Mageia committed on
Commit
4d31938
·
unverified ·
1 Parent(s): 0e016b0

fix: process pdf once

Browse files
Files changed (1) hide show
  1. app.py +33 -188
app.py CHANGED
@@ -1,209 +1,54 @@
1
  import base64
2
- import multiprocessing
3
  import os
4
- import shutil
5
  import uuid
6
- from functools import partial
7
 
8
- import fitz # PyMuPDF
9
- import gradio as gr
10
- import spaces
11
- from PIL import Image, ImageEnhance
12
- from transformers import AutoModel, AutoTokenizer
13
 
14
# Global model/tokenizer placeholders.
# NOTE(review): these are never reassigned anywhere in this file —
# initialize_model() returns locals and worker_process keeps its own copies.
# Presumably leftovers from an earlier single-process design; confirm before removing.
model = None
tokenizer = None
17
 
 
18
 
19
def initialize_model():
    """Load the GOT-OCR2 checkpoint and its tokenizer, ready for inference.

    Returns:
        tuple: (model in eval mode with device_map="auto", tokenizer).
    """
    checkpoint = "ucaslcl/GOT-OCR2_0"
    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    ocr_model = AutoModel.from_pretrained(
        checkpoint,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map="auto",
    ).eval()
    return ocr_model, tok
25
 
 
 
 
 
 
26
 
27
# Working folders: uploaded PDFs land in UPLOAD_FOLDER, rendered page
# images are written to RESULTS_FOLDER.
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"

# Make sure the required folders exist before any request is served.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULTS_FOLDER, exist_ok=True)
33
 
34
 
35
def pdf_to_images(pdf_path, zoom=10, contrast=1.5):
    """Render each page of a PDF into a high-resolution, contrast-enhanced image.

    Args:
        pdf_path: Path to the PDF file on disk.
        zoom: Render scale factor for both axes (default 10, matching the
            previous hard-coded value; large values use a lot of memory).
        contrast: Factor passed to ImageEnhance.Contrast (default 1.5 = +50%).

    Returns:
        list[Image.Image]: One RGB PIL image per page, in page order.
    """
    images = []
    pdf_document = fitz.open(pdf_path)
    try:
        for page_num in range(len(pdf_document)):
            page = pdf_document.load_page(page_num)
            mat = fitz.Matrix(zoom, zoom)
            pix = page.get_pixmap(matrix=mat, alpha=False)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

            # Boost contrast to help OCR on faint scans.
            enhancer = ImageEnhance.Contrast(img)
            img = enhancer.enhance(contrast)

            images.append(img)
    finally:
        # BUG FIX: close the document even if a page fails to render
        # (previously the handle leaked on any exception above).
        pdf_document.close()
    return images
53
-
54
-
55
def process_pdf(pdf_file):
    """Convert an uploaded PDF into per-page PNG files for the gallery preview.

    Args:
        pdf_file: Gradio file object (its ``.name`` holds the temp path),
            or None when nothing was uploaded.

    Returns:
        list[str] | None: Paths of the rendered page PNGs, or None for no input.
    """
    if pdf_file is None:
        return None

    # Copy the upload to a uniquely named temporary location.
    temp_pdf_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}.pdf")
    shutil.copy(pdf_file.name, temp_pdf_path)
    try:
        images = pdf_to_images(temp_pdf_path)
    finally:
        # BUG FIX: remove the temp copy even when rendering fails
        # (previously it leaked on any pdf_to_images exception).
        os.remove(temp_pdf_path)

    # BUG FIX: prefix page files with a unique batch id so successive or
    # concurrent uploads no longer overwrite each other's page_N.png files.
    batch_id = uuid.uuid4().hex
    image_paths = []
    for i, img in enumerate(images):
        img_path = os.path.join(RESULTS_FOLDER, f"{batch_id}_page_{i + 1}.png")
        img.save(img_path, "PNG")
        image_paths.append(img_path)

    return image_paths
 
 
 
75
 
 
76
 
77
@spaces.GPU()
def got_ocr(model, tokenizer, image_path, got_mode="format texts OCR", fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run GOT-OCR on a single image in the requested mode.

    Args:
        model: GOT-OCR2 model (moved to GPU for the call, back to CPU after).
        tokenizer: Matching tokenizer.
        image_path: Path to the image file to OCR.
        got_mode: One of the six "plain/format ..." mode strings.
        fine_grained_mode: Kept for interface compatibility; not used here.
        ocr_color: Color hint for fine-grained modes ("" to disable).
        ocr_box: Box hint "[x1,y1,x2,y2]" for fine-grained modes ("" to disable).

    Returns:
        tuple: (text result, base64-encoded rendered HTML or None).
        On any exception returns ("错误: <message>", None).
    """
    # Move the model onto the GPU only for the duration of this call.
    model = model.cuda()
    try:
        result_path = None
        if got_mode == "plain texts OCR":
            res = model.chat(tokenizer, image_path, ocr_type="ocr")
            return res, None
        elif got_mode == "plain multi-crop OCR":
            res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            return res, None
        elif got_mode == "plain fine-grained OCR":
            res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
            return res, None
        elif got_mode == "format texts OCR":
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            res = model.chat(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
        elif got_mode == "format multi-crop OCR":
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
        elif got_mode == "format fine-grained OCR":
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
        else:
            # BUG FIX: an unrecognized mode previously raised UnboundLocalError
            # on `res`; fail with a clear message instead (still caught below).
            raise ValueError(f"unknown got_mode: {got_mode!r}")

        # Format modes: read back the rendered HTML, base64-encode it.
        if result_path is not None and os.path.exists(result_path):
            with open(result_path, "r") as f:
                html_content = f.read()
            encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
            return res, encoded_html
        else:
            return res, None

    except Exception as e:
        return f"错误: {str(e)}", None
    finally:
        # Release GPU memory between calls by parking the model on the CPU.
        model = model.cpu()
116
-
117
-
118
def worker_process(task_queue, result_queue):
    """OCR worker loop: owns its own model copy and serves queued tasks.

    Pulls (image_path, got_mode, fine_grained_mode, ocr_color, ocr_box)
    tuples from ``task_queue`` until a ``None`` sentinel arrives, pushing
    each OCR text result onto ``result_queue``.
    """
    model, tokenizer = initialize_model()
    while (task := task_queue.get()) is not None:
        image_path, got_mode, fine_grained_mode, ocr_color, ocr_box = task
        text, _ = got_ocr(model, tokenizer, image_path, got_mode, fine_grained_mode, ocr_color, ocr_box)
        result_queue.put(text)
127
-
128
-
129
def perform_ocr(image_gallery, got_mode, fine_grained_type, color, box):
    """OCR every gallery page through a dedicated worker process.

    Returns a Markdown preview (base64 download link + first 1000 chars)
    and the full concatenated result for storage in ``gr.State``.
    """
    tasks = multiprocessing.Queue()
    answers = multiprocessing.Queue()

    worker = multiprocessing.Process(target=worker_process, args=(tasks, answers))
    worker.start()

    # Only the selected fine-grained hint is forwarded; the other stays "".
    ocr_color = color if fine_grained_type == "color" else ""
    ocr_box = box if fine_grained_type == "box" else ""

    pages = []
    progress = gr.Progress()
    for idx, image_info in enumerate(progress.tqdm(image_gallery)):
        tasks.put((image_info[0], got_mode, fine_grained_type, ocr_color, ocr_box))
        # One result per task, in order: block until the worker replies.
        pages.append(f"第 {idx+1} 页结果:\n{answers.get()}\n\n")

    tasks.put(None)  # stop sentinel for the worker
    worker.join()

    combined_result = "".join(pages)
    encoded_result = base64.b64encode(combined_result.encode("utf-8")).decode("utf-8")
    download_link = f'<a href="data:text/plain;base64,{encoded_result}" download="ocr_result.txt">下载完整OCR结果</a>'

    return gr.Markdown(f"{download_link}\n\n{combined_result[:1000]}..."), combined_result
156
-
157
-
158
def task_update(task):
    """Show the fine-grained dropdown only for fine-grained modes.

    The color and box widgets always start hidden; they are revealed later
    by ``fine_grained_update`` once a sub-type is chosen.
    """
    show_fine_grained = "fine-grained" in task
    return [
        gr.update(visible=show_fine_grained),
        gr.update(visible=False),
        gr.update(visible=False),
    ]
163
-
164
-
165
def fine_grained_update(fine_grained_type):
    """Toggle the color dropdown / box textbox to match the chosen sub-type.

    "color" shows only the color dropdown, "box" shows only the box textbox,
    anything else hides both.
    """
    return [
        gr.update(visible=fine_grained_type == "color"),
        gr.update(visible=fine_grained_type == "box"),
    ]
172
-
173
-
174
# --- Gradio UI wiring -------------------------------------------------------
with gr.Blocks() as demo:
    # PDF upload; on upload, pages are rendered to images and previewed.
    pdf_input = gr.File(label="上传PDF文件")
    image_gallery = gr.Gallery(
        label="PDF页面预览",
        columns=3,
        height=600,
        object_fit="contain",
        preview=True,
    )
    pdf_input.upload(fn=process_pdf, inputs=pdf_input, outputs=image_gallery)

    # OCR mode selection; fine-grained widgets stay hidden until relevant.
    task_dropdown = gr.Dropdown(
        choices=["plain texts OCR", "format texts OCR", "plain multi-crop OCR", "format multi-crop OCR", "plain fine-grained OCR", "format fine-grained OCR"],
        label="选择GOT模式",
        value="format texts OCR",
    )
    fine_grained_dropdown = gr.Dropdown(choices=["box", "color"], label="fine-grained类型", visible=False)
    color_dropdown = gr.Dropdown(choices=["red", "green", "blue"], label="颜色列表", visible=False)
    box_input = gr.Textbox(label="输入框: [x1,y1,x2,y2]", placeholder="例如: [0,0,100,100]", visible=False)

    ocr_button = gr.Button("开始OCR识别")
    ocr_result = gr.Markdown(label="OCR结果预览")
    # Holds the complete OCR text across interactions (preview is truncated).
    full_result = gr.State()

    # Visibility plumbing for the fine-grained option widgets.
    task_dropdown.change(task_update, inputs=[task_dropdown], outputs=[fine_grained_dropdown, color_dropdown, box_input])
    fine_grained_dropdown.change(fine_grained_update, inputs=[fine_grained_dropdown], outputs=[color_dropdown, box_input])

    ocr_button.click(
        fn=perform_ocr,
        inputs=[image_gallery, task_dropdown, fine_grained_dropdown, color_dropdown, box_input],
        outputs=[ocr_result, full_result],
    )

if __name__ == "__main__":
    # "spawn" start method — presumably so CUDA initializes safely in the
    # worker process (fork after CUDA init is unsafe); confirm if changed.
    multiprocessing.set_start_method("spawn")
    demo.launch()
 
 
1
import base64
import os
import uuid

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from transformers import AutoConfig, AutoModel, AutoTokenizer

from got_ocr import got_ocr

app = FastAPI()

# Model / tokenizer initialization (runs once at import time).
model_name = "ucaslcl/GOT-OCR2_0"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# BUG FIX: device_map was hard-coded to "cuda", which crashes on CPU-only
# hosts even though `device` above already falls back to "cpu".
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True, device_map=device, use_safetensors=True)
model = model.eval().to(device)
# Reuse EOS as the pad token — presumably the checkpoint ships without one;
# confirm against the model card if this ever changes.
model.config.pad_token_id = tokenizer.eos_token_id

UPLOAD_FOLDER = "./uploads"

# Ensure the upload folder exists before the first request.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
28
 
29
 
30
+ @app.post("/ocr")
31
+ async def perform_ocr(image: UploadFile = File(...)):
32
+ # 保存上传的图片
33
+ image_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}.png")
34
+ with open(image_path, "wb") as buffer:
35
+ buffer.write(await image.read())
 
 
 
 
36
 
37
+ # 执行OCR
38
+ result, html_content = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ # 删除临时文件
41
+ os.remove(image_path)
 
 
 
 
42
 
43
+ # 准备响应
44
+ response = {"result": result}
45
+ if html_content:
46
+ response["html_content"] = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
47
 
48
+ return JSONResponse(content=response)
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  if __name__ == "__main__":
52
+ import uvicorn
53
+
54
+ uvicorn.run(app, host="0.0.0.0", port=8000)