xuan2k committed on
Commit
cde08ad
·
1 Parent(s): 8511bb4

update thesis demo with SAM

Browse files
Files changed (2) hide show
  1. gradio_test.py +69 -0
  2. test.py +92 -207
gradio_test.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
# Top-level Gradio container that the UI below is built inside.
block = gr.Blocks(
    title="SAM and others",
    # theme="shivi/calm_seafoam@>=0.0.1,<1.0.0",
)
# Marker colours for drawn points; the handlers in this file only use colors[0].
colors = [(255, 0, 0), (0, 255, 0)]
# Values passed as `markerType` to cv2.drawMarker; the handlers only use markers[0].
# NOTE(review): presumably 1 = cv2.MARKER_TILTED_CROSS and 5 = cv2.MARKER_TRIANGLE_UP — confirm.
markers = [1, 5]
12
+
13
def get_point(img, sel_pix, evt: gr.SelectData):
    """Record a clicked point and draw every selected point on the image.

    Args:
        img: Current value of the image component. Normally a PIL image.
            When the component uses the sketch tool the value is assumed to
            be a dict holding the picture under the ``"image"`` key
            (TODO confirm against the Gradio version in use).
        sel_pix: Mutable list of already selected points. The new point is
            appended in place.
        evt: Gradio select event; ``evt.index`` is stored as the new point
            and used as the marker position.

    Returns:
        An RGB PIL image with one marker drawn per selected point.
    """
    # A dict cannot be converted to a uint8 array; unwrap the picture first.
    if isinstance(img, dict):
        img = img["image"]
    canvas = np.array(img, dtype=np.uint8)
    sel_pix.append(evt.index)

    print(sel_pix)
    for point in sel_pix:
        # Points arrive as lists; hand OpenCV a tuple for the position.
        cv2.drawMarker(canvas, tuple(point), colors[0], markerType=markers[0], markerSize=6, thickness=2)
    return Image.fromarray(canvas).convert("RGB")
22
+
23
+
24
+
25
def undo_button(orig_img, sel_pix):
    """Remove the most recent point and redraw the rest on a clean image.

    Args:
        orig_img: Unmarked image stored at upload time, or None when nothing
            has been stored yet.
        sel_pix: Mutable list of selected points; its last entry is removed
            in place when the list is not empty.

    Returns:
        An RGB PIL image with the remaining markers, or None when there is
        no stored image to draw on.
    """
    if sel_pix:
        sel_pix.pop()
    # Clicking undo before any upload used to crash on `None.copy()`.
    if orig_img is None:
        return None
    # np.array() builds a new array, so the stored original is never drawn on.
    canvas = np.array(orig_img, dtype=np.uint8)
    for point in sel_pix:
        cv2.drawMarker(canvas, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
    return Image.fromarray(canvas).convert("RGB")
33
+
34
def toggle_button(orig_img, mode):
    """Rebuild the upload widget with the other tool and flip the mode flag.

    Args:
        orig_img: Image to show in the rebuilt widget.
        mode: Truthy selects the "editor" tool, falsy selects the "sketch"
            tool with its brush settings.

    Returns:
        A ``(gr.Image, bool)`` pair: the rebuilt widget and ``not mode``.
    """
    print(mode)
    # Arguments common to both variants of the widget.
    shared = dict(value=orig_img, elem_id="image_upload", type='pil', label="Upload", height=512)
    if mode:
        widget = gr.Image(tool="editor", **shared)
    else:
        widget = gr.Image(tool="sketch", brush_color='#00FFFF', mask_opacity=0.6, **shared)
    return widget, not mode
42
+
43
def store_img(img):
    """Return the uploaded image together with a fresh, empty point list.

    Args:
        img: The value of the image component at upload time.

    Returns:
        ``(img, [])`` — a new image means the selected points start empty.
    """
    print("call for store")
    fresh_points = []
    return img, fresh_points
46
+
47
with block:
    # State holders: clicked points, the unmarked upload, and the tool flag.
    selected_points = gr.State([])
    original_image = gr.State()
    mode = gr.State(True)

    input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512)  # tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
    undo = gr.Button("undo mode", visible=True)
    toggle = gr.Button("toggle mode", visible=True)

    # Uploading stores the clean image and resets the point list.
    input_image.upload(
        fn=store_img,
        inputs=[input_image],
        outputs=[original_image, selected_points],
    )

    # Clicking the image records the point and redraws the markers.
    input_image.select(
        fn=get_point,
        inputs=[input_image, selected_points],
        outputs=[input_image],
    )

    undo.click(
        fn=undo_button,
        inputs=[original_image, selected_points],
        outputs=[input_image],
    )
    toggle.click(
        fn=toggle_button,
        inputs=[original_image, mode],
        outputs=[input_image, mode],
    )

block.launch()
test.py CHANGED
@@ -123,6 +123,44 @@ ram_model = None
123
  kosmos_model = None
124
  kosmos_processor = None
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
127
  args = SLConfig.fromfile(model_config_path)
128
  model = build_model(args)
@@ -290,13 +328,6 @@ def set_device(args):
290
  device = 'cpu'
291
  print(f'device={device}')
292
 
293
- def load_groundingdino_model(device):
294
- # initialize groundingdino model
295
- global groundingdino_model
296
- logger.info(f"initialize groundingdino model...")
297
- groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device=device) #'cpu')
298
- logger.info(f"initialize groundingdino model...{type(groundingdino_model)}")
299
-
300
  def get_sam_vit_h_4b8939():
301
  if not os.path.exists('./sam_vit_h_4b8939.pth'):
302
  logger.info(f"get sam_vit_h_4b8939.pth...")
@@ -327,16 +358,6 @@ def load_sd_model(device):
327
  )
328
  sd_model = sd_model.to(device)
329
 
330
- def load_lama_cleaner_model(device):
331
- # initialize lama_cleaner
332
- global lama_cleaner_model
333
- logger.info(f"initialize lama_cleaner...")
334
-
335
- lama_cleaner_model = ModelManager(
336
- name='lama',
337
- device=device,
338
- )
339
-
340
  def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
341
  try:
342
  logger.info(f'_______lama_cleaner_process_______1____')
@@ -413,41 +434,6 @@ def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
413
  image = None
414
  return image
415
 
416
- class Ram_Predictor(RamPredictor):
417
- def __init__(self, config, device='cpu'):
418
- self.config = config
419
- self.device = torch.device(device)
420
- self._build_model()
421
-
422
- def _build_model(self):
423
- self.model = RamModel(**self.config.model).to(self.device)
424
- if self.config.load_from is not None:
425
- self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
426
- self.model.train()
427
-
428
- def load_ram_model(device):
429
- # load ram model
430
- global ram_model
431
- if os.environ.get('IS_MY_DEBUG') is not None:
432
- return
433
- model_path = "./checkpoints/ram_epoch12.pth"
434
- ram_config = dict(
435
- model=dict(
436
- pretrained_model_name_or_path='bert-base-uncased',
437
- load_pretrained_weights=False,
438
- num_transformer_layer=2,
439
- input_feature_size=256,
440
- output_feature_size=768,
441
- cls_feature_size=512,
442
- num_relation_classes=56,
443
- pred_type='attention',
444
- loss_type='multi_label_ce',
445
- ),
446
- load_from=model_path,
447
- )
448
- ram_config = mmengine_Config(ram_config)
449
- ram_model = Ram_Predictor(ram_config, device)
450
-
451
  # visualization
452
  def draw_selected_mask(mask, draw):
453
  color = (255, 0, 0, 153)
@@ -524,52 +510,6 @@ def concatenate_images_vertical(image1, image2):
524
 
525
  return new_image
526
 
527
- def relate_anything(input_image, k):
528
- logger.info(f'relate_anything_1_{input_image.size}_')
529
- w, h = input_image.size
530
- max_edge = 1500
531
- if w > max_edge or h > max_edge:
532
- ratio = max(w, h) / max_edge
533
- new_size = (int(w / ratio), int(h / ratio))
534
- input_image.thumbnail(new_size)
535
-
536
- logger.info(f'relate_anything_2_')
537
- # load image
538
- pil_image = input_image.convert('RGBA')
539
- image = np.array(input_image)
540
- sam_masks = sam_mask_generator.generate(image)
541
- filtered_masks = sort_and_deduplicate(sam_masks)
542
-
543
- logger.info(f'relate_anything_3_')
544
- feat_list = []
545
- for fm in filtered_masks:
546
- feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
547
- feat_list.append(feat)
548
- feat = torch.cat(feat_list, dim=1).to(device)
549
- matrix_output, rel_triplets = ram_model.predict(feat)
550
-
551
- logger.info(f'relate_anything_4_')
552
- pil_image_list = []
553
- for i, rel in enumerate(rel_triplets[:k]):
554
- s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
555
- relation = relation_classes[r]
556
-
557
- mask_image = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
558
- mask_draw = ImageDraw.Draw(mask_image)
559
-
560
- draw_selected_mask(filtered_masks[s]['segmentation'], mask_draw)
561
- draw_object_mask(filtered_masks[o]['segmentation'], mask_draw)
562
-
563
- current_pil_image = pil_image.copy()
564
- current_pil_image.alpha_composite(mask_image)
565
-
566
- title_image = create_title_image('Red', relation, 'Blue', current_pil_image.size[0])
567
- concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
568
- pil_image_list.append(concate_pil_image)
569
-
570
- logger.info(f'relate_anything_5_{len(pil_image_list)}')
571
- return pil_image_list
572
-
573
  mask_source_draw = "draw a mask on input image"
574
  mask_source_segment = "type what to detect below"
575
 
@@ -584,7 +524,7 @@ def get_time_cost(run_task_time, time_cost_str):
584
  run_task_time = now_time
585
  return run_task_time, time_cost_str
586
 
587
- def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
588
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
589
 
590
  text_prompt = getTextTrans(text_prompt, source='zh', target='en')
@@ -607,15 +547,10 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
607
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
608
  return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
609
 
610
- if (task_type == 'relate anything'):
611
- output_images = relate_anything(input_image['image'], num_relation)
612
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
613
- return output_images, gr.Gallery.update(label='relate images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
614
-
615
  text_prompt = text_prompt.strip()
616
- if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
617
- if text_prompt == '':
618
- return [], gr.Gallery.update(label='Detection prompt is not found!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
619
 
620
  if input_image is None:
621
  return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
@@ -649,30 +584,6 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
649
  pass
650
  else:
651
  groundingdino_device = 'cpu'
652
- if device != 'cpu':
653
- try:
654
- from groundingdino import _C
655
- groundingdino_device = 'cuda:0'
656
- except:
657
- warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")
658
-
659
- boxes_filt, pred_phrases = get_grounding_output(
660
- groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
661
- )
662
- if boxes_filt.size(0) == 0:
663
- logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
664
- return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
665
- boxes_filt_ori = copy.deepcopy(boxes_filt)
666
-
667
- pred_dict = {
668
- "boxes": boxes_filt,
669
- "size": [size[1], size[0]], # H,W
670
- "labels": pred_phrases,
671
- }
672
-
673
- image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
674
- output_images.append(image_with_box)
675
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
676
 
677
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
678
  if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
@@ -680,37 +591,24 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
680
  if sam_predictor:
681
  sam_predictor.set_image(image)
682
 
683
- for i in range(boxes_filt.size(0)):
684
- boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
685
- boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
686
- boxes_filt[i][2:] += boxes_filt[i][:2]
687
-
688
  if sam_predictor:
689
- boxes_filt = boxes_filt.to(sam_device)
690
- transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
691
-
692
- masks, _, _, _ = sam_predictor.predict_torch(
693
- point_coords = None,
694
- point_labels = None,
695
  # boxes = transformed_boxes,
696
  multimask_output = False,
697
  )
698
  # masks: [9, 1, 512, 512]
699
  assert sam_checkpoint, 'sam_checkpoint is not found!'
700
  else:
701
- masks = torch.zeros(len(boxes_filt), 1, H, W)
702
- mask_count = 0
703
- for box in boxes_filt:
704
- masks[mask_count, 0, int(box[1]):int(box[3]), int(box[0]):int(box[2])] = 1
705
- mask_count += 1
706
- masks = torch.where(masks > 0, True, False)
707
  run_mode = "rectangle"
708
 
709
  # draw output image
710
  plt.figure(figsize=(10, 10))
711
- plt.imshow(image)
712
  for mask in masks:
713
- show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
714
  # for box, label in zip(boxes_filt, pred_phrases):
715
  # show_box(box.cpu().numpy(), plt.gca(), label)
716
  plt.axis('off')
@@ -760,35 +658,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
760
  image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
761
  else:
762
  # remove from mask
763
- if mask_source_radio == mask_source_segment:
764
- mask_imgs = []
765
- masks_shape = masks_ori.shape
766
- boxes_filt_ori_array = boxes_filt_ori.numpy()
767
- if inpaint_mode == 'merge':
768
- extend_shape_0 = masks_shape[0]
769
- extend_shape_1 = masks_shape[1]
770
- else:
771
- extend_shape_0 = 1
772
- extend_shape_1 = 1
773
- for i in range(extend_shape_0):
774
- for j in range(extend_shape_1):
775
- mask = masks_ori[i][j].cpu().numpy()
776
- mask_pil = Image.fromarray(mask)
777
- if remove_mode == 'segment':
778
- useRectangle = False
779
- else:
780
- useRectangle = True
781
- try:
782
- remove_mask_extend = int(remove_mask_extend)
783
- except:
784
- remove_mask_extend = 10
785
- mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
786
- xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), W, H),
787
- extend_pixels=remove_mask_extend, useRectangle=useRectangle)
788
- mask_imgs.append(mask_pil_exp)
789
- mask_pil = mix_masks(mask_imgs)
790
- output_images.append(mask_pil.convert("RGB"))
791
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
792
 
793
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
794
  image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
@@ -810,7 +680,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
810
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
811
  return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
812
 
813
- def change_radio_display(task_type, mask_source_radio):
814
  text_prompt_visible = True
815
  inpaint_prompt_visible = False
816
  mask_source_radio_visible = False
@@ -830,7 +700,7 @@ def change_radio_display(task_type, mask_source_radio):
830
  kosmos_text_output_visible = True
831
 
832
  if task_type in ['inpainting', 'outpainting']:
833
- inpaint_prompt_visible = True
834
  if task_type in ['inpainting', 'outpainting'] or task_type == "remove":
835
  mask_source_radio_visible = True
836
  if mask_source_radio == mask_source_draw:
@@ -838,7 +708,11 @@ def change_radio_display(task_type, mask_source_radio):
838
  if task_type == "relate anything":
839
  text_prompt_visible = False
840
  num_relation_visible = True
841
-
 
 
 
 
842
  return (gr.Textbox.update(visible=text_prompt_visible),
843
  gr.Textbox.update(visible=inpaint_prompt_visible),
844
  gr.Radio.update(visible=mask_source_radio_visible),
@@ -846,7 +720,8 @@ def change_radio_display(task_type, mask_source_radio):
846
  gr.Gallery.update(visible=image_gallery_visible),
847
  gr.Radio.update(visible=kosmos_input_visible),
848
  gr.Image.update(visible=kosmos_output_visible),
849
- gr.HighlightedText.update(visible=kosmos_text_output_visible))
 
850
 
851
  def get_model_device(module):
852
  try:
@@ -869,29 +744,39 @@ def click_callback(coords):
869
 
870
  def main_gradio(args):
871
  block = gr.Blocks(
872
- title="SAM and others",
873
- theme="shivi/calm_seafoam@>=0.0.1,<1.0.0",
874
  )
875
  with block:
876
  with gr.Row():
877
  with gr.Column():
 
 
878
  task_types = ["segment"]
879
- # if sam_enable:
880
- # task_types.append("segment")
881
  if inpainting_enable:
882
  task_types.append("inpainting")
883
- # task_types.append("outpainting")
884
- # if lama_cleaner_enable:
885
- # task_types.append("remove")
886
- # if ram_enable:
887
- # task_types.append("relate anything")
888
- # if kosmos_enable:
889
- # task_types.append("Kosmos-2")
890
- # task_types.append("inpainting")
891
 
892
 
893
- input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload",
894
- height=512, brush_color='#00FFFF', mask_opacity=0.6)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
895
  print(dir(input_image))
896
  task_type = gr.Radio(task_types, value="segment",
897
  label='Task type', visible=True)
@@ -956,15 +841,15 @@ def main_gradio(args):
956
  selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
957
 
958
  run_button.click(fn=run_anything_task, inputs=[
959
- input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
960
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
961
  outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
962
 
963
- mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
964
  outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
965
- task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
966
  outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation,
967
- image_gallery, kosmos_input, kosmos_output, kosmos_text_output
968
  ])
969
 
970
  # DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
@@ -1001,17 +886,17 @@ if __name__ == "__main__":
1001
  if device == 'cpu':
1002
  kosmos_enable = False
1003
 
1004
- if kosmos_enable:
1005
- kosmos_model, kosmos_processor = load_kosmos_model(device)
1006
 
1007
- if groundingdino_enable:
1008
- load_groundingdino_model('cpu')
1009
 
1010
  if sam_enable:
1011
  load_sam_model(device)
1012
 
1013
- if inpainting_enable:
1014
- load_sd_model(device)
1015
 
1016
  # if lama_cleaner_enable:
1017
  # load_lama_cleaner_model(device)
 
123
  kosmos_model = None
124
  kosmos_processor = None
125
 
126
# Marker colours for drawn points; get_point and undo_button only use colors[0].
colors = [(255, 0, 0), (0, 255, 0)]
# Values passed as `markerType` to cv2.drawMarker; only markers[0] is used.
# NOTE(review): presumably 1 = cv2.MARKER_TILTED_CROSS and 5 = cv2.MARKER_TRIANGLE_UP — confirm.
markers = [1, 5]
128
+
129
def get_point(img, sel_pix, evt: gr.SelectData):
    """Record a clicked point and draw every selected point on the image.

    Args:
        img: Current value of the image component. Normally a PIL image.
            When the component uses the sketch tool the value is assumed to
            be a dict holding the picture under the ``"image"`` key
            (TODO confirm against the Gradio version in use).
        sel_pix: Mutable list of already selected points. The new point is
            appended in place.
        evt: Gradio select event; ``evt.index`` is stored as the new point
            and used as the marker position.

    Returns:
        An RGB PIL image with one marker drawn per selected point.
    """
    # A dict cannot be converted to a uint8 array; unwrap the picture first.
    if isinstance(img, dict):
        img = img["image"]
    canvas = np.array(img, dtype=np.uint8)
    sel_pix.append(evt.index)

    print(sel_pix)
    for point in sel_pix:
        # Points arrive as lists; hand OpenCV a tuple for the position.
        cv2.drawMarker(canvas, tuple(point), colors[0], markerType=markers[0], markerSize=6, thickness=2)
    return Image.fromarray(canvas).convert("RGB")
138
+
139
+
140
+
141
def undo_button(orig_img, sel_pix):
    """Remove the most recent point and redraw the rest on a clean image.

    Args:
        orig_img: Unmarked image stored at upload time, or None when nothing
            has been stored yet.
        sel_pix: Mutable list of selected points; its last entry is removed
            in place when the list is not empty.

    Returns:
        An RGB PIL image with the remaining markers, or None when there is
        no stored image to draw on.
    """
    if sel_pix:
        sel_pix.pop()
    # Clicking undo before any upload used to crash on `None.copy()`.
    if orig_img is None:
        return None
    # np.array() builds a new array, so the stored original is never drawn on.
    canvas = np.array(orig_img, dtype=np.uint8)
    for point in sel_pix:
        cv2.drawMarker(canvas, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
    return Image.fromarray(canvas).convert("RGB")
149
+
150
def toggle_button(orig_img, task_type):
    """Rebuild the upload widget with the tool that matches the task type.

    Args:
        orig_img: Image to show in the rebuilt widget.
        task_type: Task name. "segment" selects the "editor" tool and
            "inpainting" selects the "sketch" tool with its brush settings.

    Returns:
        A ``(update, task_type)`` pair. The update is the rebuilt
        ``gr.Image`` for a known task type and a no-op ``gr.update()``
        otherwise; ``task_type`` is passed through unchanged.
    """
    print(task_type)
    # Arguments common to both variants of the widget.
    shared = dict(value=orig_img, elem_id="image_upload", type='pil', label="Upload", height=512)
    if task_type == "segment":
        ret = gr.Image(tool="editor", **shared)
    elif task_type == "inpainting":
        ret = gr.Image(tool="sketch", brush_color='#00FFFF', mask_opacity=0.6, **shared)
    else:
        # Any other task type used to leave `ret` unbound (UnboundLocalError).
        ret = gr.update()
    # The previous `not task_type` was a leftover from the boolean toggle and
    # turned the task name into False; return the name itself.
    return ret, task_type
158
+
159
+
160
def store_img(img):
    """Return the uploaded image together with a fresh, empty point list.

    Args:
        img: The value of the image component at upload time.

    Returns:
        ``(img, [])`` — a new image means the selected points start empty.
    """
    print("call for store")
    fresh_points = []
    return img, fresh_points
163
+
164
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
165
  args = SLConfig.fromfile(model_config_path)
166
  model = build_model(args)
 
328
  device = 'cpu'
329
  print(f'device={device}')
330
 
 
 
 
 
 
 
 
331
  def get_sam_vit_h_4b8939():
332
  if not os.path.exists('./sam_vit_h_4b8939.pth'):
333
  logger.info(f"get sam_vit_h_4b8939.pth...")
 
358
  )
359
  sd_model = sd_model.to(device)
360
 
 
 
 
 
 
 
 
 
 
 
361
  def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
362
  try:
363
  logger.info(f'_______lama_cleaner_process_______1____')
 
434
  image = None
435
  return image
436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  # visualization
438
  def draw_selected_mask(mask, draw):
439
  color = (255, 0, 0, 153)
 
510
 
511
  return new_image
512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  mask_source_draw = "draw a mask on input image"
514
  mask_source_segment = "type what to detect below"
515
 
 
524
  run_task_time = now_time
525
  return run_task_time, time_cost_str
526
 
527
+ def run_anything_task(input_image, input_points, origin_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
528
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
529
 
530
  text_prompt = getTextTrans(text_prompt, source='zh', target='en')
 
547
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
548
  return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
549
 
 
 
 
 
 
550
  text_prompt = text_prompt.strip()
551
+ # if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
552
+ # if text_prompt == '':
553
+ # return [], gr.Gallery.update(label='Detection prompt is not found!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
554
 
555
  if input_image is None:
556
  return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
 
584
  pass
585
  else:
586
  groundingdino_device = 'cpu'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
 
588
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
589
  if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
 
591
  if sam_predictor:
592
  sam_predictor.set_image(image)
593
 
 
 
 
 
 
594
  if sam_predictor:
595
+ logger.info(f"Forward with: {input_points}")
596
+ masks, _, _, _ = sam_predictor.predict(
597
+ point_coords = np.array(input_points),
598
+ point_labels = np.array([1 for _ in range(len(input_points))]),
 
 
599
  # boxes = transformed_boxes,
600
  multimask_output = False,
601
  )
602
  # masks: [9, 1, 512, 512]
603
  assert sam_checkpoint, 'sam_checkpoint is not found!'
604
  else:
 
 
 
 
 
 
605
  run_mode = "rectangle"
606
 
607
  # draw output image
608
  plt.figure(figsize=(10, 10))
609
+ plt.imshow(origin_image)
610
  for mask in masks:
611
+ show_mask(mask, plt.gca(), random_color=True)
612
  # for box, label in zip(boxes_filt, pred_phrases):
613
  # show_box(box.cpu().numpy(), plt.gca(), label)
614
  plt.axis('off')
 
658
  image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
659
  else:
660
  # remove from mask
661
+ aasds = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
 
663
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
664
  image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
 
680
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
681
  return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
682
 
683
+ def change_radio_display(task_type, mask_source_radio, orig_img):
684
  text_prompt_visible = True
685
  inpaint_prompt_visible = False
686
  mask_source_radio_visible = False
 
700
  kosmos_text_output_visible = True
701
 
702
  if task_type in ['inpainting', 'outpainting']:
703
+ inpaint_prompt_visible = False
704
  if task_type in ['inpainting', 'outpainting'] or task_type == "remove":
705
  mask_source_radio_visible = True
706
  if mask_source_radio == mask_source_draw:
 
708
  if task_type == "relate anything":
709
  text_prompt_visible = False
710
  num_relation_visible = True
711
+ if task_type == "segment":
712
+ ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
713
+ elif task_type == "inpainting":
714
+ ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
715
+
716
  return (gr.Textbox.update(visible=text_prompt_visible),
717
  gr.Textbox.update(visible=inpaint_prompt_visible),
718
  gr.Radio.update(visible=mask_source_radio_visible),
 
720
  gr.Gallery.update(visible=image_gallery_visible),
721
  gr.Radio.update(visible=kosmos_input_visible),
722
  gr.Image.update(visible=kosmos_output_visible),
723
+ gr.HighlightedText.update(visible=kosmos_text_output_visible),
724
+ ret, [], gr.Button("Undo point", visible = task_type == "segment"))
725
 
726
  def get_model_device(module):
727
  try:
 
744
 
745
  def main_gradio(args):
746
  block = gr.Blocks(
747
+ title="Thesis-Demo",
748
+ # theme="shivi/calm_seafoam@>=0.0.1,<1.0.0",
749
  )
750
  with block:
751
  with gr.Row():
752
  with gr.Column():
753
+ selected_points = gr.State([])
754
+ original_image = gr.State()
755
  task_types = ["segment"]
 
 
756
  if inpainting_enable:
757
  task_types.append("inpainting")
 
 
 
 
 
 
 
 
758
 
759
 
760
+ input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512)
761
+
762
+ input_image.upload(
763
+ store_img,
764
+ [input_image],
765
+ [original_image, selected_points]
766
+ )
767
+
768
+ input_image.select(
769
+ get_point,
770
+ [input_image, selected_points],
771
+ [input_image]
772
+ )
773
+
774
+ undo_point_button = gr.Button("Undo point")
775
+ undo_point_button.click(
776
+ fn= undo_button,
777
+ inputs=[original_image, selected_points],
778
+ outputs=[input_image]
779
+ )
780
  print(dir(input_image))
781
  task_type = gr.Radio(task_types, value="segment",
782
  label='Task type', visible=True)
 
841
  selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
842
 
843
  run_button.click(fn=run_anything_task, inputs=[
844
+ input_image, selected_points, original_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
845
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
846
  outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
847
 
848
+ mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
849
  outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
850
+ task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
851
  outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation,
852
+ image_gallery, kosmos_input, kosmos_output, kosmos_text_output, input_image, selected_points, undo_point_button
853
  ])
854
 
855
  # DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
 
886
  if device == 'cpu':
887
  kosmos_enable = False
888
 
889
+ # if kosmos_enable:
890
+ # kosmos_model, kosmos_processor = load_kosmos_model(device)
891
 
892
+ # if groundingdino_enable:
893
+ # load_groundingdino_model('cpu')
894
 
895
  if sam_enable:
896
  load_sam_model(device)
897
 
898
+ # if inpainting_enable:
899
+ # load_sd_model(device)
900
 
901
  # if lama_cleaner_enable:
902
  # load_lama_cleaner_model(device)