liuyizhang committed on
Commit
72fe59d
·
1 Parent(s): 1f8f331

update app.py

Browse files
GroundingDINO/demo/inference_on_a_image.py CHANGED
@@ -143,7 +143,7 @@ if __name__ == "__main__":
143
  text_prompt = args.text_prompt
144
  output_dir = args.output_dir
145
  box_threshold = args.box_threshold
146
- text_threshold = args.box_threshold
147
 
148
  # make dir
149
  os.makedirs(output_dir, exist_ok=True)
 
143
  text_prompt = args.text_prompt
144
  output_dir = args.output_dir
145
  box_threshold = args.box_threshold
146
+ text_threshold = args.text_threshold
147
 
148
  # make dir
149
  os.makedirs(output_dir, exist_ok=True)
app.py CHANGED
@@ -1,10 +1,11 @@
1
 
2
- import subprocess, os, sys, time
3
 
4
  os.environ["CUDA_VISIBLE_DEVICES"] = "0"
5
 
6
- result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
7
- print(f'pip install GroundingDINO = {result}')
 
8
 
9
  result = subprocess.run(['pip', 'list'], check=True)
10
  print(f'pip list = {result}')
@@ -12,6 +13,7 @@ print(f'pip list = {result}')
12
  sys.path.insert(0, './GroundingDINO')
13
 
14
  if not os.path.exists('./sam_vit_h_4b8939.pth'):
 
15
  result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
16
  print(f'wget sam_vit_h_4b8939.pth result = {result}')
17
 
@@ -19,10 +21,11 @@ import gradio as gr
19
 
20
  import argparse
21
  import copy
 
22
 
23
  import numpy as np
24
  import torch
25
- from PIL import Image, ImageDraw, ImageFont
26
 
27
  # Grounding DINO
28
  import GroundingDINO.groundingdino.datasets.transforms as T
@@ -31,12 +34,14 @@ from GroundingDINO.groundingdino.util import box_ops
31
  from GroundingDINO.groundingdino.util.slconfig import SLConfig
32
  from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
33
 
34
- # segment anything
35
- from segment_anything import build_sam, SamPredictor
36
  import cv2
37
  import numpy as np
38
  import matplotlib.pyplot as plt
 
 
39
 
 
 
40
 
41
  # diffusers
42
  import PIL
@@ -108,8 +113,10 @@ def plot_boxes_to_image(image_pil, tgt):
108
 
109
  def load_image(image_path):
110
  # # load image
111
- # image_pil = Image.open(image_path).convert("RGB") # load image
112
- image_pil = image_path
 
 
113
 
114
  transform = T.Compose(
115
  [
@@ -181,6 +188,38 @@ def show_box(box, ax, label):
181
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
182
  ax.text(x0, y0, label)
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
185
  ckpt_repo_id = "ShilongLiu/GroundingDINO"
186
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
@@ -189,53 +228,157 @@ output_dir = "outputs"
189
  device = "cuda"
190
 
191
  device = get_device()
192
-
193
  print(f'device={device}')
194
 
195
  # initialize groundingdino model
 
196
  groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
197
 
198
  # initialize SAM
 
199
  sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
200
 
201
  # initialize stable-diffusion-inpainting
202
- sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
203
- "runwayml/stable-diffusion-inpainting",
204
- torch_dtype=torch.float16
 
 
 
 
 
 
 
 
 
 
 
 
205
  )
206
- sd_pipe = sd_pipe.to(device)
207
 
208
- def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold):
209
- assert text_prompt, 'text_prompt is not found!'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
  # make dir
212
  os.makedirs(output_dir, exist_ok=True)
213
  # load image
214
- image_pil, image = load_image(image_path.convert("RGB"))
 
 
 
215
 
216
  file_temp = int(time.time())
217
 
218
  # visualize raw image
219
  # image_pil.save(os.path.join(output_dir, f"raw_image_{file_temp}.jpg"))
 
 
220
 
 
221
  # run grounding dino model
222
- groundingdino_device = 'cpu'
223
- if device != 'cpu':
224
- try:
225
- from groundingdino import _C
226
- groundingdino_device = 'cuda:0'
227
- except:
228
- warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")
229
-
230
- groundingdino_device = 'cpu'
231
- boxes_filt, pred_phrases = get_grounding_output(
232
- groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
233
- )
 
 
 
 
234
 
235
- size = image_pil.size
 
 
 
 
 
 
 
 
 
 
236
 
237
- if task_type == 'segment' or task_type == 'inpainting':
238
- image = np.array(image_path)
 
239
  sam_predictor.set_image(image)
240
 
241
  H, W = size[1], size[0]
@@ -253,25 +396,8 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
253
  boxes = transformed_boxes,
254
  multimask_output = False,
255
  )
256
-
257
- # masks: [1, 1, 512, 512]
258
-
259
- if task_type == 'detection':
260
- pred_dict = {
261
- "boxes": boxes_filt,
262
- "size": [size[1], size[0]], # H,W
263
- "labels": pred_phrases,
264
- }
265
- # import ipdb; ipdb.set_trace()
266
- image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
267
- image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
268
- image_with_box.save(image_path)
269
- image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
270
- os.remove(image_path)
271
- return image_result
272
- elif task_type == 'segment':
273
  assert sam_checkpoint, 'sam_checkpoint is not found!'
274
-
275
  # draw output image
276
  plt.figure(figsize=(10, 10))
277
  plt.imshow(image)
@@ -282,39 +408,106 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
282
  plt.axis('off')
283
  image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
284
  plt.savefig(image_path, bbox_inches="tight")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
286
  os.remove(image_path)
287
- return image_result
288
- elif task_type == 'inpainting':
289
- assert inpaint_prompt, 'inpaint_prompt is not found!'
290
- # inpainting pipeline
291
- mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
292
- mask_pil = Image.fromarray(mask)
293
- image_pil = Image.fromarray(image)
294
- # image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
295
-
296
- # resize for inpaint
297
- image_source_for_inpaint = image_pil.resize((512, 512))
298
- image_mask_for_inpaint = mask_pil.resize((512, 512))
299
- image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
301
 
302
  image_path = os.path.join(output_dir, f"grounded_sam_inpainting_output_{file_temp}.jpg")
303
  image_inpainting.save(image_path)
304
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
305
  os.remove(image_path)
306
- return image_result
 
 
307
  else:
308
- print("task_type:{} error!".format(task_type))
309
-
310
- def change_task_type(task_type):
 
 
 
 
 
311
  if task_type == "inpainting":
312
- return gr.Textbox.update(visible=True)
313
- else:
314
- return gr.Textbox.update(visible=False)
 
 
 
315
 
316
  if __name__ == "__main__":
317
-
318
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
319
  parser.add_argument("--debug", action="store_true", help="using debug mode")
320
  parser.add_argument("--share", action="store_true", help="share the app")
@@ -326,11 +519,14 @@ if __name__ == "__main__":
326
  with block:
327
  with gr.Row():
328
  with gr.Column():
329
- input_image = gr.Image(source='upload', type="pil")
330
- task_type = gr.Radio(["detection", "segment", "inpainting"], value="detection",
331
  label='Task type',interactive=True, visible=True)
 
 
 
332
  text_prompt = gr.Textbox(label="Detection Prompt", placeholder="Cannot be empty")
333
- inpaint_prompt = gr.Textbox(label="Inpaint Prompt", visible=True)
334
  run_button = gr.Button(label="Run")
335
  with gr.Accordion("Advanced options", open=False):
336
  box_threshold = gr.Slider(
@@ -339,18 +535,28 @@ if __name__ == "__main__":
339
  text_threshold = gr.Slider(
340
  label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
341
  )
 
 
 
 
 
 
 
 
 
342
 
343
  with gr.Column():
344
- gallery = gr.outputs.Image(
345
- type="pil",
346
- ).style(full_width=True, full_height=True)
347
 
348
  run_button.click(fn=run_grounded_sam, inputs=[
349
- input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold], outputs=[gallery])
350
- # task_type.change(fn=change_task_type, inputs=[task_type], outputs=[inpaint_prompt])
 
351
 
352
  DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). Thanks for their excellent work.'
353
  DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
354
  gr.Markdown(DESCRIPTION)
355
 
356
- block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
 
1
 
2
+ import subprocess, io, os, sys, time
3
 
4
  os.environ["CUDA_VISIBLE_DEVICES"] = "0"
5
 
6
+ if os.environ.get('IS_MY_DEBUG') is None:
7
+ result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
8
+ print(f'pip install GroundingDINO = {result}')
9
 
10
  result = subprocess.run(['pip', 'list'], check=True)
11
  print(f'pip list = {result}')
 
13
  sys.path.insert(0, './GroundingDINO')
14
 
15
# Fetch the SAM checkpoint on first start-up.
# `logger` is imported further down together with the other third-party
# modules, but this step runs before that import is reached, so it has to be
# brought into scope here; otherwise a missing checkpoint ends in a NameError
# instead of a download.
from loguru import logger

if not os.path.exists('./sam_vit_h_4b8939.pth'):
    logger.info(f"get sam_vit_h_4b8939.pth...")
    # check=True: a failed wget raises CalledProcessError rather than letting
    # the app continue without the checkpoint file.
    result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
    print(f'wget sam_vit_h_4b8939.pth result = {result}')
19
 
 
21
 
22
  import argparse
23
  import copy
24
+ from loguru import logger
25
 
26
  import numpy as np
27
  import torch
28
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
29
 
30
  # Grounding DINO
31
  import GroundingDINO.groundingdino.datasets.transforms as T
 
34
  from GroundingDINO.groundingdino.util.slconfig import SLConfig
35
  from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
36
 
 
 
37
  import cv2
38
  import numpy as np
39
  import matplotlib.pyplot as plt
40
+ from lama_cleaner.model_manager import ModelManager
41
+ from lama_cleaner.schema import Config
42
 
43
+ # segment anything
44
+ from segment_anything import build_sam, SamPredictor
45
 
46
  # diffusers
47
  import PIL
 
113
 
114
  def load_image(image_path):
115
  # # load image
116
+ if isinstance(image_path, PIL.Image.Image):
117
+ image_pil = image_path
118
+ else:
119
+ image_pil = Image.open(image_path).convert("RGB") # load image
120
 
121
  transform = T.Compose(
122
  [
 
188
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
189
  ax.text(x0, y0, label)
190
 
191
def xywh_to_xyxy(box, sizeW, sizeH):
    """Scale a centre-format box to pixel units and return its corner form.

    The four entries of ``box`` are multiplied by
    ``(sizeW, sizeH, sizeW, sizeH)``. The first two are then treated as the
    box centre and the last two as its width and height (presumably the
    input is normalised to the image size -- confirm against the detector
    output).

    Args:
        box: 1-D box with four entries, as a Python list or a ``torch.Tensor``.
        sizeW: Factor applied to the x centre and the width.
        sizeH: Factor applied to the y centre and the height.

    Returns:
        NumPy array ``[x0, y0, x1, y1]``.
    """
    if isinstance(box, list):
        box = torch.Tensor(box)
    scale = torch.Tensor([sizeW, sizeH, sizeW, sizeH])
    # The product is a new tensor, so the in-place steps below do not write
    # back into the tensor handed in by the caller.
    corners = box * scale
    # centre -> top-left corner
    corners[:2] -= corners[2:] / 2
    # width/height -> bottom-right corner
    corners[2:] += corners[:2]
    return corners.numpy()
199
+
200
def mask_extend(img, box, extend_pixels=10, useRectangle=True):
    """Grow the part of ``img`` covered by ``box`` by ``extend_pixels`` per side.

    The area inside ``box`` is cropped, resized so that it is
    ``2 * extend_pixels`` wider and taller, optionally painted solid white,
    and pasted back shifted up/left by ``extend_pixels`` so that it overhangs
    the original box equally on every side.

    Args:
        img: PIL image to edit. It is pasted into in place and also returned.
            The white fill is given as an RGB triple, so an RGB image is
            assumed when ``useRectangle`` is true -- confirm with callers.
        box: ``(x0, y0, x1, y1)`` in pixels; each entry is truncated with
            ``int``. The sequence itself is only read, never written to.
        extend_pixels: Margin added on each side of the box.
        useRectangle: When true, the enlarged region is filled with white
            instead of keeping the resized crop.

    Returns:
        The same ``img`` object after the paste.
    """
    # Read the coordinates into locals instead of assigning them back into
    # ``box``, so the caller's list/array is not modified as a side effect.
    left, top, right, bottom = (int(v) for v in box[:4])

    new_width = right - left + 2 * extend_pixels
    new_height = bottom - top + 2 * extend_pixels

    region = img.crop((left, top, right, bottom))
    grown = region.resize((int(new_width), int(new_height)))
    if useRectangle:
        ImageDraw.Draw(grown).rectangle((0, 0, new_width, new_height), fill=(255, 255, 255))

    img.paste(grown, (int(left - extend_pixels), int(top - extend_pixels)))
    return img
215
+
216
def mix_masks(imgs):
    """Merge several mask images into a single union mask.

    Every image is converted to PIL's 1-bit mode first; a pixel ends up set
    in the result when it is set in at least one of the inputs.

    Args:
        imgs: Non-empty, indexable sequence of PIL images.

    Returns:
        PIL image built from a ``uint8`` array holding 255 where any input
        mask is set and 0 elsewhere.
    """
    # ``unset`` stays 1 only where every mask seen so far is 0: multiplying
    # the inverted masks together is a logical AND of "not set".
    unset = 1 - np.asarray(imgs[0].convert("1"))
    for position in range(1, len(imgs)):
        inverted = 1 - np.asarray(imgs[position].convert("1"))
        unset = np.multiply(unset, inverted)
    union = 1 - unset
    return Image.fromarray(np.uint8(255 * union))
222
+
223
  config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
224
  ckpt_repo_id = "ShilongLiu/GroundingDINO"
225
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
 
228
  device = "cuda"
229
 
230
  device = get_device()
 
231
  print(f'device={device}')
232
 
233
  # initialize groundingdino model
234
+ logger.info(f"initialize groundingdino model...")
235
  groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
236
 
237
  # initialize SAM
238
+ logger.info(f"initialize SAM model...")
239
  sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
240
 
241
  # initialize stable-diffusion-inpainting
242
+ logger.info(f"initialize stable-diffusion-inpainting...")
243
+ sd_pipe = None
244
+ if os.environ.get('IS_MY_DEBUG') is None:
245
+ sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
246
+ "runwayml/stable-diffusion-inpainting",
247
+ torch_dtype=torch.float16
248
+ )
249
+ sd_pipe = sd_pipe.to(device)
250
+
251
+ # initialize lama_cleaner
252
+ logger.info(f"initialize lama_cleaner...")
253
+ from lama_cleaner.helper import (
254
+ load_img,
255
+ numpy_to_bytes,
256
+ resize_max_size,
257
  )
 
258
 
259
+ lama_cleaner_model = ModelManager(
260
+ name='lama',
261
+ device=device,
262
+ )
263
+
264
def lama_cleaner_process(image, mask):
    """Inpaint the masked area of ``image`` with the module-level lama model.

    Both arrays are passed through ``resize_max_size`` with a limit of 1080,
    handed to ``lama_cleaner_model`` together with a fixed ``Config``, and the
    model output is PNG-encoded and reopened as a PIL image.

    Args:
        image: Image as a NumPy array. It is indexed with three axes here, so
            presumably (H, W, C) -- confirm with the caller.
        mask: Mask as a NumPy array with at least two axes.

    Returns:
        PIL image decoded from the PNG bytes of the model output.
    """
    # Only needed for the random-seed branch; imported locally so the
    # function does not depend on a module-level import of `random`.
    import random

    # When the mask's first two dimensions are the image's swapped (and the
    # mask is not square), flip and transpose the image so that its height
    # and width line up with the mask.
    if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
        image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]

    interpolation = cv2.INTER_CUBIC
    size_limit = 1080

    config = Config(
        ldm_steps=25,
        ldm_sampler='plms',
        zits_wireframe=True,
        hd_strategy='Original',
        hd_strategy_crop_margin=196,
        hd_strategy_crop_trigger_size=1280,
        hd_strategy_resize_limit=2048,
        prompt='',
        use_croper=False,
        croper_x=0,
        croper_y=0,
        croper_height=512,
        croper_width=512,
        sd_mask_blur=5,
        sd_strength=0.75,
        sd_steps=50,
        sd_guidance_scale=7.5,
        sd_sampler='ddim',
        sd_seed=42,
        cv2_flag='INPAINT_NS',
        cv2_radius=5,
    )

    # A seed of -1 requests a random seed; with the fixed value above this
    # branch is currently not taken.
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)

    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)

    res_np_img = lama_cleaner_model(image, mask, config)
    torch.cuda.empty_cache()

    return Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
320
+
321
+ mask_source_draw = "draw a mask on input image"
322
+ mask_source_segment = "type what to detect below"
323
+
324
+ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
325
+ iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
326
+ if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
327
+ pass
328
+ else:
329
+ assert text_prompt, 'text_prompt is not found!'
330
+
331
+ logger.info(f'run_grounded_sam_1_')
332
 
333
  # make dir
334
  os.makedirs(output_dir, exist_ok=True)
335
  # load image
336
+ input_mask_pil = input_image['mask']
337
+ input_mask = np.array(input_mask_pil.convert("L"))
338
+
339
+ image_pil, image = load_image(input_image['image'].convert("RGB"))
340
 
341
  file_temp = int(time.time())
342
 
343
  # visualize raw image
344
  # image_pil.save(os.path.join(output_dir, f"raw_image_{file_temp}.jpg"))
345
+
346
+ size = image_pil.size
347
 
348
+ output_images = []
349
  # run grounding dino model
350
+ if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
351
+ pass
352
+ else:
353
+ groundingdino_device = 'cpu'
354
+ if device != 'cpu':
355
+ try:
356
+ from groundingdino import _C
357
+ groundingdino_device = 'cuda:0'
358
+ except:
359
+ warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")
360
+
361
+ groundingdino_device = 'cpu'
362
+ boxes_filt, pred_phrases = get_grounding_output(
363
+ groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
364
+ )
365
+ boxes_filt_ori = copy.deepcopy(boxes_filt)
366
 
367
+ pred_dict = {
368
+ "boxes": boxes_filt,
369
+ "size": [size[1], size[0]], # H,W
370
+ "labels": pred_phrases,
371
+ }
372
+ image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
373
+ image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
374
+ image_with_box.save(image_path)
375
+ detection_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
376
+ os.remove(image_path)
377
+ output_images.append(detection_image_result)
378
 
379
+ logger.info(f'run_grounded_sam_2_')
380
+ if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
381
+ image = np.array(input_image['image'])
382
  sam_predictor.set_image(image)
383
 
384
  H, W = size[1], size[0]
 
396
  boxes = transformed_boxes,
397
  multimask_output = False,
398
  )
399
+ # masks: [9, 1, 512, 512]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  assert sam_checkpoint, 'sam_checkpoint is not found!'
 
401
  # draw output image
402
  plt.figure(figsize=(10, 10))
403
  plt.imshow(image)
 
408
  plt.axis('off')
409
  image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
410
  plt.savefig(image_path, bbox_inches="tight")
411
+ segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
412
+ os.remove(image_path)
413
+ output_images.append(segment_image_result)
414
+
415
+ logger.info(f'run_grounded_sam_3_')
416
+ if task_type == 'detection' or task_type == 'segment':
417
+ logger.info(f'run_grounded_sam_9_{task_type}_')
418
+ return output_images
419
+ elif task_type == 'inpainting' or task_type == 'remove':
420
+ if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
421
+ task_type = 'remove'
422
+
423
+ logger.info(f'run_grounded_sam_4_{task_type}_')
424
+ if mask_source_radio == mask_source_draw:
425
+ mask_pil = input_mask_pil
426
+ mask = input_mask
427
+ else:
428
+ if inpaint_mode == 'merge':
429
+ masks = torch.sum(masks, dim=0).unsqueeze(0)
430
+ masks = torch.where(masks > 0, True, False)
431
+ mask = masks[0][0].cpu().numpy()
432
+ mask_pil = Image.fromarray(mask)
433
+
434
+ image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
435
+ mask_pil.convert("RGB").save(image_path)
436
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
437
  os.remove(image_path)
438
+ output_images.append(image_result)
439
+
440
+ if task_type == 'inpainting':
441
+ # inpainting pipeline
442
+ image_source_for_inpaint = image_pil.resize((512, 512))
443
+ image_mask_for_inpaint = mask_pil.resize((512, 512))
444
+ image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
445
+ else:
446
+ # remove from mask
447
+ if mask_source_radio == mask_source_segment:
448
+ mask_imgs = []
449
+ masks_shape = masks.shape
450
+ boxes_filt_ori_array = boxes_filt_ori.numpy()
451
+ if inpaint_mode == 'merge':
452
+ extend_shape_0 = masks_shape[0]
453
+ extend_shape_1 = masks_shape[1]
454
+ else:
455
+ extend_shape_0 = 1
456
+ extend_shape_1 = 1
457
+ for i in range(extend_shape_0):
458
+ for j in range(extend_shape_1):
459
+ mask = masks[i][j].cpu().numpy()
460
+ mask_pil = Image.fromarray(mask)
461
+
462
+ if remove_mode == 'segment':
463
+ useRectangle = False
464
+ else:
465
+ useRectangle = True
466
+
467
+ try:
468
+ remove_mask_extend = int(remove_mask_extend)
469
+ except:
470
+ remove_mask_extend = 10
471
+ mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
472
+ xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), size[0], size[1]),
473
+ extend_pixels=remove_mask_extend, useRectangle=useRectangle)
474
+ mask_imgs.append(mask_pil_exp)
475
+ mask_pil = mix_masks(mask_imgs)
476
+
477
+ image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
478
+ mask_pil.convert("RGB").save(image_path)
479
+ image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
480
+ os.remove(image_path)
481
+ output_images.append(image_result)
482
+ image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
483
+
484
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
485
 
486
  image_path = os.path.join(output_dir, f"grounded_sam_inpainting_output_{file_temp}.jpg")
487
  image_inpainting.save(image_path)
488
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
489
  os.remove(image_path)
490
+ logger.info(f'run_grounded_sam_9_{task_type}_')
491
+ output_images.append(image_result)
492
+ return output_images
493
  else:
494
+ logger.info(f"task_type:{task_type} error!")
495
+ logger.info(f'run_grounded_sam_9_9_')
496
+ return output_images
497
+
498
def change_radio_display(task_type, mask_source_radio):
    """Compute visibility updates for the prompt boxes and the mask-source radio.

    Args:
        task_type: Currently selected task name.
        mask_source_radio: Currently selected mask source label.

    Returns:
        A 3-tuple of component updates, in order: detection prompt textbox,
        inpaint prompt textbox, mask-source radio.
    """
    # The mask-source choice is only offered for the two mask-based tasks.
    needs_mask = task_type in ("inpainting", "remove")
    # Only "inpainting" takes a prompt describing the replacement content.
    show_inpaint_prompt = task_type == "inpainting"
    # A hand-drawn mask makes the detection prompt unnecessary.
    show_text_prompt = not (needs_mask and mask_source_radio == mask_source_draw)
    return (
        gr.Textbox.update(visible=show_text_prompt),
        gr.Textbox.update(visible=show_inpaint_prompt),
        gr.Radio.update(visible=needs_mask),
    )
509
 
510
  if __name__ == "__main__":
 
511
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
512
  parser.add_argument("--debug", action="store_true", help="using debug mode")
513
  parser.add_argument("--share", action="store_true", help="share the app")
 
519
  with block:
520
  with gr.Row():
521
  with gr.Column():
522
+ input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
523
+ task_type = gr.Radio(["detection", "segment", "inpainting", "remove"], value="detection",
524
  label='Task type',interactive=True, visible=True)
525
+ mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
526
+ value=mask_source_segment, label="Mask from",
527
+ interactive=True, visible=False)
528
  text_prompt = gr.Textbox(label="Detection Prompt", placeholder="Cannot be empty")
529
+ inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
530
  run_button = gr.Button(label="Run")
531
  with gr.Accordion("Advanced options", open=False):
532
  box_threshold = gr.Slider(
 
535
  text_threshold = gr.Slider(
536
  label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
537
  )
538
+ iou_threshold = gr.Slider(
539
+ label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
540
+ )
541
+ inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
542
+ with gr.Row():
543
+ with gr.Column(scale=1):
544
+ remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
545
+ with gr.Column(scale=1):
546
+ remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
547
 
548
  with gr.Column():
549
+ gallery = gr.Gallery(
550
+ label="Generated images", show_label=False, elem_id="gallery"
551
+ ).style(grid=[2], full_width=True, full_height=True)
552
 
553
  run_button.click(fn=run_grounded_sam, inputs=[
554
+ input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend], outputs=[gallery])
555
+ task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio])
556
+ mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio])
557
 
558
  DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). Thanks for their excellent work.'
559
  DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
560
  gr.Markdown(DESCRIPTION)
561
 
562
+ block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
automatic_label_demo.py CHANGED
@@ -224,7 +224,7 @@ if __name__ == "__main__":
224
  openai_proxy = args.openai_proxy
225
  output_dir = args.output_dir
226
  box_threshold = args.box_threshold
227
- text_threshold = args.box_threshold
228
  iou_threshold = args.iou_threshold
229
  device = args.device
230
 
@@ -264,7 +264,9 @@ if __name__ == "__main__":
264
  )
265
 
266
  # initialize SAM
267
- predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
 
 
268
  image = cv2.imread(image_path)
269
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
270
  predictor.set_image(image)
@@ -286,7 +288,7 @@ if __name__ == "__main__":
286
  caption = check_caption(caption, pred_phrases)
287
  print(f"Revise caption with number: {caption}")
288
 
289
- transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
290
 
291
  masks, _, _ = predictor.predict_torch(
292
  point_coords = None,
 
224
  openai_proxy = args.openai_proxy
225
  output_dir = args.output_dir
226
  box_threshold = args.box_threshold
227
+ text_threshold = args.text_threshold
228
  iou_threshold = args.iou_threshold
229
  device = args.device
230
 
 
264
  )
265
 
266
  # initialize SAM
267
+ sam = build_sam(checkpoint=sam_checkpoint)
268
+ sam.to(device=device)
269
+ predictor = SamPredictor(sam)
270
  image = cv2.imread(image_path)
271
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
272
  predictor.set_image(image)
 
288
  caption = check_caption(caption, pred_phrases)
289
  print(f"Revise caption with number: {caption}")
290
 
291
+ transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
292
 
293
  masks, _, _ = predictor.predict_torch(
294
  point_coords = None,
gradio_app.py DELETED
@@ -1,345 +0,0 @@
1
- import os
2
- # os.system('pip install v0.1.0-alpha2.tar.gz')
3
- import gradio as gr
4
-
5
- import argparse
6
- import copy
7
-
8
- import numpy as np
9
- import torch
10
- import torchvision
11
- from PIL import Image, ImageDraw, ImageFont
12
-
13
- # Grounding DINO
14
- import GroundingDINO.groundingdino.datasets.transforms as T
15
- from GroundingDINO.groundingdino.models import build_model
16
- from GroundingDINO.groundingdino.util import box_ops
17
- from GroundingDINO.groundingdino.util.slconfig import SLConfig
18
- from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
19
-
20
- # segment anything
21
- from segment_anything import build_sam, SamPredictor
22
- import cv2
23
- import numpy as np
24
- import matplotlib.pyplot as plt
25
-
26
-
27
- # diffusers
28
- import PIL
29
- import requests
30
- import torch
31
- from io import BytesIO
32
- from diffusers import StableDiffusionInpaintPipeline
33
- from huggingface_hub import hf_hub_download
34
-
35
- # BLIP
36
- from transformers import BlipProcessor, BlipForConditionalGeneration
37
-
38
-
39
- def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
40
- args = SLConfig.fromfile(model_config_path)
41
- model = build_model(args)
42
- args.device = device
43
-
44
- cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
45
- checkpoint = torch.load(cache_file, map_location='cpu')
46
- log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
47
- print("Model loaded from {} \n => {}".format(cache_file, log))
48
- _ = model.eval()
49
- return model
50
-
51
- def generate_caption(processor, blip_model, raw_image):
52
- # unconditional image captioning
53
- inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
54
- out = blip_model.generate(**inputs)
55
- caption = processor.decode(out[0], skip_special_tokens=True)
56
- return caption
57
-
58
- def plot_boxes_to_image(image_pil, tgt):
59
- H, W = tgt["size"]
60
- boxes = tgt["boxes"]
61
- labels = tgt["labels"]
62
- assert len(boxes) == len(labels), "boxes and labels must have same length"
63
-
64
- draw = ImageDraw.Draw(image_pil)
65
- mask = Image.new("L", image_pil.size, 0)
66
- mask_draw = ImageDraw.Draw(mask)
67
-
68
- # draw boxes and masks
69
- for box, label in zip(boxes, labels):
70
- # from 0..1 to 0..W, 0..H
71
- box = box * torch.Tensor([W, H, W, H])
72
- # from xywh to xyxy
73
- box[:2] -= box[2:] / 2
74
- box[2:] += box[:2]
75
- # random color
76
- color = tuple(np.random.randint(0, 255, size=3).tolist())
77
- # draw
78
- x0, y0, x1, y1 = box
79
- x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
80
-
81
- draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
82
- # draw.text((x0, y0), str(label), fill=color)
83
-
84
- font = ImageFont.load_default()
85
- if hasattr(font, "getbbox"):
86
- bbox = draw.textbbox((x0, y0), str(label), font)
87
- else:
88
- w, h = draw.textsize(str(label), font)
89
- bbox = (x0, y0, w + x0, y0 + h)
90
- # bbox = draw.textbbox((x0, y0), str(label))
91
- draw.rectangle(bbox, fill=color)
92
- draw.text((x0, y0), str(label), fill="white")
93
-
94
- mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
95
-
96
- return image_pil, mask
97
-
98
- def load_image(image_path):
99
- # # load image
100
- # image_pil = Image.open(image_path).convert("RGB") # load image
101
- image_pil = image_path
102
-
103
- transform = T.Compose(
104
- [
105
- T.RandomResize([800], max_size=1333),
106
- T.ToTensor(),
107
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
108
- ]
109
- )
110
- image, _ = transform(image_pil, None) # 3, h, w
111
- return image_pil, image
112
-
113
-
114
- def load_model(model_config_path, model_checkpoint_path, device):
115
- args = SLConfig.fromfile(model_config_path)
116
- args.device = device
117
- model = build_model(args)
118
- checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
119
- load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
120
- print(load_res)
121
- _ = model.eval()
122
- return model
123
-
124
-
125
- def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
126
- caption = caption.lower()
127
- caption = caption.strip()
128
- if not caption.endswith("."):
129
- caption = caption + "."
130
- model = model.to(device)
131
- image = image.to(device)
132
- with torch.no_grad():
133
- outputs = model(image[None], captions=[caption])
134
- logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
135
- boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
136
- logits.shape[0]
137
-
138
- # filter output
139
- logits_filt = logits.clone()
140
- boxes_filt = boxes.clone()
141
- filt_mask = logits_filt.max(dim=1)[0] > box_threshold
142
- logits_filt = logits_filt[filt_mask] # num_filt, 256
143
- boxes_filt = boxes_filt[filt_mask] # num_filt, 4
144
- logits_filt.shape[0]
145
-
146
- # get phrase
147
- tokenlizer = model.tokenizer
148
- tokenized = tokenlizer(caption)
149
- # build pred
150
- pred_phrases = []
151
- scores = []
152
- for logit, box in zip(logits_filt, boxes_filt):
153
- pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
154
- if with_logits:
155
- pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
156
- else:
157
- pred_phrases.append(pred_phrase)
158
- scores.append(logit.max().item())
159
-
160
- return boxes_filt, torch.Tensor(scores), pred_phrases
161
-
162
- def show_mask(mask, ax, random_color=False):
163
- if random_color:
164
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
165
- else:
166
- color = np.array([30/255, 144/255, 255/255, 0.6])
167
- h, w = mask.shape[-2:]
168
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
169
- ax.imshow(mask_image)
170
-
171
-
172
- def show_box(box, ax, label):
173
- x0, y0 = box[0], box[1]
174
- w, h = box[2] - box[0], box[3] - box[1]
175
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
176
- ax.text(x0, y0, label)
177
-
178
-
179
- config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
180
- ckpt_repo_id = "ShilongLiu/GroundingDINO"
181
- ckpt_filenmae = "groundingdino_swint_ogc.pth"
182
- sam_checkpoint='sam_vit_h_4b8939.pth'
183
- output_dir="outputs"
184
- device="cuda"
185
-
186
- def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode):
187
-
188
- # make dir
189
- os.makedirs(output_dir, exist_ok=True)
190
- # load image
191
- image_pil, image = load_image(image_path.convert("RGB"))
192
- # load model
193
- model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
194
- # model = load_model(config_file, ckpt_filenmae, device=device)
195
-
196
- # visualize raw image
197
- image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
198
-
199
- if task_type == 'automatic':
200
- # generate caption and tags
201
- # use Tag2Text can generate better captions
202
- # https://huggingface.co/spaces/xinyu1205/Tag2Text
203
- # but there are some bugs...
204
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
205
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
206
- text_prompt = generate_caption(processor, blip_model, image_pil)
207
- print(f"Caption: {text_prompt}")
208
-
209
- # run grounding dino model
210
- boxes_filt, scores, pred_phrases = get_grounding_output(
211
- model, image, text_prompt, box_threshold, text_threshold, device=device
212
- )
213
-
214
- size = image_pil.size
215
-
216
- if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
217
- # initialize SAM
218
- predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
219
- image = np.array(image_path)
220
- predictor.set_image(image)
221
-
222
- H, W = size[1], size[0]
223
- for i in range(boxes_filt.size(0)):
224
- boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
225
- boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
226
- boxes_filt[i][2:] += boxes_filt[i][:2]
227
-
228
- boxes_filt = boxes_filt.cpu()
229
-
230
- if task_type == 'automatic':
231
- # use NMS to handle overlapped boxes
232
- print(f"Before NMS: {boxes_filt.shape[0]} boxes")
233
- nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
234
- boxes_filt = boxes_filt[nms_idx]
235
- pred_phrases = [pred_phrases[idx] for idx in nms_idx]
236
- print(f"After NMS: {boxes_filt.shape[0]} boxes")
237
- print(f"Revise caption with number: {text_prompt}")
238
-
239
- transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
240
-
241
- masks, _, _ = predictor.predict_torch(
242
- point_coords = None,
243
- point_labels = None,
244
- boxes = transformed_boxes,
245
- multimask_output = False,
246
- )
247
-
248
- # masks: [1, 1, 512, 512]
249
-
250
- if task_type == 'det':
251
- pred_dict = {
252
- "boxes": boxes_filt,
253
- "size": [size[1], size[0]], # H,W
254
- "labels": pred_phrases,
255
- }
256
- # import ipdb; ipdb.set_trace()
257
- image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
258
- image_path = os.path.join(output_dir, "grounding_dino_output.jpg")
259
- image_with_box.save(image_path)
260
- image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
261
- return image_result
262
- elif task_type == 'seg' or task_type == 'automatic':
263
- assert sam_checkpoint, 'sam_checkpoint is not found!'
264
-
265
- # draw output image
266
- plt.figure(figsize=(10, 10))
267
- plt.imshow(image)
268
- for mask in masks:
269
- show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
270
- for box, label in zip(boxes_filt, pred_phrases):
271
- show_box(box.numpy(), plt.gca(), label)
272
- if task_type == 'automatic':
273
- plt.title(text_prompt)
274
- plt.axis('off')
275
- image_path = os.path.join(output_dir, "grounding_dino_output.jpg")
276
- plt.savefig(image_path, bbox_inches="tight")
277
- image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
278
- return image_result
279
- elif task_type == 'inpainting':
280
- assert inpaint_prompt, 'inpaint_prompt is not found!'
281
- # inpainting pipeline
282
- if inpaint_mode == 'merge':
283
- masks = torch.sum(masks, dim=0).unsqueeze(0)
284
- masks = torch.where(masks > 0, True, False)
285
- else:
286
- mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
287
- mask_pil = Image.fromarray(mask)
288
-
289
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
290
- "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
291
- )
292
- pipe = pipe.to("cuda")
293
-
294
- image_pil = image_pil.resize((512, 512))
295
- mask_pil = mask_pil.resize((512, 512))
296
-
297
- image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
298
- image = image.resize(size)
299
-
300
- image_path = os.path.join(output_dir, "grounded_sam_inpainting_output.jpg")
301
- image.save(image_path)
302
- image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
303
- return image_result
304
- else:
305
- print("task_type:{} error!".format(task_type))
306
-
307
- if __name__ == "__main__":
308
-
309
- parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
310
- parser.add_argument("--debug", action="store_true", help="using debug mode")
311
- parser.add_argument("--share", action="store_true", help="share the app")
312
- parser.add_argument('--port', type=int, default=7589, help='port to run the server')
313
- args = parser.parse_args()
314
-
315
- block = gr.Blocks().queue()
316
- with block:
317
- with gr.Row():
318
- with gr.Column():
319
- input_image = gr.Image(source='upload', type="pil", value="assets/demo1.jpg")
320
- task_type = gr.Dropdown(["det", "seg", "inpainting", "automatic"], value="automatic", label="task_type")
321
- text_prompt = gr.Textbox(label="Text Prompt")
322
- inpaint_prompt = gr.Textbox(label="Inpaint Prompt")
323
- run_button = gr.Button(label="Run")
324
- with gr.Accordion("Advanced options", open=False):
325
- box_threshold = gr.Slider(
326
- label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
327
- )
328
- text_threshold = gr.Slider(
329
- label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
330
- )
331
- iou_threshold = gr.Slider(
332
- label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
333
- )
334
- inpaint_mode = gr.Dropdown(["merge", "first"], value="merge", label="inpaint_mode")
335
-
336
- with gr.Column():
337
- gallery = gr.outputs.Image(
338
- type="pil",
339
- ).style(full_width=True, full_height=True)
340
-
341
- run_button.click(fn=run_grounded_sam, inputs=[
342
- input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode], outputs=[gallery])
343
-
344
-
345
- block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_auto_label.py DELETED
@@ -1,392 +0,0 @@
1
- import gradio as gr
2
- import json
3
- import argparse
4
- import os
5
- import copy
6
-
7
- import numpy as np
8
- import torch
9
- import torchvision
10
- from PIL import Image, ImageDraw, ImageFont
11
- import openai
12
- # Grounding DINO
13
- import GroundingDINO.groundingdino.datasets.transforms as T
14
- from GroundingDINO.groundingdino.models import build_model
15
- from GroundingDINO.groundingdino.util import box_ops
16
- from GroundingDINO.groundingdino.util.slconfig import SLConfig
17
- from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
18
- from transformers import BlipProcessor, BlipForConditionalGeneration
19
- # segment anything
20
- from segment_anything import build_sam, SamPredictor
21
- from segment_anything.utils.amg import remove_small_regions
22
- import cv2
23
- import numpy as np
24
- import matplotlib.pyplot as plt
25
-
26
-
27
- # diffusers
28
- import PIL
29
- import requests
30
- import torch
31
- from io import BytesIO
32
- from huggingface_hub import hf_hub_download
33
- from sys import platform
34
-
35
- #macos
36
- if platform == 'darwin':
37
- import matplotlib
38
- matplotlib.use('agg')
39
-
40
- def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
41
- args = SLConfig.fromfile(model_config_path)
42
- model = build_model(args)
43
- args.device = device
44
-
45
- cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
46
- checkpoint = torch.load(cache_file, map_location='cpu')
47
- log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
48
- print("Model loaded from {} \n => {}".format(cache_file, log))
49
- _ = model.eval()
50
- return model
51
-
52
- def plot_boxes_to_image(image_pil, tgt):
53
- H, W = tgt["size"]
54
- boxes = tgt["boxes"]
55
- labels = tgt["labels"]
56
- assert len(boxes) == len(labels), "boxes and labels must have same length"
57
-
58
- draw = ImageDraw.Draw(image_pil)
59
- mask = Image.new("L", image_pil.size, 0)
60
- mask_draw = ImageDraw.Draw(mask)
61
-
62
- # draw boxes and masks
63
- for box, label in zip(boxes, labels):
64
- # from 0..1 to 0..W, 0..H
65
- box = box * torch.Tensor([W, H, W, H])
66
- # from xywh to xyxy
67
- box[:2] -= box[2:] / 2
68
- box[2:] += box[:2]
69
- # random color
70
- color = tuple(np.random.randint(0, 255, size=3).tolist())
71
- # draw
72
- x0, y0, x1, y1 = box
73
- x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
74
-
75
- draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
76
- # draw.text((x0, y0), str(label), fill=color)
77
-
78
- font = ImageFont.load_default()
79
- if hasattr(font, "getbbox"):
80
- bbox = draw.textbbox((x0, y0), str(label), font)
81
- else:
82
- w, h = draw.textsize(str(label), font)
83
- bbox = (x0, y0, w + x0, y0 + h)
84
- # bbox = draw.textbbox((x0, y0), str(label))
85
- draw.rectangle(bbox, fill=color)
86
- draw.text((x0, y0), str(label), fill="white")
87
-
88
- mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
89
-
90
- return image_pil, mask
91
-
92
- def load_image(image_path):
93
- # # load image
94
- # image_pil = Image.open(image_path).convert("RGB") # load image
95
- image_pil = image_path
96
-
97
- transform = T.Compose(
98
- [
99
- T.RandomResize([800], max_size=1333),
100
- T.ToTensor(),
101
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
102
- ]
103
- )
104
- image, _ = transform(image_pil, None) # 3, h, w
105
- return image_pil, image
106
-
107
-
108
- def load_model(model_config_path, model_checkpoint_path, device):
109
- args = SLConfig.fromfile(model_config_path)
110
- args.device = device
111
- model = build_model(args)
112
- checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
113
- load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
114
- _ = model.eval()
115
- return model
116
-
117
-
118
- def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
119
- caption = caption.lower()
120
- caption = caption.strip()
121
- if not caption.endswith("."):
122
- caption = caption + "."
123
- model = model.to(device)
124
- image = image.to(device)
125
- with torch.no_grad():
126
- outputs = model(image[None], captions=[caption])
127
- logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
128
- boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
129
- logits.shape[0]
130
-
131
- # filter output
132
- logits_filt = logits.clone()
133
- boxes_filt = boxes.clone()
134
- filt_mask = logits_filt.max(dim=1)[0] > box_threshold
135
- logits_filt = logits_filt[filt_mask] # num_filt, 256
136
- boxes_filt = boxes_filt[filt_mask] # num_filt, 4
137
- logits_filt.shape[0]
138
-
139
- # get phrase
140
- tokenlizer = model.tokenizer
141
- tokenized = tokenlizer(caption)
142
- # build pred
143
- pred_phrases = []
144
- scores = []
145
- for logit, box in zip(logits_filt, boxes_filt):
146
- pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
147
- if with_logits:
148
- pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
149
- else:
150
- pred_phrases.append(pred_phrase)
151
- scores.append(logit.max().item())
152
-
153
- return boxes_filt, torch.Tensor(scores), pred_phrases
154
-
155
- def show_mask(mask, ax, random_color=False):
156
- if random_color:
157
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
158
- else:
159
- color = np.array([30/255, 144/255, 255/255, 0.6])
160
- h, w = mask.shape[-2:]
161
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
162
- ax.imshow(mask_image)
163
-
164
- def save_mask_data(output_dir, mask_list, box_list, label_list):
165
- value = 0 # 0 for background
166
-
167
- mask_img = torch.zeros(mask_list.shape[-2:])
168
- for idx, mask in enumerate(mask_list):
169
- mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
170
- plt.figure(figsize=(10, 10))
171
- plt.imshow(mask_img.numpy())
172
- plt.axis('off')
173
- mask_img_path = os.path.join(output_dir, 'mask.jpg')
174
- plt.savefig(mask_img_path, bbox_inches="tight", dpi=300, pad_inches=0.0)
175
-
176
- json_data = [{
177
- 'value': value,
178
- 'label': 'background'
179
- }]
180
- for label, box in zip(label_list, box_list):
181
- value += 1
182
- name, logit = label.split('(')
183
- logit = logit[:-1] # the last is ')'
184
- json_data.append({
185
- 'value': value,
186
- 'label': name,
187
- 'logit': float(logit),
188
- 'box': box.numpy().tolist(),
189
- })
190
-
191
- mask_json_path = os.path.join(output_dir, 'mask.json')
192
- with open(mask_json_path, 'w') as f:
193
- json.dump(json_data, f)
194
-
195
- return mask_img_path, mask_json_path
196
-
197
- def show_box(box, ax, label):
198
- x0, y0 = box[0], box[1]
199
- w, h = box[2] - box[0], box[3] - box[1]
200
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
201
- ax.text(x0, y0, label)
202
-
203
- config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
204
- ckpt_repo_id = "ShilongLiu/GroundingDINO"
205
- ckpt_filenmae = "groundingdino_swint_ogc.pth"
206
- sam_checkpoint='sam_vit_h_4b8939.pth'
207
- output_dir="outputs"
208
- device="cpu"
209
-
210
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
211
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
212
-
213
- def generate_caption(raw_image):
214
- # unconditional image captioning
215
- inputs = processor(raw_image, return_tensors="pt")
216
- out = blip_model.generate(**inputs)
217
- caption = processor.decode(out[0], skip_special_tokens=True)
218
- return caption
219
-
220
-
221
- def generate_tags(caption, split=',', max_tokens=100, model="gpt-3.5-turbo", openai_key=''):
222
- openai.api_key = openai_key
223
- prompt = [
224
- {
225
- 'role': 'system',
226
- 'content': 'Extract the unique nouns in the caption. Remove all the adjectives. ' + \
227
- f'List the nouns in singular form. Split them by "{split} ". ' + \
228
- f'Caption: {caption}.'
229
- }
230
- ]
231
- response = openai.ChatCompletion.create(model=model, messages=prompt, temperature=0.6, max_tokens=max_tokens)
232
- reply = response['choices'][0]['message']['content']
233
- # sometimes return with "noun: xxx, xxx, xxx"
234
- tags = reply.split(':')[-1].strip()
235
- return tags
236
-
237
- def check_caption(caption, pred_phrases, max_tokens=100, model="gpt-3.5-turbo"):
238
- object_list = [obj.split('(')[0] for obj in pred_phrases]
239
- object_num = []
240
- for obj in set(object_list):
241
- object_num.append(f'{object_list.count(obj)} {obj}')
242
- object_num = ', '.join(object_num)
243
- print(f"Correct object number: {object_num}")
244
-
245
- prompt = [
246
- {
247
- 'role': 'system',
248
- 'content': 'Revise the number in the caption if it is wrong. ' + \
249
- f'Caption: {caption}. ' + \
250
- f'True object number: {object_num}. ' + \
251
- 'Only give the revised caption: '
252
- }
253
- ]
254
- response = openai.ChatCompletion.create(model=model, messages=prompt, temperature=0.6, max_tokens=max_tokens)
255
- reply = response['choices'][0]['message']['content']
256
- # sometimes return with "Caption: xxx, xxx, xxx"
257
- caption = reply.split(':')[-1].strip()
258
- return caption
259
-
260
- def run_grounded_sam(image_path, openai_key, box_threshold, text_threshold, iou_threshold, area_threshold):
261
- assert openai_key, 'Openai key is not found!'
262
-
263
- # make dir
264
- os.makedirs(output_dir, exist_ok=True)
265
- # load image
266
- image_pil, image = load_image(image_path.convert("RGB"))
267
- # load model
268
- model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
269
-
270
- # visualize raw image
271
- image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
272
-
273
- caption = generate_caption(image_pil)
274
- # Currently ", " is better for detecting single tags
275
- # while ". " is a little worse in some case
276
- split = ','
277
- tags = generate_tags(caption, split=split, openai_key=openai_key)
278
-
279
- # run grounding dino model
280
- boxes_filt, scores, pred_phrases = get_grounding_output(
281
- model, image, tags, box_threshold, text_threshold, device=device
282
- )
283
-
284
- size = image_pil.size
285
-
286
- # initialize SAM
287
- predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
288
- image = np.array(image_path)
289
- predictor.set_image(image)
290
-
291
- H, W = size[1], size[0]
292
- for i in range(boxes_filt.size(0)):
293
- boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
294
- boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
295
- boxes_filt[i][2:] += boxes_filt[i][:2]
296
-
297
- boxes_filt = boxes_filt.cpu()
298
- # use NMS to handle overlapped boxes
299
- print(f"Before NMS: {boxes_filt.shape[0]} boxes")
300
- nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
301
- boxes_filt = boxes_filt[nms_idx]
302
- pred_phrases = [pred_phrases[idx] for idx in nms_idx]
303
- print(f"After NMS: {boxes_filt.shape[0]} boxes")
304
- caption = check_caption(caption, pred_phrases)
305
- print(f"Revise caption with number: {caption}")
306
-
307
- transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
308
-
309
- masks, _, _ = predictor.predict_torch(
310
- point_coords = None,
311
- point_labels = None,
312
- boxes = transformed_boxes,
313
- multimask_output = False,
314
- )
315
- # area threshold: remove the mask when area < area_thresh (in pixels)
316
- new_masks = []
317
- for mask in masks:
318
- # reshape to be used in remove_small_regions()
319
- mask = mask.cpu().numpy().squeeze()
320
- mask, _ = remove_small_regions(mask, area_threshold, mode="holes")
321
- mask, _ = remove_small_regions(mask, area_threshold, mode="islands")
322
- new_masks.append(torch.as_tensor(mask).unsqueeze(0))
323
-
324
- masks = torch.stack(new_masks, dim=0)
325
- # masks: [1, 1, 512, 512]
326
- assert sam_checkpoint, 'sam_checkpoint is not found!'
327
-
328
- # draw output image
329
- plt.figure(figsize=(10, 10))
330
- plt.imshow(image)
331
- for mask in masks:
332
- show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
333
- for box, label in zip(boxes_filt, pred_phrases):
334
- show_box(box.numpy(), plt.gca(), label)
335
- plt.axis('off')
336
- image_path = os.path.join(output_dir, "grounding_dino_output.jpg")
337
- plt.savefig(image_path, bbox_inches="tight")
338
- image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
339
-
340
- mask_img_path, _ = save_mask_data('./outputs', masks, boxes_filt, pred_phrases)
341
-
342
- mask_img = cv2.cvtColor(cv2.imread(mask_img_path), cv2.COLOR_BGR2RGB)
343
-
344
- return image_result, mask_img, caption, tags
345
-
346
- if __name__ == "__main__":
347
-
348
- parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
349
- parser.add_argument("--debug", action="store_true", help="using debug mode")
350
- parser.add_argument("--share", action="store_true", help="share the app")
351
- args = parser.parse_args()
352
-
353
- block = gr.Blocks().queue()
354
- with block:
355
- with gr.Row():
356
- with gr.Column():
357
- input_image = gr.Image(source='upload', type="pil")
358
- openai_key = gr.Textbox(label="OpenAI key")
359
-
360
- run_button = gr.Button(label="Run")
361
- with gr.Accordion("Advanced options", open=False):
362
- box_threshold = gr.Slider(
363
- label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
364
- )
365
- text_threshold = gr.Slider(
366
- label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
367
- )
368
- iou_threshold = gr.Slider(
369
- label="IoU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
370
- )
371
- area_threshold = gr.Slider(
372
- label="Area Threshold", minimum=0.0, maximum=2500, value=100, step=10
373
- )
374
-
375
- with gr.Column():
376
- image_caption = gr.Textbox(label="Image Caption")
377
- identified_labels = gr.Textbox(label="Key objects extracted by ChatGPT")
378
- gallery = gr.outputs.Image(
379
- type="pil",
380
- ).style(full_width=True, full_height=True)
381
-
382
- mask_gallary = gr.outputs.Image(
383
- type="pil",
384
- ).style(full_width=True, full_height=True)
385
-
386
-
387
- run_button.click(fn=run_grounded_sam, inputs=[
388
- input_image, openai_key, box_threshold, text_threshold, iou_threshold, area_threshold],
389
- outputs=[gallery, mask_gallary, image_caption, identified_labels])
390
-
391
-
392
- block.launch(server_name='0.0.0.0', server_port=7589, debug=args.debug, share=args.share)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
grounded_sam.ipynb CHANGED
@@ -224,7 +224,9 @@
224
  "outputs": [],
225
  "source": [
226
  "sam_checkpoint = 'sam_vit_h_4b8939.pth'\n",
227
- "sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))"
 
 
228
  ]
229
  },
230
  {
@@ -404,7 +406,7 @@
404
  "metadata": {},
405
  "outputs": [],
406
  "source": [
407
- "transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2])\n",
408
  "masks, _, _ = sam_predictor.predict_torch(\n",
409
  " point_coords = None,\n",
410
  " point_labels = None,\n",
 
224
  "outputs": [],
225
  "source": [
226
  "sam_checkpoint = 'sam_vit_h_4b8939.pth'\n",
227
+ "sam = build_sam(checkpoint=sam_checkpoint)\n",
228
+ "sam.to(device=device)\n",
229
+ "sam_predictor = SamPredictor(sam)"
230
  ]
231
  },
232
  {
 
406
  "metadata": {},
407
  "outputs": [],
408
  "source": [
409
+ "transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).to(device)\n",
410
  "masks, _, _ = sam_predictor.predict_torch(\n",
411
  " point_coords = None,\n",
412
  " point_labels = None,\n",
grounded_sam_demo.py DELETED
@@ -1,217 +0,0 @@
1
- import argparse
2
- import os
3
- import copy
4
-
5
- import numpy as np
6
- import json
7
- import torch
8
- from PIL import Image, ImageDraw, ImageFont
9
-
10
- # Grounding DINO
11
- import GroundingDINO.groundingdino.datasets.transforms as T
12
- from GroundingDINO.groundingdino.models import build_model
13
- from GroundingDINO.groundingdino.util import box_ops
14
- from GroundingDINO.groundingdino.util.slconfig import SLConfig
15
- from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
16
-
17
- # segment anything
18
- from segment_anything import build_sam, SamPredictor
19
- import cv2
20
- import numpy as np
21
- import matplotlib.pyplot as plt
22
-
23
-
24
- def load_image(image_path):
25
- # load image
26
- image_pil = Image.open(image_path).convert("RGB") # load image
27
-
28
- transform = T.Compose(
29
- [
30
- T.RandomResize([800], max_size=1333),
31
- T.ToTensor(),
32
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
33
- ]
34
- )
35
- image, _ = transform(image_pil, None) # 3, h, w
36
- return image_pil, image
37
-
38
-
39
- def load_model(model_config_path, model_checkpoint_path, device):
40
- args = SLConfig.fromfile(model_config_path)
41
- args.device = device
42
- model = build_model(args)
43
- checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
44
- load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
45
- print(load_res)
46
- _ = model.eval()
47
- return model
48
-
49
-
50
- def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
51
- caption = caption.lower()
52
- caption = caption.strip()
53
- if not caption.endswith("."):
54
- caption = caption + "."
55
- model = model.to(device)
56
- image = image.to(device)
57
- with torch.no_grad():
58
- outputs = model(image[None], captions=[caption])
59
- logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
60
- boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
61
- logits.shape[0]
62
-
63
- # filter output
64
- logits_filt = logits.clone()
65
- boxes_filt = boxes.clone()
66
- filt_mask = logits_filt.max(dim=1)[0] > box_threshold
67
- logits_filt = logits_filt[filt_mask] # num_filt, 256
68
- boxes_filt = boxes_filt[filt_mask] # num_filt, 4
69
- logits_filt.shape[0]
70
-
71
- # get phrase
72
- tokenlizer = model.tokenizer
73
- tokenized = tokenlizer(caption)
74
- # build pred
75
- pred_phrases = []
76
- for logit, box in zip(logits_filt, boxes_filt):
77
- pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
78
- if with_logits:
79
- pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
80
- else:
81
- pred_phrases.append(pred_phrase)
82
-
83
- return boxes_filt, pred_phrases
84
-
85
- def show_mask(mask, ax, random_color=False):
86
- if random_color:
87
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
88
- else:
89
- color = np.array([30/255, 144/255, 255/255, 0.6])
90
- h, w = mask.shape[-2:]
91
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
92
- ax.imshow(mask_image)
93
-
94
-
95
- def show_box(box, ax, label):
96
- x0, y0 = box[0], box[1]
97
- w, h = box[2] - box[0], box[3] - box[1]
98
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
99
- ax.text(x0, y0, label)
100
-
101
-
102
- def save_mask_data(output_dir, mask_list, box_list, label_list):
103
- value = 0 # 0 for background
104
-
105
- mask_img = torch.zeros(mask_list.shape[-2:])
106
- for idx, mask in enumerate(mask_list):
107
- mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
108
- plt.figure(figsize=(10, 10))
109
- plt.imshow(mask_img.numpy())
110
- plt.axis('off')
111
- plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
112
-
113
- json_data = [{
114
- 'value': value,
115
- 'label': 'background'
116
- }]
117
- for label, box in zip(label_list, box_list):
118
- value += 1
119
- name, logit = label.split('(')
120
- logit = logit[:-1] # the last is ')'
121
- json_data.append({
122
- 'value': value,
123
- 'label': name,
124
- 'logit': float(logit),
125
- 'box': box.numpy().tolist(),
126
- })
127
- with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
128
- json.dump(json_data, f)
129
-
130
-
131
- if __name__ == "__main__":
132
-
133
- parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
134
- parser.add_argument("--config", type=str, required=True, help="path to config file")
135
- parser.add_argument(
136
- "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
137
- )
138
- parser.add_argument(
139
- "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
140
- )
141
- parser.add_argument("--input_image", type=str, required=True, help="path to image file")
142
- parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
143
- parser.add_argument(
144
- "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
145
- )
146
-
147
- parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
148
- parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
149
-
150
- parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
151
- args = parser.parse_args()
152
-
153
- # cfg
154
- config_file = args.config # change the path of the model config file
155
- grounded_checkpoint = args.grounded_checkpoint # change the path of the model
156
- sam_checkpoint = args.sam_checkpoint
157
- image_path = args.input_image
158
- text_prompt = args.text_prompt
159
- output_dir = args.output_dir
160
- box_threshold = args.box_threshold
161
- text_threshold = args.box_threshold
162
- device = args.device
163
-
164
- # make dir
165
- os.makedirs(output_dir, exist_ok=True)
166
- # load image
167
- image_pil, image = load_image(image_path)
168
- # load model
169
- model = load_model(config_file, grounded_checkpoint, device=device)
170
-
171
- # visualize raw image
172
- image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
173
-
174
- # run grounding dino model
175
- boxes_filt, pred_phrases = get_grounding_output(
176
- model, image, text_prompt, box_threshold, text_threshold, device=device
177
- )
178
-
179
- # initialize SAM
180
- predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
181
- image = cv2.imread(image_path)
182
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
183
- predictor.set_image(image)
184
-
185
- size = image_pil.size
186
- H, W = size[1], size[0]
187
- for i in range(boxes_filt.size(0)):
188
- boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
189
- boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
190
- boxes_filt[i][2:] += boxes_filt[i][:2]
191
-
192
- boxes_filt = boxes_filt.cpu()
193
- transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
194
-
195
- masks, _, _ = predictor.predict_torch(
196
- point_coords = None,
197
- point_labels = None,
198
- boxes = transformed_boxes,
199
- multimask_output = False,
200
- )
201
-
202
- # draw output image
203
- plt.figure(figsize=(10, 10))
204
- plt.imshow(image)
205
- for mask in masks:
206
- show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
207
- for box, label in zip(boxes_filt, pred_phrases):
208
- show_box(box.numpy(), plt.gca(), label)
209
-
210
- plt.axis('off')
211
- plt.savefig(
212
- os.path.join(output_dir, "grounded_sam_output.jpg"),
213
- bbox_inches="tight", dpi=300, pad_inches=0.0
214
- )
215
-
216
- save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
217
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
grounded_sam_inpainting_demo.py DELETED
@@ -1,215 +0,0 @@
1
- import argparse
2
- import os
3
- import copy
4
-
5
- import numpy as np
6
- import torch
7
- from PIL import Image, ImageDraw, ImageFont
8
-
9
- # Grounding DINO
10
- import GroundingDINO.groundingdino.datasets.transforms as T
11
- from GroundingDINO.groundingdino.models import build_model
12
- from GroundingDINO.groundingdino.util import box_ops
13
- from GroundingDINO.groundingdino.util.slconfig import SLConfig
14
- from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
15
-
16
- # segment anything
17
- from segment_anything import build_sam, SamPredictor
18
- import cv2
19
- import numpy as np
20
- import matplotlib.pyplot as plt
21
-
22
-
23
- # diffusers
24
- import PIL
25
- import requests
26
- import torch
27
- from io import BytesIO
28
- from diffusers import StableDiffusionInpaintPipeline
29
-
30
-
31
- def load_image(image_path):
32
- # load image
33
- image_pil = Image.open(image_path).convert("RGB") # load image
34
-
35
- transform = T.Compose(
36
- [
37
- T.RandomResize([800], max_size=1333),
38
- T.ToTensor(),
39
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
40
- ]
41
- )
42
- image, _ = transform(image_pil, None) # 3, h, w
43
- return image_pil, image
44
-
45
-
46
- def load_model(model_config_path, model_checkpoint_path, device):
47
- args = SLConfig.fromfile(model_config_path)
48
- args.device = device
49
- model = build_model(args)
50
- checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
51
- load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
52
- print(load_res)
53
- _ = model.eval()
54
- return model
55
-
56
-
57
- def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
58
- caption = caption.lower()
59
- caption = caption.strip()
60
- if not caption.endswith("."):
61
- caption = caption + "."
62
- model = model.to(device)
63
- image = image.to(device)
64
- with torch.no_grad():
65
- outputs = model(image[None], captions=[caption])
66
- logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
67
- boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
68
- logits.shape[0]
69
-
70
- # filter output
71
- logits_filt = logits.clone()
72
- boxes_filt = boxes.clone()
73
- filt_mask = logits_filt.max(dim=1)[0] > box_threshold
74
- logits_filt = logits_filt[filt_mask] # num_filt, 256
75
- boxes_filt = boxes_filt[filt_mask] # num_filt, 4
76
- logits_filt.shape[0]
77
-
78
- # get phrase
79
- tokenlizer = model.tokenizer
80
- tokenized = tokenlizer(caption)
81
- # build pred
82
- pred_phrases = []
83
- for logit, box in zip(logits_filt, boxes_filt):
84
- pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
85
- if with_logits:
86
- pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
87
- else:
88
- pred_phrases.append(pred_phrase)
89
-
90
- return boxes_filt, pred_phrases
91
-
92
- def show_mask(mask, ax, random_color=False):
93
- if random_color:
94
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
95
- else:
96
- color = np.array([30/255, 144/255, 255/255, 0.6])
97
- h, w = mask.shape[-2:]
98
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
99
- ax.imshow(mask_image)
100
-
101
-
102
- def show_box(box, ax, label):
103
- x0, y0 = box[0], box[1]
104
- w, h = box[2] - box[0], box[3] - box[1]
105
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
106
- ax.text(x0, y0, label)
107
-
108
-
109
- if __name__ == "__main__":
110
-
111
- parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
112
- parser.add_argument("--config", type=str, required=True, help="path to config file")
113
- parser.add_argument(
114
- "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
115
- )
116
- parser.add_argument(
117
- "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
118
- )
119
- parser.add_argument("--input_image", type=str, required=True, help="path to image file")
120
- parser.add_argument("--det_prompt", type=str, required=True, help="text prompt")
121
- parser.add_argument("--inpaint_prompt", type=str, required=True, help="inpaint prompt")
122
- parser.add_argument(
123
- "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
124
- )
125
-
126
- parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
127
- parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
128
- parser.add_argument("--inpaint_mode", type=str, default="first", help="inpaint mode")
129
- parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
130
- args = parser.parse_args()
131
-
132
- # cfg
133
- config_file = args.config # change the path of the model config file
134
- grounded_checkpoint = args.grounded_checkpoint # change the path of the model
135
- sam_checkpoint = args.sam_checkpoint
136
- image_path = args.input_image
137
- det_prompt = args.det_prompt
138
- inpaint_prompt = args.inpaint_prompt
139
- output_dir = args.output_dir
140
- box_threshold = args.box_threshold
141
- text_threshold = args.box_threshold
142
- inpaint_mode = args.inpaint_mode
143
- device = args.device
144
-
145
- # make dir
146
- os.makedirs(output_dir, exist_ok=True)
147
- # load image
148
- image_pil, image = load_image(image_path)
149
- # load model
150
- model = load_model(config_file, grounded_checkpoint, device=device)
151
-
152
- # visualize raw image
153
- image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
154
-
155
- # run grounding dino model
156
- boxes_filt, pred_phrases = get_grounding_output(
157
- model, image, det_prompt, box_threshold, text_threshold, device=device
158
- )
159
-
160
- # initialize SAM
161
- predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
162
- image = cv2.imread(image_path)
163
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
164
- predictor.set_image(image)
165
-
166
- size = image_pil.size
167
- H, W = size[1], size[0]
168
- for i in range(boxes_filt.size(0)):
169
- boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
170
- boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
171
- boxes_filt[i][2:] += boxes_filt[i][:2]
172
-
173
- boxes_filt = boxes_filt.cpu()
174
- transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
175
-
176
- masks, _, _ = predictor.predict_torch(
177
- point_coords = None,
178
- point_labels = None,
179
- boxes = transformed_boxes,
180
- multimask_output = False,
181
- )
182
-
183
- # masks: [1, 1, 512, 512]
184
-
185
- # inpainting pipeline
186
- if inpaint_mode == 'merge':
187
- masks = torch.sum(masks, dim=0).unsqueeze(0)
188
- masks = torch.where(masks > 0, True, False)
189
- else:
190
- mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
191
- mask_pil = Image.fromarray(mask)
192
- image_pil = Image.fromarray(image)
193
-
194
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
195
- "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
196
- )
197
- pipe = pipe.to("cuda")
198
-
199
- image_pil = image_pil.resize((512, 512))
200
- mask_pil = mask_pil.resize((512, 512))
201
- # prompt = "A sofa, high quality, detailed"
202
- image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
203
- image = image.resize(size)
204
- image.save(os.path.join(output_dir, "grounded_sam_inpainting_output.jpg"))
205
-
206
- # draw output image
207
- # plt.figure(figsize=(10, 10))
208
- # plt.imshow(image)
209
- # for mask in masks:
210
- # show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
211
- # for box, label in zip(boxes_filt, pred_phrases):
212
- # show_box(box.numpy(), plt.gca(), label)
213
- # plt.axis('off')
214
- # plt.savefig(os.path.join(output_dir, "grounded_sam_output.jpg"), bbox_inches="tight")
215
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
grounded_sam_whisper_demo.py DELETED
@@ -1,258 +0,0 @@
1
- import argparse
2
- import os
3
- import copy
4
-
5
- import numpy as np
6
- import json
7
- import torch
8
- import torchvision
9
- from PIL import Image, ImageDraw, ImageFont
10
-
11
- # Grounding DINO
12
- import GroundingDINO.groundingdino.datasets.transforms as T
13
- from GroundingDINO.groundingdino.models import build_model
14
- from GroundingDINO.groundingdino.util import box_ops
15
- from GroundingDINO.groundingdino.util.slconfig import SLConfig
16
- from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
17
-
18
- # segment anything
19
- from segment_anything import build_sam, SamPredictor
20
- import cv2
21
- import numpy as np
22
- import matplotlib.pyplot as plt
23
-
24
- # whisper
25
- import whisper
26
-
27
-
28
- def load_image(image_path):
29
- # load image
30
- image_pil = Image.open(image_path).convert("RGB") # load image
31
-
32
- transform = T.Compose(
33
- [
34
- T.RandomResize([800], max_size=1333),
35
- T.ToTensor(),
36
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
37
- ]
38
- )
39
- image, _ = transform(image_pil, None) # 3, h, w
40
- return image_pil, image
41
-
42
-
43
- def load_model(model_config_path, model_checkpoint_path, device):
44
- args = SLConfig.fromfile(model_config_path)
45
- args.device = device
46
- model = build_model(args)
47
- checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
48
- load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
49
- print(load_res)
50
- _ = model.eval()
51
- return model
52
-
53
-
54
- def get_grounding_output(model, image, caption, box_threshold, text_threshold,device="cpu"):
55
- caption = caption.lower()
56
- caption = caption.strip()
57
- if not caption.endswith("."):
58
- caption = caption + "."
59
- model = model.to(device)
60
- image = image.to(device)
61
- with torch.no_grad():
62
- outputs = model(image[None], captions=[caption])
63
- logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
64
- boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
65
- logits.shape[0]
66
-
67
- # filter output
68
- logits_filt = logits.clone()
69
- boxes_filt = boxes.clone()
70
- filt_mask = logits_filt.max(dim=1)[0] > box_threshold
71
- logits_filt = logits_filt[filt_mask] # num_filt, 256
72
- boxes_filt = boxes_filt[filt_mask] # num_filt, 4
73
- logits_filt.shape[0]
74
-
75
- # get phrase
76
- tokenlizer = model.tokenizer
77
- tokenized = tokenlizer(caption)
78
- # build pred
79
- pred_phrases = []
80
- scores = []
81
- for logit, box in zip(logits_filt, boxes_filt):
82
- pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
83
- pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
84
- scores.append(logit.max().item())
85
-
86
- return boxes_filt, torch.Tensor(scores), pred_phrases
87
-
88
- def show_mask(mask, ax, random_color=False):
89
- if random_color:
90
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
91
- else:
92
- color = np.array([30/255, 144/255, 255/255, 0.6])
93
- h, w = mask.shape[-2:]
94
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
95
- ax.imshow(mask_image)
96
-
97
-
98
- def show_box(box, ax, label):
99
- x0, y0 = box[0], box[1]
100
- w, h = box[2] - box[0], box[3] - box[1]
101
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
102
- ax.text(x0, y0, label)
103
-
104
-
105
- def save_mask_data(output_dir, mask_list, box_list, label_list):
106
- value = 0 # 0 for background
107
-
108
- mask_img = torch.zeros(mask_list.shape[-2:])
109
- for idx, mask in enumerate(mask_list):
110
- mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
111
- plt.figure(figsize=(10, 10))
112
- plt.imshow(mask_img.numpy())
113
- plt.axis('off')
114
- plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
115
-
116
- json_data = [{
117
- 'value': value,
118
- 'label': 'background'
119
- }]
120
- for label, box in zip(label_list, box_list):
121
- value += 1
122
- name, logit = label.split('(')
123
- logit = logit[:-1] # the last is ')'
124
- json_data.append({
125
- 'value': value,
126
- 'label': name,
127
- 'logit': float(logit),
128
- 'box': box.numpy().tolist(),
129
- })
130
- with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
131
- json.dump(json_data, f)
132
-
133
-
134
- def speech_recognition(speech_file, model):
135
- # whisper
136
- # load audio and pad/trim it to fit 30 seconds
137
- audio = whisper.load_audio(speech_file)
138
- audio = whisper.pad_or_trim(audio)
139
-
140
- # make log-Mel spectrogram and move to the same device as the model
141
- mel = whisper.log_mel_spectrogram(audio).to(model.device)
142
-
143
- # detect the spoken language
144
- _, probs = model.detect_language(mel)
145
- speech_language = max(probs, key=probs.get)
146
-
147
- # decode the audio
148
- options = whisper.DecodingOptions()
149
- result = whisper.decode(model, mel, options)
150
-
151
- # print the recognized text
152
- speech_text = result.text
153
- return speech_text, speech_language
154
-
155
- if __name__ == "__main__":
156
-
157
- parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
158
- parser.add_argument("--config", type=str, required=True, help="path to config file")
159
- parser.add_argument(
160
- "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
161
- )
162
- parser.add_argument(
163
- "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
164
- )
165
- parser.add_argument("--input_image", type=str, required=True, help="path to image file")
166
- parser.add_argument("--speech_file", type=str, required=True, help="speech file")
167
- parser.add_argument(
168
- "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
169
- )
170
-
171
- parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
172
- parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
173
- parser.add_argument("--iou_threshold", type=float, default=0.5, help="iou threshold")
174
-
175
- parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
176
- args = parser.parse_args()
177
-
178
- # cfg
179
- config_file = args.config # change the path of the model config file
180
- grounded_checkpoint = args.grounded_checkpoint # change the path of the model
181
- sam_checkpoint = args.sam_checkpoint
182
- image_path = args.input_image
183
- output_dir = args.output_dir
184
- box_threshold = args.box_threshold
185
- text_threshold = args.text_threshold
186
- iou_threshold = args.iou_threshold
187
- device = args.device
188
-
189
- # load speech
190
- whisper_model = whisper.load_model("base")
191
- speech_text, speech_language = speech_recognition(args.speech_file, whisper_model)
192
- print(f"speech_text: {speech_text}")
193
- print(f"speech_language: {speech_language}")
194
-
195
- # make dir
196
- os.makedirs(output_dir, exist_ok=True)
197
- # load image
198
- image_pil, image = load_image(image_path)
199
- # load model
200
- model = load_model(config_file, grounded_checkpoint, device=device)
201
-
202
- # visualize raw image
203
- image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
204
-
205
- # run grounding dino model
206
- text_prompt = speech_text
207
- boxes_filt, scores, pred_phrases = get_grounding_output(
208
- model, image, text_prompt, box_threshold, text_threshold, device=device
209
- )
210
-
211
- # initialize SAM
212
- predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(args.device))
213
- image = cv2.imread(image_path)
214
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
215
- predictor.set_image(image)
216
-
217
- size = image_pil.size
218
- H, W = size[1], size[0]
219
- for i in range(boxes_filt.size(0)):
220
- boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
221
- boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
222
- boxes_filt[i][2:] += boxes_filt[i][:2]
223
-
224
- boxes_filt = boxes_filt.cpu()
225
- # use NMS to handle overlapped boxes
226
- print(f"Before NMS: {boxes_filt.shape[0]} boxes")
227
- nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
228
- boxes_filt = boxes_filt[nms_idx]
229
- pred_phrases = [pred_phrases[idx] for idx in nms_idx]
230
- print(f"After NMS: {boxes_filt.shape[0]} boxes")
231
-
232
- transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
233
-
234
- masks, _, _ = predictor.predict_torch(
235
- point_coords = None,
236
- point_labels = None,
237
- boxes = transformed_boxes.to(args.device),
238
- multimask_output = False,
239
- )
240
-
241
- # draw output image
242
- plt.figure(figsize=(10, 10))
243
- plt.imshow(image)
244
- for mask in masks:
245
- show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
246
- for box, label in zip(boxes_filt, pred_phrases):
247
- show_box(box.numpy(), plt.gca(), label)
248
-
249
- plt.title(speech_text)
250
- plt.axis('off')
251
- plt.savefig(
252
- os.path.join(output_dir, "grounded_sam_whisper_output.jpg"),
253
- bbox_inches="tight", dpi=300, pad_inches=0.0
254
- )
255
-
256
-
257
- save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
258
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
grounded_sam_whisper_inpainting_demo.py DELETED
@@ -1,281 +0,0 @@
1
- import argparse
2
- import os
3
- from warnings import warn
4
-
5
- import numpy as np
6
- import torch
7
- from PIL import Image, ImageDraw, ImageFont
8
-
9
- # Grounding DINO
10
- import GroundingDINO.groundingdino.datasets.transforms as T
11
- from GroundingDINO.groundingdino.models import build_model
12
- from GroundingDINO.groundingdino.util import box_ops
13
- from GroundingDINO.groundingdino.util.slconfig import SLConfig
14
- from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
15
-
16
- # segment anything
17
- from segment_anything import build_sam, SamPredictor
18
- import cv2
19
- import numpy as np
20
- import matplotlib.pyplot as plt
21
-
22
-
23
- # diffusers
24
- import PIL
25
- import requests
26
- import torch
27
- from io import BytesIO
28
- from diffusers import StableDiffusionInpaintPipeline
29
-
30
- # whisper
31
- import whisper
32
-
33
- # ChatGPT
34
- import openai
35
-
36
-
37
def load_image(image_path):
    """Open an image file and prepare it for the grounding model.

    Args:
        image_path: path of the image file to open.

    Returns:
        A ``(image_pil, image)`` pair: the RGB PIL image as read from disk and
        the transformed tensor (3, h, w) produced by the preprocessing pipeline.
    """
    pil_rgb = Image.open(image_path).convert("RGB")

    # Resize (target size 800, longest side capped at 1333), convert to a
    # tensor, then normalize each channel with the mean/std given below.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocess = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize(channel_mean, channel_std),
        ]
    )

    # The transforms take (image, target); there is no target here.
    tensor, _ = preprocess(pil_rgb, None)  # 3, h, w
    return pil_rgb, tensor
50
-
51
-
52
def load_model(model_config_path, model_checkpoint_path, device):
    """Build the grounding model from a config file and load its checkpoint.

    Args:
        model_config_path: path to the model config file read by ``SLConfig``.
        model_checkpoint_path: path to a checkpoint whose ``"model"`` entry
            holds the state dict.
        device: value stored on the config as ``device`` before building.

    Returns:
        The model, switched to eval mode.
    """
    cfg = SLConfig.fromfile(model_config_path)
    cfg.device = device
    net = build_model(cfg)

    # The checkpoint is always deserialized onto the CPU; the state dict is
    # loaded non-strictly and the load result is printed for inspection.
    state = torch.load(model_checkpoint_path, map_location="cpu")
    weights = clean_state_dict(state["model"])
    print(net.load_state_dict(weights, strict=False))

    net.eval()
    return net
61
-
62
-
63
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run the grounding model on one image and keep the confident detections.

    Args:
        model: grounding model; it must expose a ``tokenizer`` attribute and be
            callable as ``model(images, captions=[...])``.
        image: image tensor; a leading batch dimension is added before the
            forward pass.
        caption: text prompt. It is lower-cased, stripped, and terminated with
            a "." when it does not already end with one.
        box_threshold: a query is kept when its highest token score is above
            this value.
        text_threshold: tokens of a kept query scoring above this value are
            passed to ``get_phrases_from_posmap`` to build its phrase.
        with_logits: when true, the query's best score (first four characters
            of its string form) is appended to the phrase in parentheses.
        device: device the model and the image are moved to.

    Returns:
        ``(boxes_filt, pred_phrases)``: the kept boxes (num_filt, 4) and the
        matching list of phrase strings, in the same order.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # filter output
    # Boolean-mask indexing already returns new tensors, so the results can be
    # modified by the caller without touching `logits` / `boxes`.
    filt_mask = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[filt_mask]  # num_filt, 256
    boxes_filt = boxes[filt_mask]  # num_filt, 4

    # get phrase
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    # build pred
    pred_phrases = []
    for logit in logits_filt:
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrase = pred_phrase + f"({str(logit.max().item())[:4]})"
        pred_phrases.append(pred_phrase)

    return boxes_filt, pred_phrases
97
-
98
def show_mask(mask, ax, random_color=False):
    """Draw a mask on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: array whose last two dimensions are (height, width); it must be
            reshapeable to exactly that size.
        ax: object with an ``imshow`` method that receives the overlay.
        random_color: draw the RGB part from ``np.random.random`` instead of
            using the fixed blue; alpha is 0.6 either way.
    """
    alpha = np.array([0.6])
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) against the 4-element colour to get (h, w, 4).
    overlay = mask.reshape(height, width)[..., None] * rgba
    ax.imshow(overlay)
106
-
107
-
108
def show_box(box, ax, label):
    """Outline a box on ``ax`` and write its label at the box's first corner.

    Args:
        box: indexable with four entries read as (x0, y0, x1, y1).
        ax: axes-like object providing ``add_patch`` and ``text``.
        label: text placed at (x0, y0).
    """
    left, top = box[0], box[1]
    width = box[2] - left
    height = box[3] - top
    # Green outline with a fully transparent fill.
    outline = plt.Rectangle((left, top), width, height, edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)
    ax.text(left, top, label)
113
-
114
-
115
def speech_recognition(speech_file, model):
    """Transcribe a speech file with whisper and report the detected language.

    Args:
        speech_file: path of the audio file to load.
        model: loaded whisper model; its ``device`` decides where the
            spectrogram is placed.

    Returns:
        ``(speech_text, speech_language)``: the decoded text and the language
        key with the highest detection probability.
    """
    # Load the audio and pad/trim it to fit 30 seconds.
    waveform = whisper.pad_or_trim(whisper.load_audio(speech_file))

    # Log-Mel spectrogram, on the same device as the model.
    mel = whisper.log_mel_spectrogram(waveform).to(model.device)

    # Language detection: keep the most probable language.
    _, language_probs = model.detect_language(mel)
    speech_language = max(language_probs, key=language_probs.get)

    # Decode with default options.
    decoded = whisper.decode(model, mel, whisper.DecodingOptions())
    return decoded.text, speech_language
135
-
136
-
137
def filter_prompts_with_chatgpt(caption, max_tokens=100, model="gpt-3.5-turbo"):
    """Ask ChatGPT to split a caption into a detection prompt and an inpainting prompt.

    The reply is expected to span at least two lines; the text after the last
    ":" on the first line becomes the detection prompt and the text after the
    last ":" on the second line becomes the inpainting prompt.

    Args:
        caption: the full caption to split.
        max_tokens: completion length limit passed to the API.
        model: chat model name passed to the API.

    Returns:
        ``(det_prompt, inpaint_prompt)``. When the reply cannot be parsed, a
        warning is issued and ``caption`` is returned for both.
    """
    # The system prompt is kept exactly as sent before (including the missing
    # separators between sentences) so the request itself is unchanged.
    prompt = [
        {
            'role': 'system',
            'content': "Extract the main object to be replaced and marked it as 'main_object', "
                       "Extract the remaining part as 'other prompt' "
                       "Return (main_object, other prompt)"
                       f'Given caption: {caption}.'
        }
    ]
    response = openai.ChatCompletion.create(model=model, messages=prompt, temperature=0.6, max_tokens=max_tokens)
    reply = response['choices'][0]['message']['content']
    try:
        reply_lines = reply.split('\n')
        det_prompt = reply_lines[0].split(':')[-1].strip()
        inpaint_prompt = reply_lines[1].split(':')[-1].strip()
    except (AttributeError, IndexError):
        # Reply is not a string or has fewer than two lines: fall back to the
        # caption for both prompts instead of failing the whole run.
        warn("Failed to extract tags from caption")
        det_prompt, inpaint_prompt = caption, caption
    return det_prompt, inpaint_prompt
155
-
156
-
157
if __name__ == "__main__":
    # Speech-driven inpainting demo:
    #   1. transcribe the speech prompt(s) with whisper (optionally letting
    #      ChatGPT split one prompt into detection / inpainting parts),
    #   2. detect the described object with the grounding model,
    #   3. segment it with SAM using the detected boxes,
    #   4. repaint the masked region with a Stable Diffusion inpainting pipeline.

    parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
    parser.add_argument("--config", type=str, required=True, help="path to config file")
    parser.add_argument(
        "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
    )
    parser.add_argument(
        "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
    )
    parser.add_argument("--input_image", type=str, required=True, help="path to image file")
    parser.add_argument(
        "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
    )
    parser.add_argument("--det_speech_file", type=str, help="grounding speech file")
    parser.add_argument("--inpaint_speech_file", type=str, help="inpaint speech file")
    parser.add_argument("--prompt_speech_file", type=str, help="prompt speech file, no need to provide det_speech_file")
    parser.add_argument("--enable_chatgpt", action="store_true", help="enable chatgpt")
    parser.add_argument("--openai_key", type=str, help="key for chatgpt")
    parser.add_argument("--openai_proxy", default=None, type=str, help="proxy for chatgpt")
    parser.add_argument("--whisper_model", type=str, default="small", help="whisper model version: tiny, base, small, medium, large")
    parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
    parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
    parser.add_argument("--inpaint_mode", type=str, default="first", help="inpaint mode")
    parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
    parser.add_argument("--prompt_extra", type=str, default=" high resolution, real scene", help="extra prompt for inpaint")
    args = parser.parse_args()

    # cfg
    config_file = args.config  # change the path of the model config file
    grounded_checkpoint = args.grounded_checkpoint  # change the path of the model
    sam_checkpoint = args.sam_checkpoint
    image_path = args.input_image

    output_dir = args.output_dir
    box_threshold = args.box_threshold
    # Was args.box_threshold, which silently ignored --text_threshold.
    text_threshold = args.text_threshold
    inpaint_mode = args.inpaint_mode
    device = args.device

    # make dir
    os.makedirs(output_dir, exist_ok=True)
    # load image
    image_pil, image = load_image(image_path)
    # load model
    model = load_model(config_file, grounded_checkpoint, device=device)

    # visualize raw image
    image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

    # recognize speech
    whisper_model = whisper.load_model(args.whisper_model)

    if args.enable_chatgpt:
        # One spoken prompt; ChatGPT splits it into detection / inpainting parts.
        openai.api_key = args.openai_key
        if args.openai_proxy:
            openai.proxy = {"http": args.openai_proxy, "https": args.openai_proxy}
        speech_text, _ = speech_recognition(args.prompt_speech_file, whisper_model)
        det_prompt, inpaint_prompt = filter_prompts_with_chatgpt(speech_text)
        inpaint_prompt += args.prompt_extra
        print(f"det_prompt: {det_prompt}, inpaint_prompt: {inpaint_prompt}")
    else:
        # Two separate recordings: one for detection, one for inpainting.
        det_prompt, det_speech_language = speech_recognition(args.det_speech_file, whisper_model)
        inpaint_prompt, inpaint_speech_language = speech_recognition(args.inpaint_speech_file, whisper_model)
        print(f"det_prompt: {det_prompt}, using language: {det_speech_language}")
        print(f"inpaint_prompt: {inpaint_prompt}, using language: {inpaint_speech_language}")

    # run grounding dino model
    boxes_filt, pred_phrases = get_grounding_output(
        model, image, det_prompt, box_threshold, text_threshold, device=device
    )

    # initialize SAM
    predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    # Scale each box by the image size, then turn (centre, size) into
    # (top-left corner, bottom-right corner).
    size = image_pil.size
    H, W = size[1], size[0]
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    boxes_filt = boxes_filt.cpu()
    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])

    masks, _, _ = predictor.predict_torch(
        point_coords = None,
        point_labels = None,
        boxes = transformed_boxes,
        multimask_output = False,
    )

    # masks: [1, 1, 512, 512]

    # inpainting pipeline
    if inpaint_mode == 'merge':
        # Union of all predicted masks, kept in the same 4-D layout.
        masks = torch.sum(masks, dim=0).unsqueeze(0)
        masks = torch.where(masks > 0, True, False)
    # In 'merge' mode masks[0][0] is the merged mask; otherwise it is simply
    # the first mask (to be refined in a future release). These assignments
    # used to sit in the else branch only, so 'merge' mode failed with a
    # NameError on mask_pil.
    mask = masks[0][0].cpu().numpy()
    mask_pil = Image.fromarray(mask)
    image_pil = Image.fromarray(image)

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
    )
    pipe = pipe.to("cuda")

    # prompt = "A sofa, high quality, detailed"
    image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
    image.save(os.path.join(output_dir, "grounded_sam_inpainting_output.jpg"))

    # draw output image
    # plt.figure(figsize=(10, 10))
    # plt.imshow(image)
    # for mask in masks:
    #     show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    # for box, label in zip(boxes_filt, pred_phrases):
    #     show_box(box.numpy(), plt.gca(), label)
    # plt.axis('off')
    # plt.savefig(os.path.join(output_dir, "grounded_sam_output.jpg"), bbox_inches="tight")
281
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -21,3 +21,12 @@ transformers
21
  yapf
22
  numba
23
  segment_anything
 
 
 
 
 
 
 
 
 
 
21
  yapf
22
  numba
23
  segment_anything
24
+
25
+ # ftfy
26
+ # uuid
27
+ # psutil
28
+ # facexlib
29
+ lama-cleaner==0.25.0
30
+ # tensorflow
31
+ # easydict
32
+