File size: 5,149 Bytes
3fabfeb 5945677 3fabfeb 5945677 3fabfeb 5945677 3fabfeb 5945677 3fabfeb 5945677 3fabfeb 5945677 3fabfeb 5945677 3fabfeb 5945677 3fabfeb 5945677 3fabfeb 5945677 3fabfeb 5945677 630879e 5945677 3fabfeb 630879e 3fabfeb 5945677 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import gradio as gr
import os
import shutil
import zipfile
from pdf2image import convert_from_path
from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection
from PIL import Image, ImageDraw
import pandas as pd
import numpy as np
import torch
from torchvision import transforms
import easyocr
# Define functions for model loading and preprocessing
def load_detection_model():
    """Load the pretrained table-detection model onto the best available device.

    Returns:
        (model, device): the table-transformer detection model and the device
        string ("cuda" when available, otherwise "cpu").
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector = AutoModelForObjectDetection.from_pretrained(
        "microsoft/table-transformer-detection", revision="no_timm"
    )
    detector.to(device)
    return detector, device
def load_structure_model(device):
    """Load the table-structure-recognition model and move it to `device`.

    Args:
        device: torch device string, e.g. "cuda" or "cpu".

    Returns:
        The pretrained TableTransformerForObjectDetection structure model.
    """
    recognizer = TableTransformerForObjectDetection.from_pretrained(
        "microsoft/table-structure-recognition-v1.1-all"
    )
    recognizer.to(device)
    return recognizer
class MaxResize(object):
    """Callable transform that rescales an image so its longest side is `max_size`.

    Aspect ratio is preserved; images smaller than `max_size` are scaled up.
    """

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        w, h = image.size
        factor = self.max_size / max(w, h)
        new_size = (int(round(w * factor)), int(round(h * factor)))
        return image.resize(new_size)
def preprocess_image(image, max_size=800):
    """Convert a PIL image into a normalized (1, C, H, W) tensor batch.

    The image is resized so its longest edge equals `max_size`, converted to
    a tensor, and normalized with ImageNet statistics.
    """
    pipeline = transforms.Compose([
        MaxResize(max_size),
        transforms.ToTensor(),
        # ImageNet mean/std — the normalization the detection backbone expects.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # unsqueeze adds the batch dimension expected by the model.
    return pipeline(image).unsqueeze(0)
# Define detection functions
def detect_tables(model, pixel_values, device):
    """Run the detection model on a preprocessed batch, gradients disabled.

    Args:
        model: callable detection model.
        pixel_values: (1, C, H, W) input tensor.
        device: device string the model lives on.

    Returns:
        The raw model outputs.
    """
    batch = pixel_values.to(device)
    with torch.no_grad():
        return model(batch)
def rescale_bboxes(out_bbox, size):
    """Convert normalized (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1) pixels.

    Args:
        out_bbox: tensor of boxes in normalized center format, shape (N, 4).
        size: (width, height) of the target image in pixels.

    Returns:
        Tensor of shape (N, 4) with corner coordinates in pixel units.
    """
    img_w, img_h = size
    cx, cy, w, h = out_bbox.unbind(-1)
    corners = torch.stack(
        [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h],
        dim=1,
    )
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return corners * scale
def outputs_to_objects(outputs, img_size, id2label):
    """Turn raw detection outputs into a list of {label, score, bbox} dicts.

    Boxes are rescaled to pixel coordinates of `img_size`; predictions whose
    best class maps to "no object" are discarded.

    Args:
        outputs: model outputs exposing `.logits` and `["pred_boxes"]`.
        img_size: (width, height) of the source image.
        id2label: mapping from class index to label string.

    Returns:
        list[dict]: one dict per kept detection.
    """
    best = outputs.logits.softmax(-1).max(-1)
    labels = best.indices.detach().cpu().numpy()[0]
    scores = best.values.detach().cpu().numpy()[0]
    boxes = rescale_bboxes(outputs["pred_boxes"].detach().cpu()[0], img_size)

    results = []
    for label, score, bbox in zip(labels, scores, boxes.tolist()):
        name = id2label[int(label)]
        if name == "no object":
            continue
        results.append({
            "label": name,
            "score": float(score),
            "bbox": [float(v) for v in bbox],
        })
    return results
# OCR function
def apply_ocr(image, language="vi"):
    """Extract text fragments from an image with EasyOCR.

    Args:
        image: PIL image (anything `np.array` can convert) to read text from.
        language: EasyOCR language code (default "vi", Vietnamese).

    Returns:
        list[str]: recognized text fragments (detail=0 drops boxes/confidence).
    """
    # Reuse one Reader per language: constructing easyocr.Reader loads model
    # weights and is far too expensive to repeat for every page and table.
    reader = apply_ocr._readers.get(language)
    if reader is None:
        reader = easyocr.Reader([language])
        apply_ocr._readers[language] = reader
    return reader.readtext(np.array(image), detail=0)


# Per-language Reader cache used by apply_ocr.
apply_ocr._readers = {}
# Process PDF
def process_pdf(pdf_path, output_dir):
    """Extract detected tables (CSVs) and page text (TXT) from a scanned PDF.

    Renders each PDF page to an image, detects tables with the
    table-transformer model, OCRs each cropped table into its own CSV, OCRs
    the full page text, and bundles everything into <output_dir>/output.zip.

    Args:
        pdf_path: path to the input PDF file.
        output_dir: directory recreated from scratch for the outputs.

    Returns:
        str: path to the generated zip archive.
    """
    images = convert_from_path(pdf_path)
    model, device = load_detection_model()
    # NOTE(review): structure_model is loaded but never used below — the
    # structure-recognition step appears unimplemented; confirm intent.
    structure_model = load_structure_model(device)

    # Start from a clean output directory.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    # Build the label map once, on a *copy*: the original mutated
    # model.config.id2label inside the page loop, appending one extra
    # "no object" entry per page.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    txt_output = []
    zip_filename = os.path.join(output_dir, "output.zip")
    with zipfile.ZipFile(zip_filename, "w") as zipf:
        for page_num, image in enumerate(images):
            pixel_values = preprocess_image(image)
            outputs = detect_tables(model, pixel_values, device)
            objects = outputs_to_objects(outputs, image.size, id2label)

            # Keep only table-like detections.
            detected_tables = [
                obj for obj in objects
                if obj["label"] in ("table", "table rotated")
            ]
            for idx, table in enumerate(detected_tables):
                x_min, y_min, x_max, y_max = map(int, table["bbox"])
                cropped_table = image.crop((x_min, y_min, x_max, y_max))
                table_data = apply_ocr(cropped_table)

                # One CSV per detected table, added to the archive.
                csv_filename = os.path.join(
                    output_dir, f"page_{page_num+1}_table_{idx+1}.csv"
                )
                pd.DataFrame(table_data).to_csv(csv_filename, index=False)
                zipf.write(csv_filename, os.path.basename(csv_filename))

            # NOTE(review): this OCRs the *entire* page, tables included —
            # "remaining text" is a slight misnomer; confirm intent.
            text = apply_ocr(image)
            txt_output.append("\n".join(text))

        # Save the accumulated page text once, after all pages.
        txt_filename = os.path.join(output_dir, "remaining_text.txt")
        with open(txt_filename, "w", encoding="utf-8") as txt_file:
            txt_file.write("\n".join(txt_output))
        zipf.write(txt_filename, os.path.basename(txt_filename))
    return zip_filename
# Define Gradio UI
def process_file(pdf_file):
    """Gradio callback: run the extraction pipeline on an uploaded PDF.

    Args:
        pdf_file: Gradio file object exposing the temp path as `.name`.

    Returns:
        str: path to the zip archive produced by process_pdf.
    """
    return process_pdf(pdf_file.name, "output")
# Wire the pipeline into a simple one-input / one-output Gradio interface.
app = gr.Interface(
    fn=process_file,
    inputs=gr.File(label="Upload PDF", file_types=[".pdf"]),
    outputs=gr.File(label="Download Output"),
    title="Table Detection & OCR Extraction",
    description="Upload a scanned PDF, and this app will extract detected tables as CSVs and text as a TXT file."
)

# Launch the web UI only when run as a script (not on import).
if __name__ == "__main__":
    app.launch()