Spaces:
Sleeping
Sleeping
File size: 3,933 Bytes
a5dcb7e dcec5ef a5dcb7e dcec5ef a5dcb7e dcec5ef a5dcb7e 118959f dcec5ef 118959f dcec5ef a5dcb7e 7b83985 a5dcb7e 118959f a5dcb7e dcec5ef a5dcb7e dcec5ef a5dcb7e dcec5ef f99597a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import base64
import logging
import os
import tempfile
from datetime import datetime

import torch
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModel, AutoTokenizer
# Logging setup: INFO and above, formatted as "<time> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
app = FastAPI()

# CORS: accept cross-origin requests from any origin, method and header.
# NOTE(review): wildcard origins are combined with allow_credentials=True here;
# the Starlette CORS docs say wildcards are not meant to be used together with
# credentials. Harmless while this API uses no cookies/auth, but confirm this
# is intended before adding any.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Model initialization: load the GOT-OCR2 tokenizer and model once, at import
# time, on CUDA when it is available and on CPU otherwise.
# NOTE: trust_remote_code=True executes Python code shipped in the model repo.
model_name = "Mageia/GOT-OCR2_0"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# eval() switches the module to inference mode and returns the module itself,
# so it can be chained directly with .to(device).
model = AutoModel.from_pretrained(
    model_name, trust_remote_code=True, device_map=device
).eval().to(device)
# OCR processing
async def ocr_process(image_path, got_mode, ocr_color="", ocr_box=""):
    """Run GOT-OCR2 on an image file according to the requested mode.

    Args:
        image_path: Path of the image file on disk.
        got_mode: Mode string. The substrings "plain" / "format" select the
            OCR type, and "multi-crop" selects ``model.chat_crop`` instead of
            ``model.chat`` (e.g. "plain texts OCR", "format multi-crop OCR").
        ocr_color: Forwarded to ``model.chat`` as ``ocr_color`` (not used by
            the multi-crop path).
        ocr_box: Forwarded to ``model.chat`` as ``ocr_box`` (not used by the
            multi-crop path).

    Returns:
        For "plain" modes, whatever ``model.chat`` / ``model.chat_crop``
        returns. For "format" modes, ``{"html_content": <base64 of the
        rendered HTML file>}``. On failure or an unrecognised mode,
        ``{"error": <message>}``.
    """
    # NOTE(review): the model calls below are synchronous, so they block the
    # event loop for the whole inference even though this function is async.
    try:
        if "plain" in got_mode:
            if "multi-crop" in got_mode:
                return model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            return model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)

        if "format" in got_mode:
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            if "multi-crop" in got_mode:
                model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
            else:
                model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
            if not os.path.exists(result_path):
                # A missing render file is a rendering failure, not an unknown
                # mode, so report it distinctly.
                return {"error": "未生成渲染结果文件"}
            try:
                # Base64 the raw bytes directly; no need to decode/re-encode.
                with open(result_path, "rb") as f:
                    encoded_html = base64.b64encode(f.read()).decode("ascii")
            finally:
                # The rendered HTML is only an intermediate artifact; delete it
                # so one file per request does not accumulate on disk.
                os.remove(result_path)
            return {"html_content": encoded_html}

        return {"error": "未知的OCR模式"}
    except Exception as e:
        # API boundary: keep returning an error payload, but record the trace.
        logger.exception("OCR processing failed")
        return {"error": str(e)}
@app.post("/ocr")
async def ocr_api(request: Request, image: UploadFile = File(...), got_mode: str = Form(...), ocr_color: str = Form(""), ocr_box: str = Form("")):
    """Accept an uploaded image, run OCR on it and return the result.

    Form fields:
        image: The image file to recognise.
        got_mode: OCR mode string forwarded to ``ocr_process``.
        ocr_color, ocr_box: Optional fine-grained OCR parameters forwarded to
            ``ocr_process``.

    Returns:
        Whatever ``ocr_process`` returns (recognised text, a dict carrying
        base64 HTML, or an ``{"error": ...}`` dict).
    """
    # Log request metadata. request.client may be None (the server did not
    # report a peer address), so fall back to "Unknown" instead of crashing.
    client_host = request.client.host if request.client else "Unknown"
    user_agent = request.headers.get("user-agent", "Unknown")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_message = f"""
时间: {current_time}
IP地址: {client_host}
User-Agent: {user_agent}
图片名称: {image.filename}
OCR模式: {got_mode}
OCR颜色: {ocr_color}
OCR边界框: {ocr_box}
"""
    logger.info(log_message)

    # Save the upload to a unique temp file. The client-supplied filename is
    # used only for its extension: putting it in the path would allow path
    # traversal ("../...") and let concurrent uploads with the same name
    # overwrite each other.
    suffix = os.path.splitext(image.filename or "")[1]
    fd, image_path = tempfile.mkstemp(prefix="temp_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(await image.read())
        result = await ocr_process(image_path, got_mode, ocr_color, ocr_box)
    finally:
        # Always delete the temp file, even if saving or OCR raised.
        os.remove(image_path)

    # Log the processing result.
    logger.info(f"OCR处理结果: {result}")
    return result
@app.get("/")
async def read_root(request: Request):
    """Return service metadata and the list of supported OCR modes.

    Also logs the caller's address and User-Agent for the root-path visit.
    """
    # request.client may be None when the server cannot report a peer
    # address; fall back to "Unknown" instead of raising AttributeError.
    client_host = request.client.host if request.client else "Unknown"
    user_agent = request.headers.get("user-agent", "Unknown")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_message = f"""
时间: {current_time}
IP地址: {client_host}
User-Agent: {user_agent}
访问: 根路径
"""
    logger.info(log_message)
    return {
        "message": "欢迎使用OCR API",
        "user_agent": user_agent,
        "model": model_name,
        "device": device,
        "ocr_mode": [
            "plain texts OCR",
            "format texts OCR",
            "plain multi-crop OCR",
            "format multi-crop OCR",
            "plain fine-grained OCR",
            "format fine-grained OCR",
        ],
    }
|