xqt commited on
Commit
f91c3fb
·
1 Parent(s): e312782

REF: SAM2 AMG and the corresponding test case.

Browse files
SegmentAnything2AssistApp.py CHANGED
@@ -257,25 +257,27 @@ def generate_auto_mask(
257
  if VERBOSE:
258
  print("SegmentAnything2AssistApp::generate_auto_mask::Called.")
259
 
260
- __auto_masks, masks, bboxes = segment_anything2assist.generate_automatic_masks(
261
- image,
262
- points_per_side,
263
- points_per_batch,
264
- pred_iou_thresh,
265
- stability_score_thresh,
266
- stability_score_offset,
267
- mask_threshold,
268
- box_nms_thresh,
269
- crop_n_layers,
270
- crop_nms_thresh,
271
- crop_overlay_ratio,
272
- crop_n_points_downscale_factor,
273
- min_mask_region_area,
274
- use_m2m,
275
- multimask_output,
 
 
276
  )
277
 
278
- if len(__auto_masks) == 0:
279
  gradio.Warning(
280
  "No masks generated, please tweak the advanced parameters.", duration=5
281
  )
@@ -294,7 +296,7 @@ def generate_auto_mask(
294
  ),
295
  )
296
  else:
297
- choices = [str(i) for i in range(len(__auto_masks))]
298
 
299
  returning_image = __generate_auto_mask(
300
  image, ["0"], output_mode, False, masks, bboxes
 
257
  if VERBOSE:
258
  print("SegmentAnything2AssistApp::generate_auto_mask::Called.")
259
 
260
+ masks, bboxes, predicted_iou, stability_score = (
261
+ segment_anything2assist.generate_automatic_masks(
262
+ image,
263
+ points_per_side,
264
+ points_per_batch,
265
+ pred_iou_thresh,
266
+ stability_score_thresh,
267
+ stability_score_offset,
268
+ mask_threshold,
269
+ box_nms_thresh,
270
+ crop_n_layers,
271
+ crop_nms_thresh,
272
+ crop_overlay_ratio,
273
+ crop_n_points_downscale_factor,
274
+ min_mask_region_area,
275
+ use_m2m,
276
+ multimask_output,
277
+ )
278
  )
279
 
280
+ if len(masks) == 0:
281
  gradio.Warning(
282
  "No masks generated, please tweak the advanced parameters.", duration=5
283
  )
 
296
  ),
297
  )
298
  else:
299
+ choices = [str(i) for i in range(len(masks))]
300
 
301
  returning_image = __generate_auto_mask(
302
  image, ["0"], output_mode, False, masks, bboxes
src/SegmentAnything2Assist/SegmentAnything2Assist.py CHANGED
@@ -98,7 +98,7 @@ class SegmentAnything2Assist:
98
  )
99
 
100
  if download:
101
- self.download_model()
102
 
103
  if self.is_model_available():
104
  self.sam2 = sam2.build_sam.build_sam2(
@@ -121,14 +121,14 @@ class SegmentAnything2Assist:
121
  print(f"SegmentAnything2Assist::is_model_available::{ret}")
122
  return ret
123
 
124
- def load_model(self) -> bool:
125
  if self.is_model_available():
126
  self.sam2 = sam2.build_sam(checkpoint=self.model_path)
127
  return True
128
 
129
  return False
130
 
131
- def download_model(self, force: bool = False) -> bool:
132
  if not force and self.is_model_available():
133
  print(f"{self.model_path} already exists. Skipping download.")
134
  return False
@@ -162,7 +162,17 @@ class SegmentAnything2Assist:
162
  min_mask_region_area=0,
163
  use_m2m=False,
164
  multimask_output=True,
165
- ):
 
 
 
 
 
 
 
 
 
 
166
  if self.sam2 is None:
167
  print(
168
  "SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
@@ -196,8 +206,15 @@ class SegmentAnything2Assist:
196
  cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) for mask in segmentation_masks
197
  ]
198
  bbox_masks = [mask["bbox"] for mask in masks]
199
-
200
- return masks, segmentation_masks, bbox_masks
 
 
 
 
 
 
 
201
 
202
  def generate_masks_from_image(
203
  self,
@@ -208,7 +225,15 @@ class SegmentAnything2Assist:
208
  mask_threshold=0.0,
209
  max_hole_area=0.0,
210
  max_sprinkle_area=0.0,
211
- ):
 
 
 
 
 
 
 
 
212
  generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
213
  self.sam2,
214
  mask_threshold=mask_threshold,
@@ -240,8 +265,6 @@ class SegmentAnything2Assist:
240
  image_with_bounding_boxes = image.copy()
241
  all_masks = None
242
 
243
- cv2.imwrite(".tmp/mask_2.png", masks[3])
244
-
245
  for _ in auto_list:
246
  mask = masks[_]
247
  mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
@@ -252,8 +275,6 @@ class SegmentAnything2Assist:
252
  else:
253
  all_masks = cv2.bitwise_or(all_masks, mask)
254
 
255
- cv2.imwrite(".tmp/mask_3.png", masks[3])
256
-
257
  random_color = numpy.random.randint(0, 255, size=3)
258
  image_with_bounding_boxes = cv2.rectangle(
259
  image_with_bounding_boxes,
 
98
  )
99
 
100
  if download:
101
+ self.__download_model()
102
 
103
  if self.is_model_available():
104
  self.sam2 = sam2.build_sam.build_sam2(
 
121
  print(f"SegmentAnything2Assist::is_model_available::{ret}")
122
  return ret
123
 
124
+ def __load_model(self) -> bool:
125
  if self.is_model_available():
126
  self.sam2 = sam2.build_sam(checkpoint=self.model_path)
127
  return True
128
 
129
  return False
130
 
131
+ def __download_model(self, force: bool = False) -> bool:
132
  if not force and self.is_model_available():
133
  print(f"{self.model_path} already exists. Skipping download.")
134
  return False
 
162
  min_mask_region_area=0,
163
  use_m2m=False,
164
  multimask_output=True,
165
+ ) -> typing.Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]:
166
+ """
167
+ Generates automatic masks from the given image.
168
+
169
+ Returns:
170
+ typing.Tuple: Four numpy arrays where:
171
+ - segmentation_masks: Numpy array shape (N, H, W, C) where N is the number of masks, H is the height of the image, W is the width of the image, and C is the number of channels. Each N is a binary mask of the image of shape (H, W, C).
172
+ - bbox_masks: Numpy array of shape (N, 4) where N is the number of masks and 4 is the bounding box coordinates. Each mask is a bounding box of shape (x, y, w, h).
173
+ - predicted_iou: Numpy array of shape (N,) where N is the number of masks. Each value is the predicted IOU of the mask.
174
+ - stability_score: Numpy array of shape (N,) where N is the number of masks. Each value is the stability score of the mask.
175
+ """
176
  if self.sam2 is None:
177
  print(
178
  "SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
 
206
  cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) for mask in segmentation_masks
207
  ]
208
  bbox_masks = [mask["bbox"] for mask in masks]
209
+ predicted_iou = [mask["predicted_iou"] for mask in masks]
210
+ stability_score = [mask["stability_score"] for mask in masks]
211
+
212
+ return (
213
+ numpy.array(segmentation_masks, dtype=numpy.uint8),
214
+ numpy.array(bbox_masks, dtype=numpy.uint32),
215
+ numpy.array(predicted_iou, dtype=numpy.float32),
216
+ numpy.array(stability_score, dtype=numpy.float32),
217
+ )
218
 
219
  def generate_masks_from_image(
220
  self,
 
225
  mask_threshold=0.0,
226
  max_hole_area=0.0,
227
  max_sprinkle_area=0.0,
228
+ ) -> typing.Tuple[numpy.ndarray, numpy.ndarray]:
229
+ """
230
+ Generates masks from the given image.
231
+
232
+ Returns:
233
+ typing.Tuple: Two numpy arrays where:
234
+ - masks_chw: Numpy array shape (1, H, W) for the mask, H is the height of the image, and W is the width of the image.
235
+ - mask_iou: Numpy array of shape (1,) for IOU of the mask.
236
+ """
237
  generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
238
  self.sam2,
239
  mask_threshold=mask_threshold,
 
265
  image_with_bounding_boxes = image.copy()
266
  all_masks = None
267
 
 
 
268
  for _ in auto_list:
269
  mask = masks[_]
270
  mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
 
275
  else:
276
  all_masks = cv2.bitwise_or(all_masks, mask)
277
 
 
 
278
  random_color = numpy.random.randint(0, 255, size=3)
279
  image_with_bounding_boxes = cv2.rectangle(
280
  image_with_bounding_boxes,
test/test_module.py CHANGED
@@ -2,6 +2,8 @@ import unittest
2
  import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
3
  import cv2
4
 
 
 
5
 
6
  class TestSegmentAnything2Assist(unittest.TestCase):
7
  def setUp(self) -> None:
@@ -39,21 +41,46 @@ class TestSegmentAnything2Assist(unittest.TestCase):
39
  device="cpu",
40
  )
41
 
42
- def test_generate_automatic_mask(self):
43
  image = cv2.imread("test/assets/liberty.jpg")
44
 
45
  sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
46
  sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
47
  )
48
 
49
- masks, segmentation_masks, bboxes = sam_model.generate_automatic_masks(image)
 
 
50
 
51
- print(type(masks[0]))
52
- print(type(segmentation_masks[0]))
53
- print(type(bboxes[0]))
 
 
 
 
 
 
 
 
54
 
55
- self.assertEqual(len(masks), len(segmentation_masks))
56
- self.assertEqual(len(masks), len(bboxes))
57
 
58
- # for mask, segmentation_mask, bbox in zip(masks, segmentation_masks, bboxes):
59
- self.assertEqual(segmentation_masks[0].shape, image.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
3
  import cv2
4
 
5
+ import numpy
6
+
7
 
8
  class TestSegmentAnything2Assist(unittest.TestCase):
9
  def setUp(self) -> None:
 
41
  device="cpu",
42
  )
43
 
44
+ def _generate_automatic_mask(self):
45
  image = cv2.imread("test/assets/liberty.jpg")
46
 
47
  sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
48
  sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
49
  )
50
 
51
+ segmentation_masks, bboxes, predicted_iou, stability_score = (
52
+ sam_model.generate_automatic_masks(image)
53
+ )
54
 
55
+ self.assertEqual(len(segmentation_masks.shape), 4)
56
+ self.assertEqual(segmentation_masks[0].shape, image.shape)
57
+ self.assertEqual(segmentation_masks.shape[3], 3)
58
+ self.assertEqual(type(segmentation_masks[0][0][0][0]), numpy.uint8)
59
+ self.assertEqual(len(bboxes.shape), 2)
60
+ self.assertEqual(bboxes[0].shape, (4,))
61
+ self.assertEqual(type(bboxes[0][0]), numpy.uint32)
62
+ self.assertEqual(len(predicted_iou.shape), 1)
63
+ self.assertEqual(type(predicted_iou[0]), numpy.float32)
64
+ self.assertEqual(len(stability_score.shape), 1)
65
+ self.assertEqual(type(stability_score[0]), numpy.float32)
66
 
67
+ for segmentation_mask in segmentation_masks:
68
+ self.assertEqual(segmentation_mask.shape, image.shape)
69
 
70
+ def test_generate_masks_from_image(self):
71
+ image = cv2.imread("test/assets/liberty.jpg")
72
+
73
+ sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
74
+ sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
75
+ )
76
+
77
+ mask_chw, mask_iou = sam_model.generate_masks_from_image(
78
+ image, None, None, None
79
+ )
80
+
81
+ self.assertEqual(len(mask_chw.shape), 3)
82
+ self.assertEqual(mask_chw[0].shape, image.shape)
83
+ self.assertEqual(mask_chw.shape[0], 1)
84
+
85
+ self.assertEqual(len(mask_iou.shape), 1)
86
+ self.assertEqual(mask_iou.shape[0], 1)