bwingenroth commited on
Commit
d851af1
·
verified ·
1 Parent(s): afc4ada

Create app.py

Browse files

![lrpw0232_Page_14.png](/static-proxy?url=https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F652d4da24a1bd9cfff0f568a%2F1MXWoVVJcSpeBOdbeMW6K.png)%3Cbr%2F%3E!%5Bfnmf0234_Page_2.png%5D(https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F652d4da24a1bd9cfff0f568a%2FSpbkEZFB8dDs3pmcyRNDt.png)%3Cbr%2F%3E!%5Bpublaynet_example.jpeg%5D(https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F652d4da24a1bd9cfff0f568a%2FXwg13XhIGlqcBl4BrO2_M.jpeg)%3Cbr%2F%3E!%5Bfpmj0236_Page_018.png%5D(https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F652d4da24a1bd9cfff0f568a%2F8mxGF-WLJNyYbkh4Qz0g1.png)%3Cbr%2F%3E!%5Bfpmj0236_Page_012.png%5D(https%3A%2F%2Fcdn-uploads.huggingface.co%2Fproduction%2Fuploads%2F652d4da24a1bd9cfff0f568a%2FqGnOjplJHlvyW5pkyDZlX.png)%3C!-- HTML_TAG_END -->

Files changed (1) hide show
  1. app.py +298 -0
app.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system('git clone https://github.com/facebookresearch/detectron2.git')
3
+ os.system('pip install -e detectron2')
4
+ os.system("git clone https://github.com/microsoft/unilm.git")
5
+ os.system("sed -i 's/from collections import Iterable/from collections.abc import Iterable/' unilm/dit/object_detection/ditod/table_evaluation/data_structure.py")
6
+ os.system("curl -LJ -o publaynet_dit-b_cascade.pth 'https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth?sv=2022-11-02&ss=b&srt=o&sp=r&se=2033-06-08T16:48:15Z&st=2023-06-08T08:48:15Z&spr=https&sig=a9VXrihTzbWyVfaIDlIT1Z0FoR1073VB0RLQUMuudD4%3D'")
7
+
8
+ import sys
9
+ sys.path.append("unilm")
10
+ sys.path.append("detectron2")
11
+
12
+ import cv2
13
+ import filetype
14
+ from PIL import Image
15
+ import numpy as np
16
+ from io import BytesIO
17
+ from pdf2image import convert_from_bytes, convert_from_path
18
+
19
+ import re
20
+ import requests
21
+ from urllib.parse import urlparse, parse_qs
22
+
23
+ from unilm.dit.object_detection.ditod import add_vit_config
24
+
25
+ import torch
26
+
27
+ from detectron2.config import CfgNode as CN
28
+ from detectron2.config import get_cfg
29
+ from detectron2.utils.visualizer import ColorMode, Visualizer
30
+ from detectron2.data import MetadataCatalog
31
+ from detectron2.engine import DefaultPredictor
32
+
33
+ from huggingface_hub import hf_hub_download
34
+
35
+ import gradio as gr
36
+
37
+
38
+ # Step 1: instantiate config
39
+ cfg = get_cfg()
40
+ add_vit_config(cfg)
41
+ #cfg.merge_from_file("cascade_dit_base.yml")
42
+ cfg.merge_from_file("unilm/dit/object_detection/publaynet_configs/cascade/cascade_dit_base.yaml")
43
+
44
+ # Step 2: add model weights URL to config
45
+ filepath = hf_hub_download(repo_id="Sebas6k/DiT_weights", filename="publaynet_dit-b_cascade.pth", repo_type="model")
46
+ cfg.MODEL.WEIGHTS = filepath
47
+
48
+ # Step 3: set device
49
+ cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
50
+
51
+ # Step 4: define model
52
+ predictor = DefaultPredictor(cfg)
53
+
54
+
55
+ def analyze_image(img):
56
+ md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
57
+ if cfg.DATASETS.TEST[0]=='icdar2019_test':
58
+ md.set(thing_classes=["table"])
59
+ else:
60
+ md.set(thing_classes=["text","title","list","table","figure"]) ## these are categories from PubLayNet (PubMed PDF/XML data): https://ieeexplore.ieee.org/document/8977963
61
+
62
+ outputs = predictor(img)
63
+ instances = outputs["instances"]
64
+
65
+ # Ensure we're operating on CPU for numpy compatibility
66
+ instances = instances.to("cpu")
67
+
68
+ # Filter out figures based on class labels
69
+ high_confidence = []
70
+ medium_confidence = []
71
+ low_confidence = []
72
+ for i in range(len(instances)):
73
+ if md.thing_classes[instances.pred_classes[i]] == "figure":
74
+ box = instances.pred_boxes.tensor[i].numpy().astype(int)
75
+ cropped_img = img[box[1]:box[3], box[0]:box[2]]
76
+ confidence_score = instances.scores[i].numpy() * 100 # convert to percentage
77
+ confidence_text = f"Score: {confidence_score:.2f}%"
78
+
79
+ # Overlay confidence score on the image
80
+ # Enhanced label visualization with orange color
81
+ font_scale = 0.9
82
+ font_thickness = 2
83
+ text_color = (255, 255, 255) # white background
84
+ background_color = (255, 165, 0) # RGB for orange
85
+
86
+ (text_width, text_height), _ = cv2.getTextSize(confidence_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
87
+ padding = 12
88
+ text_offset_x = padding - 3
89
+ text_offset_y = cropped_img.shape[0] - padding + 2
90
+ box_coords = ((text_offset_x, text_offset_y + padding // 2), (text_offset_x + text_width + padding, text_offset_y - text_height - padding // 2))
91
+ cv2.rectangle(cropped_img, box_coords[0], box_coords[1], background_color, cv2.FILLED)
92
+ cv2.putText(cropped_img, confidence_text, (text_offset_x, text_offset_y), cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness)
93
+
94
+ # Categorize images based on confidence levels
95
+ if confidence_score > 85:
96
+ high_confidence.append(cropped_img)
97
+ elif confidence_score > 50:
98
+ medium_confidence.append(cropped_img)
99
+ else:
100
+ low_confidence.append(cropped_img)
101
+
102
+ v = Visualizer(img[:, :, ::-1], md, scale=1.0, instance_mode=ColorMode.SEGMENTATION)
103
+ result_image = v.draw_instance_predictions(instances).get_image()[:, :, ::-1]
104
+
105
+ return result_image, high_confidence, medium_confidence, low_confidence
106
+ # output = predictor(img)["instances"]
107
+ # v = Visualizer(img[:, :, ::-1],
108
+ # md,
109
+ # scale=1.0,
110
+ # instance_mode=ColorMode.SEGMENTATION)
111
+ # result = v.draw_instance_predictions(output.to("cpu"))
112
+ # result_image = result.get_image()[:, :, ::-1]
113
+ #
114
+ ## figs = [img[box[1]:box[3], box[0]:box[2]] for box, cls in zip(output.pred_boxes, output.pred_classes) if md.thing_classes[cls] == "figure"]
115
+ #
116
+ # return result_image, figs
117
+
118
+ def handle_input(input_data):
119
+ images = []
120
+
121
+ #input_data is a dict with keys 'text' and 'files'
122
+ if 'text' in input_data and input_data['text']:
123
+ input_text = input_data['text'].strip()
124
+
125
+ # this is either a URL or a PDF ID
126
+ if input_text.startswith('http://') or input_text.startswith('https://'):
127
+ # Extract the ID from the URL
128
+ url_parts = urlparse(input_text)
129
+ query_params = parse_qs(url_parts.fragment) # Assumes ID is a fragment parameter
130
+ pdf_id = query_params.get('id', [None])[0]
131
+ if not pdf_id:
132
+ raise ValueError("PDF ID not found in URL")
133
+ else:
134
+ # Assume input is a direct PDF ID
135
+ pdf_id = input_text
136
+
137
+ if not re.match(r'^[a-zA-Z]{4}\d{4}$', pdf_id):
138
+ raise ValueError("Invalid PDF ID format. Expected four letters followed by four numbers.")
139
+
140
+ # Assume input is a PDF ID, convert to URL
141
+ # Now construct the download URL
142
+ pdf_url = construct_download_url(pdf_id)
143
+
144
+ #https://download.industrydocuments.ucsf.edu/k/t/k/l/ktkl0236/ktkl0236.pdf
145
+ # Assume input is a PDF URL
146
+ pdf_data = download_pdf(pdf_url)
147
+ images = pdf_to_images(pdf_data)
148
+
149
+ if 'files' in input_data and input_data['files']:
150
+ for file_path in input_data['files']:
151
+ print("Type of file as uploaded:", type(file_path))
152
+ print(f" File: {file_path}")
153
+
154
+ # Check if the input is a file and determine its type
155
+ kind = filetype.guess(file_path)
156
+ if kind.mime.startswith('image'):
157
+ # Process a single image
158
+ images.append(load_image(file_path)) # Process image directly
159
+ elif kind.mime == 'application/pdf':
160
+ # Convert PDF pages to images
161
+ images.extend(pdf_to_images(file_path))
162
+ else:
163
+ raise ValueError("Unsupported file type.")
164
+ if not images:
165
+ raise ValueError("No valid input provided. Please upload a file or enter a PDF ID.")
166
+
167
+ # Assuming processing images returns galleries of images by confidence
168
+ return process_images(images)
169
+
170
+ def load_image(img_path):
171
+ print(f"Loading image: {img_path}")
172
+ # Load an image from a file path
173
+ image = Image.open(img_path)
174
+ if isinstance(image, Image.Image):
175
+ image = np.array(image) # Convert PIL Image to numpy array
176
+ # Ensure the image is in the correct format
177
+ if image.ndim == 2: # Image is grayscale
178
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
179
+ elif image.ndim == 3 and image.shape[2] == 3:
180
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
181
+ # image = image[:, :, ::-1] # Convert RGB to BGR if necessary
182
+
183
+ return image
184
+
185
+ def construct_download_url(pdf_id):
186
+ # Construct the download URL from the PDF ID
187
+ # https://download.examples.edu/k/t/k/l/ktkl0236/ktkl0236.pdf
188
+ path_parts = '/'.join(pdf_id[i] for i in range(4)) # 'k/t/k/l'
189
+ download_url = f"https://download.industrydocuments.ucsf.edu/{path_parts}/{pdf_id}/{pdf_id}.pdf"
190
+ return download_url
191
+
192
+
193
+ def download_pdf(pdf_url):
194
+ # Download the PDF file from the given URL
195
+ response = requests.get(pdf_url)
196
+ response.raise_for_status() # Ensure we notice bad responses
197
+ return BytesIO(response.content)
198
+
199
+
200
+ def pdf_to_images(data_or_path):
201
+ # Create a temporary directory to store the page images
202
+ temp_dir = "temp_images"
203
+ os.makedirs(temp_dir, exist_ok=True)
204
+
205
+
206
+ try:
207
+ # Convert PDF to a list of PIL images
208
+ # Handle both BytesIO and file path input for PDF conversion
209
+ if isinstance(data_or_path, BytesIO):
210
+ # Convert directly from bytes
211
+ pages = convert_from_bytes(data_or_path.read())
212
+ elif isinstance(data_or_path, str):
213
+ # Convert from a file path
214
+ pages = convert_from_path(data_or_path)
215
+
216
+ # Save each page as an image file
217
+ page_images = []
218
+ for i, page in enumerate(pages):
219
+ image_path = os.path.join(temp_dir, f"page_{i+1}.jpg")
220
+ page.save(image_path, "JPEG")
221
+ page_images.append(load_image(image_path))
222
+
223
+ return page_images
224
+
225
+ except Exception as e:
226
+ print(f"Error converting PDF to images: {str(e)}")
227
+ return []
228
+ finally:
229
+ # Clean up the temporary directory (optional)
230
+ # os.rmdir(temp_dir)
231
+ pass
232
+
233
+ def process_images(images):
234
+ all_processed_images = []
235
+ all_high_confidence = []
236
+ all_medium_confidence = []
237
+ all_low_confidence = []
238
+
239
+ for img in images:
240
+ #print("Type of img before processing:", type(img))
241
+ #print(f" img before processing: {img}")
242
+ processed_images, high_confidence, medium_confidence, low_confidence = analyze_image(img)
243
+ all_processed_images.append(processed_images)
244
+ all_high_confidence.extend(high_confidence)
245
+ all_medium_confidence.extend(medium_confidence)
246
+ all_low_confidence.extend(low_confidence)
247
+
248
+ return all_processed_images, all_high_confidence, all_medium_confidence, all_low_confidence
249
+
250
+ title = "OIDA Image Collection Interactive demo: Document Layout Analysis with DiT and PubLayNet"
251
+ description = "<h3>OIDA Demo -- adapted liberally from <a href='https://huggingface.co/spaces/nielsr/dit-document-layout-analysis'>https://huggingface.co/spaces/nielsr/dit-document-layout-analysis</a></h3>Demo for Microsoft's DiT, the Document Image Transformer for state-of-the-art document understanding tasks. This particular model is fine-tuned on PubLayNet, a large dataset for document layout analysis (read more at the links below). To use it, simply upload an image or use the example image below and click 'Submit'. Results will show up in a few seconds. If you want to make the output bigger, right-click on it and select 'Open image in new tab'."
252
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.02378' target='_blank'>Paper</a> | <a href='https://github.com/microsoft/unilm/tree/master/dit' target='_blank'>Github Repo</a> | <a href='https://huggingface.co/docs/transformers/master/en/model_doc/dit' target='_blank'>HuggingFace doc</a> | <a href='https://ieeexplore.ieee.org/document/8977963' target='_blank'>PubLayNet paper</a></p>"
253
+ #examples =[['fpmj0236_Page_012.png'],['fnmf0234_Page_2.png'],['publaynet_example.jpeg'],['fpmj0236_Page_018.png'],['lrpw0232_Page_14.png'],['kllx0250'],['https://www.industrydocuments.ucsf.edu/opioids/docs/#id=yqgg0230']]
254
+ examples =[{'files': ['fnmf0234_Page_2.png']},{'files': ['fpmj0236_Page_012.png']},{'files': ['lrpw0232.pdf']},{'text': 'https://www.industrydocuments.ucsf.edu/opioids/docs/#id=yqgg0230'},{'files':['fpmj0236_Page_018.png']},{'files':['lrpw0232_Page_14.png']},{'files':['publaynet_example.jpeg']},{'text':'kllx0250'},{'text':'txhk0255'}]
255
+ #txhk0255
256
+ css = ".output-image, .input-image, .image-preview {height: 600px !important} td.textbox {display:none;} #component-5 .submit-button {display:none;}"
257
+
258
+ #iface = gr.Interface(fn=handle_input,
259
+ # inputs=gr.MultimodalTextbox(interactive=True,
260
+ # label="Upload image/PDF file OR enter OIDA ID or URL",
261
+ # file_types=["image",".pdf"],
262
+ # placeholder="Upload image/PDF file OR enter OIDA ID or URL"),
263
+ # outputs=[gr.Gallery(label="annotated documents"),
264
+ # gr.Gallery(label="Figures with High (>85%) Confidence Scores"),
265
+ # gr.Gallery(label="Figures with Moderate (50-85%) Confidence Scores"),
266
+ # gr.Gallery(label="Figures with Lower Confidence (under 50%) Scores")],
267
+ # title=title,
268
+ # description=description,
269
+ # examples=examples,
270
+ # article=article,
271
+ # css=css)
272
+ ## enable_queue=True)
273
+ with gr.Blocks(css=css) as iface:
274
+ gr.Markdown(f"# {title}")
275
+ gr.HTML(description)
276
+
277
+ with gr.Row():
278
+ with gr.Column():
279
+ input = gr.MultimodalTextbox(interactive=True,
280
+ label="Upload image/PDF file OR enter OIDA ID or URL",
281
+ file_types=["image",".pdf"],
282
+ placeholder="Upload image/PDF file OR enter OIDA ID or URL",
283
+ submit_btn=None)
284
+ submit_btn = gr.Button("Submit")
285
+ gr.HTML('<br /><br /><hr />')
286
+ gr.Examples(examples, [input])
287
+
288
+ with gr.Column():
289
+ outputs = [gr.Gallery(label="annotated documents"),
290
+ gr.Gallery(label="Figures with High (>85%) Confidence Scores"),
291
+ gr.Gallery(label="Figures with Moderate (50-85%) Confidence Scores"),
292
+ gr.Gallery(label="Figures with Lower Confidence (under 50%) Scores")]
293
+
294
+ with gr.Row():
295
+ gr.HTML(article)
296
+ submit_btn.click(handle_input, [input], outputs)
297
+
298
+ iface.launch(debug=True, auth=[("oida", "OIDA3.1"), ("Brian", "Hi")]) #, cache_examples=True)