Zengyf-CVer commited on
Commit
0b4e363
·
1 Parent(s): 268a7b8

add examples

Browse files
Files changed (1) hide show
  1. app.py +66 -7
app.py CHANGED
@@ -7,6 +7,8 @@
7
  import argparse
8
  import csv
9
  import sys
 
 
10
 
11
  import gradio as gr
12
  import torch
@@ -25,6 +27,9 @@ model_name_tmp = ""
25
  # 设备临时变量
26
  device_tmp = ""
27
 
 
 
 
28
 
29
  def parse_args(known=False):
30
  parser = argparse.ArgumentParser(description="Gradio YOLOv5 Det v0.1")
@@ -144,7 +149,25 @@ def yaml_parse(file_path):
144
  return yaml.safe_load(open(file_path, "r", encoding="utf-8").read())
145
 
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  def main(args):
 
 
148
  global model
149
 
150
  slider_step = 0.05 # 滑动步长
@@ -157,15 +180,21 @@ def main(args):
157
  cls_name = args.cls_name
158
  device = args.device
159
 
 
 
 
 
 
 
 
 
 
 
160
  # 模型加载
161
  model = model_loading(model_name, device)
162
- # 模型名称
163
- # model_names = [i[0] for i in list(csv.reader(open(model_cfg)))] # csv版
164
- model_names = yaml_parse(model_cfg).get("model_names") # yaml版
165
 
166
- # 类别名称
167
- # model_cls_name = [i[0] for i in list(csv.reader(open(cls_name)))] # csv版
168
- model_cls_name = yaml_parse(cls_name).get("model_cls_name") # yaml版
169
 
170
  # -------------------输入组件-------------------
171
  inputs_img = gr.inputs.Image(type="pil", label="原始图片")
@@ -205,7 +234,36 @@ def main(args):
205
  # 描述
206
  description = "<div align='center'>可自定义目标检测模型、安装简单、使用方便</div>"
207
 
208
- gr.close_all()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  # 接口
211
  gr.Interface(
@@ -214,6 +272,7 @@ def main(args):
214
  outputs=[outputs, outputs02],
215
  title=title,
216
  description=description,
 
217
  theme="seafoam",
218
  # live=True, # 实时变更输出
219
  flagging_dir="run" # 输出目录
 
7
  import argparse
8
  import csv
9
  import sys
10
+ from pathlib import Path
11
+
12
 
13
  import gradio as gr
14
  import torch
 
27
  # 设备临时变量
28
  device_tmp = ""
29
 
30
+ # 文件后缀
31
+ suffix_list = [".csv", ".yaml"]
32
+
33
 
34
  def parse_args(known=False):
35
  parser = argparse.ArgumentParser(description="Gradio YOLOv5 Det v0.1")
 
149
  return yaml.safe_load(open(file_path, "r", encoding="utf-8").read())
150
 
151
 
152
+ # yaml csv 文件解析
153
+ def yaml_csv(file_path, file_tag):
154
+ file_suffix = Path(file_path).suffix
155
+ if file_suffix == suffix_list[0]:
156
+ # 模型名称
157
+ file_names = [i[0] for i in list(csv.reader(open(file_path)))] # csv版
158
+ elif file_suffix == suffix_list[1]:
159
+ # 模型名称
160
+ file_names = yaml_parse(file_path).get(file_tag) # yaml版
161
+ else:
162
+ print(f"{file_path}格式不正确!程序退出!")
163
+ sys.exit()
164
+
165
+ return file_names
166
+
167
+
168
  def main(args):
169
+ gr.close_all()
170
+
171
  global model
172
 
173
  slider_step = 0.05 # 滑动步长
 
180
  cls_name = args.cls_name
181
  device = args.device
182
 
183
+ # # 模型加载
184
+ # model = model_loading(model_name, device)
185
+ # # 模型名称
186
+ # # model_names = [i[0] for i in list(csv.reader(open(model_cfg)))] # csv版
187
+ # model_names = yaml_parse(model_cfg).get("model_names") # yaml版
188
+
189
+ # # 类别名称
190
+ # # model_cls_name = [i[0] for i in list(csv.reader(open(cls_name)))] # csv版
191
+ # model_cls_name = yaml_parse(cls_name).get("model_cls_name") # yaml版
192
+
193
  # 模型加载
194
  model = model_loading(model_name, device)
 
 
 
195
 
196
+ model_names = yaml_csv(model_cfg, "model_names")
197
+ model_cls_name = yaml_csv(cls_name, "model_cls_name")
 
198
 
199
  # -------------------输入组件-------------------
200
  inputs_img = gr.inputs.Image(type="pil", label="原始图片")
 
234
  # 描述
235
  description = "<div align='center'>可自定义目标检测模型、安装简单、使用方便</div>"
236
 
237
+ # 示例图片
238
+ examples = [
239
+ [
240
+ "./img_example/bus.jpg",
241
+ "cpu",
242
+ "yolov5s",
243
+ 0.6,
244
+ 0.5,
245
+ True,
246
+ ["人", "公交车"],
247
+ ],
248
+ [
249
+ "./img_example/Millenial-at-work.jpg",
250
+ "0",
251
+ "yolov5l",
252
+ 0.5,
253
+ 0.45,
254
+ True,
255
+ ["人", "椅子", "杯子", "笔记本电脑"],
256
+ ],
257
+ [
258
+ "./img_example/zidane.jpg",
259
+ "0",
260
+ "yolov5m",
261
+ 0.25,
262
+ 0.5,
263
+ False,
264
+ ["人", "领带"],
265
+ ],
266
+ ]
267
 
268
  # 接口
269
  gr.Interface(
 
272
  outputs=[outputs, outputs02],
273
  title=title,
274
  description=description,
275
+ examples=examples,
276
  theme="seafoam",
277
  # live=True, # 实时变更输出
278
  flagging_dir="run" # 输出目录