# Page metadata captured along with this file (not program code):
# vaishanthr's picture
# updated files
# a62ae31
# raw
# history blame
# 6.14 kB
from ultralytics import YOLO
import cv2
import gradio as gr
import numpy as np
import os
import torch
from image_segmenter import ImageSegmenter
# params
# Segmentation model shared by every Gradio callback below; model_selector is
# intended to swap it out when the user picks a different model size.
img_seg = ImageSegmenter(model_type="yolov8m-seg-custom")
# Cooperative stop flag for video processing, intended to be raised by cancel().
CANCEL_PROCESSING = False
def resize(image):
    """
    Resize an image array to a fixed 640x480 frame, keeping its orientation.

    Portrait inputs (taller than wide) become 480 wide x 640 tall; every other
    input, including square ones, becomes 640 wide x 480 tall. The aspect
    ratio is not preserved.

    Args:
        image: array whose first two dimensions are (height, width).

    Returns:
        The resized array produced by cv2.resize.
    """
    height, width = image.shape[:2]
    # cv2.resize takes dsize as (width, height).
    target_size = (480, 640) if height > width else (640, 480)
    return cv2.resize(image, target_size)
def process_image(image):
    """
    Run segmentation on a single image from the Image tab.

    Args:
        image: image array supplied by the gr.Image input (presumably RGB —
            confirm against the Gradio version in use), or None when Predict
            is clicked without an image loaded.

    Returns:
        The first element returned by img_seg.predict for the resized image,
        or None when no image was provided, so the UI shows no result
        instead of raising an error.
    """
    if image is None:
        return None
    image = resize(image)
    prediction, _ = img_seg.predict(image)
    return prediction
def process_video(vid_path=None):
    """
    Run segmentation on each frame of a video and stream the results.

    Frames are read with OpenCV, resized via resize(), converted from BGR
    (OpenCV's channel order) to RGB, and passed to img_seg.predict. The first
    element of each prediction is yielded so Gradio can stream it.

    Processing stops at end of stream, on a read failure, or when
    CANCEL_PROCESSING has been raised (see cancel()). The capture is always
    released, including when the consumer stops iterating early.

    Args:
        vid_path: path to the video file from the gr.Video input; None
            (no video selected) yields nothing.

    Yields:
        The first element returned by img_seg.predict for each frame.
    """
    global CANCEL_PROCESSING
    if vid_path is None:
        return
    # Clear any cancel request left over from a previous run.
    CANCEL_PROCESSING = False
    vid_cap = cv2.VideoCapture(vid_path)
    try:
        while vid_cap.isOpened():
            if CANCEL_PROCESSING:
                break
            ret, frame = vid_cap.read()
            if not ret:
                # End of stream or read error. isOpened() stays True here, so
                # without this break the loop would spin forever.
                break
            print("Making frame predictions ....")
            frame = resize(frame)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            prediction, _ = img_seg.predict(frame)
            yield prediction
    finally:
        vid_cap.release()
def update_segmentation_options(options):
    """
    Apply the "Options" checkbox selection to the shared segmenter.

    Each display flag on img_seg is set to whether its label is currently
    ticked.

    Args:
        options: the labels currently selected in the CheckboxGroup; None is
            treated as "nothing selected".
    """
    selected = set(options or ())
    img_seg.is_show_bounding_boxes = 'Show Boundary Box' in selected
    img_seg.is_show_segmentation = 'Show Segmentation Region' in selected
    # NOTE(review): the checkbox groups built in this file's __main__ section
    # only offer the two labels above, so from the UI this is always False.
    img_seg.is_show_segmentation_boundary = 'Show Segmentation Boundary' in selected
def update_confidence_threshold(thres_val):
    """
    Set the shared segmenter's confidence cutoff from the UI slider.

    The slider works in percent (1-100); the segmenter receives the
    equivalent fraction, e.g. 60 -> 0.6.
    """
    fraction = thres_val / 100
    img_seg.confidence_threshold = fraction
def model_selector(model_type):
    """
    Replace the shared img_seg with a segmenter for the chosen model size.

    Display settings previously applied to img_seg (bounding boxes,
    segmentation flags, confidence threshold) are copied to the new instance.
    Without this, switching models would silently ignore what the UI still
    shows as selected.

    Args:
        model_type: the dropdown label. Unrecognized labels fall back to the
            medium model.
    """
    global img_seg  # without this, the new segmenter would only be a local
    # NOTE(review): "Large" currently maps to the medium checkpoint, and the
    # small checkpoint name uses underscores unlike the others. Confirm these
    # match the model files actually shipped with the app.
    model_names = {
        "Small - Better performance and less accuracy": "yolov8s_seg_custom",
        "Medium - Balanced performance and accuracy": "yolov8m-seg-custom",
        "Large - Slow performance and high accuracy": "yolov8m-seg-custom",
    }
    yolo_model = model_names.get(model_type, "yolov8m-seg-custom")
    new_seg = ImageSegmenter(model_type=yolo_model)
    for attr in ("is_show_bounding_boxes", "is_show_segmentation",
                 "is_show_segmentation_boundary", "confidence_threshold"):
        if hasattr(img_seg, attr):
            setattr(new_seg, attr, getattr(img_seg, attr))
    img_seg = new_seg
def cancel():
    """
    Request that the running video prediction stop.

    Sets the module-level CANCEL_PROCESSING flag. Without the ``global``
    declaration, the assignment would only create a local variable and the
    flag would never change.
    """
    global CANCEL_PROCESSING
    CANCEL_PROCESSING = True
if __name__ == "__main__":
    # Dropdown labels. model_selector matches on these exact strings, so both
    # tabs share one list to keep them in sync.
    model_choices = [
        "Small - Better performance and less accuracy",
        "Medium - Balanced performance and accuracy",
        "Large - Slow performance and high accuracy",
    ]
    default_model_choice = "Medium - Balanced performance and accuracy"
    display_options = ["Show Boundary Box", "Show Segmentation Region"]
    conf_thres_info = "Choose the threshold above which objects should be detected"
    model_info = "Select the inference model before running predictions!"
    base_dir = os.path.dirname(__file__)

    # gradio gui app
    with gr.Blocks() as my_app:
        # title
        gr.Markdown("<h1><center>Hand detection and segmentation</center></h1>")

        # tabs
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image()
                    model_type_img = gr.Dropdown(
                        model_choices,
                        label="Model Type", value=default_model_choice,
                        info=model_info)
                    options_checkbox_img = gr.CheckboxGroup(display_options, label="Options")
                    conf_thres_img = gr.Slider(1, 100, value=60, label="Confidence Threshold", info=conf_thres_info)
                    submit_btn_img = gr.Button(value="Predict")
                with gr.Column(scale=2):
                    with gr.Row():
                        img_output = gr.Image(height=300, label="Segmentation")
            gr.Markdown("## Sample Images")
            gr.Examples(
                # NOTE(review): only one sample image is listed. A second one
                # (e.g. img_2.jpg) may have been intended; add it if the asset
                # exists. cache_examples=True runs process_image on every entry
                # at startup, so a missing file would break launch.
                examples=[os.path.join(base_dir, "assets/images/img_1.jpg")],
                inputs=img_input,
                outputs=img_output,
                fn=process_image,
                cache_examples=True,
            )

        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column(scale=1):
                    vid_input = gr.Video()
                    model_type_vid = gr.Dropdown(
                        model_choices,
                        label="Model Type", value=default_model_choice,
                        info=model_info)
                    options_checkbox_vid = gr.CheckboxGroup(display_options, label="Options")
                    conf_thres_vid = gr.Slider(1, 100, value=60, label="Confidence Threshold", info=conf_thres_info)
                    with gr.Row():
                        cancel_btn = gr.Button(value="Cancel")
                        submit_btn_vid = gr.Button(value="Predict")
                with gr.Column(scale=2):
                    with gr.Row():
                        # process_video yields one frame at a time, so the
                        # output is an Image that is updated as frames stream in.
                        vid_output = gr.Image(height=300, label="Segmentation")
            gr.Markdown("## Sample Videos")
            gr.Examples(
                examples=[os.path.join(base_dir, "assets/videos/vid_1.mp4"),
                          os.path.join(base_dir, "assets/videos/vid_2.mp4")],
                # Examples only fill vid_input; prediction runs via Predict.
                inputs=vid_input,
            )

        # image tab logic
        submit_btn_img.click(process_image, inputs=img_input, outputs=img_output)
        options_checkbox_img.change(update_segmentation_options, options_checkbox_img, [])
        conf_thres_img.change(update_confidence_threshold, conf_thres_img, [])
        model_type_img.change(model_selector, model_type_img, [])

        # video tab logic
        vid_event = submit_btn_vid.click(process_video, inputs=vid_input, outputs=vid_output)
        model_type_vid.change(model_selector, model_type_vid, [])
        # `cancels` makes Gradio stop the queued/running video job itself, in
        # addition to calling cancel().
        cancel_btn.click(cancel, inputs=[], outputs=[], cancels=[vid_event])
        options_checkbox_vid.change(update_segmentation_options, options_checkbox_vid, [])
        conf_thres_vid.change(update_confidence_threshold, conf_thres_vid, [])

    # NOTE(review): `concurrency_count` is a Gradio 3.x queue() argument; it
    # was removed in Gradio 4. Confirm the pinned Gradio version.
    my_app.queue(concurrency_count=5, max_size=20).launch(debug=True)