Spaces:
Runtime error
Runtime error
Zengyf-CVer
commited on
Commit
·
0b4e363
1
Parent(s):
268a7b8
add examples
Browse files
app.py
CHANGED
@@ -7,6 +7,8 @@
|
|
7 |
import argparse
|
8 |
import csv
|
9 |
import sys
|
|
|
|
|
10 |
|
11 |
import gradio as gr
|
12 |
import torch
|
@@ -25,6 +27,9 @@ model_name_tmp = ""
|
|
25 |
# 设备临时变量
|
26 |
device_tmp = ""
|
27 |
|
|
|
|
|
|
|
28 |
|
29 |
def parse_args(known=False):
|
30 |
parser = argparse.ArgumentParser(description="Gradio YOLOv5 Det v0.1")
|
@@ -144,7 +149,25 @@ def yaml_parse(file_path):
|
|
144 |
return yaml.safe_load(open(file_path, "r", encoding="utf-8").read())
|
145 |
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
def main(args):
|
|
|
|
|
148 |
global model
|
149 |
|
150 |
slider_step = 0.05 # 滑动步长
|
@@ -157,15 +180,21 @@ def main(args):
|
|
157 |
cls_name = args.cls_name
|
158 |
device = args.device
|
159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
# 模型加载
|
161 |
model = model_loading(model_name, device)
|
162 |
-
# 模型名称
|
163 |
-
# model_names = [i[0] for i in list(csv.reader(open(model_cfg)))] # csv版
|
164 |
-
model_names = yaml_parse(model_cfg).get("model_names") # yaml版
|
165 |
|
166 |
-
|
167 |
-
|
168 |
-
model_cls_name = yaml_parse(cls_name).get("model_cls_name") # yaml版
|
169 |
|
170 |
# -------------------输入组件-------------------
|
171 |
inputs_img = gr.inputs.Image(type="pil", label="原始图片")
|
@@ -205,7 +234,36 @@ def main(args):
|
|
205 |
# 描述
|
206 |
description = "<div align='center'>可自定义目标检测模型、安装简单、使用方便</div>"
|
207 |
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
# 接口
|
211 |
gr.Interface(
|
@@ -214,6 +272,7 @@ def main(args):
|
|
214 |
outputs=[outputs, outputs02],
|
215 |
title=title,
|
216 |
description=description,
|
|
|
217 |
theme="seafoam",
|
218 |
# live=True, # 实时变更输出
|
219 |
flagging_dir="run" # 输出目录
|
|
|
7 |
import argparse
|
8 |
import csv
|
9 |
import sys
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
|
13 |
import gradio as gr
|
14 |
import torch
|
|
|
27 |
# 设备临时变量
|
28 |
device_tmp = ""
|
29 |
|
30 |
+
# 文件后缀
|
31 |
+
suffix_list = [".csv", ".yaml"]
|
32 |
+
|
33 |
|
34 |
def parse_args(known=False):
|
35 |
parser = argparse.ArgumentParser(description="Gradio YOLOv5 Det v0.1")
|
|
|
149 |
return yaml.safe_load(open(file_path, "r", encoding="utf-8").read())
|
150 |
|
151 |
|
152 |
+
# yaml csv 文件解析
|
153 |
+
def yaml_csv(file_path, file_tag):
|
154 |
+
file_suffix = Path(file_path).suffix
|
155 |
+
if file_suffix == suffix_list[0]:
|
156 |
+
# 模型名称
|
157 |
+
file_names = [i[0] for i in list(csv.reader(open(file_path)))] # csv版
|
158 |
+
elif file_suffix == suffix_list[1]:
|
159 |
+
# 模型名称
|
160 |
+
file_names = yaml_parse(file_path).get(file_tag) # yaml版
|
161 |
+
else:
|
162 |
+
print(f"{file_path}格式不正确!程序退出!")
|
163 |
+
sys.exit()
|
164 |
+
|
165 |
+
return file_names
|
166 |
+
|
167 |
+
|
168 |
def main(args):
|
169 |
+
gr.close_all()
|
170 |
+
|
171 |
global model
|
172 |
|
173 |
slider_step = 0.05 # 滑动步长
|
|
|
180 |
cls_name = args.cls_name
|
181 |
device = args.device
|
182 |
|
183 |
+
# # 模型加载
|
184 |
+
# model = model_loading(model_name, device)
|
185 |
+
# # 模型名称
|
186 |
+
# # model_names = [i[0] for i in list(csv.reader(open(model_cfg)))] # csv版
|
187 |
+
# model_names = yaml_parse(model_cfg).get("model_names") # yaml版
|
188 |
+
|
189 |
+
# # 类别名称
|
190 |
+
# # model_cls_name = [i[0] for i in list(csv.reader(open(cls_name)))] # csv版
|
191 |
+
# model_cls_name = yaml_parse(cls_name).get("model_cls_name") # yaml版
|
192 |
+
|
193 |
# 模型加载
|
194 |
model = model_loading(model_name, device)
|
|
|
|
|
|
|
195 |
|
196 |
+
model_names = yaml_csv(model_cfg, "model_names")
|
197 |
+
model_cls_name = yaml_csv(cls_name, "model_cls_name")
|
|
|
198 |
|
199 |
# -------------------输入组件-------------------
|
200 |
inputs_img = gr.inputs.Image(type="pil", label="原始图片")
|
|
|
234 |
# 描述
|
235 |
description = "<div align='center'>可自定义目标检测模型、安装简单、使用方便</div>"
|
236 |
|
237 |
+
# 示例图片
|
238 |
+
examples = [
|
239 |
+
[
|
240 |
+
"./img_example/bus.jpg",
|
241 |
+
"cpu",
|
242 |
+
"yolov5s",
|
243 |
+
0.6,
|
244 |
+
0.5,
|
245 |
+
True,
|
246 |
+
["人", "公交车"],
|
247 |
+
],
|
248 |
+
[
|
249 |
+
"./img_example/Millenial-at-work.jpg",
|
250 |
+
"0",
|
251 |
+
"yolov5l",
|
252 |
+
0.5,
|
253 |
+
0.45,
|
254 |
+
True,
|
255 |
+
["人", "椅子", "杯子", "笔记本电脑"],
|
256 |
+
],
|
257 |
+
[
|
258 |
+
"./img_example/zidane.jpg",
|
259 |
+
"0",
|
260 |
+
"yolov5m",
|
261 |
+
0.25,
|
262 |
+
0.5,
|
263 |
+
False,
|
264 |
+
["人", "领带"],
|
265 |
+
],
|
266 |
+
]
|
267 |
|
268 |
# 接口
|
269 |
gr.Interface(
|
|
|
272 |
outputs=[outputs, outputs02],
|
273 |
title=title,
|
274 |
description=description,
|
275 |
+
examples=examples,
|
276 |
theme="seafoam",
|
277 |
# live=True, # 实时变更输出
|
278 |
flagging_dir="run" # 输出目录
|