"""Gradio demo: detect a table in an image and recognize its structure.

Pipeline (entry point ``process_pdf``):
1. a YOLOv8 detector locates tables and the first detection is cropped,
2. PaddleOCR reads the text tokens inside the crop,
3. a second YOLOv8 model detects structure objects (rows, columns, headers,
   spanning cells) in the crop,
4. ``postprocess.objects_to_cells`` combines structure objects and tokens
   into cells, which are returned as an annotated image, a DataFrame and JSON.
"""
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
import cv2
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import csv
import pandas as pd
from ultralytics import YOLO
import torch
from paddleocr import PaddleOCR
import postprocess
import gradio as gr

device = "cuda" if torch.cuda.is_available() else "cpu"

detection_model = YOLO('yolov8/runs/detect/yolov8s-custom-detection/weights/best.pt').to(device)
structure_model = YOLO('yolov8/runs/detect/yolov8s-custom-structure-all/weights/best.pt').to(device)
# TODO use large det_limit_side_len to get better OCR result
ocr_model = PaddleOCR(use_angle_cls=True, lang="ch", det_limit_side_len=1920)

detection_class_names = ['table', 'table rotated']
structure_class_names = [
    'table', 'table column', 'table row', 'table column header',
    'table projected row header', 'table spanning cell', 'no object'
]
# Class name -> class index of the structure model.
structure_class_map = {name: idx for idx, name in enumerate(structure_class_names)}
structure_class_thresholds = {
    "table": 0.5,
    "table column": 0.5,
    "table row": 0.5,
    "table column header": 0.5,
    "table projected row header": 0.5,
    "table spanning cell": 0.5,
    # A threshold of 10 can never be reached by a confidence in [0, 1];
    # presumably this disables 'no object' inside postprocess -- confirm.
    "no object": 10
}


def _predict_boxes(model, image, imgsz):
    """Run a YOLO *model* on *image* and return one row per detected box.

    Each row is ``[x_center, y_center, width, height, confidence, class_id]``;
    the box values come from ultralytics ``Boxes.xywhn`` (normalized to
    [0, 1] by the input image size).
    """
    boxes = model.predict(image, imgsz=imgsz)[0].boxes.cpu().numpy()
    return [
        list(xywhn) + [conf, cls]
        for xywhn, conf, cls in zip(boxes.xywhn, boxes.conf, boxes.cls)
    ]


def table_detection(image, imgsz=800):
    """Detect tables in *image*; rows as described in ``_predict_boxes``."""
    return _predict_boxes(detection_model, image, imgsz)


def table_structure(image, imgsz=1024):
    """Detect table structure objects in *image* (see ``_predict_boxes``)."""
    return _predict_boxes(structure_model, image, imgsz)


def crop_image(image, detection_result, padding=10):
    """Crop *image* to the first detected table, padded by *padding* pixels.

    Args:
        image: array indexed as ``image[y, x, ...]``.
        detection_result: rows from ``table_detection``. TODO: only the first
            detected table is used.
        padding: pixels added on every side of the box, clamped to the image.

    Returns:
        The cropped region (a view into *image*), or *image* unchanged when
        nothing was detected.
    """
    if len(detection_result) == 0:
        return image
    height, width = image.shape[:2]
    x_center, y_center, box_w, box_h = detection_result[0][:4]
    x1 = max(0, int((x_center - box_w / 2) * width) - padding)
    y1 = max(0, int((y_center - box_h / 2) * height) - padding)
    x2 = min(width, int((x_center + box_w / 2) * width) + padding)
    y2 = min(height, int((y_center + box_h / 2) * height) + padding)
    return image[y1:y2, x1:x2]


def convert_stucture(ocr_result, image, structure_result):
    """Turn structure detections and OCR tokens into table cells.

    Args:
        ocr_result: token dicts (``{'bbox': [x1, y1, x2, y2], 'text': ...}``)
            in pixel coordinates of *image*, as produced by ``ocr``.
        image: the (cropped) table image the detections refer to.
        structure_result: rows from ``table_structure``.

    Returns:
        ``(table_structures, cells, confidence_score)`` as returned by
        ``postprocess.objects_to_cells``.
    """
    height, width = image.shape[:2]

    # Normalized center/size boxes -> absolute [x1, y1, x2, y2] pixel boxes.
    table_objects = []
    for x_center, y_center, box_w, box_h, score, class_id in structure_result:
        bbox = [
            int((x_center - box_w / 2) * width),
            int((y_center - box_h / 2) * height),
            int((x_center + box_w / 2) * width),
            int((y_center + box_h / 2) * height),
        ]
        table_objects.append({'bbox': bbox, 'score': float(score), 'label': int(class_id)})

    table = {'objects': table_objects, 'page_num': 0}

    # Use the highest-scoring 'table' object as the table region; without one,
    # fall back to the whole image so no OCR token is discarded.
    table_label = structure_class_map['table']
    table_candidates = [obj for obj in table_objects if obj['label'] == table_label]
    if table_candidates:
        table_bbox = list(max(table_candidates, key=lambda obj: obj['score'])['bbox'])
    else:
        table_bbox = [0, 0, width, height]

    # Keep tokens whose box overlaps the table region by at least 50% (iob).
    tokens_in_table = [
        token for token in ocr_result
        if postprocess.iob(token['bbox'], table_bbox) >= 0.5
    ]
    return postprocess.objects_to_cells(
        table, table_objects, tokens_in_table,
        structure_class_names, structure_class_thresholds,
    )


def visualize_cells(image, table_structures, cells):
    """Draw cell boxes on *image* and collect the cell text into a DataFrame.

    *image* is modified in place: each cell box is drawn in green (0, 255, 0).
    A cell's text is the concatenation of its spans' ``'text'`` values and is
    stored at the cell's first row index and first column index only.

    Returns:
        ``(image, df, json_str)`` where *df* has string column labels
        ``'0'..'n-1'`` and *json_str* is ``df.to_json()``.
    """
    num_cols = len(table_structures['columns'])
    num_rows = len(table_structures['rows'])
    data_rows = [['' for _ in range(num_cols)] for _ in range(num_rows)]

    for cell in cells:
        x1, y1, x2, y2 = (int(coord) for coord in cell['bbox'][:4])
        text = ''.join(span['text'] for span in cell['spans'] if 'text' in span)
        data_rows[cell['row_nums'][0]][cell['column_nums'][0]] = text
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))

    df = pd.DataFrame(data_rows)
    df.columns = df.columns.astype(str)
    return image, df, df.to_json()


def ocr(image):
    """Run PaddleOCR on *image* and return text tokens for ``postprocess``.

    Only the first entry of the PaddleOCR result is used.

    Returns:
        A list of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts in pixel
        coordinates of *image*; empty when no text is found.
    """
    result = ocr_model.ocr(image, cls=True)
    lines = result[0] if result else None
    if lines is None:
        return []
    tokens = []
    for line in lines:
        quad, text = line[0], line[1][0]
        # Axis-aligned box enclosing all four quad corners, so slanted text
        # is fully covered (identical to corners 0/2 for upright boxes).
        xs = [point[0] for point in quad]
        ys = [point[1] for point in quad]
        tokens.append({'bbox': [min(xs), min(ys), max(xs), max(ys)], 'text': text})
    return tokens


def detect_and_crop_table(image):
    """Detect tables in *image* and return the crop around the first one."""
    return crop_image(image, table_detection(image))


def recognize_table(image, ocr_result):
    """Recognize the structure of a cropped table image.

    Returns:
        ``(annotated_image, df, json_str)`` from ``visualize_cells``.
    """
    structure_result = table_structure(image)
    print('structure_result:', structure_result)
    table_structures, cells, confidence_score = convert_stucture(ocr_result, image, structure_result)
    print('table_structures:', table_structures)
    print('cells:', cells)
    print('confidence_score:', confidence_score)
    return visualize_cells(image, table_structures, cells)


def process_pdf(image):
    """Gradio handler: extract the table from an uploaded RGB image array.

    Despite its name this receives an image (``gr.Image(type="numpy")``), not
    a PDF. The models and OCR are fed BGR data; the annotated crop is
    converted back to RGB for display.

    Raises:
        gr.Error: when no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    cropped_table = detect_and_crop_table(bgr_image)
    ocr_result = ocr(cropped_table)
    annotated, df, data = recognize_table(cropped_table, ocr_result)
    print('df:', df)
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB), df, data


title = "Demo: table detection & recognition with Table Structure Recognition (Yolov8)."
description = """Demo for table extraction with the Table Structure Recognition (Yolov8)."""
examples = [['image.png'], ['mistral_paper.png']]

app = gr.Interface(fn=process_pdf,
                   inputs=gr.Image(type="numpy"),
                   outputs=[gr.Image(type="numpy", label="Detected table"),
                            gr.Dataframe(label="Table as CSV"),
                            gr.JSON(label="Data as JSON")],
                   title=title,
                   description=description,
                   examples=examples)
app.queue()
# For local debugging: app.launch(debug=True, share=True)
app.launch()