File size: 3,178 Bytes
a5dcb7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import base64
import html
import os

import gradio as gr
import spaces
import torch
from transformers import AutoModel, AutoTokenizer

# Hugging Face checkpoint used for both the tokenizer and the OCR model.
model_name = "ucaslcl/GOT-OCR2_0"
# Use the GPU when torch reports CUDA as available, otherwise fall back to CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"

# Both loaders are called with trust_remote_code=True, i.e. the checkpoint's
# own Python code is allowed to run while loading.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Load the model, switch it to inference mode and place it on the chosen device.
model = (
    AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
    .eval()
    .to(device)
)


def _text_to_html(text):
    """Return *text* escaped for the HTML output, keeping its line breaks."""
    # The result is shown in a gr.HTML component, so raw "<", ">" and "&" from
    # recognised text must not be interpreted as markup.
    return f'<div style="white-space: pre-wrap;">{html.escape(str(text))}</div>'


def _render_file_to_html(result_path):
    """Read the rendered HTML file and return a download link plus an inline preview."""
    with open(result_path, "r", encoding="utf-8") as f:
        html_content = f.read()
    # Embed the whole document as a base64 data URI so no extra file has to be served.
    encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    preview = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
    download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
    return f"{download_link}\n\n{preview}"


def _discard_file(path):
    """Delete *path* on a best-effort basis."""
    try:
        os.remove(path)
    except OSError:
        # Cleanup failure must not hide a successful OCR result.
        pass


@spaces.GPU()
def ocr_process(image, got_mode, ocr_color="", ocr_box="", progress=gr.Progress()):
    """Run OCR on an image file and return an HTML snippet for display.

    Args:
        image: Path of the image file to process, or None when nothing was uploaded.
        got_mode: Mode label. "plain" / "format" select the output type,
            "multi-crop" selects ``model.chat_crop`` instead of ``model.chat``,
            and "fine-grained" enables the colour/box hints.
        ocr_color: Colour hint forwarded to ``model.chat`` (fine-grained modes only).
        ocr_box: Box hint forwarded to ``model.chat`` (fine-grained modes only).
        progress: Gradio progress tracker used to report the processing stage.

    Returns:
        An HTML string: the escaped recognised text for plain modes, a download
        link plus iframe preview of the rendered file for format modes, or an
        error message starting with "错误".
    """
    if image is None:
        return "错误:未提供图片"

    try:
        image_path = image
        # The rendered HTML is written next to the input image.
        result_path = f"{os.path.splitext(image_path)[0]}_result.html"

        # The colour/box inputs are labelled as fine-grained-only in the UI, so
        # ignore anything typed into them while another mode is selected.
        if "fine-grained" not in got_mode:
            ocr_color = ""
            ocr_box = ""
        multi_crop = "multi-crop" in got_mode

        progress(0, desc="开始处理...")

        if "plain" in got_mode:
            progress(0.3, desc="执行OCR识别...")
            if multi_crop:
                res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            else:
                res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
            progress(1, desc="处理完成")
            return _text_to_html(res)

        if "format" in got_mode:
            progress(0.3, desc="执行OCR识别...")
            if multi_crop:
                res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
            else:
                res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)

            progress(0.7, desc="生成结果...")
            if os.path.exists(result_path):
                try:
                    rendered = _render_file_to_html(result_path)
                finally:
                    # Remove the temporary render so it neither accumulates on
                    # disk nor gets picked up by a later run whose render failed.
                    _discard_file(result_path)
                progress(1, desc="处理完成")
                return rendered

            # No render file was produced. Fall back to whatever the model call
            # returned instead of misreporting an unknown mode.
            if res:
                progress(1, desc="处理完成")
                return _text_to_html(res)
            return "错误: 未生成渲染结果文件"

        return "错误: 未知的OCR模式"
    except Exception as e:
        # UI boundary: report the failure in the output panel instead of raising.
        return f"错误: {html.escape(str(e))}"


# Build the Gradio page: image upload, mode selector, optional fine-grained
# hints, a submit button and an HTML panel for the result.
with gr.Blocks() as demo:
    gr.Markdown("# OCR 图像识别")

    with gr.Row():
        image_input = gr.Image(type="filepath", label="上传图片")

    # The mode labels are matched by substring ("plain", "format", "multi-crop")
    # inside ocr_process.
    got_mode = gr.Dropdown(
        choices=[
            "plain texts OCR",
            "format texts OCR",
            "plain multi-crop OCR",
            "format multi-crop OCR",
            "plain fine-grained OCR",
            "format fine-grained OCR",
        ],
        label="OCR模式",
        value="plain texts OCR",
    )

    with gr.Row():
        ocr_color = gr.Textbox(label="OCR颜色 (仅用于fine-grained模式)")
        ocr_box = gr.Textbox(label="OCR边界框 (仅用于fine-grained模式)")

    submit_button = gr.Button("开始OCR识别")

    output = gr.HTML(label="识别结果")

    # Input order must match the positional parameters of ocr_process.
    submit_button.click(
        fn=ocr_process,
        inputs=[image_input, got_mode, ocr_color, ocr_box],
        outputs=output,
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()