bacngv commited on
Commit
5945677
·
verified ·
1 Parent(s): 3fabfeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -100
app.py CHANGED
@@ -1,137 +1,137 @@
1
  import gradio as gr
2
  import os
 
 
3
  from pdf2image import convert_from_path
4
- from PIL import Image
5
- import torch
6
- from torchvision import transforms
7
  from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection
 
8
  import pandas as pd
9
  import numpy as np
 
 
10
  import easyocr
11
- import matplotlib.pyplot as plt
12
 
13
-
14
- # Load detection and structure models
15
  def load_detection_model():
16
  model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  model.to(device)
19
  return model, device
20
 
21
-
22
  def load_structure_model(device):
23
- model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
24
- model.to(device)
25
- return model
26
 
27
-
28
- # Preprocess image
29
- class MaxResize:
30
  def __init__(self, max_size=800):
31
  self.max_size = max_size
32
 
33
  def __call__(self, image):
34
  width, height = image.size
35
- max_dim = max(width, height)
36
- scale = self.max_size / max_dim
37
- resized = image.resize((int(round(width * scale)), int(round(height * scale))))
38
- return resized
39
-
40
 
41
  def preprocess_image(image, max_size=800):
42
- transform = transforms.Compose([
43
  MaxResize(max_size),
44
  transforms.ToTensor(),
45
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
46
  ])
47
- pixel_values = transform(image).unsqueeze(0)
48
  return pixel_values
49
 
50
-
51
- # Detect tables
52
- def detect_tables(model, image, device):
53
- pixel_values = preprocess_image(image).to(device)
54
  with torch.no_grad():
55
  outputs = model(pixel_values)
56
  return outputs
57
 
58
-
59
- # Post-process outputs
60
- def post_process(outputs, img_size, id2label):
61
- def box_cxcywh_to_xyxy(x):
62
- cx, cy, w, h = x.unbind(-1)
63
- b = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)]
64
- return torch.stack(b, dim=1)
65
-
66
- def rescale_bboxes(out_bbox, size):
67
- img_w, img_h = size
68
- b = box_cxcywh_to_xyxy(out_bbox)
69
- return b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
70
-
71
- img_w, img_h = img_size
72
- prob = outputs.logits.softmax(-1)
73
- labels = prob.argmax(-1)
74
- boxes = rescale_bboxes(outputs.pred_boxes.squeeze(0).detach().cpu(), (img_w, img_h))
75
-
76
- tables = []
77
- for label, box in zip(labels[0], boxes):
78
- if id2label[label.item()] == "table":
79
- tables.append(box.tolist())
80
- return tables
81
-
82
-
83
- # OCR extraction
84
- def extract_table_data(image, bboxes):
85
- reader = easyocr.Reader(['vi'])
86
- data = []
87
- for bbox in bboxes:
88
- x_min, y_min, x_max, y_max = map(int, bbox)
89
- cropped = image.crop((x_min, y_min, x_max, y_max))
90
- text = reader.readtext(np.array(cropped), detail=0)
91
- data.append(" ".join(text))
92
- return data
93
-
94
-
95
- # Process uploaded PDF
96
- def process_pdf(pdf_file):
97
- output = []
98
-
99
- # Convert PDF to images
100
- pages = convert_from_path(pdf_file.name)
101
- detection_model, device = load_detection_model()
102
- id2label = detection_model.config.id2label
103
- id2label[len(id2label)] = "no object"
104
-
105
- for i, page in enumerate(pages):
106
- # Detect tables
107
- outputs = detect_tables(detection_model, page, device)
108
- table_bboxes = post_process(outputs, page.size, id2label)
109
-
110
- # Extract table data
111
- tables = extract_table_data(page, table_bboxes)
112
-
113
- # Save as DataFrame
114
- df = pd.DataFrame(tables)
115
- csv_path = f"page_{i + 1}_tables.csv"
116
- df.to_csv(csv_path, index=False)
117
- output.append(csv_path)
118
-
119
- return output
120
-
121
-
122
- # Gradio interface
123
- def app_interface(pdf_file):
124
- output_files = process_pdf(pdf_file)
125
- return output_files
126
-
127
-
128
- interface = gr.Interface(
129
- fn=app_interface,
 
 
 
 
 
130
  inputs=gr.inputs.File(label="Upload PDF"),
131
- outputs=gr.outputs.File(label="Extracted Tables"),
132
- title="Table Detection and Extraction",
133
- description="Upload a PDF, and this tool will extract tables into CSV format."
134
  )
135
 
136
  if __name__ == "__main__":
137
- interface.launch()
 
1
  import gradio as gr
2
  import os
3
+ import shutil
4
+ import zipfile
5
  from pdf2image import convert_from_path
 
 
 
6
  from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection
7
+ from PIL import Image, ImageDraw
8
  import pandas as pd
9
  import numpy as np
10
+ import torch
11
+ from torchvision import transforms
12
  import easyocr
 
13
 
14
+ # Define functions for model loading and preprocessing
 
15
  def load_detection_model():
16
  model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  model.to(device)
19
  return model, device
20
 
 
21
  def load_structure_model(device):
22
+ structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
23
+ structure_model.to(device)
24
+ return structure_model
25
 
26
+ class MaxResize(object):
 
 
27
  def __init__(self, max_size=800):
28
  self.max_size = max_size
29
 
30
  def __call__(self, image):
31
  width, height = image.size
32
+ current_max_size = max(width, height)
33
+ scale = self.max_size / current_max_size
34
+ resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
35
+ return resized_image
 
36
 
37
  def preprocess_image(image, max_size=800):
38
+ detection_transform = transforms.Compose([
39
  MaxResize(max_size),
40
  transforms.ToTensor(),
41
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
42
  ])
43
+ pixel_values = detection_transform(image).unsqueeze(0)
44
  return pixel_values
45
 
46
+ # Define detection functions
47
+ def detect_tables(model, pixel_values, device):
48
+ pixel_values = pixel_values.to(device)
 
49
  with torch.no_grad():
50
  outputs = model(pixel_values)
51
  return outputs
52
 
53
+ def rescale_bboxes(out_bbox, size):
54
+ img_w, img_h = size
55
+ x_c, y_c, w, h = out_bbox.unbind(-1)
56
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
57
+ return torch.stack(b, dim=1) * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
58
+
59
+ def outputs_to_objects(outputs, img_size, id2label):
60
+ m = outputs.logits.softmax(-1).max(-1)
61
+ pred_labels = list(m.indices.detach().cpu().numpy())[0]
62
+ pred_scores = list(m.values.detach().cpu().numpy())[0]
63
+ pred_bboxes = outputs["pred_boxes"].detach().cpu()[0]
64
+ pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]
65
+ objects = [
66
+ {"label": id2label[int(label)], "score": float(score), "bbox": [float(x) for x in bbox]}
67
+ for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes)
68
+ if id2label[int(label)] != "no object"
69
+ ]
70
+ return objects
71
+
72
+ # OCR function
73
+ def apply_ocr(image, language="vi"):
74
+ reader = easyocr.Reader([language])
75
+ result = reader.readtext(np.array(image), detail=0)
76
+ return result
77
+
78
+ # Process PDF
79
+ def process_pdf(pdf_path, output_dir):
80
+ images = convert_from_path(pdf_path)
81
+ model, device = load_detection_model()
82
+ structure_model = load_structure_model(device)
83
+
84
+ if os.path.exists(output_dir):
85
+ shutil.rmtree(output_dir)
86
+ os.makedirs(output_dir)
87
+
88
+ txt_output = []
89
+ zip_filename = os.path.join(output_dir, "output.zip")
90
+ with zipfile.ZipFile(zip_filename, "w") as zipf:
91
+ for page_num, image in enumerate(images):
92
+ pixel_values = preprocess_image(image)
93
+ outputs = detect_tables(model, pixel_values, device)
94
+ id2label = model.config.id2label
95
+ id2label[len(id2label)] = "no object"
96
+ objects = outputs_to_objects(outputs, image.size, id2label)
97
+
98
+ # Detect tables
99
+ detected_tables = [obj for obj in objects if obj["label"] in ["table", "table rotated"]]
100
+ for idx, table in enumerate(detected_tables):
101
+ x_min, y_min, x_max, y_max = map(int, table["bbox"])
102
+ cropped_table = image.crop((x_min, y_min, x_max, y_max))
103
+ table_data = apply_ocr(cropped_table)
104
+
105
+ # Save CSV
106
+ csv_filename = os.path.join(output_dir, f"page_{page_num+1}_table_{idx+1}.csv")
107
+ pd.DataFrame(table_data).to_csv(csv_filename, index=False)
108
+ zipf.write(csv_filename, os.path.basename(csv_filename))
109
+
110
+ # Extract remaining text
111
+ text = apply_ocr(image)
112
+ txt_output.append("\n".join(text))
113
+
114
+ # Save text
115
+ txt_filename = os.path.join(output_dir, "remaining_text.txt")
116
+ with open(txt_filename, "w", encoding="utf-8") as txt_file:
117
+ txt_file.write("\n".join(txt_output))
118
+ zipf.write(txt_filename, os.path.basename(txt_filename))
119
+
120
+ return zip_filename
121
+
122
+ # Define Gradio UI
123
+ def process_file(pdf_file):
124
+ output_dir = "output"
125
+ output_zip = process_pdf(pdf_file.name, output_dir)
126
+ return output_zip
127
+
128
+ app = gr.Interface(
129
+ fn=process_file,
130
  inputs=gr.inputs.File(label="Upload PDF"),
131
+ outputs=gr.outputs.File(label="Download Output"),
132
+ title="Table Detection & OCR Extraction",
133
+ description="Upload a scanned PDF, and this app will extract detected tables as CSVs and text as a TXT file."
134
  )
135
 
136
  if __name__ == "__main__":
137
+ app.launch()