File size: 5,149 Bytes
3fabfeb
 
5945677
 
3fabfeb
 
5945677
3fabfeb
 
5945677
 
3fabfeb
 
5945677
3fabfeb
 
 
 
 
 
 
5945677
 
 
3fabfeb
5945677
3fabfeb
 
 
 
 
5945677
 
 
 
3fabfeb
 
5945677
3fabfeb
 
 
 
5945677
3fabfeb
 
5945677
 
 
3fabfeb
 
 
 
5945677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630879e
 
5945677
 
3fabfeb
 
630879e
3fabfeb
5945677
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gradio as gr
import os
import shutil
import zipfile
from pdf2image import convert_from_path
from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection
from PIL import Image, ImageDraw
import pandas as pd
import numpy as np
import torch
from torchvision import transforms
import easyocr

# Define functions for model loading and preprocessing
def load_detection_model():
    """Load the pretrained table-detection model and move it to a device.

    The device is ``"cuda"`` when ``torch.cuda.is_available()`` reports a GPU
    and ``"cpu"`` otherwise.

    Returns:
        A ``(model, device)`` tuple: the loaded detection model and the name
        of the device it was moved to.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    detector = AutoModelForObjectDetection.from_pretrained(
        "microsoft/table-transformer-detection",
        revision="no_timm",
    )
    detector.to(target_device)
    return detector, target_device

def load_structure_model(device):
    """Load the pretrained table-structure-recognition model onto ``device``.

    Args:
        device: Device identifier handed to the model's ``.to()`` call.

    Returns:
        The loaded ``TableTransformerForObjectDetection`` instance.
    """
    checkpoint = "microsoft/table-structure-recognition-v1.1-all"
    recognizer = TableTransformerForObjectDetection.from_pretrained(checkpoint)
    recognizer.to(device)
    return recognizer

class MaxResize(object):
    """Callable transform that scales an image so its longer side is ``max_size``.

    Both sides are multiplied by the same factor, so the aspect ratio is kept
    up to integer rounding. Smaller images are scaled up, larger ones down.
    """

    def __init__(self, max_size=800):
        # Target length, in pixels, of the longer image side after resizing.
        self.max_size = max_size

    def __call__(self, image):
        """Return a resized copy of ``image``.

        Args:
            image: Object exposing ``size`` as ``(width, height)`` and a
                ``resize((width, height))`` method (e.g. a PIL image).

        Returns:
            Whatever ``image.resize`` returns for the computed target size.
        """
        width, height = image.size
        scale = self.max_size / max(width, height)
        # Rounding can collapse the short side of a very elongated image to
        # zero pixels, which resize implementations reject; keep at least one.
        new_width = max(1, int(round(scale * width)))
        new_height = max(1, int(round(scale * height)))
        return image.resize((new_width, new_height))

def preprocess_image(image, max_size=800):
    """Turn a PIL image into a normalized, batched tensor for the detector.

    The image is resized with ``MaxResize(max_size)``, converted to a tensor,
    normalized per channel, and given a leading batch dimension.

    Args:
        image: PIL image to preprocess. Non-RGB images are converted to RGB.
        max_size: Length of the longer side after resizing.

    Returns:
        The transformed image tensor with an extra leading dimension of size 1.
    """
    # The normalization below supplies statistics for exactly three channels,
    # so grayscale, RGBA or palette images must be converted to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    detection_transform = transforms.Compose([
        MaxResize(max_size),
        transforms.ToTensor(),
        # NOTE(review): these look like the standard ImageNet mean/std values.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return detection_transform(image).unsqueeze(0)

# Define detection functions
def detect_tables(model, pixel_values, device):
    """Run ``model`` on ``pixel_values`` with gradient tracking disabled.

    Args:
        model: Callable invoked with the input tensor.
        pixel_values: Input tensor; it is moved to ``device`` before the call.
        device: Device the input is moved to.

    Returns:
        Whatever ``model`` returns for the moved input.
    """
    batch = pixel_values.to(device)
    with torch.no_grad():
        return model(batch)

def rescale_bboxes(out_bbox, size):
    """Convert center-format boxes to corner-format boxes scaled to an image.

    Args:
        out_bbox: Tensor whose last dimension holds
            ``(x_center, y_center, width, height)``. Any number of leading
            dimensions is accepted (a single box, a list of boxes, a batch).
            The values are multiplied by the image size, so they are expected
            to be relative coordinates.
        size: ``(width, height)`` of the target image.

    Returns:
        A tensor shaped like ``out_bbox`` whose last dimension holds
        ``(x_min, y_min, x_max, y_max)`` scaled by the image size.
    """
    img_w, img_h = size
    x_c, y_c, w, h = out_bbox.unbind(-1)
    # Stack on the last dimension so the result keeps the input's layout for
    # any number of leading dimensions, not only for 2-D input.
    corners = torch.stack(
        [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h],
        dim=-1,
    )
    # Build the scale on the same device as the boxes so the multiplication
    # also works for tensors that have not been moved to the CPU.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h],
        dtype=torch.float32,
        device=out_bbox.device,
    )
    return corners * scale

def outputs_to_objects(outputs, img_size, id2label):
    """Decode raw detector outputs into a list of labelled boxes.

    Only the first element of the batch is decoded. For every prediction the
    class with the highest softmax probability is taken; predictions whose
    label is ``"no object"`` are dropped.

    Args:
        outputs: Model output exposing ``logits`` as an attribute and
            ``"pred_boxes"`` by key.
        img_size: ``(width, height)`` passed to ``rescale_bboxes``.
        id2label: Mapping from integer class id to label string.

    Returns:
        A list of dicts with keys ``"label"`` (str), ``"score"`` (float) and
        ``"bbox"`` (list of four floats as produced by ``rescale_bboxes``).
    """
    best = outputs.logits.softmax(-1).max(-1)
    class_ids = best.indices.detach().cpu().numpy()[0]
    confidences = best.values.detach().cpu().numpy()[0]

    normalized_boxes = outputs["pred_boxes"].detach().cpu()[0]
    pixel_boxes = [row.tolist() for row in rescale_bboxes(normalized_boxes, img_size)]

    detections = []
    for class_id, confidence, box in zip(class_ids, confidences, pixel_boxes):
        name = id2label[int(class_id)]
        if name == "no object":
            continue
        detections.append(
            {
                "label": name,
                "score": float(confidence),
                "bbox": [float(coord) for coord in box],
            }
        )
    return detections

# OCR function
# Readers are cached per language code: constructing an easyocr.Reader loads
# the OCR models, which is far too slow to repeat for every crop and page.
_OCR_READERS = {}

def apply_ocr(image, language="vi"):
    """Run EasyOCR on ``image`` and return the recognized text.

    Args:
        image: Image convertible with ``np.array`` (e.g. a PIL image).
        language: EasyOCR language code used to select the reader;
            defaults to ``"vi"``.

    Returns:
        The result of ``Reader.readtext(..., detail=0)`` for the image.
    """
    reader = _OCR_READERS.get(language)
    if reader is None:
        reader = _OCR_READERS[language] = easyocr.Reader([language])
    return reader.readtext(np.array(image), detail=0)

# Process PDF
def process_pdf(pdf_path, output_dir):
    """Extract tables and page text from a PDF and bundle them into a ZIP.

    Each page is rendered to an image and run through the table detector.
    Every detected table region is cropped, OCR'd and written to
    ``page_<page>_table_<n>.csv``. The OCR text of each full page (table
    regions are not masked out) is collected into ``remaining_text.txt``.
    All files are written to ``output_dir`` and added to ``output.zip`` in
    that same directory.

    Note:
        ``output_dir`` is deleted and recreated on every call.

    Args:
        pdf_path: Path of the PDF file to process.
        output_dir: Directory that receives the CSV, TXT and ZIP files.

    Returns:
        Path of the created ZIP archive.
    """
    images = convert_from_path(pdf_path)
    # The structure-recognition model is intentionally not loaded here:
    # nothing in this function uses it, and loading it only costs time and
    # memory.
    model, device = load_detection_model()

    # Build the label map once, on a copy, so the model's config is not
    # mutated and does not gain another "no object" entry for every page.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    txt_output = []
    zip_filename = os.path.join(output_dir, "output.zip")
    with zipfile.ZipFile(zip_filename, "w") as zipf:
        for page_num, image in enumerate(images, start=1):
            pixel_values = preprocess_image(image)
            outputs = detect_tables(model, pixel_values, device)
            objects = outputs_to_objects(outputs, image.size, id2label)

            # Keep only table detections; each one is cropped and OCR'd.
            detected_tables = [
                obj for obj in objects
                if obj["label"] in ("table", "table rotated")
            ]
            for table_num, table in enumerate(detected_tables, start=1):
                x_min, y_min, x_max, y_max = map(int, table["bbox"])
                cropped_table = image.crop((x_min, y_min, x_max, y_max))
                table_data = apply_ocr(cropped_table)

                # One CSV per table, stored in the archive under its base name.
                csv_filename = os.path.join(
                    output_dir, f"page_{page_num}_table_{table_num}.csv"
                )
                pd.DataFrame(table_data).to_csv(csv_filename, index=False)
                zipf.write(csv_filename, os.path.basename(csv_filename))

            # OCR of the whole page, one entry per page.
            txt_output.append("\n".join(apply_ocr(image)))

        # The text file must be complete before it is added to the archive.
        txt_filename = os.path.join(output_dir, "remaining_text.txt")
        with open(txt_filename, "w", encoding="utf-8") as txt_file:
            txt_file.write("\n".join(txt_output))
        zipf.write(txt_filename, os.path.basename(txt_filename))

    return zip_filename

# Define Gradio UI
def process_file(pdf_file):
    """Gradio callback: run the extraction pipeline on the uploaded PDF.

    Args:
        pdf_file: The upload as delivered by the ``gr.File`` input. Depending
            on the Gradio version and component settings this is either a
            path (``str`` / path-like) or a file wrapper exposing the path as
            ``.name``; both forms are accepted.

    Returns:
        Path of the ZIP archive produced by ``process_pdf``.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if pdf_file is None:
        raise gr.Error("Please upload a PDF file.")

    if isinstance(pdf_file, (str, os.PathLike)):
        pdf_path = os.fspath(pdf_file)
    else:
        # File-wrapper objects carry the temporary file's path in ``.name``.
        pdf_path = pdf_file.name

    return process_pdf(pdf_path, "output")

# Gradio front end: one PDF upload as input, one downloadable file as output.
app = gr.Interface(
    title="Table Detection & OCR Extraction",
    description="Upload a scanned PDF, and this app will extract detected tables as CSVs and text as a TXT file.",
    fn=process_file,
    inputs=gr.File(file_types=[".pdf"], label="Upload PDF"),
    outputs=gr.File(label="Download Output"),
)


# Launch the app only when this file is run as a script, not on import.
if __name__ == "__main__":
    app.launch()