blitzkrieg0000 commited on
Commit
7dc5008
·
verified ·
1 Parent(s): 3ddfba1

Update Predict.py

Browse files
Files changed (1) hide show
  1. Predict.py +125 -29
Predict.py CHANGED
@@ -1,57 +1,153 @@
1
  import cv2
2
- from matplotlib import pyplot as plt
3
  import numpy as np
4
- from ultralytics import YOLO
5
  import torch
 
 
 
 
 
6
 
7
  # Data
8
- test01 = "data/16_3450.png"
9
- test_image = test01
 
 
10
 
11
  # Load a model
12
- model = YOLO("weight/yolov9c-cable-seg.pt") # load a custom model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  with torch.no_grad():
16
  results = model(
17
- test_image,
18
- save=True,
19
  show_boxes=False,
20
- project="./result/",
21
  conf=0.5,
 
22
  retina_masks=False
23
  )
24
 
 
 
25
 
26
- with torch.no_grad():
27
- for result in results:
28
- masks = result.masks.data
29
- boxes = result.boxes.data
30
 
31
- #ALL
32
- canvas = torch.any(masks, dim=0).int() * 255
33
 
34
- clss = boxes[:, 5]
35
- obj_indices = torch.where(clss == 4)
 
 
 
 
 
 
36
 
37
- # Cable
38
- obj_masks = masks[obj_indices]
39
- obj_mask = torch.any(obj_masks, dim=0).int() * 255
40
- # cropped_image = result.orig_img[obj_mask.cpu().numpy()]
41
 
42
-
43
- fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(27,15))
44
- axs[0][0].imshow(result.orig_img)
45
  axs[0][0].set_title("Orijinal Görüntü")
46
 
47
- axs[0][1].imshow(canvas.cpu().numpy())
48
  axs[0][1].set_title("Segmentasyon Maskesi")
49
 
50
- mask = np.array(obj_mask.cpu().numpy())*255
51
- cv2.imwrite("cable_mask.png", mask)
52
- axs[1][0].imshow(obj_mask.cpu().numpy())
53
- axs[1][0].set_title("Seçilen")
54
 
55
- axs[1][1].imshow(result.plot())
56
  axs[1][1].set_title("Sonuç")
 
 
 
 
 
57
  plt.show()
 
1
  import cv2
 
2
  import numpy as np
 
3
  import torch
4
+ from matplotlib import pyplot as plt
5
+ from ultralytics import YOLO
6
+
7
+
8
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
9
 
10
  # Data
11
+ test_image = "data/DJI_20240905091530_0003_W.JPG"
12
+
13
+ LABELS = {0: "Boş", 1: "Çelik Direkler", 2: "Kafes Kule", 3: "Kablo", 4: "Ahşap Kule"}
14
+ colorMap = {"Boş":"#ffffff", "Çelik Direkler":"#0000ff", "Kafes Kule":"#ff0000", "Kablo":"#00ff00", "Ahşap Kule":"#ff0000"}
15
 
16
  # Load a model
17
+ model = YOLO("Weight/yolov9c-cable-seg.pt") # load a custom model
18
+ model.fuse()
19
+
20
+
21
+ def ParseResults(results, threshold=0.5, scale_masks=True):
22
+ batches = []
23
+
24
+ SCORES = torch.Tensor([]).to(DEVICE)
25
+ CLASSES = torch.Tensor([]).to(DEVICE)
26
+ MASKS = torch.Tensor([]).to(DEVICE)
27
+ BOXES = torch.Tensor([]).to(DEVICE)
28
+
29
+ with torch.no_grad():
30
+ for result in results:
31
+ original_shape = result.orig_shape
32
+ _scores = result.boxes.conf # 7
33
+ _classes = result.boxes.cls # 7
34
+ _masks = result.masks.data # 7, 480, 640
35
+ _boxes = result.boxes.xyxy # 7, 4
36
+
37
+ # Threshold Filter
38
+ conditions = _scores > threshold
39
+ SCORES = torch.cat((SCORES, _scores[conditions]), dim=0)
40
+ CLASSES = torch.cat((CLASSES, _classes[conditions]), dim=0)
41
+ BOXES = torch.cat((BOXES, _boxes[conditions]), dim=0)
42
+ mask = _masks[conditions]
43
+ if scale_masks:
44
+ mask = ScaleMasks(mask, original_shape[:2])
45
+
46
+ MASKS = torch.cat((MASKS, mask), dim=0)
47
+
48
+ batches += [(SCORES, CLASSES, MASKS, BOXES)]
49
+
50
+ return batches
51
+
52
+
53
+ def ScaleMasks(masks: torch.Tensor, shape: tuple) -> torch.Tensor:
54
+ masks = masks.unsqueeze(0)
55
+ interpolatedMask:torch.Tensor = torch.nn.functional.interpolate(masks, shape, mode="nearest")
56
+ interpolatedMask = interpolatedMask.squeeze(0)
57
+ return interpolatedMask
58
+
59
+
60
+ def DrawResults(image, scores: torch.Tensor, classes: torch.Tensor, masks: torch.Tensor, boxes: torch.Tensor, labels:dict=LABELS, class_filter:list=None):
61
+ _image = np.array(image).copy()
62
+ _image = cv2.cvtColor(_image, cv2.COLOR_BGR2RGB)
63
+ maskCanvas = np.zeros_like(_image)
64
+
65
+
66
+ with torch.no_grad():
67
+ scores = scores.cpu().numpy()
68
+ classes = classes.cpu().numpy().astype(np.int32)
69
+ masks = masks.cpu().numpy()
70
+ boxes = boxes.cpu().numpy()
71
+
72
+ for score, cls, mask, box in zip(scores, classes, masks, boxes):
73
+ label = labels[cls]
74
+
75
+ if class_filter and cls not in class_filter:
76
+ continue
77
+
78
+ box = box.astype(np.int32)
79
+ mask = cv2.cvtColor(mask*255, cv2.COLOR_GRAY2BGR).astype(np.uint8)
80
+ maskCanvas = cv2.addWeighted(maskCanvas, 1.0, mask, 1.0, 0)
81
+ maskCanvas = cv2.rectangle(maskCanvas, (box[0], box[1]), (box[2], box[3]), color=(255, 0, 0), thickness=3) # Red color for bounding box
82
+ maskCanvas = cv2.putText(maskCanvas, f"{label} : {score:.2f}", (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color=(255, 0, 0), thickness=2)
83
+
84
+ canvas = cv2.addWeighted(_image, 1.0, maskCanvas.astype(np.uint8), 1.0, 0)
85
+ return canvas, maskCanvas
86
+
87
+
88
+ def RescaleTheMask(orijinal_image, masks):
89
+ _masks = []
90
+ for contour in masks:
91
+ b_mask = np.zeros(orijinal_image.shape[:2], np.uint8)
92
+ contour = contour.astype(np.int32)
93
+ # contour = contour.reshape(-1, 1, 2)
94
+
95
+ w = orijinal_image.shape[0]
96
+ h = orijinal_image.shape[1]
97
+
98
+ mask = cv2.drawContours(b_mask, [contour], -1, (1, 1, 1), cv2.FILLED)
99
+ _masks += [mask]
100
+
101
+ return _masks
102
+
103
+
104
+
105
+ image = cv2.imread(test_image)
106
 
107
 
108
  with torch.no_grad():
109
  results = model(
110
+ image,
111
+ save=False,
112
  show_boxes=False,
113
+ project="./inference/",
114
  conf=0.5,
115
+ iou=0.7,
116
  retina_masks=False
117
  )
118
 
119
+ batches = ParseResults(results, threshold=0.5, scale_masks=True)
120
+ scores, classes, masks, boxes = batches[0]
121
 
122
+ canvas, mask = DrawResults(image, scores, classes, masks, boxes, class_filter=[3])
 
 
 
123
 
 
 
124
 
125
+ # ALL Segmentation
126
+ # canvas = torch.any(result.masks.data, dim=0).int() * 255
127
+
128
+ # Instance Segmentation
129
+ # objIdx = torch.where(result.boxes.cls.data == 3)
130
+ # objMasks = result.masks.data[objIdx]
131
+ # obj_mask = torch.any(objMasks, dim=0).int() * 255
132
+
133
 
 
 
 
 
134
 
135
+ #! Plot
136
+ fig, axs = plt.subplots(2, 2, figsize=(27, 15))
137
+ axs[0][0].imshow(image)
138
  axs[0][0].set_title("Orijinal Görüntü")
139
 
140
+ axs[0][1].imshow(mask)
141
  axs[0][1].set_title("Segmentasyon Maskesi")
142
 
143
+ # axs[1][0].imshow(obj_mask.cpu().numpy())
144
+ # axs[1][0].set_title("Seçilen")
 
 
145
 
146
+ axs[1][1].imshow(canvas)
147
  axs[1][1].set_title("Sonuç")
148
+
149
+ # mask = np.array(obj_mask.cpu().numpy())*255
150
+ # cv2.imwrite("cable_mask.png", mask)
151
+
152
+ plt.tight_layout()
153
  plt.show()