Zengyf-CVer commited on
Commit
4766c77
·
1 Parent(s): 7291007

app update

Browse files
Files changed (1) hide show
  1. app.py +70 -22
app.py CHANGED
@@ -1,14 +1,16 @@
1
- # Gradio YOLOv8 Det v0.1
2
  # 创建人:曾逸夫
3
- # 创建时间:2023-01-15
4
-
5
- import os
6
 
7
  import argparse
8
  import csv
 
9
  import sys
10
 
11
  from ultralytics import YOLO
 
 
 
12
  import gc
13
  import json
14
  import random
@@ -34,22 +36,16 @@ SimSun = font_manager.FontProperties(fname=SimSun_path, size=12)
34
  # 新罗马字体
35
  TimesNesRoman = font_manager.FontProperties(fname=TimesNesRoman_path, size=12)
36
 
37
- from copy import deepcopy
38
-
39
  import yaml
40
  from PIL import Image, ImageDraw, ImageFont
41
 
42
  from util.fonts_opt import is_fonts
43
- from util.pdf_opt import pdf_generate
44
 
45
  ROOT_PATH = sys.path[0] # 根目录
46
 
47
- # 本地模型路径
48
- local_model_path = f"{ROOT_PATH}/models"
49
- local_model_path_02 = f"{ROOT_PATH}/yolov5"
50
-
51
  # Gradio YOLOv8 Det版本
52
- GYD_VERSION = "Gradio YOLOv8 Det v0.1"
53
 
54
  # 文件后缀
55
  suffix_list = [".csv", ".yaml"]
@@ -62,7 +58,7 @@ obj_style = ["小目标", "中目标", "大目标"]
62
 
63
 
64
  def parse_args(known=False):
65
- parser = argparse.ArgumentParser(description="Gradio YOLOv8 Det v0.1")
66
  parser.add_argument("--model_type", "-mt", default="online", type=str, help="model type")
67
  parser.add_argument("--source", "-src", default="upload", type=str, help="image input source")
68
  parser.add_argument("--source_video", "-src_v", default="upload", type=str, help="video input source")
@@ -194,13 +190,11 @@ def random_color(cls_num, is_light=True):
194
 
195
 
196
  # 检测绘制
197
- def pil_draw(img_path, score_l, bbox_l, cls_l, cls_index_l, textFont, color_list):
198
- img = Image.open(img_path)
199
  img_pil = ImageDraw.Draw(img)
200
  id = 0
201
 
202
  for score, (xmin, ymin, xmax, ymax), label, cls_index in zip(score_l, bbox_l, cls_l, cls_index_l):
203
-
204
  img_pil.rectangle([xmin, ymin, xmax, ymax], fill=None, outline=color_list[cls_index], width=2) # 边界框
205
  countdown_msg = f"{id}-{label} {score:.2f}"
206
  text_w, text_h = textFont.getsize(countdown_msg) # 标签尺寸
@@ -226,6 +220,42 @@ def pil_draw(img_path, score_l, bbox_l, cls_l, cls_index_l, textFont, color_list
226
  return img
227
 
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  # YOLOv5图片检测函数
230
  def yolo_det_img(img_path, model_name, infer_size, conf, iou):
231
 
@@ -242,12 +272,22 @@ def yolo_det_img(img_path, model_name, infer_size, conf, iou):
242
 
243
  # 模型加载
244
  predict_results = model_loading(img_path, conf, iou, infer_size, yolo_model=f"{model_name}.pt")
 
245
  xyxy_list = predict_results.boxes.xyxy.cpu().numpy().tolist()
246
  conf_list = predict_results.boxes.conf.cpu().numpy().tolist()
247
  cls_list = predict_results.boxes.cls.cpu().numpy().tolist()
248
 
 
249
  color_list = random_color(len(model_cls_name_cp), True)
250
 
 
 
 
 
 
 
 
 
251
  # 判断检测对象是否为空
252
  if (xyxy_list != []):
253
 
@@ -266,7 +306,6 @@ def yolo_det_img(img_path, model_name, infer_size, conf, iou):
266
  textFont = ImageFont.truetype(str(f"{ROOT_PATH}/fonts/malgun.ttf"), size=FONTSIZE)
267
 
268
  for i in range(len(xyxy_list)):
269
- # id = int(i) # 实例ID
270
  obj_cls_index = int(cls_list[i]) # 类别索引
271
  cls_index_det_stat.append(obj_cls_index)
272
 
@@ -284,16 +323,13 @@ def yolo_det_img(img_path, model_name, infer_size, conf, iou):
284
  conf = float(conf_list[i]) # 置信度
285
  score_det_stat.append(conf)
286
 
287
- # fps = f"{(1000 / float(results.t[1])):.2f}" # FPS
288
-
289
  # ---------- 加入目标尺寸 ----------
290
  w_obj = x1 - x0
291
  h_obj = y1 - y0
292
  area_obj = w_obj * h_obj
293
  area_obj_all.append(area_obj)
294
 
295
- det_img = pil_draw(img_path, score_det_stat, bbox_det_stat, cls_det_stat, cls_index_det_stat, textFont,
296
- color_list)
297
 
298
  # -------------- 目标尺寸计算 --------------
299
  for i in range(len(area_obj_all)):
@@ -372,7 +408,7 @@ def main(args):
372
  title = "Gradio YOLOv8 Det"
373
 
374
  # 描述
375
- description = "Author: 曾逸夫(Zeng Yifu), Github: https://github.com/Zengyf-CVer, thanks to [Gradio](https://github.com/gradio-app/gradio) & [YOLOv8](https://github.com/ultralytics/ultralytics)"
376
 
377
  # 示例图片
378
  examples_imgs = [
@@ -399,6 +435,18 @@ def main(args):
399
  "yolov8x",
400
  1280,
401
  0.5,
 
 
 
 
 
 
 
 
 
 
 
 
402
  0.5,],]
403
 
404
  # 接口
 
1
+ # Gradio YOLOv8 Det v0.2
2
  # 创建人:曾逸夫
3
+ # 创建时间:2023-01-20
 
 
4
 
5
  import argparse
6
  import csv
7
+ import os
8
  import sys
9
 
10
  from ultralytics import YOLO
11
+
12
+ csv.field_size_limit(sys.maxsize)
13
+
14
  import gc
15
  import json
16
  import random
 
36
  # 新罗马字体
37
  TimesNesRoman = font_manager.FontProperties(fname=TimesNesRoman_path, size=12)
38
 
39
+ import torch
 
40
  import yaml
41
  from PIL import Image, ImageDraw, ImageFont
42
 
43
  from util.fonts_opt import is_fonts
 
44
 
45
  ROOT_PATH = sys.path[0] # 根目录
46
 
 
 
 
 
47
  # Gradio YOLOv8 Det版本
48
+ GYD_VERSION = "Gradio YOLOv8 Det v0.2"
49
 
50
  # 文件后缀
51
  suffix_list = [".csv", ".yaml"]
 
58
 
59
 
60
  def parse_args(known=False):
61
+ parser = argparse.ArgumentParser(description="Gradio YOLOv8 Det v0.2")
62
  parser.add_argument("--model_type", "-mt", default="online", type=str, help="model type")
63
  parser.add_argument("--source", "-src", default="upload", type=str, help="image input source")
64
  parser.add_argument("--source_video", "-src_v", default="upload", type=str, help="video input source")
 
190
 
191
 
192
  # 检测绘制
193
+ def pil_draw(img, score_l, bbox_l, cls_l, cls_index_l, textFont, color_list):
 
194
  img_pil = ImageDraw.Draw(img)
195
  id = 0
196
 
197
  for score, (xmin, ymin, xmax, ymax), label, cls_index in zip(score_l, bbox_l, cls_l, cls_index_l):
 
198
  img_pil.rectangle([xmin, ymin, xmax, ymax], fill=None, outline=color_list[cls_index], width=2) # 边界框
199
  countdown_msg = f"{id}-{label} {score:.2f}"
200
  text_w, text_h = textFont.getsize(countdown_msg) # 标签尺寸
 
220
  return img
221
 
222
 
223
+ # 绘制多边形
224
+ def polygon_drawing(img_mask, canvas, color_seg):
225
+ # ------- RGB转BGR -------
226
+ color_seg = list(color_seg)
227
+ color_seg[0], color_seg[2] = color_seg[2], color_seg[0]
228
+ color_seg = tuple(color_seg)
229
+ # 定义多边形的顶点
230
+ pts = np.array(img_mask, dtype=np.int32)
231
+
232
+ # 多边形绘制
233
+ cv2.drawContours(canvas, [pts], -1, color_seg, thickness=-1)
234
+
235
+
236
+ # 输出分割结果
237
+ def seg_output(img_path, seg_mask_list, color_list, cls_list):
238
+ img = cv2.imread(img_path)
239
+ w, h = img.shape[1], img.shape[0]
240
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
241
+
242
+ w, h = img.shape[1], img.shape[0]
243
+ canvas = np.zeros((h, w, 3), dtype=np.uint8)
244
+ canvas = cv2.cvtColor(canvas, cv2.COLOR_BGR2BGRA)
245
+
246
+ # 获取分割坐标
247
+ for seg_mask, cls_index in zip(seg_mask_list, cls_list):
248
+ img_mask = []
249
+ for i in range(len(seg_mask)):
250
+ img_mask.append([seg_mask[i][0] * w, seg_mask[i][1] * h])
251
+
252
+ polygon_drawing(img_mask, canvas, color_list[int(cls_index)]) # 绘制分割图形
253
+
254
+ img_mask_merge = cv2.add(img, canvas) # 合并图像
255
+
256
+ return img_mask_merge
257
+
258
+
259
  # YOLOv5图片检测函数
260
  def yolo_det_img(img_path, model_name, infer_size, conf, iou):
261
 
 
272
 
273
  # 模型加载
274
  predict_results = model_loading(img_path, conf, iou, infer_size, yolo_model=f"{model_name}.pt")
275
+ # 检测参数
276
  xyxy_list = predict_results.boxes.xyxy.cpu().numpy().tolist()
277
  conf_list = predict_results.boxes.conf.cpu().numpy().tolist()
278
  cls_list = predict_results.boxes.cls.cpu().numpy().tolist()
279
 
280
+ # 颜色列表
281
  color_list = random_color(len(model_cls_name_cp), True)
282
 
283
+ # 图像分割
284
+ if (model_name[-3:] == "seg"):
285
+ masks_list = predict_results.masks.segments
286
+ img_mask_merge = seg_output(img_path, masks_list, color_list, cls_list)
287
+ img = Image.fromarray(cv2.cvtColor(img_mask_merge, cv2.COLOR_BGRA2RGBA))
288
+ else:
289
+ img = Image.open(img_path)
290
+
291
  # 判断检测对象是否为空
292
  if (xyxy_list != []):
293
 
 
306
  textFont = ImageFont.truetype(str(f"{ROOT_PATH}/fonts/malgun.ttf"), size=FONTSIZE)
307
 
308
  for i in range(len(xyxy_list)):
 
309
  obj_cls_index = int(cls_list[i]) # 类别索引
310
  cls_index_det_stat.append(obj_cls_index)
311
 
 
323
  conf = float(conf_list[i]) # 置信度
324
  score_det_stat.append(conf)
325
 
 
 
326
  # ---------- 加入目标尺寸 ----------
327
  w_obj = x1 - x0
328
  h_obj = y1 - y0
329
  area_obj = w_obj * h_obj
330
  area_obj_all.append(area_obj)
331
 
332
+ det_img = pil_draw(img, score_det_stat, bbox_det_stat, cls_det_stat, cls_index_det_stat, textFont, color_list)
 
333
 
334
  # -------------- 目标尺寸计算 --------------
335
  for i in range(len(area_obj_all)):
 
408
  title = "Gradio YOLOv8 Det"
409
 
410
  # 描述
411
+ description = "<div align='center'>Object detection and image segmentation system based on YOLOv8</div><div align='center'>Author: 曾逸夫(Zeng Yifu), Github: https://github.com/Zengyf-CVer, thanks to [Gradio](https://github.com/gradio-app/gradio) & [YOLOv8](https://github.com/ultralytics/ultralytics)</div>"
412
 
413
  # 示例图片
414
  examples_imgs = [
 
435
  "yolov8x",
436
  1280,
437
  0.5,
438
+ 0.5,],
439
+ [
440
+ "./img_examples/bus.jpg",
441
+ "yolov8s-seg",
442
+ 640,
443
+ 0.6,
444
+ 0.5,],
445
+ [
446
+ "./img_examples/Millenial-at-work.jpg",
447
+ "yolov8x-seg",
448
+ 1280,
449
+ 0.5,
450
  0.5,],]
451
 
452
  # 接口