# LIVE / app.py
# Xu Ma · "Update app.py" · commit 434282c · 6.92 kB
import argparse
import csv
import sys
sys.path.append("./LIVE")
from pathlib import Path
import gradio as gr
import torch
import yaml
from PIL import Image
from subprocess import call
# First entry of the import search path (normally the directory of the
# launching script).
ROOT_PATH = sys.path[0]

# torch.hub repository that detection models are loaded from.
model_path = "ultralytics/yolov5"

# Cache keys describing the currently loaded model (name and device).
# yolo_det compares its arguments against these to decide whether to reload.
model_name_tmp = device_tmp = ""

# Supported list-file suffixes: index 0 is CSV, index 1 is YAML.
suffix_list = [".csv", ".yaml"]
def parse_args(known=False):
    """Build the demo's command-line parser and parse ``sys.argv``.

    Args:
        known: when True, unrecognised arguments are tolerated and discarded
            (``parse_known_args``) instead of making argparse exit.

    Returns:
        argparse.Namespace with ``model_name``, ``model_cfg``, ``cls_name``,
        ``nms_conf``, ``nms_iou``, ``label_dnt_show``, ``device`` and
        ``inference_size``.
    """
    # (flags, add_argument keyword options), registered in this order.
    option_table = (
        (("--model_name", "-mn"),
         {"default": "yolov5s", "type": str, "help": "model name"}),
        (("--model_cfg", "-mc"),
         {"default": "./model_config/model_name_p5_all.yaml", "type": str,
          "help": "model config"}),
        (("--cls_name", "-cls"),
         {"default": "./cls_name/cls_name.yaml", "type": str,
          "help": "cls name"}),
        (("--nms_conf", "-conf"),
         {"default": 0.5, "type": float,
          "help": "model NMS confidence threshold"}),
        (("--nms_iou", "-iou"),
         {"default": 0.45, "type": float,
          "help": "model NMS IoU threshold"}),
        # Stored value is True unless -lds is given, which flips it to False.
        (("--label_dnt_show", "-lds"),
         {"action": "store_false", "default": True, "help": "label show"}),
        (("--device", "-dev"),
         {"default": "cpu", "type": str,
          "help": "cuda or cpu, hugging face only cpu"}),
        (("--inference_size", "-isz"),
         {"default": 640, "type": int, "help": "model inference size"}),
    )

    parser = argparse.ArgumentParser(description="Gradio LIVE")
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    if not known:
        return parser.parse_args()
    namespace, _extras = parser.parse_known_args()
    return namespace
# Model loading
def model_loading(model_name, device):
    """Load *model_name* from the torch.hub repository ``model_path``.

    ``force_reload=True`` asks torch.hub to re-fetch the repository on every
    call rather than reuse a cached checkout. ``device`` is forwarded to the
    repository's hub entry point (presumably "cpu" or a CUDA device string;
    confirm against the hubconf of ``model_path``).

    Returns:
        Whatever the hub entry point for *model_name* returns.
    """
    hub_options = {"force_reload": True, "device": device}
    return torch.hub.load(model_path, model_name, **hub_options)
# Detection info
def export_json(results, model, img_size):
    """Convert detection results into JSON-serialisable per-image lists.

    Args:
        results: detection results exposing ``xyxyn`` (one 2-D tensor per
            image whose rows are read as ``[x0, y0, x1, y1, confidence,
            class]``) and ``t`` (timings; ``t[1]`` is treated as the
            inference time in milliseconds, hence the ``1000 / t`` FPS
            conversion; TODO confirm against the results type).
        model: loaded model; class names come from ``model.model.names``.
        img_size: ``(width, height)`` of the input image, copied into every
            record.

    Returns:
        One list per image in ``results.xyxyn``. Each list holds one dict per
        detection with the keys ``id``, ``class``, ``class_name``,
        ``normalized_box`` (``x0``/``y0``/``x1``/``y1`` rounded to 6
        places), ``confidence`` (2 places), ``fps`` (2 places), ``width``
        and ``height``.
    """
    names = model.model.names
    # FPS is identical for every detection, so compute it once. A zero or
    # negative timing would otherwise raise ZeroDivisionError; report 0.0.
    inference_ms = float(results.t[1])
    fps = round(1000 / inference_ms, 2) if inference_ms > 0 else 0.0
    width, height = img_size[0], img_size[1]

    per_image = []
    for detections in results.xyxyn:
        records = []
        for idx, det in enumerate(detections):
            # Convert the box to Python floats once instead of per coordinate.
            x0, y0, x1, y1 = det[:4].tolist()
            cls_id = int(det[5])
            records.append(
                {
                    "id": idx,
                    "class": cls_id,
                    "class_name": names[cls_id],
                    "normalized_box": {
                        "x0": round(x0, 6),
                        "y0": round(y0, 6),
                        "x1": round(x1, 6),
                        "y1": round(y1, 6),
                    },
                    "confidence": round(float(det[4]), 2),
                    "fps": fps,
                    "width": width,
                    "height": height,
                }
            )
        per_image.append(records)
    return per_image
def yolo_det(img, experiment_id, device=None, model_name=None, inference_size=None, conf=None, iou=None, label_opt=None, model_cls=None):
    """Run detection on *img* and return the rendered image plus detections.

    The loaded model is cached in the module globals ``model``,
    ``model_name_tmp`` and ``device_tmp``. It is reloaded only when nothing
    is loaded yet or when the requested model name or device differs from
    the cached one.

    Args:
        img: input image. ``img.size`` is used as ``(width, height)``, so it
            is presumably a PIL image.
        experiment_id: path-adding scheduler chosen in the UI; not used here.
        device: device string for model_loading; ``None`` -> ``"cpu"``.
        model_name: hub model name; ``None`` -> ``"yolov5s"``.
        inference_size: value passed as ``size=`` to the model;
            ``None`` -> 640.
        conf: NMS confidence threshold; ``None`` -> 0.5.
        iou: NMS IoU threshold; ``None`` -> 0.45.
        label_opt: whether render() draws labels; ``None`` -> True.
        model_cls: assigned to ``model.classes`` as a class filter (``None``
            presumably keeps all classes; confirm against the hub model).

    Returns:
        ``(det_img, det_json)``: the first rendered image as a PIL image and
        the detection records of that image (see export_json).
    """
    global model, model_name_tmp, device_tmp

    # Callers that supply only (img, experiment_id), e.g. a Gradio interface
    # with two inputs, leave every tuning argument as None. Fall back to the
    # same defaults as parse_args instead of passing None into the model.
    if device is None:
        device = "cpu"
    if model_name is None:
        model_name = "yolov5s"
    if inference_size is None:
        inference_size = 640
    if conf is None:
        conf = 0.5
    if iou is None:
        iou = 0.45
    if label_opt is None:
        label_opt = True

    # Avoid repeated loading: reload only when needed. Both cache keys are
    # updated together, and only after a successful load, so that a failed
    # load is retried on the next call.
    if "model" not in globals() or model_name != model_name_tmp or device != device_tmp:
        model = model_loading(model_name, device)
        model_name_tmp, device_tmp = model_name, device

    # ----------- model tuning -----------
    model.conf = conf  # NMS confidence threshold
    model.iou = iou  # NMS IoU threshold
    model.max_det = 1000  # maximum number of detection boxes
    model.classes = model_cls  # class filter

    results = model(img, size=inference_size)  # run detection
    results.render(labels=label_opt)  # draw boxes onto the result images
    # Recent YOLOv5 versions expose the rendered images as `ims`; older ones
    # used `imgs`. Accept either.
    rendered = results.ims if hasattr(results, "ims") else results.imgs
    det_img = Image.fromarray(rendered[0])  # rendered detection image
    det_json = export_json(results, model, img.size)[0]  # detection records
    return det_img, det_json
def run_cmd(command):
    """Echo *command*, run it through the system shell, and return its status.

    A non-zero exit status is reported on stderr but does not stop the
    program, so callers that ignore the result keep running. Interrupting
    the command with Ctrl-C terminates the program with status 1.

    Args:
        command: shell command line. It runs with ``shell=True``, so only
            pass trusted, hard-coded strings.

    Returns:
        The command's exit status (0 on success, negative if killed by a
        signal on POSIX).
    """
    print(command)
    try:
        status = call(command, shell=True)
    except KeyboardInterrupt:
        print("Process interrupted")
        sys.exit(1)
    if status != 0:
        print(f"Command failed with exit status {status}: {command}", file=sys.stderr)
    return status
# Import-time environment setup, run in this order: print toolchain and
# workspace info, fetch git submodules, install via setup.py, list the
# directory again, then run main.py on figures/smile.png.
for _setup_cmd in (
    "gcc --version",
    "pwd",
    "ls",
    "git submodule update --init --recursive",
    "python setup.py install --user",
    "ls",
    "python main.py --config config/base.yaml --experiment experiment_5x1 --signature smile --target figures/smile.png --log_dir log/",
):
    run_cmd(_setup_cmd)
del _setup_cmd
# YAML file parsing
def yaml_parse(file_path):
    """Parse the UTF-8 YAML file at *file_path* with ``yaml.safe_load``.

    The file is read inside a ``with`` block, so the handle is closed
    deterministically even if parsing raises.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        return yaml.safe_load(handle.read())
# YAML / CSV file parsing
def yaml_csv(file_path, file_tag):
    """Load a list of names from a CSV or YAML file.

    Args:
        file_path: path whose suffix must be ``suffix_list[0]`` (CSV) or
            ``suffix_list[1]`` (YAML).
        file_tag: key looked up in the parsed YAML mapping; ignored for CSV.

    Returns:
        For CSV, the first cell of every non-empty row. For YAML, the value
        stored under *file_tag* (``None`` when the key is absent).

    Exits the program with status 1 when the suffix is not supported.
    """
    file_suffix = Path(file_path).suffix
    if file_suffix == suffix_list[0]:
        # CSV version: the csv module expects newline="". Blank lines come
        # back as empty rows and are skipped instead of raising IndexError.
        # UTF-8 matches the encoding yaml_parse uses.
        with open(file_path, newline="", encoding="utf-8") as handle:
            file_names = [row[0] for row in csv.reader(handle) if row]
    elif file_suffix == suffix_list[1]:
        # YAML version
        file_names = yaml_parse(file_path).get(file_tag)
    else:
        print(f"{file_path}格式不正确!程序退出!")
        # Error path: exit with a non-zero status instead of signalling success.
        sys.exit(1)
    return file_names
def main(args):
    """Build and launch the Gradio demo.

    Args:
        args: parsed options from parse_args(). The detection settings
            (model name, device, inference size, NMS thresholds, label
            toggle) are forwarded to yolo_det on every request.

    Note: ``gr.inputs`` / ``gr.outputs``, ``default=`` and ``show_tips``
    belong to the legacy Gradio API (removed in Gradio 4), so pin the Gradio
    version accordingly.
    """
    gr.close_all()

    # Path-adding scheduler options. The first entry is the default choice.
    scheduler_choices = [
        "add [1, 1, 1, 1, 1] total 5 paths",
        "add [1, 1, 1, 1, 1, 1, 1, 1] total 8 paths",
        "add [1,2,4,8,16,32, ...] total 128 paths",
        "add [1,2,4,8,16,32, ...] total 256 paths",
    ]

    # -------------------Inputs-------------------
    inputs_img = gr.inputs.Image(type="pil", label="Input Image")
    # The default must exactly match one of the choices.
    experiment_id = gr.inputs.Radio(
        choices=scheduler_choices,
        type="value",
        default=scheduler_choices[0],
        label="Path Adding Scheduler",
    )
    inputs = [
        inputs_img,  # input image
        experiment_id,  # path adding scheduler
    ]

    # -------------------Outputs-------------------
    outputs = gr.outputs.Image(type="pil", label="检测图片")  # detection image
    outputs02 = gr.outputs.JSON(label="检测信息")  # detection info

    title = "LIVE: Towards Layer-wise Image Vectorization"
    description = "<div align='center'>(CVPR 2022 Oral Presentation)</div>"
    examples = [
        ["./examples/1.png", scheduler_choices[0]],
        ["./examples/2.png", scheduler_choices[0]],
        ["./examples/3.jpg", scheduler_choices[2]],
        ["./examples/4.png", scheduler_choices[3]],
        ["./examples/5.png", scheduler_choices[0]],
    ]

    def predict(img, scheduler):
        # The UI supplies only the image and scheduler. Forward the CLI
        # settings so the detector does not receive None for every option.
        return yolo_det(
            img,
            scheduler,
            device=args.device,
            model_name=args.model_name,
            inference_size=args.inference_size,
            conf=args.nms_conf,
            iou=args.nms_iou,
            label_opt=args.label_dnt_show,
        )

    gr.Interface(
        fn=predict,
        inputs=inputs,
        outputs=[outputs, outputs02],
        title=title,
        description=description,
        examples=examples,
        theme="seafoam",
        flagging_dir="run",  # output directory for flagged samples
    ).launch(
        inbrowser=True,  # open the default browser automatically
        show_tips=True,  # show Gradio tips on launch
    )
if __name__ == "__main__":
    # Script entry point: parse the command line, then start the demo.
    args = parse_args(known=False)
    main(args=args)