Spaces:
Runtime error
Runtime error
Ubuntu
committed on
Commit
•
e5efca7
1
Parent(s):
c071a86
Update Inpainting Demo
Browse files
- .gitignore +1 -0
- .log/log.txt +6 -0
- SegFormer +1 -0
- output.png +0 -0
- requirements.txt +2 -2
- test.png +0 -0
- test.py +168 -76
.gitignore
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
__pycache__
|
2 |
*.pyc
|
3 |
checkpoints/
|
|
|
4 |
*.pth
|
|
|
1 |
__pycache__
|
2 |
*.pyc
|
3 |
checkpoints/
|
4 |
+
I2SB/
|
5 |
*.pth
|
.log/log.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[19:02:29] INFO (0:00:00) Loaded options from opt_pkl_path=PosixPath('I2SB/results/inpaint-freeform2030/options.pkl')!
|
2 |
+
INFO (0:00:00) [Diffusion] Built I2SB diffusion: steps=1000!
|
3 |
+
[19:02:33] INFO (0:00:03) [Net] Initialized network from ckpt_pkl='I2SB/data/256x256_diffusion_uncond_fixedsigma.pkl'! Size=552807171!
|
4 |
+
[19:02:44] INFO (0:00:14) [Net] Loaded pretrained adm ckpt_pt='I2SB/data/256x256_diffusion_uncond_fixedsigma.pt'!
|
5 |
+
[19:02:49] INFO (0:00:19) [Net] Loaded network ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
|
6 |
+
[19:02:50] INFO (0:00:20) [Ema] Loaded ema ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
|
SegFormer
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 64ab11278eb30b8e2d8ea1d10a777fc5b1563948
|
output.png
ADDED
requirements.txt
CHANGED
@@ -18,8 +18,8 @@ timm
|
|
18 |
# torch==2.0.0
|
19 |
# torchvision==0.15.1
|
20 |
|
21 |
-
torch==2.2.1
|
22 |
-
torchvision==0.17.1
|
23 |
|
24 |
gevent
|
25 |
yapf
|
|
|
18 |
# torch==2.0.0
|
19 |
# torchvision==0.15.1
|
20 |
|
21 |
+
# torch==2.2.1
|
22 |
+
# torchvision==0.17.1
|
23 |
|
24 |
gevent
|
25 |
yapf
|
test.png
ADDED
test.py
CHANGED
@@ -36,6 +36,34 @@ from GroundingDINO.groundingdino.util import box_ops
|
|
36 |
from GroundingDINO.groundingdino.util.slconfig import SLConfig
|
37 |
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
import cv2
|
40 |
import numpy as np
|
41 |
import matplotlib
|
@@ -126,6 +154,30 @@ kosmos_processor = None
|
|
126 |
colors = [(255, 0, 0), (0, 255, 0)]
|
127 |
markers = [1, 5]
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
def get_point(img, sel_pix, evt: gr.SelectData):
|
130 |
img = np.array(img, dtype=np.uint8)
|
131 |
sel_pix.append(evt.index)
|
@@ -146,6 +198,10 @@ def undo_button(orig_img, sel_pix):
|
|
146 |
for point in sel_pix:
|
147 |
cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
|
148 |
return Image.fromarray(temp).convert("RGB")
|
|
|
|
|
|
|
|
|
149 |
|
150 |
def toggle_button(orig_img, task_type):
|
151 |
print(task_type)
|
@@ -173,6 +229,37 @@ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
|
|
173 |
_ = model.eval()
|
174 |
return model
|
175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
def plot_boxes_to_image(image_pil, tgt):
|
177 |
H, W = tgt["size"]
|
178 |
boxes = tgt["boxes"]
|
@@ -238,6 +325,8 @@ def load_image(image_path):
|
|
238 |
image, _ = transform(image_pil, None) # 3, h, w
|
239 |
return image_pil, image
|
240 |
|
|
|
|
|
241 |
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
|
242 |
caption = caption.lower()
|
243 |
caption = caption.strip()
|
@@ -357,6 +446,24 @@ def load_sd_model(device):
|
|
357 |
torch_dtype=torch.float16,
|
358 |
)
|
359 |
sd_model = sd_model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
|
361 |
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
|
362 |
try:
|
@@ -511,7 +618,7 @@ def concatenate_images_vertical(image1, image2):
|
|
511 |
return new_image
|
512 |
|
513 |
mask_source_draw = "draw a mask on input image"
|
514 |
-
mask_source_segment = "
|
515 |
|
516 |
def get_time_cost(run_task_time, time_cost_str):
|
517 |
now_time = int(time.time()*1000)
|
@@ -524,11 +631,8 @@ def get_time_cost(run_task_time, time_cost_str):
|
|
524 |
run_task_time = now_time
|
525 |
return run_task_time, time_cost_str
|
526 |
|
527 |
-
def run_anything_task(input_image, input_points, origin_image,
|
528 |
-
|
529 |
-
|
530 |
-
text_prompt = getTextTrans(text_prompt, source='zh', target='en')
|
531 |
-
inpaint_prompt = getTextTrans(inpaint_prompt, source='zh', target='en')
|
532 |
|
533 |
run_task_time = 0
|
534 |
time_cost_str = ''
|
@@ -543,27 +647,19 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
|
|
543 |
image_pil, image = load_image(input_image.convert("RGB"))
|
544 |
input_img = input_image
|
545 |
|
546 |
-
kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil,
|
547 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
548 |
return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
|
549 |
|
550 |
-
text_prompt = text_prompt.strip()
|
551 |
-
# if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
552 |
-
# if text_prompt == '':
|
553 |
-
# return [], gr.Gallery.update(label='Detection prompt is not found!ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
554 |
-
|
555 |
if input_image is None:
|
556 |
return [], gr.Gallery.update(label='Please upload a image!ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
557 |
|
558 |
file_temp = int(time.time())
|
559 |
-
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/
|
560 |
|
561 |
output_images = []
|
562 |
|
563 |
# load image
|
564 |
-
if mask_source_radio == mask_source_draw:
|
565 |
-
input_mask_pil = input_image['mask']
|
566 |
-
input_mask = np.array(input_mask_pil.convert("L"))
|
567 |
|
568 |
if isinstance(input_image, dict):
|
569 |
image_pil, image = load_image(input_image['image'].convert("RGB"))
|
@@ -626,17 +722,17 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
|
|
626 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
627 |
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
628 |
elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
|
629 |
-
if
|
630 |
task_type = 'remove'
|
631 |
|
632 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
|
633 |
if mask_source_radio == mask_source_draw:
|
|
|
|
|
634 |
mask_pil = input_mask_pil
|
635 |
mask = input_mask
|
636 |
else:
|
637 |
masks_ori = copy.deepcopy(masks)
|
638 |
-
if inpaint_mode == 'merge':
|
639 |
-
masks = torch.sum(masks, dim=0).unsqueeze(0)
|
640 |
masks = torch.where(masks > 0, True, False)
|
641 |
mask = masks[0][0].cpu().numpy()
|
642 |
mask_pil = Image.fromarray(mask)
|
@@ -644,18 +740,11 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
|
|
644 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
645 |
|
646 |
if task_type in ['inpainting', 'outpainting']:
|
647 |
-
#
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
img_arr = np.array(image_mask_for_inpaint)
|
653 |
-
img_arr = np.where(img_arr > 0, 1, img_arr)
|
654 |
-
img_arr = 1 - img_arr
|
655 |
-
image_mask_for_inpaint = Image.fromarray(255*img_arr.astype('uint8'))
|
656 |
-
output_images.append(image_mask_for_inpaint.convert("RGB"))
|
657 |
-
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
658 |
-
image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
659 |
else:
|
660 |
# remove from mask
|
661 |
aasds = 1
|
@@ -681,8 +770,6 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
|
|
681 |
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
682 |
|
683 |
def change_radio_display(task_type, mask_source_radio, orig_img):
|
684 |
-
text_prompt_visible = True
|
685 |
-
inpaint_prompt_visible = False
|
686 |
mask_source_radio_visible = False
|
687 |
num_relation_visible = False
|
688 |
|
@@ -693,35 +780,29 @@ def change_radio_display(task_type, mask_source_radio, orig_img):
|
|
693 |
print(task_type)
|
694 |
if task_type == "Kosmos-2":
|
695 |
if kosmos_enable:
|
696 |
-
text_prompt_visible = False
|
697 |
image_gallery_visible = False
|
698 |
kosmos_input_visible = True
|
699 |
kosmos_output_visible = True
|
700 |
kosmos_text_output_visible = True
|
701 |
|
702 |
-
if task_type in ['inpainting', 'outpainting']:
|
703 |
-
inpaint_prompt_visible = False
|
704 |
if task_type in ['inpainting', 'outpainting'] or task_type == "remove":
|
705 |
mask_source_radio_visible = True
|
706 |
-
if mask_source_radio == mask_source_draw:
|
707 |
-
text_prompt_visible = False
|
708 |
if task_type == "relate anything":
|
709 |
-
text_prompt_visible = False
|
710 |
num_relation_visible = True
|
711 |
if task_type == "segment":
|
712 |
ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
713 |
elif task_type == "inpainting":
|
714 |
ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
715 |
|
716 |
-
return (gr.
|
717 |
-
gr.Textbox.update(visible=inpaint_prompt_visible),
|
718 |
-
gr.Radio.update(visible=mask_source_radio_visible),
|
719 |
gr.Slider.update(visible=num_relation_visible),
|
720 |
gr.Gallery.update(visible=image_gallery_visible),
|
721 |
gr.Radio.update(visible=kosmos_input_visible),
|
722 |
gr.Image.update(visible=kosmos_output_visible),
|
723 |
gr.HighlightedText.update(visible=kosmos_text_output_visible),
|
724 |
-
ret, [],
|
|
|
|
|
725 |
|
726 |
def get_model_device(module):
|
727 |
try:
|
@@ -770,42 +851,52 @@ def main_gradio(args):
|
|
770 |
[input_image, selected_points],
|
771 |
[input_image]
|
772 |
)
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
780 |
print(dir(input_image))
|
781 |
task_type = gr.Radio(task_types, value="segment",
|
782 |
label='Task type', visible=True)
|
783 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
784 |
-
value=
|
785 |
visible=False)
|
786 |
-
text_prompt = gr.Textbox(label="Detection", placeholder="Cannot be empty")
|
787 |
-
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
788 |
num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
|
789 |
|
790 |
kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
|
791 |
|
792 |
run_button = gr.Button(label="Run", visible=True)
|
793 |
-
with gr.Accordion("Advanced options", open=False) as advanced_options:
|
794 |
-
|
795 |
-
|
796 |
-
|
797 |
-
|
798 |
-
|
799 |
-
|
800 |
-
|
801 |
-
|
802 |
-
|
803 |
-
|
804 |
-
|
805 |
-
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
|
810 |
with gr.Column():
|
811 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
|
@@ -841,15 +932,15 @@ def main_gradio(args):
|
|
841 |
selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
|
842 |
|
843 |
run_button.click(fn=run_anything_task, inputs=[
|
844 |
-
input_image, selected_points, original_image,
|
845 |
-
|
846 |
outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
|
847 |
|
848 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
849 |
-
outputs=[
|
850 |
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
851 |
-
outputs=[
|
852 |
-
image_gallery, kosmos_input, kosmos_output, kosmos_text_output, input_image, selected_points, undo_point_button
|
853 |
])
|
854 |
|
855 |
# DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
|
@@ -895,8 +986,9 @@ if __name__ == "__main__":
|
|
895 |
if sam_enable:
|
896 |
load_sam_model(device)
|
897 |
|
898 |
-
|
899 |
-
|
|
|
900 |
|
901 |
# if lama_cleaner_enable:
|
902 |
# load_lama_cleaner_model(device)
|
|
|
36 |
from GroundingDINO.groundingdino.util.slconfig import SLConfig
|
37 |
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
38 |
|
39 |
+
# I2SB
|
40 |
+
import sys
|
41 |
+
|
42 |
+
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")
|
43 |
+
|
44 |
+
import numpy as np
|
45 |
+
import torch
|
46 |
+
import torch.distributed as dist
|
47 |
+
import torchvision.transforms as transforms
|
48 |
+
import torchvision.utils as tu
|
49 |
+
from easydict import EasyDict as edict
|
50 |
+
from fastapi import (Body, Depends, FastAPI, File, Form, HTTPException, Query,
|
51 |
+
UploadFile)
|
52 |
+
from ipdb import set_trace as debug
|
53 |
+
from PIL import Image
|
54 |
+
from torch.multiprocessing import Process
|
55 |
+
from torch.utils.data import DataLoader, Subset
|
56 |
+
from torch_ema import ExponentialMovingAverage
|
57 |
+
|
58 |
+
import I2SB.distributed_util as dist_util
|
59 |
+
from I2SB.corruption import build_corruption
|
60 |
+
from I2SB.dataset import air_liquide
|
61 |
+
from I2SB.i2sb import Runner, ckpt_util, download_ckpt
|
62 |
+
from I2SB.logger import Logger
|
63 |
+
from I2SB.sample import *
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
import cv2
|
68 |
import numpy as np
|
69 |
import matplotlib
|
|
|
154 |
colors = [(255, 0, 0), (0, 255, 0)]
|
155 |
markers = [1, 5]
|
156 |
|
157 |
+
i2sb_opt = edict(
|
158 |
+
distributed=False,
|
159 |
+
device="cuda",
|
160 |
+
batch_size=1,
|
161 |
+
nfe=10,
|
162 |
+
dataset="sample",
|
163 |
+
dataset_dir=Path(f"dataset/sample"),
|
164 |
+
n_gpu_per_node=1,
|
165 |
+
use_fp16=False,
|
166 |
+
ckpt="inpaint-freeform2030",
|
167 |
+
image_size=256,
|
168 |
+
partition=None,
|
169 |
+
global_size=1,
|
170 |
+
global_rank=0,
|
171 |
+
clip_denoise=True
|
172 |
+
)
|
173 |
+
|
174 |
+
i2sb_transforms = transforms.Compose([
|
175 |
+
transforms.Resize(i2sb_opt.image_size),
|
176 |
+
transforms.CenterCrop(i2sb_opt.image_size),
|
177 |
+
transforms.ToTensor(),
|
178 |
+
transforms.Lambda(lambda t: (t * 2) - 1) # [0,1] --> [-1, 1]
|
179 |
+
])
|
180 |
+
|
181 |
def get_point(img, sel_pix, evt: gr.SelectData):
|
182 |
img = np.array(img, dtype=np.uint8)
|
183 |
sel_pix.append(evt.index)
|
|
|
198 |
for point in sel_pix:
|
199 |
cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
|
200 |
return Image.fromarray(temp).convert("RGB")
|
201 |
+
|
202 |
+
def clear_button(orig_img):
|
203 |
+
|
204 |
+
return orig_img, []
|
205 |
|
206 |
def toggle_button(orig_img, task_type):
|
207 |
print(task_type)
|
|
|
229 |
_ = model.eval()
|
230 |
return model
|
231 |
|
232 |
+
def load_i2sb_model():
|
233 |
+
RESULT_DIR = Path("I2SB/results")
|
234 |
+
global i2sb_model
|
235 |
+
global ckpt_opt
|
236 |
+
global corrupt_type
|
237 |
+
global nfe
|
238 |
+
|
239 |
+
s = time.time()
|
240 |
+
|
241 |
+
# main from here
|
242 |
+
log = Logger(0, ".log")
|
243 |
+
|
244 |
+
# get (default) ckpt option
|
245 |
+
ckpt_opt = ckpt_util.build_ckpt_option(i2sb_opt, log, RESULT_DIR / i2sb_opt.ckpt)
|
246 |
+
corrupt_type = ckpt_opt.corrupt
|
247 |
+
nfe = i2sb_opt.nfe or ckpt_opt.interval-1
|
248 |
+
|
249 |
+
# build corruption method
|
250 |
+
# corrupt_method = build_corruption(i2sb_opt, log, corrupt_type=cor
|
251 |
+
# rupt_type)
|
252 |
+
runner = Runner(ckpt_opt, log, save_opt=False)
|
253 |
+
if i2sb_opt.use_fp16:
|
254 |
+
runner.ema.copy_to() # copy weight from ema to net
|
255 |
+
runner.net.diffusion_model.convert_to_fp16()
|
256 |
+
runner.ema = ExponentialMovingAverage(
|
257 |
+
runner.net.parameters(), decay=0.99) # re-init ema with fp16 weight
|
258 |
+
|
259 |
+
print("Loading time:", (time.time()-s)*1e3, "ms.")
|
260 |
+
i2sb_model = runner
|
261 |
+
return runner
|
262 |
+
|
263 |
def plot_boxes_to_image(image_pil, tgt):
|
264 |
H, W = tgt["size"]
|
265 |
boxes = tgt["boxes"]
|
|
|
325 |
image, _ = transform(image_pil, None) # 3, h, w
|
326 |
return image_pil, image
|
327 |
|
328 |
+
|
329 |
+
|
330 |
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
|
331 |
caption = caption.lower()
|
332 |
caption = caption.strip()
|
|
|
446 |
torch_dtype=torch.float16,
|
447 |
)
|
448 |
sd_model = sd_model.to(device)
|
449 |
+
|
450 |
+
def forward_i2sb(img, mask):
|
451 |
+
print(np.unique(img),mask.shape)
|
452 |
+
mask = np.where(mask > 0, 1, 0)
|
453 |
+
img_tensor = i2sb_transforms(img).to(
|
454 |
+
i2sb_opt.device).unsqueeze(0)
|
455 |
+
|
456 |
+
mask_tensor = torch.from_numpy(np.resize(np.array(mask), (256,256))).to(
|
457 |
+
i2sb_opt.device).unsqueeze(0).unsqueeze(0)
|
458 |
+
print("POST PROCESSING\t", torch.unique(img_tensor))
|
459 |
+
# corrupt_tensor = img_tensor * (1. - mask_tensor) + mask_tensor
|
460 |
+
f = time.time()
|
461 |
+
xs, _ = i2sb_model.ddpm_sampling(
|
462 |
+
ckpt_opt, img_tensor, mask=mask_tensor, cond=None, clip_denoise=i2sb_opt.clip_denoise, nfe=nfe, verbose=i2sb_opt.n_gpu_per_node == 1)
|
463 |
+
recon_img = xs[:, 0, ...].to(i2sb_opt.device)
|
464 |
+
tu.save_image((recon_img+1)/2, "output.png")
|
465 |
+
print(recon_img.shape)
|
466 |
+
return transforms.ToPILImage()(((recon_img+1)/2)[0])
|
467 |
|
468 |
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
|
469 |
try:
|
|
|
618 |
return new_image
|
619 |
|
620 |
mask_source_draw = "draw a mask on input image"
|
621 |
+
mask_source_segment = "upload a mask"
|
622 |
|
623 |
def get_time_cost(run_task_time, time_cost_str):
|
624 |
now_time = int(time.time()*1000)
|
|
|
631 |
run_task_time = now_time
|
632 |
return run_task_time, time_cost_str
|
633 |
|
634 |
+
def run_anything_task(input_image, input_points, origin_image, task_type,
|
635 |
+
mask_source_radio, cleaner_size_limit=1080):
|
|
|
|
|
|
|
636 |
|
637 |
run_task_time = 0
|
638 |
time_cost_str = ''
|
|
|
647 |
image_pil, image = load_image(input_image.convert("RGB"))
|
648 |
input_img = input_image
|
649 |
|
650 |
+
kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_model, kosmos_processor)
|
651 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
652 |
return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
|
653 |
|
|
|
|
|
|
|
|
|
|
|
654 |
if input_image is None:
|
655 |
return [], gr.Gallery.update(label='Please upload a image!ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
656 |
|
657 |
file_temp = int(time.time())
|
658 |
+
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')
|
659 |
|
660 |
output_images = []
|
661 |
|
662 |
# load image
|
|
|
|
|
|
|
663 |
|
664 |
if isinstance(input_image, dict):
|
665 |
image_pil, image = load_image(input_image['image'].convert("RGB"))
|
|
|
722 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
723 |
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
724 |
elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
|
725 |
+
if mask_source_radio == mask_source_segment:
|
726 |
task_type = 'remove'
|
727 |
|
728 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
|
729 |
if mask_source_radio == mask_source_draw:
|
730 |
+
input_mask_pil = input_image['mask']
|
731 |
+
input_mask = np.array(input_mask_pil.convert("L"))
|
732 |
mask_pil = input_mask_pil
|
733 |
mask = input_mask
|
734 |
else:
|
735 |
masks_ori = copy.deepcopy(masks)
|
|
|
|
|
736 |
masks = torch.where(masks > 0, True, False)
|
737 |
mask = masks[0][0].cpu().numpy()
|
738 |
mask_pil = Image.fromarray(mask)
|
|
|
740 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
741 |
|
742 |
if task_type in ['inpainting', 'outpainting']:
|
743 |
+
# image_inpainting = sd_model(prompt = "", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
744 |
+
input_img.save("test.png")
|
745 |
+
image_inpainting = forward_i2sb(input_img, mask)
|
746 |
+
|
747 |
+
print("RESULT\t", np.array(image_inpainting))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
748 |
else:
|
749 |
# remove from mask
|
750 |
aasds = 1
|
|
|
770 |
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
771 |
|
772 |
def change_radio_display(task_type, mask_source_radio, orig_img):
|
|
|
|
|
773 |
mask_source_radio_visible = False
|
774 |
num_relation_visible = False
|
775 |
|
|
|
780 |
print(task_type)
|
781 |
if task_type == "Kosmos-2":
|
782 |
if kosmos_enable:
|
|
|
783 |
image_gallery_visible = False
|
784 |
kosmos_input_visible = True
|
785 |
kosmos_output_visible = True
|
786 |
kosmos_text_output_visible = True
|
787 |
|
|
|
|
|
788 |
if task_type in ['inpainting', 'outpainting'] or task_type == "remove":
|
789 |
mask_source_radio_visible = True
|
|
|
|
|
790 |
if task_type == "relate anything":
|
|
|
791 |
num_relation_visible = True
|
792 |
if task_type == "segment":
|
793 |
ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
794 |
elif task_type == "inpainting":
|
795 |
ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
796 |
|
797 |
+
return (gr.Radio.update(visible=mask_source_radio_visible),
|
|
|
|
|
798 |
gr.Slider.update(visible=num_relation_visible),
|
799 |
gr.Gallery.update(visible=image_gallery_visible),
|
800 |
gr.Radio.update(visible=kosmos_input_visible),
|
801 |
gr.Image.update(visible=kosmos_output_visible),
|
802 |
gr.HighlightedText.update(visible=kosmos_text_output_visible),
|
803 |
+
ret, [],
|
804 |
+
gr.Button("Undo point", visible = task_type == "segment"),
|
805 |
+
gr.Button("Clear point", visible = task_type == "segment"),)
|
806 |
|
807 |
def get_model_device(module):
|
808 |
try:
|
|
|
851 |
[input_image, selected_points],
|
852 |
[input_image]
|
853 |
)
|
854 |
+
with gr.Row():
|
855 |
+
with gr.Column():
|
856 |
+
|
857 |
+
undo_point_button = gr.Button("Undo point")
|
858 |
+
undo_point_button.click(
|
859 |
+
fn= undo_button,
|
860 |
+
inputs=[original_image, selected_points],
|
861 |
+
outputs=[input_image]
|
862 |
+
)
|
863 |
+
|
864 |
+
with gr.Column():
|
865 |
+
|
866 |
+
clear_point_button = gr.Button("Clear point")
|
867 |
+
clear_point_button.click(
|
868 |
+
fn= clear_button,
|
869 |
+
inputs=[original_image],
|
870 |
+
outputs=[input_image, selected_points]
|
871 |
+
)
|
872 |
+
|
873 |
print(dir(input_image))
|
874 |
task_type = gr.Radio(task_types, value="segment",
|
875 |
label='Task type', visible=True)
|
876 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
877 |
+
value=mask_source_draw, label="Mask from",
|
878 |
visible=False)
|
|
|
|
|
879 |
num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
|
880 |
|
881 |
kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
|
882 |
|
883 |
run_button = gr.Button(label="Run", visible=True)
|
884 |
+
# with gr.Accordion("Advanced options", open=False) as advanced_options:
|
885 |
+
# box_threshold = gr.Slider(
|
886 |
+
# label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
|
887 |
+
# )
|
888 |
+
# text_threshold = gr.Slider(
|
889 |
+
# label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
|
890 |
+
# )
|
891 |
+
# iou_threshold = gr.Slider(
|
892 |
+
# label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
|
893 |
+
# )
|
894 |
+
# inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
|
895 |
+
# with gr.Row():
|
896 |
+
# with gr.Column(scale=1):
|
897 |
+
# remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
|
898 |
+
# with gr.Column(scale=1):
|
899 |
+
# remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
|
900 |
|
901 |
with gr.Column():
|
902 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
|
|
|
932 |
selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
|
933 |
|
934 |
run_button.click(fn=run_anything_task, inputs=[
|
935 |
+
input_image, selected_points, original_image, task_type,
|
936 |
+
mask_source_radio],
|
937 |
outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
|
938 |
|
939 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
940 |
+
outputs=[mask_source_radio, num_relation])
|
941 |
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
942 |
+
outputs=[mask_source_radio, num_relation,
|
943 |
+
image_gallery, kosmos_input, kosmos_output, kosmos_text_output, input_image, selected_points, undo_point_button, clear_point_button
|
944 |
])
|
945 |
|
946 |
# DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
|
|
|
986 |
if sam_enable:
|
987 |
load_sam_model(device)
|
988 |
|
989 |
+
if inpainting_enable:
|
990 |
+
load_sd_model(device)
|
991 |
+
load_i2sb_model()
|
992 |
|
993 |
# if lama_cleaner_enable:
|
994 |
# load_lama_cleaner_model(device)
|