# Captured from a Hugging Face Spaces page; the Space's status was shown as "Runtime error".
import argparse | |
import csv | |
import sys | |
sys.path.append("./LIVE") | |
from pathlib import Path | |
import gradio as gr | |
import torch | |
import yaml | |
from PIL import Image | |
from subprocess import call | |
# Project root: sys.path[0], i.e. the directory of the entry script.
# NOTE(review): ROOT_PATH is not referenced anywhere in the visible code.
ROOT_PATH = sys.path[0]

# torch.hub repository that detector models are loaded from.
model_path = "ultralytics/yolov5"

# Cache keys read and written by yolo_det() to decide whether to reload.
model_name_tmp = ""  # name of the currently loaded model ("" = none yet)
device_tmp = ""  # device of the currently loaded model ("" = none yet)

# Accepted suffixes for name-list files: index 0 is CSV, index 1 is YAML.
suffix_list = [".csv", ".yaml"]
def parse_args(known=False):
    """Build the demo's command-line parser and parse ``sys.argv``.

    Args:
        known: when True, unrecognised arguments are ignored
            (``parse_known_args``); otherwise argparse rejects them.

    Returns:
        argparse.Namespace with model_name, model_cfg, cls_name, nms_conf,
        nms_iou, label_dnt_show, device and inference_size.
    """
    option_specs = (
        (("--model_name", "-mn"),
         dict(default="yolov5s", type=str, help="model name")),
        (("--model_cfg", "-mc"),
         dict(default="./model_config/model_name_p5_all.yaml", type=str,
              help="model config")),
        (("--cls_name", "-cls"),
         dict(default="./cls_name/cls_name.yaml", type=str, help="cls name")),
        (("--nms_conf", "-conf"),
         dict(default=0.5, type=float, help="model NMS confidence threshold")),
        (("--nms_iou", "-iou"),
         dict(default=0.45, type=float, help="model NMS IoU threshold")),
        # store_false on a True default: passing the flag turns labels OFF.
        (("--label_dnt_show", "-lds"),
         dict(action="store_false", default=True, help="label show")),
        (("--device", "-dev"),
         dict(default="cpu", type=str, help="cuda or cpu, hugging face only cpu")),
        (("--inference_size", "-isz"),
         dict(default=640, type=int, help="model inference size")),
    )
    parser = argparse.ArgumentParser(description="Gradio LIVE")
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    if known:
        namespace, _unknown = parser.parse_known_args()
        return namespace
    return parser.parse_args()
# Model loading
def model_loading(model_name, device):
    """Load ``model_name`` from the torch.hub repository ``model_path`` onto ``device``.

    This is not a local load: ``force_reload=True`` makes torch.hub fetch the
    repository afresh on every call instead of reusing its cache.
    """
    return torch.hub.load(model_path, model_name, device=device, force_reload=True)
# Detection info: convert model results into JSON-serialisable records.
def export_json(results, model, img_size):
    """Convert detection results into one list of detection dicts per image.

    Args:
        results: detection results; reads ``results.xyxyn`` (one 2-D tensor per
            image whose rows are ``[x0, y0, x1, y1, confidence, class]``) and
            ``results.t[1]``.
        model: loaded model; ``model.model.names`` maps class index -> name.
        img_size: ``(width, height)`` copied into every record.

    Returns:
        A list with one entry per image, each a list of dicts (one per
        detection) with keys id, class, class_name, normalized_box,
        confidence, fps, width, height.
    """

    def _record(index, row):
        """Build the JSON record for detection ``index`` (one row of ``xyxyn``)."""
        # Convert the box once instead of slicing and calling .tolist() per coordinate.
        x0, y0, x1, y1 = row[:4].tolist()
        class_id = int(row[5])
        return {
            "id": index,
            "class": class_id,
            "class_name": model.model.names[class_id],
            "normalized_box": {
                "x0": round(x0, 6),
                "y0": round(y0, 6),
                "x1": round(x1, 6),
                "y1": round(y1, 6),
            },
            "confidence": round(float(row[4]), 2),
            # results.t[1] is presumably the inference time in milliseconds
            # (YOLOv5 Detections timing) — TODO confirm; 1000 / ms gives FPS.
            "fps": round(1000 / float(results.t[1]), 2),
            "width": img_size[0],
            "height": img_size[1],
        }

    return [
        [_record(index, row) for index, row in enumerate(result)]
        for result in results.xyxyn
    ]
def yolo_det(
    img,
    experiment_id,
    device=None,
    model_name=None,
    inference_size=None,
    conf=None,
    iou=None,
    label_opt=None,
    model_cls=None,
):
    """Run the hub detector on ``img``; return the rendered image and its detections.

    The loaded model is cached in the module globals ``model``,
    ``model_name_tmp`` and ``device_tmp``. It is reloaded only when the
    requested model name or device differs from the cached one.

    Args:
        img: input image (``img.size`` is read as (width, height) — PIL assumed).
        experiment_id: path-adding scheduler choice from the UI; not used here.
        device: device string passed to ``model_loading``.
        model_name: hub model name passed to ``model_loading``.
        inference_size: ``size`` passed to the model call.
        conf: NMS confidence threshold assigned to ``model.conf``.
        iou: NMS IoU threshold assigned to ``model.iou``.
        label_opt: passed as ``labels=`` to ``results.render``.
        model_cls: class filter assigned to ``model.classes``.

    Returns:
        (rendered image as PIL Image, list of detection dicts for the first image).
    """
    global model, model_name_tmp, device_tmp
    # Record both cache keys together, and only after the load succeeds.
    # A repeat call with the same settings then reuses the model, and a failed
    # load is retried on the next call instead of being skipped.
    if model_name != model_name_tmp or device != device_tmp:
        model = model_loading(model_name, device)
        model_name_tmp, device_tmp = model_name, device
    # -----------Model tuning-----------
    model.conf = conf  # NMS confidence threshold
    model.iou = iou  # NMS IoU threshold
    model.max_det = 1000  # maximum number of detection boxes
    model.classes = model_cls  # class filter
    results = model(img, size=inference_size)  # run detection
    results.render(labels=label_opt)  # render detections onto the result images
    # NOTE(review): newer YOLOv5 releases expose rendered images as
    # ``results.ims`` rather than ``results.imgs`` — confirm against the hub version.
    det_img = Image.fromarray(results.imgs[0])  # rendered image
    det_json = export_json(results, model, img.size)[0]  # detections of the first image
    return det_img, det_json
def run_cmd(command):
    """Echo ``command``, run it through the shell and return its exit status.

    The string is executed with ``shell=True``, so it must come from trusted
    code, never from user input. A non-zero exit status is reported on stdout
    rather than raised, so a failing step does not abort the caller. Ctrl-C
    exits the whole process with status 1.

    Returns:
        The command's exit code (on POSIX, negative if killed by a signal).
    """
    try:
        print(command)
        returncode = call(command, shell=True)
    except KeyboardInterrupt:
        print("Process interrupted")
        sys.exit(1)
    if returncode != 0:
        print(f"Command failed with exit code {returncode}: {command}")
    return returncode
# Setup steps executed at import time, in this order, before the UI is built:
# print toolchain/environment info, fetch git submodules, install the package
# with setup.py, then run main.py once on figures/smile.png.
# NOTE(review): every command runs relative to the current working directory
# and this loop does not check exit statuses — confirm that is intended.
for _setup_command in (
    "gcc --version",
    "pwd",
    "ls",
    "git submodule update --init --recursive",
    "python setup.py install --user",
    "ls",
    "python main.py --config config/base.yaml --experiment experiment_5x1 --signature smile --target figures/smile.png --log_dir log/",
):
    run_cmd(_setup_command)
del _setup_command
# YAML file parsing
def yaml_parse(file_path):
    """Parse the UTF-8 YAML file at ``file_path`` with ``yaml.safe_load``.

    The file is opened in a ``with`` block so the handle is closed promptly
    (the previous version left it to the garbage collector).

    Returns:
        The parsed document (for an empty file, whatever safe_load returns — None).
    """
    with open(file_path, "r", encoding="utf-8") as yaml_file:
        return yaml.safe_load(yaml_file)
# YAML / CSV name-list parsing
def yaml_csv(file_path, file_tag):
    """Read a list of names from a CSV or YAML file, chosen by file suffix.

    Args:
        file_path: path whose suffix is ``suffix_list[0]`` (CSV) or
            ``suffix_list[1]`` (YAML).
        file_tag: key looked up in the YAML mapping; ignored for CSV.

    Returns:
        For CSV, the first column of every non-empty row. For YAML, the value
        under ``file_tag``, or None if the key is missing or the file is empty.

    Exits the program with status 1 if the suffix is not supported.
    """
    file_suffix = Path(file_path).suffix
    if file_suffix == suffix_list[0]:
        # CSV: one name per row, first column. Blank rows are skipped because
        # csv.reader yields [] for them, and row[0] would raise IndexError.
        with open(file_path, newline="", encoding="utf-8") as csv_file:
            file_names = [row[0] for row in csv.reader(csv_file) if row]
    elif file_suffix == suffix_list[1]:
        # YAML: names stored under file_tag; an empty document parses to None.
        file_names = (yaml_parse(file_path) or {}).get(file_tag)
    else:
        # Message: "<path> has an invalid format! Exiting!"
        print(f"{file_path}格式不正确!程序退出!")
        sys.exit(1)  # non-zero: this is an error, not a normal termination
    return file_names
def main(args):
    """Build and launch the Gradio interface.

    Args:
        args: namespace from ``parse_args()``. Its model_name, device,
            inference_size, nms_conf, nms_iou and label_dnt_show values are
            forwarded to ``yolo_det`` on every request (model_cfg and cls_name
            are not used here).
    """
    gr.close_all()
    # Path-adding scheduler options offered in the UI.
    scheduler_choices = [
        "add [1, 1, 1, 1, 1] total 5 paths",
        "add [1, 1, 1, 1, 1, 1, 1, 1] total 8 paths",
        "add [1,2,4,8,16,32, ...] total 128 paths",
        "add [1,2,4,8,16,32, ...] total 256 paths",
    ]
    # -------------------Inputs-------------------
    inputs_img = gr.inputs.Image(type="pil", label="Input Image")
    experiment_id = gr.inputs.Radio(
        choices=scheduler_choices,
        type="value",
        # The default must be one of the choices; the old literal
        # "add [1,1,1,1,1] paths" matched none of them.
        default=scheduler_choices[0],
        label="Path Adding Scheduler",
    )
    inputs = [
        inputs_img,  # input image
        experiment_id,  # path adding scheduler
    ]
    # -------------------Outputs-------------------
    outputs = gr.outputs.Image(type="pil", label="检测图片")  # "detection image"
    outputs02 = gr.outputs.JSON(label="检测信息")  # "detection info"

    def predict(img, scheduler):
        """Gradio callback: run yolo_det with the detector settings from ``args``."""
        # The interface supplies only (image, scheduler choice). Without this
        # wrapper, yolo_det would get None for model_name/device/thresholds.
        # NOTE(review): the UI describes LIVE vectorization, but this runs the
        # YOLOv5 detector, which ignores the scheduler — TODO wire to LIVE.
        return yolo_det(
            img,
            scheduler,
            device=args.device,
            model_name=args.model_name,
            inference_size=args.inference_size,
            conf=args.nms_conf,
            iou=args.nms_iou,
            label_opt=args.label_dnt_show,
        )

    title = "LIVE: Towards Layer-wise Image Vectorization"
    description = "<div align='center'>(CVPR 2022 Oral Presentation)</div>"
    examples = [
        ["./examples/1.png", "add [1, 1, 1, 1, 1] total 5 paths"],
        ["./examples/2.png", "add [1, 1, 1, 1, 1] total 5 paths"],
        ["./examples/3.jpg", "add [1,2,4,8,16,32, ...] total 128 paths"],
        ["./examples/4.png", "add [1,2,4,8,16,32, ...] total 256 paths"],
        ["./examples/5.png", "add [1, 1, 1, 1, 1] total 5 paths"],
    ]
    gr.Interface(
        fn=predict,
        inputs=inputs,
        outputs=[outputs, outputs02],
        title=title,
        description=description,
        examples=examples,
        theme="seafoam",
        flagging_dir="run",  # directory where flagged samples are written
    ).launch(
        inbrowser=True,  # open the default browser automatically
        show_tips=True,  # show Gradio tips about new features
    )
if __name__ == "__main__":
    # Script entry point: parse CLI options, then build and launch the UI.
    args = parse_args(known=False)
    main(args=args)