PDF2TEXT / app.py
bacngv's picture
Update app.py
630879e verified
raw
history blame
5.15 kB
import gradio as gr
import os
import shutil
import zipfile
from pdf2image import convert_from_path
from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection
from PIL import Image, ImageDraw
import pandas as pd
import numpy as np
import torch
from torchvision import transforms
import easyocr
# Define functions for model loading and preprocessing
def load_detection_model():
    """Load the pretrained table-detection model onto the best available device.

    Returns:
        tuple: ``(model, device)`` where ``model`` is the
        ``microsoft/table-transformer-detection`` checkpoint (``no_timm``
        revision) and ``device`` is ``"cuda"`` when CUDA is available,
        otherwise ``"cpu"``.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    detector = AutoModelForObjectDetection.from_pretrained(
        "microsoft/table-transformer-detection",
        revision="no_timm",
    )
    detector.to(target_device)
    return detector, target_device
def load_structure_model(device):
    """Load the pretrained table-structure-recognition model onto ``device``.

    Args:
        device: Device identifier passed to the model's ``.to()`` call.

    Returns:
        The ``microsoft/table-structure-recognition-v1.1-all`` model after it
        has been moved to ``device``.
    """
    checkpoint = "microsoft/table-structure-recognition-v1.1-all"
    recognizer = TableTransformerForObjectDetection.from_pretrained(checkpoint)
    recognizer.to(device)
    return recognizer
class MaxResize:
    """Callable transform that rescales an image so its longest side is ``max_size``.

    Both dimensions are multiplied by the same factor (so the aspect ratio is
    kept) and rounded to the nearest integer.
    """

    def __init__(self, max_size=800):
        # Target length, in pixels, for the longer image side.
        self.max_size = max_size

    def __call__(self, image):
        """Return ``image`` resized so that ``max(width, height) == max_size``.

        Args:
            image: Object exposing ``size`` as ``(width, height)`` and a
                ``resize((width, height))`` method.

        Returns:
            Whatever ``image.resize`` returns for the scaled dimensions.
        """
        longest_side = max(image.size)
        factor = self.max_size / longest_side
        new_dims = tuple(int(round(factor * side)) for side in image.size)
        return image.resize(new_dims)
def preprocess_image(image, max_size=800):
    """Turn an image into a normalized, batched tensor for the detection model.

    The image is resized so its longest side equals ``max_size``, converted to
    a tensor, normalized with fixed per-channel mean/std values, and given a
    leading batch dimension of size 1.

    Args:
        image: Input image accepted by ``MaxResize`` and ``transforms.ToTensor``.
        max_size: Target length of the longest image side after resizing.

    Returns:
        The transformed tensor with an extra dimension inserted at index 0.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose(
        [
            MaxResize(max_size),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )
    normalized = pipeline(image)
    return normalized.unsqueeze(0)
# Define detection functions
def detect_tables(model, pixel_values, device):
    """Run ``model`` on ``pixel_values`` with gradient tracking disabled.

    Args:
        model: Callable taking the input tensor and returning the raw outputs.
        pixel_values: Input tensor; it is moved to ``device`` before the call.
        device: Device on which to run the forward pass.

    Returns:
        Whatever ``model`` returns for the given input.
    """
    batch = pixel_values.to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        return model(batch)
def rescale_bboxes(out_bbox, size):
    """Convert centre-format boxes to corner-format boxes in pixel units.

    Each box is given as ``(x_center, y_center, width, height)`` expressed as
    fractions of the image size, and is converted to
    ``(x_min, y_min, x_max, y_max)`` scaled by the image width and height.

    Args:
        out_bbox: Tensor whose last dimension has size 4. Any number of
            leading dimensions is accepted (a single box, ``(N, 4)``, or a
            batched ``(B, N, 4)`` tensor).
        size: ``(img_w, img_h)`` of the image the boxes refer to.

    Returns:
        Tensor with the same shape as ``out_bbox`` holding the corner
        coordinates in pixels.
    """
    img_w, img_h = size
    x_c, y_c, w, h = out_bbox.unbind(-1)
    # Stack on the last axis so the layout is right for any leading shape;
    # for the 2-D (N, 4) case this is identical to stacking on dim=1.
    corners = torch.stack(
        [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h],
        dim=-1,
    )
    # Build the scale on the same device as the boxes to avoid a
    # cross-device multiplication error when the boxes are not on the CPU.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h],
        dtype=torch.float32,
        device=out_bbox.device,
    )
    return corners * scale
def outputs_to_objects(outputs, img_size, id2label):
    """Convert raw detection outputs for the first batch item into plain dicts.

    For every prediction the most probable class (softmax over the last axis
    of ``outputs.logits``) is taken; predictions whose label is
    ``"no object"`` are dropped.

    Args:
        outputs: Model output exposing ``.logits`` and ``["pred_boxes"]``.
        img_size: ``(width, height)`` used to rescale boxes to pixels.
        id2label: Mapping from class index to label name; it must contain an
            entry for every index the model can predict.

    Returns:
        list[dict]: One dict per kept prediction with keys ``"label"``
        (str), ``"score"`` (float) and ``"bbox"`` (list of four floats).
    """
    best = outputs.logits.softmax(-1).max(-1)
    # Index 0 selects the first (only) image of the batch.
    labels = best.indices.detach().cpu().numpy()[0]
    scores = best.values.detach().cpu().numpy()[0]
    boxes = rescale_bboxes(outputs["pred_boxes"].detach().cpu()[0], img_size)

    objects = []
    for label, score, box in zip(labels, scores, boxes):
        name = id2label[int(label)]
        if name == "no object":
            continue
        objects.append(
            {
                "label": name,
                "score": float(score),
                "bbox": [float(coord) for coord in box.tolist()],
            }
        )
    return objects
# OCR function
# Cache of EasyOCR readers keyed by language code. Building a reader is
# expensive, and apply_ocr is called once per table and once per page, so
# each reader is created once and then reused.
_OCR_READERS = {}


def apply_ocr(image, language="vi"):
    """Run OCR on ``image`` and return the recognized text fragments.

    Args:
        image: Image convertible with ``np.array(image)``.
        language: Language code passed to ``easyocr.Reader``.

    Returns:
        The result of ``reader.readtext(..., detail=0)`` for the image.
    """
    reader = _OCR_READERS.get(language)
    if reader is None:
        reader = easyocr.Reader([language])
        _OCR_READERS[language] = reader
    return reader.readtext(np.array(image), detail=0)
# Process PDF
def process_pdf(pdf_path, output_dir):
    """Detect tables on every page of a PDF, OCR them, and zip the results.

    For each page, every detected table is cropped, OCR'd and saved as
    ``page_<p>_table_<t>.csv``. The OCR text of each full page is collected
    into ``remaining_text.txt``. All files are written to ``output_dir`` and
    added to ``output.zip`` inside that directory.

    Args:
        pdf_path: Path of the PDF file to process.
        output_dir: Directory for the generated files. If it already exists
            it is deleted and recreated.

    Returns:
        str: Path of the created zip archive.
    """
    images = convert_from_path(pdf_path)
    model, device = load_detection_model()
    # The structure-recognition model is deliberately not loaded: nothing in
    # this function consumes it, so loading it only cost time and memory.

    # Build the label map once, on a copy, so the model's own config is not
    # mutated and does not gain an extra "no object" entry for every page.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    txt_output = []
    zip_filename = os.path.join(output_dir, "output.zip")
    with zipfile.ZipFile(zip_filename, "w") as zipf:
        for page_num, image in enumerate(images, start=1):
            pixel_values = preprocess_image(image)
            outputs = detect_tables(model, pixel_values, device)
            objects = outputs_to_objects(outputs, image.size, id2label)

            # Detect tables
            detected_tables = [
                obj for obj in objects if obj["label"] in ("table", "table rotated")
            ]
            page_w, page_h = image.size
            for idx, table in enumerate(detected_tables, start=1):
                x_min, y_min, x_max, y_max = map(int, table["bbox"])
                # Clamp to the page so the crop never includes padding
                # outside the image, and skip boxes with no area.
                x_min, y_min = max(x_min, 0), max(y_min, 0)
                x_max, y_max = min(x_max, page_w), min(y_max, page_h)
                if x_max <= x_min or y_max <= y_min:
                    continue
                cropped_table = image.crop((x_min, y_min, x_max, y_max))
                table_data = apply_ocr(cropped_table)

                # Save CSV
                csv_filename = os.path.join(
                    output_dir, f"page_{page_num}_table_{idx}.csv"
                )
                pd.DataFrame(table_data).to_csv(csv_filename, index=False)
                zipf.write(csv_filename, os.path.basename(csv_filename))

            # NOTE(review): OCR runs on the full page, so text inside the
            # detected tables is also included here -- confirm this is the
            # intended meaning of "remaining" text.
            text = apply_ocr(image)
            txt_output.append("\n".join(text))

        # Save text
        txt_filename = os.path.join(output_dir, "remaining_text.txt")
        with open(txt_filename, "w", encoding="utf-8") as txt_file:
            txt_file.write("\n".join(txt_output))
        zipf.write(txt_filename, os.path.basename(txt_filename))

    return zip_filename
# Define Gradio UI
def process_file(pdf_file):
    """Gradio callback: process the uploaded PDF and return the result zip path.

    Args:
        pdf_file: Value delivered by the ``gr.File`` input. Depending on the
            Gradio version and component settings this is either an object
            exposing the file path as ``.name`` or the path string itself;
            it is ``None`` when nothing was uploaded.

    Returns:
        str: Path of the zip archive produced by ``process_pdf``.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if pdf_file is None:
        # Give the user a readable message instead of an AttributeError.
        raise gr.Error("Please upload a PDF file before submitting.")

    # Accept both a file-like object with `.name` and a plain path string.
    pdf_path = getattr(pdf_file, "name", pdf_file)

    # NOTE(review): the output directory is fixed and is wiped on each call,
    # so concurrent requests would overwrite each other's results.
    output_dir = "output"
    return process_pdf(pdf_path, output_dir)
# Gradio interface: one PDF upload in, one downloadable file (the zip) out.
app = gr.Interface(
    title="Table Detection & OCR Extraction",
    description="Upload a scanned PDF, and this app will extract detected tables as CSVs and text as a TXT file.",
    fn=process_file,
    inputs=gr.File(label="Upload PDF", file_types=[".pdf"]),
    outputs=gr.File(label="Download Output"),
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()