Joshnicholas committed on
Commit 0f22253 · verified · 1 Parent(s): b24fc32

Create app.py

Files changed (1): app.py (+296, -0)

app.py ADDED
@@ -0,0 +1,296 @@
### Stolen from https://huggingface.co/spaces/pierreguillou/tatr-demo/blob/main/app.py

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
from PIL import Image, ImageDraw
import numpy as np
import csv
import pandas as pd

from torchvision import transforms

from transformers import AutoModelForObjectDetection
import torch

import easyocr

import gradio as gr


device = "cuda" if torch.cuda.is_available() else "cpu"


class MaxResize(object):
    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))

        return resized_image


detection_transform = transforms.Compose([
    MaxResize(800),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

structure_transform = transforms.Compose([
    MaxResize(1000),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
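
# A quick sanity check of the transforms, as a sketch (the 1600x1200 size is a
# hypothetical example): MaxResize scales the longer side to max_size while
# preserving aspect ratio, and Normalize uses the standard ImageNet mean/std.
_example = Image.new("RGB", (1600, 1200))
assert MaxResize(800)(_example).size == (800, 600)           # PIL size is (width, height)
assert detection_transform(_example).shape == (3, 600, 800)  # tensors are (channels, height, width)
del _example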

# load table detection model
# processor = TableTransformerImageProcessor(max_size=800)
model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm").to(device)

# load table structure recognition model
# structure_processor = TableTransformerImageProcessor(max_size=1000)
structure_model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(device)

# load EasyOCR reader
reader = easyocr.Reader(['en'])


# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    # convert (center_x, center_y, width, height) boxes to (x_min, y_min, x_max, y_max)
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    # scale normalized [0, 1] boxes up to absolute pixel coordinates
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
    return boxes


def outputs_to_objects(outputs, img_size, id2label):
    # keep, for every query, the highest-scoring class and its confidence
    m = outputs.logits.softmax(-1).max(-1)
    pred_labels = list(m.indices.detach().cpu().numpy())[0]
    pred_scores = list(m.values.detach().cpu().numpy())[0]
    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = id2label[int(label)]
        if class_label != 'no object':
            objects.append({'label': class_label, 'score': float(score),
                            'bbox': [float(elem) for elem in bbox]})

    return objects
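
# Illustrative return value of outputs_to_objects (all numbers hypothetical):
# [{'label': 'table', 'score': 0.998, 'bbox': [14.2, 30.7, 585.1, 410.9]}, ...]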


def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it"""
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    image = Image.open(buf)
    return image


def visualize_detected_tables(img, det_tables):
    plt.imshow(img, interpolation="lanczos")
    fig = plt.gcf()
    fig.set_size_inches(20, 20)
    ax = plt.gca()

    for det_table in det_tables:
        bbox = det_table['bbox']

        if det_table['label'] == 'table':
            facecolor = (1, 0, 0.45)
            edgecolor = (1, 0, 0.45)
            alpha = 0.3
            linewidth = 2
            hatch = '//////'
        elif det_table['label'] == 'table rotated':
            facecolor = (0.95, 0.6, 0.1)
            edgecolor = (0.95, 0.6, 0.1)
            alpha = 0.3
            linewidth = 2
            hatch = '//////'
        else:
            continue

        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
                                 edgecolor='none', facecolor=facecolor, alpha=0.1)
        ax.add_patch(rect)
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
                                 edgecolor=edgecolor, facecolor='none', linestyle='-', alpha=alpha)
        ax.add_patch(rect)
        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0,
                                 edgecolor=edgecolor, facecolor='none', linestyle='-', hatch=hatch, alpha=0.2)
        ax.add_patch(rect)

    plt.xticks([], [])
    plt.yticks([], [])

    legend_elements = [Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
                             label='Table', hatch='//////', alpha=0.3),
                       Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
                             label='Table (rotated)', hatch='//////', alpha=0.3)]
    plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
               fontsize=10, ncol=2)
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')

    return fig


def detect_and_crop_table(image):
    # prepare image for the model
    # pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = detection_transform(image).unsqueeze(0).to(device)

    # forward pass
    with torch.no_grad():
        outputs = model(pixel_values)

    # postprocess to get detected tables (work on a copy so the shared config
    # mapping is not mutated on every call)
    id2label = model.config.id2label.copy()
    id2label[len(id2label)] = "no object"
    detected_tables = outputs_to_objects(outputs, image.size, id2label)

    # visualize
    # fig = visualize_detected_tables(image, detected_tables)
    # image = fig2img(fig)

    # fail gracefully instead of raising an IndexError when nothing is found
    if len(detected_tables) == 0:
        raise gr.Error("No table detected in the image.")

    # crop first detected table out of image
    cropped_table = image.crop(detected_tables[0]["bbox"])

    return cropped_table


def recognize_table(image):
    # prepare image for the model
    # pixel_values = structure_processor(images=image, return_tensors="pt").pixel_values
    pixel_values = structure_transform(image).unsqueeze(0).to(device)

    # forward pass
    with torch.no_grad():
        outputs = structure_model(pixel_values)

    # postprocess to get individual elements (again on a copy of the mapping)
    id2label = structure_model.config.id2label.copy()
    id2label[len(id2label)] = "no object"
    cells = outputs_to_objects(outputs, image.size, id2label)

    # visualize cells on a copy of the cropped table, so the clean crop stays
    # available for OCR
    annotated_image = image.copy()
    draw = ImageDraw.Draw(annotated_image)

    for cell in cells:
        draw.rectangle(cell["bbox"], outline="red")

    return annotated_image, cells


def get_cell_coordinates_by_row(table_data):
    # Extract rows and columns
    rows = [entry for entry in table_data if entry['label'] == 'table row']
    columns = [entry for entry in table_data if entry['label'] == 'table column']

    # Sort rows and columns by their Y and X coordinates, respectively
    rows.sort(key=lambda x: x['bbox'][1])
    columns.sort(key=lambda x: x['bbox'][0])

    # Function to find cell coordinates
    def find_cell_coordinates(row, column):
        cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
        return cell_bbox

    # Generate cell coordinates and count cells in each row
    cell_coordinates = []

    for row in rows:
        row_cells = []
        for column in columns:
            cell_bbox = find_cell_coordinates(row, column)
            row_cells.append({'column': column['bbox'], 'cell': cell_bbox})

        # Sort cells in the row by X coordinate
        row_cells.sort(key=lambda x: x['column'][0])

        # Append row information to cell_coordinates
        cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})

    # Sort rows from top to bottom
    cell_coordinates.sort(key=lambda x: x['row'][1])

    return cell_coordinates
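
# Illustrative intersection (hypothetical boxes): a row spanning
# [0, 110, 500, 130] and a column spanning [100, 0, 200, 400] intersect in the
# cell [100, 110, 200, 130]: x-extent from the column, y-extent from the row.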


def apply_ocr(cell_coordinates, cropped_table):
    # let's OCR row by row
    data = dict()
    max_num_columns = 0
    for idx, row in enumerate(cell_coordinates):
        row_text = []
        for cell in row["cells"]:
            # crop cell out of image
            cell_image = np.array(cropped_table.crop(cell["cell"]))
            # apply OCR
            result = reader.readtext(cell_image)
            if len(result) > 0:
                text = " ".join([x[1] for x in result])
                row_text.append(text)
            else:
                # keep an empty placeholder so later cells stay in their columns
                row_text.append("")

        if len(row_text) > max_num_columns:
            max_num_columns = len(row_text)

        data[str(idx)] = row_text

    # pad rows which don't have max_num_columns elements
    # to make sure all rows have the same number of columns
    for idx, row_data in data.copy().items():
        if len(row_data) != max_num_columns:
            row_data = row_data + ["" for _ in range(max_num_columns - len(row_data))]
        data[str(idx)] = row_data

    # write to csv (newline='' avoids blank rows on Windows)
    with open('output.csv', 'w', newline='') as result_file:
        wr = csv.writer(result_file, dialect='excel')

        for row, row_text in data.items():
            wr.writerow(row_text)

    # return as Pandas dataframe; the first table row becomes the header
    df = pd.read_csv('output.csv')

    return df, data
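
# Illustrative shape of the returned `data` dict (contents hypothetical):
# {'0': ['Name', 'Score'], '1': ['alpha', '0.91'], '2': ['beta', '0.87']}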


def process_pdf(image):
    # despite the name, this receives a PIL image (e.g. a rendered PDF page)
    cropped_table = detect_and_crop_table(image)

    annotated_table, cells = recognize_table(cropped_table)

    cell_coordinates = get_cell_coordinates_by_row(cells)

    # OCR the clean crop; the annotated copy is only for display
    df, data = apply_ocr(cell_coordinates, cropped_table)

    return annotated_table, df, data


title = "Demo: table detection & recognition with Table Transformer (TATR)"
description = """Demo for table extraction with the Table Transformer. First, table detection is performed on the input image using https://huggingface.co/microsoft/table-transformer-detection,
after which the detected table is cropped out and https://huggingface.co/microsoft/table-transformer-structure-recognition-v1.1-all is used to recognize the individual rows, columns and cells. OCR is then performed per cell, row by row."""
examples = [['image.png'], ['mistral_paper.png']]

app = gr.Interface(fn=process_pdf,
                   inputs=gr.Image(type="pil"),
                   outputs=[gr.Image(type="pil", label="Detected table"), gr.Dataframe(label="Table as CSV"), gr.JSON(label="Data as JSON")],
                   title=title,
                   description=description,
                   examples=examples)
app.queue()
app.launch(debug=True)
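
# A minimal headless sketch of the same pipeline without the Gradio UI, left
# commented out so it doesn't interfere with the Space ("table_page.png" is a
# hypothetical filename):
#
#   img = Image.open("table_page.png").convert("RGB")
#   annotated, df, data = process_pdf(img)
#   df.to_csv("table.csv", index=False)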