Spaces:
Runtime error
Runtime error
update thesis demo with SAM
Browse files- gradio_test.py +69 -0
- test.py +92 -207
gradio_test.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
block = gr.Blocks(
|
7 |
+
title="SAM and others",
|
8 |
+
# theme="shivi/calm_seafoam@>=0.0.1,<1.0.0",
|
9 |
+
)
|
10 |
+
colors = [(255, 0, 0), (0, 255, 0)]
|
11 |
+
markers = [1, 5]
|
12 |
+
|
13 |
+
def get_point(img, sel_pix, evt: gr.SelectData):
|
14 |
+
img = np.array(img, dtype=np.uint8)
|
15 |
+
sel_pix.append(evt.index)
|
16 |
+
# draw points
|
17 |
+
|
18 |
+
print(sel_pix)
|
19 |
+
for point in sel_pix:
|
20 |
+
cv2.drawMarker(img, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
|
21 |
+
return Image.fromarray(img).convert("RGB")
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def undo_button(orig_img, sel_pix):
|
26 |
+
temp = orig_img.copy()
|
27 |
+
temp = np.array(temp, dtype=np.uint8)
|
28 |
+
if len(sel_pix) != 0:
|
29 |
+
sel_pix.pop()
|
30 |
+
for point in sel_pix:
|
31 |
+
cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
|
32 |
+
return Image.fromarray(temp).convert("RGB")
|
33 |
+
|
34 |
+
def toggle_button(orig_img, mode):
|
35 |
+
print(mode)
|
36 |
+
if mode:
|
37 |
+
ret = gr.Image(value= orig_img,elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
38 |
+
else:
|
39 |
+
ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
40 |
+
mode = not mode
|
41 |
+
return ret, mode
|
42 |
+
|
43 |
+
def store_img(img):
|
44 |
+
print("call for store")
|
45 |
+
return img, [] # when new image is uploaded, `selected_points` should be empty
|
46 |
+
|
47 |
+
with block:
|
48 |
+
selected_points = gr.State([])
|
49 |
+
original_image = gr.State()
|
50 |
+
mode = gr.State(True)
|
51 |
+
input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512,)# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
52 |
+
undo = gr.Button("undo mode", visible=True)
|
53 |
+
toggle = gr.Button("toggle mode", visible=True)
|
54 |
+
input_image.upload(
|
55 |
+
store_img,
|
56 |
+
[input_image],
|
57 |
+
[original_image, selected_points]
|
58 |
+
)
|
59 |
+
|
60 |
+
input_image.select(
|
61 |
+
get_point,
|
62 |
+
[input_image, selected_points],
|
63 |
+
[input_image]
|
64 |
+
)
|
65 |
+
|
66 |
+
undo.click(fn=undo_button, inputs=[original_image, selected_points], outputs=[input_image])
|
67 |
+
toggle.click(fn=toggle_button, inputs=[original_image, mode], outputs=[input_image, mode])
|
68 |
+
|
69 |
+
block.launch()
|
test.py
CHANGED
@@ -123,6 +123,44 @@ ram_model = None
|
|
123 |
kosmos_model = None
|
124 |
kosmos_processor = None
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
|
127 |
args = SLConfig.fromfile(model_config_path)
|
128 |
model = build_model(args)
|
@@ -290,13 +328,6 @@ def set_device(args):
|
|
290 |
device = 'cpu'
|
291 |
print(f'device={device}')
|
292 |
|
293 |
-
def load_groundingdino_model(device):
|
294 |
-
# initialize groundingdino model
|
295 |
-
global groundingdino_model
|
296 |
-
logger.info(f"initialize groundingdino model...")
|
297 |
-
groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device=device) #'cpu')
|
298 |
-
logger.info(f"initialize groundingdino model...{type(groundingdino_model)}")
|
299 |
-
|
300 |
def get_sam_vit_h_4b8939():
|
301 |
if not os.path.exists('./sam_vit_h_4b8939.pth'):
|
302 |
logger.info(f"get sam_vit_h_4b8939.pth...")
|
@@ -327,16 +358,6 @@ def load_sd_model(device):
|
|
327 |
)
|
328 |
sd_model = sd_model.to(device)
|
329 |
|
330 |
-
def load_lama_cleaner_model(device):
|
331 |
-
# initialize lama_cleaner
|
332 |
-
global lama_cleaner_model
|
333 |
-
logger.info(f"initialize lama_cleaner...")
|
334 |
-
|
335 |
-
lama_cleaner_model = ModelManager(
|
336 |
-
name='lama',
|
337 |
-
device=device,
|
338 |
-
)
|
339 |
-
|
340 |
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
|
341 |
try:
|
342 |
logger.info(f'_______lama_cleaner_process_______1____')
|
@@ -413,41 +434,6 @@ def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
|
|
413 |
image = None
|
414 |
return image
|
415 |
|
416 |
-
class Ram_Predictor(RamPredictor):
|
417 |
-
def __init__(self, config, device='cpu'):
|
418 |
-
self.config = config
|
419 |
-
self.device = torch.device(device)
|
420 |
-
self._build_model()
|
421 |
-
|
422 |
-
def _build_model(self):
|
423 |
-
self.model = RamModel(**self.config.model).to(self.device)
|
424 |
-
if self.config.load_from is not None:
|
425 |
-
self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
|
426 |
-
self.model.train()
|
427 |
-
|
428 |
-
def load_ram_model(device):
|
429 |
-
# load ram model
|
430 |
-
global ram_model
|
431 |
-
if os.environ.get('IS_MY_DEBUG') is not None:
|
432 |
-
return
|
433 |
-
model_path = "./checkpoints/ram_epoch12.pth"
|
434 |
-
ram_config = dict(
|
435 |
-
model=dict(
|
436 |
-
pretrained_model_name_or_path='bert-base-uncased',
|
437 |
-
load_pretrained_weights=False,
|
438 |
-
num_transformer_layer=2,
|
439 |
-
input_feature_size=256,
|
440 |
-
output_feature_size=768,
|
441 |
-
cls_feature_size=512,
|
442 |
-
num_relation_classes=56,
|
443 |
-
pred_type='attention',
|
444 |
-
loss_type='multi_label_ce',
|
445 |
-
),
|
446 |
-
load_from=model_path,
|
447 |
-
)
|
448 |
-
ram_config = mmengine_Config(ram_config)
|
449 |
-
ram_model = Ram_Predictor(ram_config, device)
|
450 |
-
|
451 |
# visualization
|
452 |
def draw_selected_mask(mask, draw):
|
453 |
color = (255, 0, 0, 153)
|
@@ -524,52 +510,6 @@ def concatenate_images_vertical(image1, image2):
|
|
524 |
|
525 |
return new_image
|
526 |
|
527 |
-
def relate_anything(input_image, k):
|
528 |
-
logger.info(f'relate_anything_1_{input_image.size}_')
|
529 |
-
w, h = input_image.size
|
530 |
-
max_edge = 1500
|
531 |
-
if w > max_edge or h > max_edge:
|
532 |
-
ratio = max(w, h) / max_edge
|
533 |
-
new_size = (int(w / ratio), int(h / ratio))
|
534 |
-
input_image.thumbnail(new_size)
|
535 |
-
|
536 |
-
logger.info(f'relate_anything_2_')
|
537 |
-
# load image
|
538 |
-
pil_image = input_image.convert('RGBA')
|
539 |
-
image = np.array(input_image)
|
540 |
-
sam_masks = sam_mask_generator.generate(image)
|
541 |
-
filtered_masks = sort_and_deduplicate(sam_masks)
|
542 |
-
|
543 |
-
logger.info(f'relate_anything_3_')
|
544 |
-
feat_list = []
|
545 |
-
for fm in filtered_masks:
|
546 |
-
feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
|
547 |
-
feat_list.append(feat)
|
548 |
-
feat = torch.cat(feat_list, dim=1).to(device)
|
549 |
-
matrix_output, rel_triplets = ram_model.predict(feat)
|
550 |
-
|
551 |
-
logger.info(f'relate_anything_4_')
|
552 |
-
pil_image_list = []
|
553 |
-
for i, rel in enumerate(rel_triplets[:k]):
|
554 |
-
s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
|
555 |
-
relation = relation_classes[r]
|
556 |
-
|
557 |
-
mask_image = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
|
558 |
-
mask_draw = ImageDraw.Draw(mask_image)
|
559 |
-
|
560 |
-
draw_selected_mask(filtered_masks[s]['segmentation'], mask_draw)
|
561 |
-
draw_object_mask(filtered_masks[o]['segmentation'], mask_draw)
|
562 |
-
|
563 |
-
current_pil_image = pil_image.copy()
|
564 |
-
current_pil_image.alpha_composite(mask_image)
|
565 |
-
|
566 |
-
title_image = create_title_image('Red', relation, 'Blue', current_pil_image.size[0])
|
567 |
-
concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
|
568 |
-
pil_image_list.append(concate_pil_image)
|
569 |
-
|
570 |
-
logger.info(f'relate_anything_5_{len(pil_image_list)}')
|
571 |
-
return pil_image_list
|
572 |
-
|
573 |
mask_source_draw = "draw a mask on input image"
|
574 |
mask_source_segment = "type what to detect below"
|
575 |
|
@@ -584,7 +524,7 @@ def get_time_cost(run_task_time, time_cost_str):
|
|
584 |
run_task_time = now_time
|
585 |
return run_task_time, time_cost_str
|
586 |
|
587 |
-
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
588 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
|
589 |
|
590 |
text_prompt = getTextTrans(text_prompt, source='zh', target='en')
|
@@ -607,15 +547,10 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
607 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
608 |
return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
|
609 |
|
610 |
-
if (task_type == 'relate anything'):
|
611 |
-
output_images = relate_anything(input_image['image'], num_relation)
|
612 |
-
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
613 |
-
return output_images, gr.Gallery.update(label='relate images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
614 |
-
|
615 |
text_prompt = text_prompt.strip()
|
616 |
-
if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
617 |
-
|
618 |
-
|
619 |
|
620 |
if input_image is None:
|
621 |
return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
@@ -649,30 +584,6 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
649 |
pass
|
650 |
else:
|
651 |
groundingdino_device = 'cpu'
|
652 |
-
if device != 'cpu':
|
653 |
-
try:
|
654 |
-
from groundingdino import _C
|
655 |
-
groundingdino_device = 'cuda:0'
|
656 |
-
except:
|
657 |
-
warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")
|
658 |
-
|
659 |
-
boxes_filt, pred_phrases = get_grounding_output(
|
660 |
-
groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
|
661 |
-
)
|
662 |
-
if boxes_filt.size(0) == 0:
|
663 |
-
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
|
664 |
-
return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
665 |
-
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
666 |
-
|
667 |
-
pred_dict = {
|
668 |
-
"boxes": boxes_filt,
|
669 |
-
"size": [size[1], size[0]], # H,W
|
670 |
-
"labels": pred_phrases,
|
671 |
-
}
|
672 |
-
|
673 |
-
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
674 |
-
output_images.append(image_with_box)
|
675 |
-
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
676 |
|
677 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
678 |
if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
|
@@ -680,37 +591,24 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
680 |
if sam_predictor:
|
681 |
sam_predictor.set_image(image)
|
682 |
|
683 |
-
for i in range(boxes_filt.size(0)):
|
684 |
-
boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
|
685 |
-
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
|
686 |
-
boxes_filt[i][2:] += boxes_filt[i][:2]
|
687 |
-
|
688 |
if sam_predictor:
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
point_coords = None,
|
694 |
-
point_labels = None,
|
695 |
# boxes = transformed_boxes,
|
696 |
multimask_output = False,
|
697 |
)
|
698 |
# masks: [9, 1, 512, 512]
|
699 |
assert sam_checkpoint, 'sam_checkpoint is not found!'
|
700 |
else:
|
701 |
-
masks = torch.zeros(len(boxes_filt), 1, H, W)
|
702 |
-
mask_count = 0
|
703 |
-
for box in boxes_filt:
|
704 |
-
masks[mask_count, 0, int(box[1]):int(box[3]), int(box[0]):int(box[2])] = 1
|
705 |
-
mask_count += 1
|
706 |
-
masks = torch.where(masks > 0, True, False)
|
707 |
run_mode = "rectangle"
|
708 |
|
709 |
# draw output image
|
710 |
plt.figure(figsize=(10, 10))
|
711 |
-
plt.imshow(
|
712 |
for mask in masks:
|
713 |
-
show_mask(mask
|
714 |
# for box, label in zip(boxes_filt, pred_phrases):
|
715 |
# show_box(box.cpu().numpy(), plt.gca(), label)
|
716 |
plt.axis('off')
|
@@ -760,35 +658,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
760 |
image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
761 |
else:
|
762 |
# remove from mask
|
763 |
-
|
764 |
-
mask_imgs = []
|
765 |
-
masks_shape = masks_ori.shape
|
766 |
-
boxes_filt_ori_array = boxes_filt_ori.numpy()
|
767 |
-
if inpaint_mode == 'merge':
|
768 |
-
extend_shape_0 = masks_shape[0]
|
769 |
-
extend_shape_1 = masks_shape[1]
|
770 |
-
else:
|
771 |
-
extend_shape_0 = 1
|
772 |
-
extend_shape_1 = 1
|
773 |
-
for i in range(extend_shape_0):
|
774 |
-
for j in range(extend_shape_1):
|
775 |
-
mask = masks_ori[i][j].cpu().numpy()
|
776 |
-
mask_pil = Image.fromarray(mask)
|
777 |
-
if remove_mode == 'segment':
|
778 |
-
useRectangle = False
|
779 |
-
else:
|
780 |
-
useRectangle = True
|
781 |
-
try:
|
782 |
-
remove_mask_extend = int(remove_mask_extend)
|
783 |
-
except:
|
784 |
-
remove_mask_extend = 10
|
785 |
-
mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
|
786 |
-
xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), W, H),
|
787 |
-
extend_pixels=remove_mask_extend, useRectangle=useRectangle)
|
788 |
-
mask_imgs.append(mask_pil_exp)
|
789 |
-
mask_pil = mix_masks(mask_imgs)
|
790 |
-
output_images.append(mask_pil.convert("RGB"))
|
791 |
-
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
792 |
|
793 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
794 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
|
@@ -810,7 +680,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
810 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
811 |
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
812 |
|
813 |
-
def change_radio_display(task_type, mask_source_radio):
|
814 |
text_prompt_visible = True
|
815 |
inpaint_prompt_visible = False
|
816 |
mask_source_radio_visible = False
|
@@ -830,7 +700,7 @@ def change_radio_display(task_type, mask_source_radio):
|
|
830 |
kosmos_text_output_visible = True
|
831 |
|
832 |
if task_type in ['inpainting', 'outpainting']:
|
833 |
-
inpaint_prompt_visible =
|
834 |
if task_type in ['inpainting', 'outpainting'] or task_type == "remove":
|
835 |
mask_source_radio_visible = True
|
836 |
if mask_source_radio == mask_source_draw:
|
@@ -838,7 +708,11 @@ def change_radio_display(task_type, mask_source_radio):
|
|
838 |
if task_type == "relate anything":
|
839 |
text_prompt_visible = False
|
840 |
num_relation_visible = True
|
841 |
-
|
|
|
|
|
|
|
|
|
842 |
return (gr.Textbox.update(visible=text_prompt_visible),
|
843 |
gr.Textbox.update(visible=inpaint_prompt_visible),
|
844 |
gr.Radio.update(visible=mask_source_radio_visible),
|
@@ -846,7 +720,8 @@ def change_radio_display(task_type, mask_source_radio):
|
|
846 |
gr.Gallery.update(visible=image_gallery_visible),
|
847 |
gr.Radio.update(visible=kosmos_input_visible),
|
848 |
gr.Image.update(visible=kosmos_output_visible),
|
849 |
-
gr.HighlightedText.update(visible=kosmos_text_output_visible)
|
|
|
850 |
|
851 |
def get_model_device(module):
|
852 |
try:
|
@@ -869,29 +744,39 @@ def click_callback(coords):
|
|
869 |
|
870 |
def main_gradio(args):
|
871 |
block = gr.Blocks(
|
872 |
-
title="
|
873 |
-
theme="shivi/calm_seafoam@>=0.0.1,<1.0.0",
|
874 |
)
|
875 |
with block:
|
876 |
with gr.Row():
|
877 |
with gr.Column():
|
|
|
|
|
878 |
task_types = ["segment"]
|
879 |
-
# if sam_enable:
|
880 |
-
# task_types.append("segment")
|
881 |
if inpainting_enable:
|
882 |
task_types.append("inpainting")
|
883 |
-
# task_types.append("outpainting")
|
884 |
-
# if lama_cleaner_enable:
|
885 |
-
# task_types.append("remove")
|
886 |
-
# if ram_enable:
|
887 |
-
# task_types.append("relate anything")
|
888 |
-
# if kosmos_enable:
|
889 |
-
# task_types.append("Kosmos-2")
|
890 |
-
# task_types.append("inpainting")
|
891 |
|
892 |
|
893 |
-
input_image = gr.Image(
|
894 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
895 |
print(dir(input_image))
|
896 |
task_type = gr.Radio(task_types, value="segment",
|
897 |
label='Task type', visible=True)
|
@@ -956,15 +841,15 @@ def main_gradio(args):
|
|
956 |
selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
|
957 |
|
958 |
run_button.click(fn=run_anything_task, inputs=[
|
959 |
-
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
960 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
|
961 |
outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
|
962 |
|
963 |
-
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
|
964 |
outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
|
965 |
-
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
|
966 |
outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation,
|
967 |
-
image_gallery, kosmos_input, kosmos_output, kosmos_text_output
|
968 |
])
|
969 |
|
970 |
# DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
|
@@ -1001,17 +886,17 @@ if __name__ == "__main__":
|
|
1001 |
if device == 'cpu':
|
1002 |
kosmos_enable = False
|
1003 |
|
1004 |
-
if kosmos_enable:
|
1005 |
-
|
1006 |
|
1007 |
-
if groundingdino_enable:
|
1008 |
-
|
1009 |
|
1010 |
if sam_enable:
|
1011 |
load_sam_model(device)
|
1012 |
|
1013 |
-
if inpainting_enable:
|
1014 |
-
|
1015 |
|
1016 |
# if lama_cleaner_enable:
|
1017 |
# load_lama_cleaner_model(device)
|
|
|
123 |
kosmos_model = None
|
124 |
kosmos_processor = None
|
125 |
|
126 |
+
colors = [(255, 0, 0), (0, 255, 0)]
|
127 |
+
markers = [1, 5]
|
128 |
+
|
129 |
+
def get_point(img, sel_pix, evt: gr.SelectData):
|
130 |
+
img = np.array(img, dtype=np.uint8)
|
131 |
+
sel_pix.append(evt.index)
|
132 |
+
# draw points
|
133 |
+
|
134 |
+
print(sel_pix)
|
135 |
+
for point in sel_pix:
|
136 |
+
cv2.drawMarker(img, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
|
137 |
+
return Image.fromarray(img).convert("RGB")
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
def undo_button(orig_img, sel_pix):
|
142 |
+
temp = orig_img.copy()
|
143 |
+
temp = np.array(temp, dtype=np.uint8)
|
144 |
+
if len(sel_pix) != 0:
|
145 |
+
sel_pix.pop()
|
146 |
+
for point in sel_pix:
|
147 |
+
cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
|
148 |
+
return Image.fromarray(temp).convert("RGB")
|
149 |
+
|
150 |
+
def toggle_button(orig_img, task_type):
|
151 |
+
print(task_type)
|
152 |
+
if task_type == "segment":
|
153 |
+
ret = gr.Image(value= orig_img,elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
154 |
+
elif task_type == "inpainting":
|
155 |
+
ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
156 |
+
task_type = not task_type
|
157 |
+
return ret, task_type
|
158 |
+
|
159 |
+
|
160 |
+
def store_img(img):
|
161 |
+
print("call for store")
|
162 |
+
return img, [] # when new image is uploaded, `selected_points` should be empty
|
163 |
+
|
164 |
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
|
165 |
args = SLConfig.fromfile(model_config_path)
|
166 |
model = build_model(args)
|
|
|
328 |
device = 'cpu'
|
329 |
print(f'device={device}')
|
330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
def get_sam_vit_h_4b8939():
|
332 |
if not os.path.exists('./sam_vit_h_4b8939.pth'):
|
333 |
logger.info(f"get sam_vit_h_4b8939.pth...")
|
|
|
358 |
)
|
359 |
sd_model = sd_model.to(device)
|
360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
|
362 |
try:
|
363 |
logger.info(f'_______lama_cleaner_process_______1____')
|
|
|
434 |
image = None
|
435 |
return image
|
436 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
437 |
# visualization
|
438 |
def draw_selected_mask(mask, draw):
|
439 |
color = (255, 0, 0, 153)
|
|
|
510 |
|
511 |
return new_image
|
512 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
513 |
mask_source_draw = "draw a mask on input image"
|
514 |
mask_source_segment = "type what to detect below"
|
515 |
|
|
|
524 |
run_task_time = now_time
|
525 |
return run_task_time, time_cost_str
|
526 |
|
527 |
+
def run_anything_task(input_image, input_points, origin_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
528 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
|
529 |
|
530 |
text_prompt = getTextTrans(text_prompt, source='zh', target='en')
|
|
|
547 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
548 |
return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
|
549 |
|
|
|
|
|
|
|
|
|
|
|
550 |
text_prompt = text_prompt.strip()
|
551 |
+
# if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
552 |
+
# if text_prompt == '':
|
553 |
+
# return [], gr.Gallery.update(label='Detection prompt is not found!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
554 |
|
555 |
if input_image is None:
|
556 |
return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
|
|
584 |
pass
|
585 |
else:
|
586 |
groundingdino_device = 'cpu'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
587 |
|
588 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
589 |
if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
|
|
|
591 |
if sam_predictor:
|
592 |
sam_predictor.set_image(image)
|
593 |
|
|
|
|
|
|
|
|
|
|
|
594 |
if sam_predictor:
|
595 |
+
logger.info(f"Forward with: {input_points}")
|
596 |
+
masks, _, _, _ = sam_predictor.predict(
|
597 |
+
point_coords = np.array(input_points),
|
598 |
+
point_labels = np.array([1 for _ in range(len(input_points))]),
|
|
|
|
|
599 |
# boxes = transformed_boxes,
|
600 |
multimask_output = False,
|
601 |
)
|
602 |
# masks: [9, 1, 512, 512]
|
603 |
assert sam_checkpoint, 'sam_checkpoint is not found!'
|
604 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
605 |
run_mode = "rectangle"
|
606 |
|
607 |
# draw output image
|
608 |
plt.figure(figsize=(10, 10))
|
609 |
+
plt.imshow(origin_image)
|
610 |
for mask in masks:
|
611 |
+
show_mask(mask, plt.gca(), random_color=True)
|
612 |
# for box, label in zip(boxes_filt, pred_phrases):
|
613 |
# show_box(box.cpu().numpy(), plt.gca(), label)
|
614 |
plt.axis('off')
|
|
|
658 |
image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
659 |
else:
|
660 |
# remove from mask
|
661 |
+
aasds = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
662 |
|
663 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
664 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
|
|
|
680 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
681 |
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
682 |
|
683 |
+
def change_radio_display(task_type, mask_source_radio, orig_img):
|
684 |
text_prompt_visible = True
|
685 |
inpaint_prompt_visible = False
|
686 |
mask_source_radio_visible = False
|
|
|
700 |
kosmos_text_output_visible = True
|
701 |
|
702 |
if task_type in ['inpainting', 'outpainting']:
|
703 |
+
inpaint_prompt_visible = False
|
704 |
if task_type in ['inpainting', 'outpainting'] or task_type == "remove":
|
705 |
mask_source_radio_visible = True
|
706 |
if mask_source_radio == mask_source_draw:
|
|
|
708 |
if task_type == "relate anything":
|
709 |
text_prompt_visible = False
|
710 |
num_relation_visible = True
|
711 |
+
if task_type == "segment":
|
712 |
+
ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
713 |
+
elif task_type == "inpainting":
|
714 |
+
ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
715 |
+
|
716 |
return (gr.Textbox.update(visible=text_prompt_visible),
|
717 |
gr.Textbox.update(visible=inpaint_prompt_visible),
|
718 |
gr.Radio.update(visible=mask_source_radio_visible),
|
|
|
720 |
gr.Gallery.update(visible=image_gallery_visible),
|
721 |
gr.Radio.update(visible=kosmos_input_visible),
|
722 |
gr.Image.update(visible=kosmos_output_visible),
|
723 |
+
gr.HighlightedText.update(visible=kosmos_text_output_visible),
|
724 |
+
ret, [], gr.Button("Undo point", visible = task_type == "segment"))
|
725 |
|
726 |
def get_model_device(module):
|
727 |
try:
|
|
|
744 |
|
745 |
def main_gradio(args):
|
746 |
block = gr.Blocks(
|
747 |
+
title="Thesis-Demo",
|
748 |
+
# theme="shivi/calm_seafoam@>=0.0.1,<1.0.0",
|
749 |
)
|
750 |
with block:
|
751 |
with gr.Row():
|
752 |
with gr.Column():
|
753 |
+
selected_points = gr.State([])
|
754 |
+
original_image = gr.State()
|
755 |
task_types = ["segment"]
|
|
|
|
|
756 |
if inpainting_enable:
|
757 |
task_types.append("inpainting")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
758 |
|
759 |
|
760 |
+
input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512)
|
761 |
+
|
762 |
+
input_image.upload(
|
763 |
+
store_img,
|
764 |
+
[input_image],
|
765 |
+
[original_image, selected_points]
|
766 |
+
)
|
767 |
+
|
768 |
+
input_image.select(
|
769 |
+
get_point,
|
770 |
+
[input_image, selected_points],
|
771 |
+
[input_image]
|
772 |
+
)
|
773 |
+
|
774 |
+
undo_point_button = gr.Button("Undo point")
|
775 |
+
undo_point_button.click(
|
776 |
+
fn= undo_button,
|
777 |
+
inputs=[original_image, selected_points],
|
778 |
+
outputs=[input_image]
|
779 |
+
)
|
780 |
print(dir(input_image))
|
781 |
task_type = gr.Radio(task_types, value="segment",
|
782 |
label='Task type', visible=True)
|
|
|
841 |
selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
|
842 |
|
843 |
run_button.click(fn=run_anything_task, inputs=[
|
844 |
+
input_image, selected_points, original_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
845 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
|
846 |
outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
|
847 |
|
848 |
+
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
849 |
outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
|
850 |
+
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
851 |
outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation,
|
852 |
+
image_gallery, kosmos_input, kosmos_output, kosmos_text_output, input_image, selected_points, undo_point_button
|
853 |
])
|
854 |
|
855 |
# DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
|
|
|
886 |
if device == 'cpu':
|
887 |
kosmos_enable = False
|
888 |
|
889 |
+
# if kosmos_enable:
|
890 |
+
# kosmos_model, kosmos_processor = load_kosmos_model(device)
|
891 |
|
892 |
+
# if groundingdino_enable:
|
893 |
+
# load_groundingdino_model('cpu')
|
894 |
|
895 |
if sam_enable:
|
896 |
load_sam_model(device)
|
897 |
|
898 |
+
# if inpainting_enable:
|
899 |
+
# load_sd_model(device)
|
900 |
|
901 |
# if lama_cleaner_enable:
|
902 |
# load_lama_cleaner_model(device)
|