File size: 4,729 Bytes
e730543 504dbd0 e730543 504dbd0 e730543 504dbd0 e730543 504dbd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import os
from pdf2image import convert_from_path
from utils import preprocess_image, remove_table_areas, MaxResize
from models import load_detection_model, load_structure_model, detect_objects, outputs_to_objects
from visualization import visualize_detected_tables, visualize_cropped_table
from ocr import apply_ocr, apply_ocr_remaining_area
from io_utils import save_remaining_text_to_txt, save_to_csv
from torchvision import transforms
import torch
# Main function to process the PDF file
def process_pdf(pdf_path, output_folder):
    """Extract non-table text and table data from every page of a PDF.

    Each page is rendered to an image and run through the detection model.
    Regions labelled ``'table'`` or ``'table rotated'`` are removed before the
    rest of the page is OCR'd and saved with ``save_remaining_text_to_txt``;
    each of those regions is also cropped and passed to
    ``process_cropped_table``, which writes one CSV per table.

    Args:
        pdf_path: Path of the PDF file to convert and analyse.
        output_folder: Directory for the result files; created if missing.

    Returns:
        A list of output files. For each page it holds the value returned by
        ``save_remaining_text_to_txt`` (presumably the text file path; confirm
        in ``io_utils``) followed by the CSV path of every detected table.
    """
    # Convert PDF to images
    images = convert_from_path(pdf_path)
    model, device = load_detection_model()
    structure_model = load_structure_model(device)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_folder, exist_ok=True)

    # Build the label map once, on a copy: adding the extra "no object" entry
    # directly to model.config.id2label would grow the shared config by one
    # key for every processed page.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    output_files = []  # Collects the paths of all result files
    total_pages = len(images)

    for page_num, image in enumerate(images):
        print(f"Processing page {page_num + 1}/{total_pages}")

        # Detect tables in the image
        pixel_values = preprocess_image(image)
        outputs = detect_objects(model, pixel_values, device)
        objects = outputs_to_objects(outputs, image.size, id2label)

        # Visualize detected tables
        visualize_detected_tables(image, objects)

        # Bounding boxes of everything the detector labelled as a table
        detected_bboxes = [
            obj['bbox']
            for obj in objects
            if obj['label'] in ['table', 'table rotated']
        ]

        # Create a new image with the table areas removed
        image_without_tables = remove_table_areas(image, detected_bboxes)

        # Perform OCR on the remaining area of the image
        remaining_text = apply_ocr_remaining_area(image_without_tables)

        # Save the remaining text to a file and add it to output_files
        txt_file = save_remaining_text_to_txt(remaining_text, output_folder, page_num)
        output_files.append(txt_file)

        # Process each cropped table; process_cropped_table writes the CSV
        for idx, bbox in enumerate(detected_bboxes):
            x_min, y_min, x_max, y_max = bbox
            cropped_table = image.crop((x_min, y_min, x_max, y_max))
            process_cropped_table(
                cropped_table, structure_model, device, page_num, idx, output_folder
            )

            # Record the CSV path; the name must match the one that
            # process_cropped_table uses when it saves the table.
            csv_filename = os.path.join(
                output_folder, f'page_{page_num + 1}_table_{idx + 1}.csv'
            )
            output_files.append(csv_filename)

    # Return the list of output files for Gradio
    return output_files
# Function to process each cropped table and save to CSV
def process_cropped_table(cropped_table, structure_model, device, page_num, table_index, output_folder):
    """Recognise the structure of one cropped table, OCR its cells, save a CSV.

    Args:
        cropped_table: Image of a single table; its ``.size`` is passed to
            ``outputs_to_objects`` and it is fed through the transforms below.
        structure_model: Structure-recognition model; it is called on the
            pixel tensor and its ``config.id2label`` is read.
        device: Device the input tensor is moved to before the forward pass.
        page_num: Zero-based page index (the file name uses ``page_num + 1``).
        table_index: Zero-based index of the table on its page (the file name
            uses ``table_index + 1``).
        output_folder: Directory in which the CSV file is written.

    Returns:
        The data produced by ``apply_ocr`` for the table's cells, which is
        also what gets written to the CSV.
    """
    structure_transform = transforms.Compose([
        MaxResize(1000),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the leading batch dimension before the forward pass.
    pixel_values = structure_transform(cropped_table).unsqueeze(0).to(device)

    # Forward pass (no gradients needed for inference)
    with torch.no_grad():
        outputs = structure_model(pixel_values)

    # Add the "no object" label on a copy: writing into
    # structure_model.config.id2label directly would add one more key to the
    # shared config on every call.
    structure_id2label = dict(structure_model.config.id2label)
    structure_id2label[len(structure_id2label)] = "no object"

    cells = outputs_to_objects(outputs, cropped_table.size, structure_id2label)

    # Visualize the detected cells in the cropped table
    visualize_cropped_table(cropped_table, cells)

    # Apply OCR to the detected cells
    cell_coordinates = get_cell_coordinates_by_row(cells)
    data = apply_ocr(cell_coordinates, cropped_table)

    # Save extracted data to CSV
    csv_filename = os.path.join(
        output_folder, f'page_{page_num + 1}_table_{table_index + 1}.csv'
    )
    save_to_csv(data, csv_filename)
    return data
def get_cell_coordinates_by_row(table_data):
    """Build a row-major grid of cell boxes from detected rows and columns.

    Args:
        table_data: Sequence of dicts, each with a ``'label'`` string and a
            ``'bbox'`` sequence indexed 0..3. Only entries labelled
            ``'table row'`` or ``'table column'`` are used; all others are
            ignored.

    Returns:
        A list with one inner list per row, rows ordered by ``bbox[1]``.
        Each inner list holds one ``{'bbox': [...]}`` dict per column, columns
        ordered by ``bbox[0]``. A cell box takes elements 0 and 2 from the
        column box and elements 1 and 3 from the row box. With no rows the
        result is ``[]``; with rows but no columns every inner list is empty.
    """
    rows = [entry for entry in table_data if entry['label'] == 'table row']
    columns = [entry for entry in table_data if entry['label'] == 'table column']

    # Order rows by bbox[1] and columns by bbox[0] so the grid is row-major.
    rows.sort(key=lambda entry: entry['bbox'][1])
    columns.sort(key=lambda entry: entry['bbox'][0])

    cell_coordinates = []
    for row in rows:
        row_box = row['bbox']
        # Intersect the row with every column to get that row's cells.
        row_cells = [
            {'bbox': [column['bbox'][0], row_box[1], column['bbox'][2], row_box[3]]}
            for column in columns
        ]
        cell_coordinates.append(row_cells)
    return cell_coordinates