|
import gradio as gr |
|
import os |
|
import shutil |
|
import zipfile |
|
from pdf2image import convert_from_path |
|
from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection |
|
from PIL import Image, ImageDraw |
|
import pandas as pd |
|
import numpy as np |
|
import torch |
|
from torchvision import transforms |
|
import easyocr |
|
|
|
|
|
def load_detection_model():
    """Load the Table Transformer detection checkpoint onto the best available device.

    Returns:
        A ``(model, device)`` tuple where ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    detector = AutoModelForObjectDetection.from_pretrained(
        "microsoft/table-transformer-detection",
        revision="no_timm",
    )
    detector.to(target_device)
    return detector, target_device
|
|
|
def load_structure_model(device):
    """Load the Table Transformer structure-recognition checkpoint onto *device*.

    Args:
        device: torch device string (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        The loaded ``TableTransformerForObjectDetection`` model.
    """
    recognizer = TableTransformerForObjectDetection.from_pretrained(
        "microsoft/table-structure-recognition-v1.1-all"
    )
    recognizer.to(device)
    return recognizer
|
|
|
class MaxResize:
    """Callable transform that rescales an image so its longer side equals ``max_size``.

    The aspect ratio is preserved. Images smaller than ``max_size`` are
    upscaled, not left alone.
    """

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        w, h = image.size
        factor = self.max_size / max(w, h)
        target = (int(round(factor * w)), int(round(factor * h)))
        return image.resize(target)
|
|
|
def preprocess_image(image, max_size=800):
    """Turn a PIL image into a normalised, batched tensor for the detector.

    The image is resized with ``MaxResize(max_size)``, converted with
    ``ToTensor`` and normalised with the standard ImageNet mean/std values.
    A leading batch dimension of size 1 is then added.
    """
    pipeline = transforms.Compose(
        [
            MaxResize(max_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    tensor = pipeline(image)
    return tensor.unsqueeze(0)
|
|
|
|
|
def detect_tables(model, pixel_values, device):
    """Run *model* on *pixel_values* (moved to *device*) with autograd disabled.

    Returns:
        Whatever the model's forward pass returns.
    """
    inputs = pixel_values.to(device)
    with torch.no_grad():
        result = model(inputs)
    return result
|
|
|
def rescale_bboxes(out_bbox, size):
    """Convert centre-format boxes to absolute corner coordinates.

    Args:
        out_bbox: tensor whose last dimension holds ``(cx, cy, w, h)``,
            presumably relative to the image (0..1), as DETR-style
            ``pred_boxes`` are. Any leading shape is accepted, e.g. ``(4,)``,
            ``(N, 4)`` or ``(B, N, 4)``.
        size: ``(width, height)`` of the target image, e.g. PIL ``Image.size``.

    Returns:
        Tensor of the same shape as *out_bbox* holding
        ``(x_min, y_min, x_max, y_max)`` scaled by the image size.
    """
    img_w, img_h = size
    x_c, y_c, w, h = out_bbox.unbind(-1)
    # Stack on the last axis so the layout matches the input for any rank
    # (dim=1 breaks for a single box and scrambles batched input).
    corners = torch.stack(
        [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h], dim=-1
    )
    # Build the scale on the input's device so CUDA tensors work too.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h], dtype=torch.float32, device=out_bbox.device
    )
    return corners * scale
|
|
|
def outputs_to_objects(outputs, img_size, id2label):
    """Convert raw detector outputs for the first image in the batch into labelled boxes.

    Each query's class is the arg-max of the softmaxed logits. Queries whose
    label maps to ``"no object"`` in *id2label* are dropped.

    Args:
        outputs: model output exposing ``logits`` and ``"pred_boxes"``.
        img_size: ``(width, height)`` used to scale boxes to pixels.
        id2label: mapping from class index to label name; must contain every
            predicted index (a missing index raises ``KeyError``).

    Returns:
        List of dicts with keys ``"label"``, ``"score"`` (float) and
        ``"bbox"`` (list of four floats, corner format).
    """
    best = outputs.logits.softmax(-1).max(-1)
    labels = best.indices.detach().cpu().numpy()[0]
    scores = best.values.detach().cpu().numpy()[0]
    boxes = rescale_bboxes(outputs["pred_boxes"].detach().cpu()[0], img_size)

    detections = []
    for label, score, box in zip(labels, scores, boxes):
        name = id2label[int(label)]
        if name == "no object":
            continue
        detections.append(
            {
                "label": name,
                "score": float(score),
                "bbox": [float(coord) for coord in box.tolist()],
            }
        )
    return detections
|
|
|
|
|
# EasyOCR readers cached by language code. Constructing an easyocr.Reader
# loads (and may download) its model weights, so doing it once per table and
# once per page was the dominant cost of processing a PDF.
_OCR_READERS = {}


def apply_ocr(image, language="vi"):
    """OCR *image* with EasyOCR and return the recognised text fragments.

    Args:
        image: anything ``np.array`` can convert to an image array (a PIL
            image in this file).
        language: EasyOCR language code; defaults to ``"vi"`` (Vietnamese).

    Returns:
        List of recognised strings. ``detail=0`` drops boxes and confidences.
    """
    reader = _OCR_READERS.get(language)
    if reader is None:
        reader = easyocr.Reader([language])
        _OCR_READERS[language] = reader
    return reader.readtext(np.array(image), detail=0)
|
|
|
|
|
def _table_boxes_from_objects(objects, image_size):
    """Return integer table boxes clipped to the image, skipping empty ones.

    Only objects labelled ``"table"`` or ``"table rotated"`` are kept. Each
    box is ``(x_min, y_min, x_max, y_max)`` suitable for ``Image.crop``.
    """
    width, height = image_size
    boxes = []
    for obj in objects:
        if obj["label"] not in ("table", "table rotated"):
            continue
        x_min, y_min, x_max, y_max = map(int, obj["bbox"])
        # Predicted boxes can spill past the page edges; clip them so the crop
        # does not pad with black and so degenerate boxes can be detected.
        x_min, y_min = max(x_min, 0), max(y_min, 0)
        x_max, y_max = min(x_max, width), min(y_max, height)
        if x_max > x_min and y_max > y_min:
            boxes.append((x_min, y_min, x_max, y_max))
    return boxes


def _mask_regions(image, boxes, fill="white"):
    """Return a copy of *image* with every box in *boxes* painted with *fill*."""
    if not boxes:
        return image
    masked = image.copy()
    draw = ImageDraw.Draw(masked)
    for box in boxes:
        draw.rectangle(box, fill=fill)
    return masked


def process_pdf(pdf_path, output_dir):
    """Extract tables and remaining page text from a PDF into a ZIP archive.

    Each page is rasterised with pdf2image and tables are located with the
    Table Transformer detection model. Each detected table region is OCR'd
    into its own CSV, one recognised text fragment per row. The rest of each
    page, with table regions blanked out, is OCR'd into
    ``remaining_text.txt``.

    Args:
        pdf_path: path of the PDF to process.
        output_dir: working directory. It is deleted and recreated, so it must
            not hold anything the caller wants to keep.

    Returns:
        Path of ``output.zip`` inside *output_dir*.
    """
    images = convert_from_path(pdf_path)
    model, device = load_detection_model()
    # The table-structure model is deliberately not loaded here: nothing in
    # this pipeline consumes it, so it only cost time and memory per request.

    # Build the label map once, on a copy, instead of writing into
    # model.config.id2label on every page. The index one past the configured
    # labels is the detector's extra class, treated as "no object".
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    txt_output = []
    zip_filename = os.path.join(output_dir, "output.zip")

    with zipfile.ZipFile(zip_filename, "w") as zipf:
        for page_num, image in enumerate(images, start=1):
            pixel_values = preprocess_image(image)
            outputs = detect_tables(model, pixel_values, device)
            objects = outputs_to_objects(outputs, image.size, id2label)
            table_boxes = _table_boxes_from_objects(objects, image.size)

            for idx, box in enumerate(table_boxes, start=1):
                table_data = apply_ocr(image.crop(box))
                csv_filename = os.path.join(
                    output_dir, f"page_{page_num}_table_{idx}.csv"
                )
                pd.DataFrame(table_data).to_csv(csv_filename, index=False)
                zipf.write(csv_filename, os.path.basename(csv_filename))

            # Blank out the tables so their text is not duplicated in the
            # "remaining" text output.
            text = apply_ocr(_mask_regions(image, table_boxes))
            txt_output.append("\n".join(text))

        txt_filename = os.path.join(output_dir, "remaining_text.txt")
        with open(txt_filename, "w", encoding="utf-8") as txt_file:
            txt_file.write("\n".join(txt_output))
        zipf.write(txt_filename, os.path.basename(txt_filename))

    return zip_filename
|
|
|
|
|
def process_file(pdf_file):
    """Gradio handler: run ``process_pdf`` on the uploaded file and return the ZIP path.

    Args:
        pdf_file: the ``gr.File`` value. Depending on the Gradio version this
            is a filepath string or a tempfile-like object with ``.name``;
            both are accepted.

    Returns:
        Path of the generated ZIP archive.

    Raises:
        gr.Error: if no file was uploaded.
    """
    import tempfile

    if pdf_file is None:
        raise gr.Error("Please upload a PDF file.")
    pdf_path = pdf_file if isinstance(pdf_file, (str, os.PathLike)) else pdf_file.name

    # A fresh directory per request. A shared fixed "output" directory was
    # rmtree'd by every call, so concurrent requests clobbered each other.
    # These directories are left for the OS temp cleanup because Gradio must
    # still be able to read the returned file after this function returns.
    output_dir = tempfile.mkdtemp(prefix="pdf_ocr_")
    return process_pdf(pdf_path, output_dir)
|
|
|
# Gradio UI: a single PDF upload in, a single downloadable ZIP
# (per-table CSVs plus remaining text) out.
app = gr.Interface(
    fn=process_file,
    inputs=gr.File(label="Upload PDF", file_types=[".pdf"]),
    outputs=gr.File(label="Download Output"),
    title="Table Detection & OCR Extraction",
    description=(
        "Upload a scanned PDF, and this app will extract detected tables "
        "as CSVs and text as a TXT file."
    ),
)


if __name__ == "__main__":
    # Serve the interface only when the module is executed as a script.
    app.launch()