liuyizhang commited on
Commit
779c33a
·
1 Parent(s): f2bd037

update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -14
app.py CHANGED
@@ -326,9 +326,10 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
326
  if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
327
  pass
328
  else:
329
- assert text_prompt, 'text_prompt is not found!'
330
 
331
- logger.info(f'run_grounded_sam_1_')
 
332
 
333
  # make dir
334
  os.makedirs(output_dir, exist_ok=True)
@@ -337,9 +338,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
337
  input_mask = np.array(input_mask_pil.convert("L"))
338
 
339
  image_pil, image = load_image(input_image['image'].convert("RGB"))
340
-
341
- file_temp = int(time.time())
342
-
343
  # visualize raw image
344
  # image_pil.save(os.path.join(output_dir, f"raw_image_{file_temp}.jpg"))
345
 
@@ -376,7 +375,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
376
  os.remove(image_path)
377
  output_images.append(detection_image_result)
378
 
379
- logger.info(f'run_grounded_sam_2_')
380
  if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
381
  image = np.array(input_image['image'])
382
  sam_predictor.set_image(image)
@@ -412,15 +411,15 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
412
  os.remove(image_path)
413
  output_images.append(segment_image_result)
414
 
415
- logger.info(f'run_grounded_sam_3_')
416
  if task_type == 'detection' or task_type == 'segment':
417
- logger.info(f'run_grounded_sam_9_{task_type}_')
418
  return output_images
419
  elif task_type == 'inpainting' or task_type == 'remove':
420
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
421
  task_type = 'remove'
422
 
423
- logger.info(f'run_grounded_sam_4_{task_type}_')
424
  if mask_source_radio == mask_source_draw:
425
  mask_pil = input_mask_pil
426
  mask = input_mask
@@ -487,12 +486,12 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
487
  image_inpainting.save(image_path)
488
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
489
  os.remove(image_path)
490
- logger.info(f'run_grounded_sam_9_{task_type}_')
491
  output_images.append(image_result)
492
  return output_images
493
  else:
494
  logger.info(f"task_type:{task_type} error!")
495
- logger.info(f'run_grounded_sam_9_9_')
496
  return output_images
497
 
498
  def change_radio_display(task_type, mask_source_radio):
@@ -525,15 +524,15 @@ if __name__ == "__main__":
525
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
526
  value=mask_source_segment, label="Mask from",
527
  interactive=True, visible=False)
528
- text_prompt = gr.Textbox(label="Detection Prompt", placeholder="Cannot be empty")
529
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
530
  run_button = gr.Button(label="Run")
531
  with gr.Accordion("Advanced options", open=False):
532
  box_threshold = gr.Slider(
533
- label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
534
  )
535
  text_threshold = gr.Slider(
536
- label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
537
  )
538
  iou_threshold = gr.Slider(
539
  label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
 
326
  if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
327
  pass
328
  else:
329
+ assert text_prompt, f'text_prompt for {task_type} is not found!'
330
 
331
+ file_temp = int(time.time())
332
+ logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_[{text_prompt}]_1_')
333
 
334
  # make dir
335
  os.makedirs(output_dir, exist_ok=True)
 
338
  input_mask = np.array(input_mask_pil.convert("L"))
339
 
340
  image_pil, image = load_image(input_image['image'].convert("RGB"))
341
+
 
 
342
  # visualize raw image
343
  # image_pil.save(os.path.join(output_dir, f"raw_image_{file_temp}.jpg"))
344
 
 
375
  os.remove(image_path)
376
  output_images.append(detection_image_result)
377
 
378
+ logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_2_')
379
  if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
380
  image = np.array(input_image['image'])
381
  sam_predictor.set_image(image)
 
411
  os.remove(image_path)
412
  output_images.append(segment_image_result)
413
 
414
+ logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_3_')
415
  if task_type == 'detection' or task_type == 'segment':
416
+ logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_9_')
417
  return output_images
418
  elif task_type == 'inpainting' or task_type == 'remove':
419
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
420
  task_type = 'remove'
421
 
422
+ logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_4_')
423
  if mask_source_radio == mask_source_draw:
424
  mask_pil = input_mask_pil
425
  mask = input_mask
 
486
  image_inpainting.save(image_path)
487
  image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
488
  os.remove(image_path)
489
+ logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_9_')
490
  output_images.append(image_result)
491
  return output_images
492
  else:
493
  logger.info(f"task_type:{task_type} error!")
494
+ logger.info(f'run_grounded_sam_[{file_temp}]_9_9_')
495
  return output_images
496
 
497
  def change_radio_display(task_type, mask_source_radio):
 
524
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
525
  value=mask_source_segment, label="Mask from",
526
  interactive=True, visible=False)
527
+ text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.' , Like this: cat . dog . chair ]", placeholder="Cannot be empty")
528
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
529
  run_button = gr.Button(label="Run")
530
  with gr.Accordion("Advanced options", open=False):
531
  box_threshold = gr.Slider(
532
+ label="Box Threshold", minimum=0.0, maximum=1.0, value=0.6, step=0.001
533
  )
534
  text_threshold = gr.Slider(
535
+ label="Text Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
536
  )
537
  iou_threshold = gr.Slider(
538
  label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001