Block Diagram Symbol Detection Model

This model was introduced in the paper "Unveiling the Power of Integration: Block Diagram Summarization through Local-Global Fusion", accepted at ACL 2024. The full code is available in the BlockNet GitHub repository.

Model description

This model is an object detector based on YOLOv5, which provides the core capability of detecting and localizing objects in an image. Using the CBD, FC_A, and FC_B datasets, which include annotations for the different shapes and arrows in a diagram, we train the model to recognize six labels: arrow, terminator, process, decision, data, and text.

Training dataset

  • YOLOv5 is fine-tuned with the annotations provided in this GitHub repository for symbol detection in block diagrams.
  • 396 samples from the real-world English block diagram dataset (CBD)
  • 357 samples from the handwritten English block diagram dataset (FC_A)
  • 476 samples from the handwritten English block diagram dataset (FC_B)

How to use

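Here is how to use this model in PyTorch. The snippet imports the `models` and `utils` modules of the YOLOv5-based symbol-detection code from the BlockNet repository, so run it where those modules are importable (for example, from inside that directory or with it on `PYTHONPATH`):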

import argparse
import os
from pathlib import Path
import torch

from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh
from utils.plots import Annotator, colors, save_one_box, save_block_box
from utils.torch_utils import select_device, smart_inference_mode

def load_model(weights, device, dnn, data, fp16):
    """Create a YOLOv5 ``DetectMultiBackend`` for the given weights.

    Args:
        weights: path to the trained weights file.
        device: device spec, resolved through ``select_device`` before loading.
        dnn: forwarded to ``DetectMultiBackend`` unchanged.
        data: dataset YAML path, forwarded to ``DetectMultiBackend``.
        fp16: forwarded to ``DetectMultiBackend`` (half-precision flag).

    Returns:
        The constructed ``DetectMultiBackend`` instance.
    """
    target_device = select_device(device)
    return DetectMultiBackend(
        weights,
        device=target_device,
        dnn=dnn,
        data=data,
        fp16=fp16,
    )

def run_single_image_inference(model, img_path, stride, names, pt, conf_thres=0.35, iou_thres=0.7, max_det=100, augment=True, visualize=False, line_thickness=1, hide_labels=False, hide_conf=False, save_conf=False, save_crop=False, save_block=True, imgsz=(640, 640), vid_stride=1, bs=1, classes=None, agnostic_nms=False, save_txt=True, save_img=True):
    """Detect block-diagram symbols in the image(s) found at ``img_path``.

    Args:
        model: a loaded ``DetectMultiBackend`` (e.g. from ``load_model``).
        img_path: path handed to YOLOv5's ``LoadImages``.
        stride, names, pt: the model's ``stride``, ``names`` and ``pt`` attributes.
        conf_thres, iou_thres, max_det, classes, agnostic_nms: settings for
            ``non_max_suppression``.
        augment: passed to the model's forward call (test-time augmentation).
        imgsz: inference size; rounded to a multiple of ``stride`` by
            ``check_img_size`` before the images are loaded.
        vid_stride: forwarded to ``LoadImages``.
        bs: batch size used for the warm-up pass when the backend is neither
            PyTorch nor Triton.

    The remaining parameters (visualize, line_thickness, hide_labels,
    hide_conf, save_conf, save_crop, save_block, save_txt, save_img) are kept
    for signature compatibility but do not affect the result. In particular,
    feature visualization is always disabled.

    Returns:
        A tuple of ``(class_id, [x, y, w, h])`` entries. The box comes from
        ``xyxy2xywh`` and is divided by the original image's width/height, so it
        is normalized. Entries are grouped per image in loader order. Within an
        image they are sorted by the box's y value, then its x value (the
        center values produced by ``xyxy2xywh``).
    """
    # Round the size to a stride multiple *before* building the loader, so the
    # letterboxed images and the warm-up pass use the same shape.
    imgsz = check_img_size(imgsz, s=stride)
    dataset = LoadImages(img_path, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)

    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))

    # Accumulates results across every image the loader yields (previously this
    # was reset per image, so only the last image survived).
    sorted_data_list = []

    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad():
        for _path, im, im0s, _vid_cap, _msg in dataset:
            # Pre-process: uint8 HWC-letterboxed numpy -> float tensor in [0, 1].
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()
            im /= 255
            if im.ndim == 3:
                im = im[None]  # add the batch dimension

            # NOTE(review): visualize is forced off, as in the original snippet;
            # YOLOv5 presumably expects a save directory here rather than a bool.
            pred = model(im, augment=augment, visualize=False)
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

            for det in pred:  # one entry per image in the batch
                gn = torch.tensor(im0s.shape)[[1, 0, 1, 0]]  # normalization gain (w, h, w, h)
                data_for_image = []  # defined for every image, even with no detections
                if len(det):
                    # Map boxes from the letterboxed size back to the original image.
                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0s.shape).round()
                    for *xyxy, _conf, cls in reversed(det):
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                        data_for_image.append((int(cls), xywh))

                # Sort by the box's y value, then its x value (index 1, then 0 of xywh).
                data_for_image.sort(key=lambda item: (item[1][1], item[1][0]))
                sorted_data_list.extend(data_for_image)

    return tuple(sorted_data_list)


# Paths used by the example below (relative to the working directory).
object_detection_output_path = 'symbol_detection/runs/detect/exp/labels'  # not used by this snippet
yolo_weights_path = 'symbol_detection/runs/train/best_all/weights/best.pt'
yolo_yaml_file = 'symbol_detection/data/mydata.yaml'

# Use the first GPU when one is present; otherwise fall back to the CPU instead
# of requesting a CUDA device that does not exist.
yolo_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
yolo_model = load_model(yolo_weights_path, device=yolo_device, dnn=False, data=yolo_yaml_file, fp16=False)
stride, names, pt = yolo_model.stride, yolo_model.names, yolo_model.pt

# Example usage: returns a tuple of (class_id, normalized xywh) pairs.
image_path = "image.png"
labels = run_single_image_inference(yolo_model, image_path, stride, names, pt)

Contact

If you have any questions about this work, please contact Shreyanshu Bhushan at the following email address: [email protected].

License

The content of this project itself is licensed under the Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0).

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference API
Unable to determine this model's library. Check the docs.