|
--- |
|
tags: |
|
- YOLOv5 |
|
- Object-Detection |
|
- Vision |
|
--- |
|
|
|
# Block Diagram Symbol Detection Model |
|
|
|
This model was introduced in the paper **"Unveiling the Power of Integration: Block Diagram Summarization through Local-Global Fusion"**, accepted at ACL 2024. The full code is available in the [BlockNet](https://github.com/shreyanshu09/BlockNet) GitHub repository.
|
|
|
## Model description |
|
|
|
This model is a YOLOv5-based object detector fine-tuned for block diagrams. Using the CBD, FC_A, and FC_B datasets, which include annotations for the shapes and arrows that appear in diagrams, we train the model to recognize six labels: arrow, terminator, process, decision, data, and text.
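The six labels are mapped to integer class indices by the dataset YAML (`mydata.yaml` in the repository). The exact index order is defined there; the mapping below is only an illustrative sketch with an assumed ordering:

```python
# Illustrative class map; the authoritative index order lives in mydata.yaml,
# so treat these indices as an assumption rather than ground truth.
BLOCK_CLASSES = {
    0: "arrow",
    1: "terminator",
    2: "process",
    3: "decision",
    4: "data",
    5: "text",
}
```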
|
|
|
## Training dataset |
|
|
|
- YOLOv5 is fine-tuned for symbol detection in block diagrams with the annotations provided in this [GitHub](https://github.com/shreyanshu09/Block-Diagram-Datasets) repository (a sketch of the label format follows this list):
|
- 396 samples from the real-world English block diagram dataset (CBD)

- 357 samples from the handwritten English block diagram dataset (FC_A)

- 476 samples from the handwritten English block diagram dataset (FC_B)
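The annotations use the standard YOLOv5 label format: one `.txt` file per image and one line per object, with coordinates normalized to the image size. A minimal parsing sketch (the sample values are invented):

```python
# One label line: <class> <x_center> <y_center> <width> <height>, all in [0, 1]
sample_line = "2 0.512 0.334 0.218 0.067"  # invented values
cls, xc, yc, w, h = sample_line.split()
print(int(cls), float(xc), float(yc), float(w), float(h))
```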
|
|
|
## How to use |
|
|
|
Here is how to use this model in PyTorch: |
|
|
|
```python |
|
from pathlib import Path

import torch

# The modules below come from the BlockNet (YOLOv5-based) repository
from models.common import DetectMultiBackend
from utils.dataloaders import LoadImages
from utils.general import Profile, check_img_size, non_max_suppression, scale_boxes, xyxy2xywh
from utils.plots import Annotator, colors
from utils.torch_utils import select_device
|
|
|
def load_model(weights, device, dnn, data, fp16):
    # Resolve the device string (e.g. 'cuda:0' or 'cpu') and load the weights
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=fp16)
    return model
|
|
|
def run_single_image_inference(model, img_path, stride, names, pt, conf_thres=0.35, iou_thres=0.7, max_det=100, augment=True, visualize=False, line_thickness=1, hide_labels=False, hide_conf=False, save_conf=False, save_crop=False, save_block=True, imgsz=(640, 640), vid_stride=1, bs=1, classes=None, agnostic_nms=False, save_txt=True, save_img=True): |
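    """Run detection on one image and return a tuple of (class_id, [x, y, w, h]) pairs,
    where boxes are centre-based and normalized, sorted top-to-bottom then left-to-right."""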
|
    # Make sure the requested size is a multiple of the model stride, then load the image
    imgsz = check_img_size(imgsz, s=stride)
    dataset = LoadImages(img_path, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)

    # Run inference
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
    sorted_data_list = []  # accumulates (class_id, xywh) across all processed frames
    for path, im, im0s, vid_cap, s in dataset:
        with dt[0]:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim

        # Inference
        with dt[1]:
            visualize = False  # feature-map visualization is disabled here
            pred = model(im, augment=augment, visualize=visualize)

        # NMS: drop boxes below conf_thres, suppress overlaps above iou_thres
        with dt[2]:
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

        # Second-stage classifier (optional)
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

        # Process predictions
        for i, det in enumerate(pred):  # per image
            seen += 1
            p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            s += '%gx%g ' % im.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            imc = im0.copy() if save_crop or save_block else im0  # for save_crop
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, 5].unique():
                    n = (det[:, 5] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                data_for_image = []
                # Collect results and draw labelled boxes in memory
                for *xyxy, conf, cls in reversed(det):
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    data_for_image.append((int(cls), xywh))
                    c = int(cls)  # integer class
                    label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                    annotator.box_label(xyxy, label, color=colors(c, True))

                # Sort the detections by box centre, top-to-bottom then left-to-right
                sorted_data_for_image = sorted(data_for_image, key=lambda x: (x[1][1], x[1][0]))
                sorted_data_list.extend(sorted_data_for_image)

    # Return the combined sorted data as a tuple
    return tuple(sorted_data_list)
|
|
|
|
|
# Paths to the fine-tuned weights and the dataset config (relative to the BlockNet repo)
yolo_weights_path = 'symbol_detection/runs/train/best_all/weights/best.pt'
yolo_yaml_file = 'symbol_detection/data/mydata.yaml'

yolo_model = load_model(yolo_weights_path, device='cuda:0', dnn=False, data=yolo_yaml_file, fp16=False)
|
stride, names, pt = yolo_model.stride, yolo_model.names, yolo_model.pt |
|
|
|
# Example usage |
|
image_path = "image.png" |
|
labels = run_single_image_inference(yolo_model, image_path, stride, names, pt) |
|
``` |
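The returned `labels` is a tuple of `(class_id, [x_center, y_center, width, height])` pairs in normalized coordinates, sorted top-to-bottom and then left-to-right. For example, to print them with human-readable names using the `names` mapping loaded above:

```python
for cls_id, (xc, yc, w, h) in labels:
    print(f"{names[cls_id]}: centre=({xc:.3f}, {yc:.3f}), size=({w:.3f}, {h:.3f})")
```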
|
|
|
## Contact |
|
|
|
If you have any questions about this work, please contact **[Shreyanshu Bhushan](https://github.com/shreyanshu09)** at **[email protected]**.
|
|
|
## License |
|
|
|
The content of this project is licensed under the [Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.
|
|