import argparse import csv import sys sys.path.append("./LIVE") from pathlib import Path import gradio as gr import torch import yaml from PIL import Image from subprocess import call ROOT_PATH = sys.path[0] # 根目录 # 模型路径 model_path = "ultralytics/yolov5" # 模型名称临时变量 model_name_tmp = "" # 设备临时变量 device_tmp = "" # 文件后缀 suffix_list = [".csv", ".yaml"] def parse_args(known=False): parser = argparse.ArgumentParser(description="Gradio LIVE") parser.add_argument( "--model_name", "-mn", default="yolov5s", type=str, help="model name" ) parser.add_argument( "--model_cfg", "-mc", default="./model_config/model_name_p5_all.yaml", type=str, help="model config", ) parser.add_argument( "--cls_name", "-cls", default="./cls_name/cls_name.yaml", type=str, help="cls name", ) parser.add_argument( "--nms_conf", "-conf", default=0.5, type=float, help="model NMS confidence threshold", ) parser.add_argument( "--nms_iou", "-iou", default=0.45, type=float, help="model NMS IoU threshold" ) parser.add_argument( "--label_dnt_show", "-lds", action="store_false", default=True, help="label show", ) parser.add_argument( "--device", "-dev", default="cpu", type=str, help="cuda or cpu, hugging face only cpu", ) parser.add_argument( "--inference_size", "-isz", default=640, type=int, help="model inference size" ) args = parser.parse_known_args()[0] if known else parser.parse_args() return args # 模型加载 def model_loading(model_name, device): # 加载本地模型 model = torch.hub.load(model_path, model_name, force_reload=True, device=device) return model # 检测信息 def export_json(results, model, img_size): return [ [ { "id": int(i), "class": int(result[i][5]), "class_name": model.model.names[int(result[i][5])], "normalized_box": { "x0": round(result[i][:4].tolist()[0], 6), "y0": round(result[i][:4].tolist()[1], 6), "x1": round(result[i][:4].tolist()[2], 6), "y1": round(result[i][:4].tolist()[3], 6), }, "confidence": round(float(result[i][4]), 2), "fps": round(1000 / float(results.t[1]), 2), "width": img_size[0], "height": img_size[1], } for i in range(len(result)) ] for result in results.xyxyn ] def yolo_det(img, experiment_id, device=None, model_name=None, inference_size=None, conf=None, iou=None, label_opt=None, model_cls=None): global model, model_name_tmp, device_tmp if model_name_tmp != model_name: # 模型判断,避免反复加载 model_name_tmp = model_name model = model_loading(model_name_tmp, device) elif device_tmp != device: device_tmp = device model = model_loading(model_name_tmp, device) # -----------模型调参----------- model.conf = conf # NMS 置信度阈值 model.iou = iou # NMS IOU阈值 model.max_det = 1000 # 最大检测框数 model.classes = model_cls # 模型类别 results = model(img, size=inference_size) # 检测 results.render(labels=label_opt) # 渲染 det_img = Image.fromarray(results.imgs[0]) # 检测图片 det_json = export_json(results, model, img.size)[0] # 检测信息 return det_img, det_json def run_cmd(command): try: print(command) call(command, shell=True) except KeyboardInterrupt: print("Process interrupted") sys.exit(1) run_cmd("gcc --version") run_cmd("pwd") run_cmd("ls") run_cmd("git submodule update --init --recursive") run_cmd("python setup.py install --user") run_cmd("ls") run_cmd("python main.py --config config/base.yaml --experiment experiment_5x1 --signature smile --target figures/smile.png --log_dir log/") # yaml文件解析 def yaml_parse(file_path): return yaml.safe_load(open(file_path, "r", encoding="utf-8").read()) # yaml csv 文件解析 def yaml_csv(file_path, file_tag): file_suffix = Path(file_path).suffix if file_suffix == suffix_list[0]: # 模型名称 file_names = [i[0] for i in list(csv.reader(open(file_path)))] # csv版 elif file_suffix == suffix_list[1]: # 模型名称 file_names = yaml_parse(file_path).get(file_tag) # yaml版 else: print(f"{file_path}格式不正确!程序退出!") sys.exit() return file_names def main(args): gr.close_all() # -------------------Inputs------------------- inputs_img = gr.inputs.Image(type="pil", label="Input Image") experiment_id = gr.inputs.Radio( choices=[ "add [1, 1, 1, 1, 1] total 5 paths", "add [1, 1, 1, 1, 1, 1, 1, 1] total 8 paths", "add [1,2,4,8,16,32, ...] total 128 paths", "add [1,2,4,8,16,32, ...] total 256 paths"], type="value", default="add [1,1,1,1,1] paths", label="Path Adding Scheduler" ) # inputs inputs = [ inputs_img, # input image experiment_id, # path adding scheduler ] # outputs outputs = gr.outputs.Image(type="pil", label="检测图片") outputs02 = gr.outputs.JSON(label="检测信息") # title title = "LIVE: Towards Layer-wise Image Vectorization" # description description = "