File size: 3,933 Bytes
a5dcb7e
dcec5ef
a5dcb7e
dcec5ef
a5dcb7e
 
dcec5ef
 
a5dcb7e
118959f
dcec5ef
 
 
 
118959f
 
dcec5ef
 
 
 
 
 
 
 
 
a5dcb7e
7b83985
a5dcb7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118959f
a5dcb7e
 
dcec5ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5dcb7e
 
 
 
 
 
 
 
 
 
 
dcec5ef
 
 
a5dcb7e
 
 
 
dcec5ef
 
 
 
 
 
 
 
 
 
 
 
 
f99597a
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import base64
import logging
import os
import tempfile
from datetime import datetime

import torch
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModel, AutoTokenizer

# Configure logging: INFO level, each record formatted as
# "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

app = FastAPI()

# Add the CORS middleware: any origin, method and header is accepted.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# the FastAPI CORS documentation advises against that combination — confirm
# whether credentialed cross-origin requests are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialise the model once at import time.
model_name = "Mageia/GOT-OCR2_0"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True allows transformers to run the Python code shipped
# with the model repository.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
# Switch to evaluation mode and move to the selected device.
# NOTE(review): device_map above already targets the same device, so the
# .to(device) call is presumably redundant — confirm before removing.
model = model.eval().to(device)


# OCR处理函数
async def ocr_process(image_path, got_mode, ocr_color="", ocr_box=""):
    """Run the OCR model on an image file in the requested mode.

    The mode is matched by substring: a mode containing "plain" returns the
    model output directly, a mode containing "format" returns the rendered
    HTML, and "multi-crop" inside either one selects ``model.chat_crop``
    instead of ``model.chat``.

    Args:
        image_path: Path of the image file to process.
        got_mode: Mode label, e.g. "plain texts OCR" or
            "format multi-crop OCR".
        ocr_color: Forwarded to ``model.chat`` as ``ocr_color``; not used for
            multi-crop modes.
        ocr_box: Forwarded to ``model.chat`` as ``ocr_box``; not used for
            multi-crop modes.

    Returns:
        For plain modes, whatever ``model.chat`` / ``model.chat_crop``
        returns (presumably the recognised text). For format modes,
        ``{"html_content": <base64 of the rendered HTML>}``. On failure,
        ``{"error": <message>}``: the mode is unrecognised, no render file
        was produced, or an exception was raised (it is logged, not
        propagated).
    """
    try:
        use_multi_crop = "multi-crop" in got_mode

        if "plain" in got_mode:
            if use_multi_crop:
                return model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            return model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)

        if "format" in got_mode:
            # The model writes its rendered HTML next to the input image.
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            try:
                if use_multi_crop:
                    model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
                else:
                    model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)

                if not os.path.exists(result_path):
                    # Previously this fell through to the "unknown mode"
                    # error, which misreported the actual problem.
                    return {"error": "未生成渲染结果文件"}

                with open(result_path, "r", encoding="utf-8") as f:
                    html_content = f.read()
            finally:
                # Best-effort cleanup so render files do not accumulate and a
                # stale file can never be served for a later request.
                try:
                    os.remove(result_path)
                except OSError:
                    pass

            encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
            return {"html_content": encoded_html}

        return {"error": "未知的OCR模式"}
    except Exception as e:
        # Keep the error-dict contract for callers, but record the traceback.
        logger.exception("OCR处理失败 (模式: %s)", got_mode)
        return {"error": str(e)}


@app.post("/ocr")
async def ocr_api(request: Request, image: UploadFile = File(...), got_mode: str = Form(...), ocr_color: str = Form(""), ocr_box: str = Form("")):
    """Handle an OCR request for a single uploaded image.

    The request is logged, the upload is written to a uniquely named
    temporary file, ``ocr_process`` is run on it, and the temporary file is
    removed afterwards even if processing raises.

    Args:
        request: Incoming request; used for the client address and
            User-Agent in the log entry.
        image: Uploaded image file.
        got_mode: OCR mode label forwarded to ``ocr_process``.
        ocr_color: Optional colour hint forwarded to ``ocr_process``.
        ocr_box: Optional bounding box forwarded to ``ocr_process``.

    Returns:
        Whatever ``ocr_process`` returns for the uploaded image.
    """
    # Log the request details. request.client can be None (no peer address
    # in the ASGI scope), so guard the attribute access.
    client_host = request.client.host if request.client else "Unknown"
    user_agent = request.headers.get("user-agent", "Unknown")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    log_message = f"""
    时间: {current_time}
    IP地址: {client_host}
    User-Agent: {user_agent}
    图片名称: {image.filename}
    OCR模式: {got_mode}
    OCR颜色: {ocr_color}
    OCR边界框: {ocr_box}
    """
    logger.info(log_message)

    # Only the extension of the client-supplied name is reused, and only when
    # it is plain alphanumerics; the rest of the path is server-generated so
    # separators in the name cannot affect where the file is written and
    # concurrent uploads with the same name cannot collide.
    suffix = os.path.splitext(os.path.basename(image.filename or ""))[1]
    if not suffix[1:].isalnum():
        suffix = ""

    payload = await image.read()
    fd, image_path = tempfile.mkstemp(prefix="temp_", suffix=suffix)
    try:
        # Save the uploaded image; the file is closed before the model opens it.
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(payload)

        # Run OCR.
        result = await ocr_process(image_path, got_mode, ocr_color, ocr_box)
    finally:
        # Always delete the temporary image, including on failure.
        os.remove(image_path)

    # Log the outcome.
    logger.info(f"OCR处理结果: {result}")

    return result


@app.get("/")
async def read_root(request: Request):
    """Log the visit and return a welcome payload describing the service.

    Args:
        request: Incoming request; used for the client address and
            User-Agent.

    Returns:
        A dict with a welcome message, the caller's User-Agent, the model
        name, the device in use and the list of supported OCR mode labels.
    """
    # request.client can be None (no peer address in the ASGI scope), so
    # guard the attribute access instead of raising AttributeError.
    client_host = request.client.host if request.client else "Unknown"
    user_agent = request.headers.get("user-agent", "Unknown")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    log_message = f"""
    时间: {current_time}
    IP地址: {client_host}
    User-Agent: {user_agent}
    访问: 根路径
    """
    logger.info(log_message)

    return {
        "message": "欢迎使用OCR API",
        "user_agent": user_agent,
        "model": model_name,
        "device": device,
        "ocr_mode": [
            "plain texts OCR",
            "format texts OCR",
            "plain multi-crop OCR",
            "format multi-crop OCR",
            "plain fine-grained OCR",
            "format fine-grained OCR",
        ],
    }