# Thesis-Demo / test.py
# Committed by xuan2k — commit 4979800 ("update demo")
import warnings
# Hide library warnings to keep the demo's console output readable.
warnings.filterwarnings('ignore')
import subprocess, io, os, sys, time
# Pin Gradio 3.50.2 before importing it: the UI code in this file uses the
# Gradio 3.x API (gr.Gallery.update, Component.style, ...).
# NOTE: this runs pip on every start-up.
os.system("pip install gradio==3.50.2")
import gradio as gr
from loguru import logger
# Expose only GPU 0 to this process; set before torch is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# if os.environ.get('IS_MY_DEBUG') is None:
#     result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
#     print(f'pip install GroundingDINO = {result}')
logger.info(f"Start app...")
# Dump the installed packages for deployment debugging. pip's output goes
# straight to stdout; `result` only holds the CompletedProcess.
result = subprocess.run(['pip', 'list'], check=True)
print(f'pip list = {result}')
sys.path.insert(0, './GroundingDINO')
import argparse
import copy
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont, ImageOps
# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
# I2SB
import sys
# NOTE(review): these absolute paths tie the demo to /home/ubuntu/Thesis-Demo,
# while the GroundingDINO entry above is relative to the working directory.
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/SegFormer")
import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms as transforms
import torchvision.utils as tu
from easydict import EasyDict as edict
from fastapi import (Body, Depends, FastAPI, File, Form, HTTPException, Query,
                     UploadFile)
from ipdb import set_trace as debug
from PIL import Image
from torch.multiprocessing import Process
from torch.utils.data import DataLoader, Subset
from torch_ema import ExponentialMovingAverage
import I2SB.distributed_util as dist_util
from I2SB.corruption import build_corruption
from I2SB.dataset import air_liquide
from I2SB.i2sb import Runner, ckpt_util, download_ckpt
from I2SB.logger import Logger
from I2SB.sample import *
from pathlib import Path
inpaint_checkpoint = Path("/home/ubuntu/Thesis-Demo/I2SB/results")
# NOTE(review): transformers 4.32.0 is installed only when this absolute
# results directory is missing (presumably on any machine other than the
# author's). The link between the two is not evident here — confirm intent.
if not inpaint_checkpoint.exists():
    os.system("pip install transformers==4.32.0")
# SegFormer
from PIL import Image
from SegFormer.mmseg.apis import inference_segmentor, init_segmentor, visualize_result_pyplot
from SegFormer.mmseg.core.evaluation import get_palette
import cv2
import numpy as np
import matplotlib
# Non-interactive backend: figures are only rendered to files on the server.
matplotlib.use('AGG')
# NOTE(review): this attribute access only works if matplotlib.pyplot has
# already been imported somewhere (presumably by a module imported above).
# `import matplotlib.pyplot as plt` would make the dependency explicit.
plt = matplotlib.pyplot
# import matplotlib.pyplot as plt
# Feature switches for the demo. When IS_MY_DEBUG is set in the environment,
# the heavyweight SAM and inpainting models are switched off. RAM and Kosmos-2
# are always off here.
groundingdino_enable = True
lama_cleaner_enable = True
ram_enable = False
kosmos_enable = False
sam_enable = os.environ.get('IS_MY_DEBUG') is None
inpainting_enable = os.environ.get('IS_MY_DEBUG') is None
# segment anything
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
# diffusers
import PIL
import requests
import torch
from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline
from huggingface_hub import hf_hub_download
from util_computer import computer_info
# relate anything
from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask
from ram_train_eval import RamModel, RamPredictor
from mmengine.config import Config as mmengine_Config
# lama-cleaner helpers are imported only when that feature flag is on.
if lama_cleaner_enable:
    from lama_cleaner.helper import (
        load_img,
        numpy_to_bytes,
        resize_max_size,
    )
# from transformers import AutoProcessor, AutoModelForVision2Seq
import ast
# Kosmos-2 (off by default: kosmos_enable = False above) needs transformers
# installed from git main before its helpers can be imported.
if kosmos_enable:
    os.system("pip install transformers@git+https://github.com/huggingface/transformers.git@main")
    # os.system("pip install transformers==4.32.0")
    from kosmos_utils import *
from util_tencent import getTextTrans
# Model configuration / checkpoint locations.
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filenmae = "groundingdino_swint_ogc.pth"  # (sic) name kept for compatibility
sam_checkpoint = './sam_vit_h_4b8939.pth'
output_dir = "outputs"
device = 'cpu'  # overwritten by set_device() when run as a script
os.makedirs(output_dir, exist_ok=True)

# Model handles, populated by the load_* functions.
groundingdino_model = sam_device = sam_model = None
sam_predictor = sam_mask_generator = None
sd_model = lama_cleaner_model = ram_model = None
kosmos_model = kosmos_processor = None

# Colour and cv2 marker type used when drawing the clicked points.
colors = [(255, 0, 0), (0, 255, 0)]
markers = [1, 5]

# Inference options for the I2SB inpainting model (read by load_i2sb_model / forward_i2sb).
i2sb_opt = edict({
    "distributed": False,
    "device": "cuda",
    "batch_size": 1,
    "nfe": 10,
    "dataset": "sample",
    "dataset_dir": Path("dataset/sample"),
    "n_gpu_per_node": 1,
    "use_fp16": False,
    "ckpt": "inpaint-freeform2030",
    "image_size": 256,
    "partition": None,
    "global_size": 1,
    "global_rank": 0,
    "clip_denoise": True,
})

# Resize + centre-crop to image_size, then map pixel values from [0, 1] to [-1, 1].
i2sb_transforms = transforms.Compose([
    transforms.Resize(i2sb_opt.image_size),
    transforms.CenterCrop(i2sb_opt.image_size),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: 2 * x - 1),
])
def get_point(img, sel_pix, evt: gr.SelectData):
    """Record a clicked pixel and return the image with every recorded point marked.

    Args:
        img: the currently displayed image (PIL or array-like).
        sel_pix: Gradio state list of clicked points; the new click is
            appended to it in place.
        evt: Gradio select event; `evt.index` is the clicked position.

    Returns:
        RGB PIL image with a marker drawn at every point in `sel_pix`.
    """
    canvas = np.array(img, dtype=np.uint8)
    sel_pix.append(evt.index)
    print(sel_pix)
    marker_color, marker_type = colors[0], markers[0]
    for pt in sel_pix:
        cv2.drawMarker(canvas, pt, marker_color, markerType=marker_type, markerSize=6, thickness=2)
    return Image.fromarray(canvas).convert("RGB")
def undo_button(orig_img, sel_pix):
    """Forget the most recently clicked point and redraw the remaining ones.

    Markers are redrawn on a fresh array built from the marker-free original
    image. `sel_pix` is modified in place. If no original image is stored
    (falsy `orig_img`), it is returned unchanged.
    """
    if not orig_img:
        return orig_img
    canvas = np.array(orig_img, dtype=np.uint8)
    if sel_pix:
        sel_pix.pop()
    for pt in sel_pix:
        cv2.drawMarker(canvas, pt, colors[0], markerType=markers[0], markerSize=6, thickness=2)
    return Image.fromarray(canvas).convert("RGB")
def clear_button(orig_img):
    """Restore the original upload and drop every clicked point.

    Returns:
        (orig_img, new empty point list).
    """
    no_points = []
    return orig_img, no_points
def toggle_button(orig_img, task_type):
    """Rebuild the upload widget for `task_type` and return the other mode.

    "segment" uses the plain editor, where points are clicked. "inpainting"
    uses the sketch tool, where a mask is painted. The previous code flipped
    the string with `not task_type`, which always produced False, and raised
    UnboundLocalError for any other value.

    Args:
        orig_img: image to show in the rebuilt widget.
        task_type: "segment" or "inpainting".

    Returns:
        (new gr.Image component, the opposite task type string).

    Raises:
        ValueError: if `task_type` is neither "segment" nor "inpainting".
    """
    print(task_type)
    common = dict(value=orig_img, elem_id="image_upload", type='pil', label="Upload", height=512)
    if task_type == "segment":
        ret = gr.Image(tool="editor", **common)
        next_task = "inpainting"
    elif task_type == "inpainting":
        ret = gr.Image(tool="sketch", brush_color='#00FFFF', mask_opacity=0.6, **common)
        next_task = "segment"
    else:
        raise ValueError(f"toggle_button: unsupported task_type {task_type!r}")
    return ret, next_task
def store_img(img):
    """Keep a copy of a freshly uploaded image and reset the clicked points.

    Returns:
        (img, empty point list). A new image invalidates any earlier points.
    """
    print("call for store")
    fresh_points = []
    return img, fresh_points
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
    """Build a GroundingDINO model from a config file and load Hub weights into it.

    Args:
        model_config_path: path to the SLConfig python config.
        repo_id, filename: Hugging Face Hub location of the checkpoint.
        device: passed to torch.load's map_location and stored on the config.

    Returns:
        The model in eval mode.

    NOTE(review): the model itself is never moved with .to(device); only the
    checkpoint tensors are mapped there. Confirm that callers move it.
    """
    cfg = SLConfig.fromfile(model_config_path)
    model = build_model(cfg)
    cfg.device = device
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # torch.load unpickles the file, so only point this at trusted repositories.
    state = torch.load(ckpt_path, map_location=device)
    # strict=False: missing/unexpected keys are reported in the printed result.
    load_report = model.load_state_dict(clean_state_dict(state['model']), strict=False)
    print("Model loaded from {} \n => {}".format(ckpt_path, load_report))
    model.eval()
    return model
def load_i2sb_model():
    """Build the I2SB inpainting Runner and publish it through module globals.

    Sets the globals `i2sb_model`, `ckpt_opt`, `corrupt_type` and `nfe`, which
    forward_i2sb reads, and returns the Runner.
    """
    global i2sb_model, ckpt_opt, corrupt_type, nfe
    result_dir = Path("I2SB/results")
    started = time.time()

    log = Logger(0, ".log")
    # Checkpoint options, presumably read from result_dir/<ckpt> and combined with i2sb_opt.
    ckpt_opt = ckpt_util.build_ckpt_option(i2sb_opt, log, result_dir / i2sb_opt.ckpt)
    corrupt_type = ckpt_opt.corrupt
    # Sampling steps: the configured nfe, or interval - 1 when nfe is unset/0.
    nfe = i2sb_opt.nfe or ckpt_opt.interval - 1

    runner = Runner(ckpt_opt, log, save_opt=False)
    if i2sb_opt.use_fp16:
        runner.ema.copy_to()  # copy the EMA weights into the network
        runner.net.diffusion_model.convert_to_fp16()
        # Re-create the EMA so that it tracks the fp16 parameters.
        runner.ema = ExponentialMovingAverage(runner.net.parameters(), decay=0.99)

    elapsed_ms = (time.time() - started) * 1e3
    logger.info(f"I2SB Loading time:\t {elapsed_ms} ms.")
    print("Loading time:", elapsed_ms, "ms.")
    i2sb_model = runner
    return runner
def load_segformer(device):
    """Initialise the SegFormer-B3 segmentor on `device`.

    The model is stored in the global `segformer_model` and also returned.
    """
    global segformer_model
    started = time.time()
    segformer_model = init_segmentor(
        "SegFormer/local_configs/segformer/B3/segformer.b3.256x256.wtm.160k.py",
        "SegFormer/work_dirs/segformer.b3.256x256.wtm.160k/iter_160000.pth",
        device=device,
    )
    logger.info(f"SegFormer Loading time:\t {(time.time()-started)*1e3} ms.")
    return segformer_model
def plot_boxes_to_image(image_pil, tgt):
    """Draw labelled boxes on `image_pil` and build a matching filled-box mask.

    Args:
        image_pil: PIL image; it is drawn on in place.
        tgt: dict with "size" -> (H, W), "boxes" -> boxes as normalised
            (cx, cy, w, h) scaled by (W, H) below, and "labels" -> one label
            per box.

    Returns:
        (image_pil, mask), where mask is an "L" image that is 255 inside every box.
    """
    H, W = tgt["size"]
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    assert len(boxes) == len(labels), "boxes and labels must have same length"

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)

    # Load the label font once. If OpenCV's bundled DejaVu font is unavailable
    # (e.g. opencv builds without Qt fonts), fall back to Pillow's default font.
    # The label text is then still drawn instead of being silently skipped.
    font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
    try:
        font = ImageFont.truetype(font_path, 36)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    scale = torch.Tensor([W, H, W, H])
    for box, label in zip(boxes, labels):
        # from 0..1 to 0..W, 0..H, then from (cx, cy, w, h) to (x0, y0, x1, y1)
        box = box * scale
        box[:2] -= box[2:] / 2
        box[2:] += box[:2]
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        x0, y0, x1, y1 = (int(v) for v in box)
        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)

        text = str(label)
        text_xy = (x0 + 2, y0 + 2)
        # Size the label background with the same font that draws the text.
        if hasattr(font, "getbbox"):
            _, _, tx1, ty1 = draw.textbbox(text_xy, text, font=font)
        else:  # older Pillow without getbbox/textbbox support for this font
            tw, th = draw.textsize(text, font=font)
            tx1, ty1 = text_xy[0] + tw, text_xy[1] + th
        draw.rectangle((x0, y0, tx1 + 2, ty1 + 2), fill=color)
        draw.text(text_xy, text, font=font, fill="white")

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
    return image_pil, mask
def load_image(image_path):
    """Return (PIL image, normalised tensor) prepared for GroundingDINO.

    `image_path` may be a file path, which is opened and converted to RGB, or
    an already-open PIL image, which is used as-is.
    """
    image_pil = (image_path if isinstance(image_path, PIL.Image.Image)
                 else Image.open(image_path).convert("RGB"))
    dino_transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_tensor, _ = dino_transform(image_pil, None)  # (3, h, w)
    return image_pil, image_tensor
def show_mask(mask, ax, random_color=False):
    """Overlay `mask` on the matplotlib axis `ax` as a translucent colour layer.

    The colour is random RGB when `random_color` is set, otherwise a fixed
    blue. Alpha is always 0.6. The last two dimensions of `mask` are taken
    as (h, w).
    """
    if random_color:
        rgba = np.append(np.random.random(3), 0.6)
    else:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    ax.imshow(np.reshape(mask, (h, w, 1)) * np.reshape(rgba, (1, 1, -1)))
def show_box(box, ax, label):
    """Draw an (x0, y0, x1, y1) box as a green outline, with `label` at its top-left corner."""
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                            edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)
    ax.text(x0, y0, label)
def xywh_to_xyxy(box, sizeW, sizeH):
    """Convert normalised (cx, cy, w, h) box(es) to absolute (x0, y0, x1, y1).

    Args:
        box: list, tuple, numpy array or torch tensor with 4 values in its
            last dimension. A batch of shape (N, 4) is also accepted. The
            input is never modified.
        sizeW, sizeH: image width and height used to scale the box.

    Returns:
        numpy array with the same shape as `box`.
    """
    if isinstance(box, list):
        box = torch.Tensor(box)
    elif not isinstance(box, torch.Tensor):
        # tuples / numpy arrays previously fell through to a failing multiply.
        box = torch.as_tensor(np.asarray(box), dtype=torch.float32)
    # detach().cpu() lets CUDA or grad-tracking tensors reach .numpy(). The
    # multiplication allocates a new tensor, so the in-place ops below never
    # touch the caller's data.
    box = box.detach().cpu() * torch.Tensor([sizeW, sizeH, sizeW, sizeH])
    box[..., :2] -= box[..., 2:] / 2
    box[..., 2:] += box[..., :2]
    return box.numpy()
def mask_extend(img, box, extend_pixels=10, useRectangle=True):
    """Grow the `box` region of `img` by `extend_pixels` on every side, in place.

    The crop of `box` is resized to the enlarged size. It is pasted back
    centred over the original box.
    With useRectangle=True the enlarged patch is first filled solid white.

    Args:
        img: PIL image; modified in place and also returned.
        box: (x0, y0, x1, y1); values are truncated to int. The caller's box
            is no longer modified (the previous version overwrote its entries).
        extend_pixels: margin added on each side.
        useRectangle: fill the enlarged patch with white instead of the
            scaled-up crop.

    Returns:
        `img`.
    """
    x0, y0, x1, y1 = (int(v) for v in box[:4])
    region = img.crop((x0, y0, x1, y1))
    new_width = x1 - x0 + 2 * extend_pixels
    new_height = y1 - y0 + 2 * extend_pixels
    region_resized = region.resize((int(new_width), int(new_height)))
    if useRectangle:
        region_draw = ImageDraw.Draw(region_resized)
        region_draw.rectangle((0, 0, new_width, new_height), fill=(255, 255, 255))
    img.paste(region_resized, (int(x0 - extend_pixels), int(y0 - extend_pixels)))
    return img
def mix_masks(imgs):
    """Return the pixel-wise union of several masks as a 0/255 "L"-compatible image.

    Each mask is binarised with `.convert("1")`. A pixel is 255 in the
    result when it is set in any mask.
    """
    union = np.asarray(imgs[0].convert("1")).astype(bool)
    for other in imgs[1:]:
        union = union | np.asarray(other.convert("1")).astype(bool)
    return Image.fromarray(np.uint8(255 * union))
def set_device(args):
    """Choose the global compute `device`.

    The device is `args.cuda` when CUDA is available and IS_MY_DEBUG is not
    set; otherwise it is 'cpu'.
    """
    global device
    use_gpu = os.environ.get('IS_MY_DEBUG') is None and torch.cuda.is_available()
    device = args.cuda if use_gpu else 'cpu'
    print(f'device={device}')
def get_sam_vit_h_4b8939():
    """Download the SAM ViT-H checkpoint into the working directory if it is missing."""
    if os.path.exists('./sam_vit_h_4b8939.pth'):
        return
    logger.info(f"get sam_vit_h_4b8939.pth...")
    url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
    result = subprocess.run(['wget', '-nv', url], check=True)  # raises on failure
    print(f'wget sam_vit_h_4b8939.pth result = {result}')
def load_sam_model(device):
    """Build SAM ViT-H on `device`, downloading the checkpoint if needed.

    Also creates its point-prompt predictor and automatic mask generator.
    Results are stored in the globals sam_model, sam_predictor,
    sam_mask_generator and sam_device.
    """
    global sam_model, sam_predictor, sam_mask_generator, sam_device
    get_sam_vit_h_4b8939()
    logger.info(f"initialize SAM model...")
    sam_device = device
    sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
    sam_predictor, sam_mask_generator = SamPredictor(sam_model), SamAutomaticMaskGenerator(sam_model)
def load_sd_model(device):
    """Load the Stable Diffusion inpainting pipeline into the global `sd_model`.

    When IS_MY_DEBUG is set, loading is skipped and `sd_model` stays None.
    Half precision is now used only on CUDA devices. Many PyTorch CPU kernels
    have no float16 implementation, so an fp16 pipeline on the CPU fails at
    inference time.
    """
    global sd_model
    logger.info(f"initialize stable-diffusion-inpainting...")
    sd_model = None
    if os.environ.get('IS_MY_DEBUG') is not None:
        return
    use_fp16 = str(device).startswith('cuda')
    sd_model = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        # "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16 if use_fp16 else torch.float32,
    )
    sd_model = sd_model.to(device)
def forward_i2sb(img, mask, dilation_mask_extend):
    """Inpaint `img` inside `mask` with the globally loaded I2SB model.

    Args:
        img: PIL image (the caller resizes it to 256x256 first).
        mask: 2-D array, non-zero where content should be regenerated. It may
            have a different resolution than `img` (e.g. a SAM mask or a
            sketch mask at upload resolution).
        dilation_mask_extend: dilation kernel size as a string from the UI.
            It is ignored unless it is a positive integer.

    Returns:
        (reconstructed PIL image, PIL preview of the corrupted input with the
        masked area set to white). Both are i2sb_opt.image_size square.
    """
    size = i2sb_opt.image_size
    mask = np.where(np.asarray(mask) > 0, 1, 0).astype(np.uint8)
    logger.debug(f"forward_i2sb: mask values={np.unique(mask)}, shape={mask.shape}")

    kernel_str = str(dilation_mask_extend)
    if kernel_str.isdigit() and int(kernel_str) > 0:
        k = int(kernel_str)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k, k))
        mask = cv2.dilate(mask, kernel, iterations=1)

    # Resample the mask geometrically. np.resize (used previously) repeats or
    # truncates the flattened buffer, which scrambles any mask that is not
    # already size x size.
    if mask.shape[:2] != (size, size):
        mask = cv2.resize(mask, (size, size), interpolation=cv2.INTER_NEAREST)

    img_tensor = i2sb_transforms(img).to(i2sb_opt.device).unsqueeze(0)
    mask_tensor = torch.from_numpy(mask).float().to(i2sb_opt.device)[None, None]
    # Masked pixels forced to +1 (white in [-1, 1]); only returned as a preview.
    corrupt_tensor = img_tensor * (1. - mask_tensor) + mask_tensor

    xs, _ = i2sb_model.ddpm_sampling(
        ckpt_opt, img_tensor, mask=mask_tensor, cond=None,
        clip_denoise=i2sb_opt.clip_denoise, nfe=nfe,
        verbose=i2sb_opt.n_gpu_per_node == 1)
    recon_img = xs[:, 0, ...].to(i2sb_opt.device)

    to_pil = transforms.ToPILImage()
    return to_pil(((recon_img + 1) / 2)[0]), to_pil(((corrupt_tensor + 1) / 2)[0])
def forward_segformer(img):
    """Run the global SegFormer model on an RGB image and return its label map as uint8."""
    # Swap to BGR channel order before handing the array to mmseg.
    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    label_maps = inference_segmentor(segformer_model, bgr)
    return np.asarray(label_maps[0], dtype=np.uint8)
# visualization
def draw_selected_mask(mask, draw):
    """Paint every non-zero pixel of a 2-D `mask` onto `draw` in translucent red.

    Args:
        mask: 2-D array; non-zero entries are painted.
        draw: PIL ImageDraw.Draw (or any object with a compatible `point`).
    """
    color = (255, 0, 0, 153)
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return
    # One batched call with (x, y) tuples instead of one Python call per pixel.
    draw.point(list(zip(cols.tolist(), rows.tolist())), fill=color)
def draw_object_mask(mask, draw):
    """Paint every non-zero pixel of a 2-D `mask` onto `draw` in translucent blue.

    Args:
        mask: 2-D array; non-zero entries are painted.
        draw: PIL ImageDraw.Draw (or any object with a compatible `point`).
    """
    color = (0, 0, 255, 153)
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return
    # One batched call with (x, y) tuples instead of one Python call per pixel.
    draw.point(list(zip(cols.tolist(), rows.tolist())), fill=color)
def create_title_image(word1, word2, word3, width, font_path='./assets/OpenSans-Bold.ttf'):
    """Render three words (red, black, blue) on a white strip `width` pixels wide.

    The font starts at 40 pt and is shrunk one point at a time (never below
    1 pt) until the words plus their spacing fit within `width`. The text is
    drawn once, after the final size is known. If the font cannot be loaded,
    a blank white strip is returned (best-effort, as before).

    Returns:
        RGB PIL image of size (width, 60).
    """
    words = [word1, word2, word3]
    word_colors = [(255, 0, 0), (0, 0, 0), (0, 0, 255)]
    image = Image.new('RGB', (width, 60), (255, 255, 255))

    def text_width(fnt, text):
        # FreeTypeFont.getsize was removed in Pillow 10; getlength replaces it.
        if hasattr(fnt, "getlength"):
            return fnt.getlength(text)
        return fnt.getsize(text)[0]

    font_size = 40
    try:
        while True:
            font = ImageFont.truetype(font_path, font_size)
            word_spacing = font_size / 2
            # One leading space plus one after each of the first two words.
            total_width = sum(text_width(font, w) for w in words) + word_spacing * 3
            if total_width <= width or font_size <= 1:
                break
            font_size -= 1
    except (OSError, ImportError) as e:
        logger.warning(f"create_title_image: cannot load font {font_path}: {e}")
        return image

    draw = ImageDraw.Draw(image)
    x_offset = word_spacing
    for word, color in zip(words, word_colors):
        draw.text((x_offset, 0), word, color, font=font)
        x_offset += text_width(font, word) + word_spacing
    return image
def concatenate_images_vertical(image1, image2):
    """Stack `image2` under `image1` on a new RGBA canvas as wide as the wider input."""
    (w_top, h_top), (w_bottom, h_bottom) = image1.size, image2.size
    canvas = Image.new('RGBA', (max(w_top, w_bottom), h_top + h_bottom))
    for img, y in ((image1, 0), (image2, h_top)):
        canvas.paste(img, (0, y))
    return canvas
# Choices of the "Mask from" radio; run_anything_task compares against these.
mask_source_draw, mask_source_segment = "draw a mask on input image", "upload a mask"
def get_time_cost(run_task_time, time_cost_str):
    """Append the milliseconds elapsed since `run_task_time` to a '-->' separated trail.

    A `run_task_time` of 0 starts a new trail, reset to 'start'.

    Returns:
        (current time in ms to pass to the next call, updated trail string).
    """
    now_ms = int(time.time() * 1000)
    if run_task_time == 0:
        return now_ms, 'start'
    if time_cost_str != '':
        time_cost_str += '-->'
    time_cost_str += str(now_ms - run_task_time)
    return now_ms, time_cost_str
def run_anything_task(input_image, input_points, origin_image, task_type,
                      mask_source_radio, segmentation_radio, dilation_mask_extend):
    """Gradio handler for the "Run" button: segment and/or inpaint the upload.

    Args:
        input_image: the upload widget value. This is a PIL image, or a dict
            with 'image' and 'mask' PIL images when the sketch tool is active.
        input_points: list of clicked [x, y] points used as SAM prompts.
        origin_image: the uploaded image without point markers drawn on it.
        task_type: "segment", "inpainting" or "pipeline" (segment, then inpaint).
        mask_source_radio: mask_source_draw or mask_source_segment.
        segmentation_radio: "SAM" or "SegFormer".
        dilation_mask_extend: mask dilation kernel size, as a string.

    Returns:
        (gallery images, gallery update, time-cost string, time-cost textbox
        update), matching the four outputs wired in main_gradio. Requests that
        cannot be served return the images gathered so far with an explanatory
        gallery label. Previously several of these paths died with NameError
        on unbound `masks` / `mask` / `image_inpainting`.
    """
    run_task_time, time_cost_str = get_time_cost(0, '')

    def _result(images, label='result images'):
        # Reads time_cost_str at call time, so every step recorded so far is included.
        return (images, gr.Gallery.update(label=label), time_cost_str,
                gr.Textbox.update(visible=(time_cost_str != '')))

    if input_image is None:
        return _result([], 'Please upload a image!😂😂😂😂')

    file_temp = int(time.time())
    logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')

    # The sketch tool delivers {'image': ..., 'mask': ...}; otherwise a plain image.
    input_img = input_image['image'] if isinstance(input_image, dict) else input_image
    image_pil = input_img.convert("RGB")
    output_images = [input_img]
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')

    masks = None
    if task_type in ('segment', 'pipeline'):
        if origin_image is None:
            return _result(output_images, 'Please re-upload the image!')
        image = np.array(origin_image)
        if segmentation_radio == "SAM":
            if sam_predictor is None:
                return _result(output_images, 'SAM model is not loaded!')
            if not input_points:
                return _result(output_images, 'Please click at least one point on the image!')
            sam_predictor.set_image(image)
            logger.info(f"Forward with: {input_points}")
            # Unpacked into four values as before. Presumably the bundled
            # segment_anything returns four (upstream SAM returns three) — TODO confirm.
            masks, _, _, _ = sam_predictor.predict(
                point_coords=np.array(input_points),
                point_labels=np.ones(len(input_points), dtype=int),
                multimask_output=False,
            )
            # Render the masks over the image with matplotlib and read the figure back.
            plt.figure(figsize=(10, 10))
            plt.imshow(origin_image)
            for mask in masks:
                show_mask(mask, plt.gca(), random_color=True)
            plt.axis('off')
            image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
            plt.savefig(image_path, bbox_inches="tight")
            plt.clf()
            plt.close('all')
            segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
            os.remove(image_path)
        else:
            masks = forward_segformer(image)
            segment_image_result = visualize_result_pyplot(
                segformer_model, image, masks, get_palette("wtm"), dilation=dilation_mask_extend)
        output_images.append(Image.fromarray(segment_image_result))
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
    if task_type in ('detection', 'segment'):
        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
        return _result(output_images)
    if task_type not in ('inpainting', 'outpainting', 'pipeline'):
        logger.info(f"task_type:{task_type} error!")
        logger.info(f'run_anything_task_[{file_temp}]_9_9_')
        return _result(output_images)

    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
    if task_type == "pipeline":
        # SAM: take the first predicted mask. SegFormer: every non-zero class
        # in the label map counts as masked.
        mask = np.where(masks[0] == 1, 1, 0) if segmentation_radio == "SAM" else masks
        mask_pil = Image.fromarray(np.where(mask > 0, 255, 0).astype(np.uint8))
    elif mask_source_radio == mask_source_draw and isinstance(input_image, dict):
        mask_pil = input_image['mask']
        mask = np.array(mask_pil.convert("L"))
    else:
        # No mask is available: "upload a mask" has no input widget, or the
        # widget is not in sketch mode.
        logger.info(f'run_anything_task_failed_')
        return _result(output_images, 'Please draw a mask on the input image!')
    output_images.append(mask_pil.convert("RGB"))
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    if task_type == 'outpainting':
        # This demo has no outpainting backend.
        logger.info(f'run_anything_task_failed_')
        return _result(output_images, 'outpainting is not supported!')

    # I2SB works at 256x256. The RGB copy avoids feeding extra channels to the
    # model; results are resized back to the upload's size.
    image_inpainting, corrupted = forward_i2sb(image_pil.resize((256, 256)), mask, dilation_mask_extend)
    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
    output_images.append(corrupted.resize(image_pil.size))
    output_images.append(image_inpainting.resize(image_pil.size))
    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
    return _result(output_images)
def change_radio_display(task_type, mask_source_radio, orig_img):
    """Return UI updates for the controls that depend on the selected task type.

    Outputs, in order: mask-source radio, relation slider, result gallery,
    segmentation-model radio, dilation textbox, upload image component,
    cleared selected points, undo button, clear button.
    `mask_source_radio` is accepted for the Gradio wiring but is not used.
    Unknown task types leave the upload widget unchanged; previously they
    raised UnboundLocalError.
    """
    print(task_type)
    common = dict(value=orig_img, elem_id="image_upload", type='pil', label="Upload", height=512)
    if task_type == "inpainting":
        # Sketch tool, so the user can paint the inpainting mask.
        ret = gr.Image(tool="sketch", brush_color='#00FFFF', mask_opacity=0.6, **common)
    elif task_type in ["segment", "pipeline"]:
        # Plain editor; points are collected through the image's select event.
        ret = gr.Image(tool="editor", **common)
    else:
        ret = gr.update()

    mask_source_radio_visible = task_type in ['inpainting', 'outpainting', 'remove']
    num_relation_visible = task_type == "relate anything"
    # The gallery is hidden only for Kosmos-2, and only when that feature is enabled.
    image_gallery_visible = not (task_type == "Kosmos-2" and kosmos_enable)
    point_tools_visible = task_type != "inpainting"
    return (gr.Radio.update(visible=mask_source_radio_visible),
            gr.Slider.update(visible=num_relation_visible),
            gr.Gallery.update(visible=image_gallery_visible),
            gr.Radio(["SegFormer", "SAM"], value="SAM", label="Segementation Model", visible=point_tools_visible),
            gr.Textbox(label="Dilation kernel size", value='7', visible=task_type == "pipeline"),
            ret, [],
            gr.Button("Undo point", visible=point_tools_visible),
            gr.Button("Clear point", visible=point_tools_visible))
def get_model_device(module):
    """Best-effort lookup of the device a model's weights live on.

    The first direct child that owns a real 'weight' tensor decides, as
    before. Children whose weight is None are now skipped; previously they
    produced 'Error'. When no child qualifies, any parameter of the module
    is used, which covers leaf modules such as nn.Linear.

    Returns:
        A torch.device; 'None' for a None module; 'UnKnown' when the module
        has no parameters; 'Error' when the object cannot be inspected.
    """
    if module is None:
        return 'None'
    try:
        if isinstance(module, torch.nn.DataParallel):
            module = module.module
        for submodule in module.children():
            weight = getattr(submodule, "_parameters", {}).get("weight")
            if weight is not None:
                return weight.device
        for param in module.parameters():
            return param.device
        return 'UnKnown'
    except Exception:
        # Deliberate best-effort: callers only display the result.
        return 'Error'
def click_callback(coords):
    """Debug hook: print the coordinates of a click event."""
    message_parts = ("Clicked at here: ", coords)
    print(*message_parts)
def main_gradio(args):
    """Build the Gradio Blocks UI, wire its callbacks and launch the server.

    Args:
        args: parsed CLI namespace; `port`, `debug` and `share` are used.
    """
    block = gr.Blocks(
        title="Thesis-Demo",
        # theme="shivi/calm_seafoam@>=0.0.1,<1.0.0",
    )

    def _mask_source_changed(task, source, orig):
        # change_radio_display returns nine updates; this event drives only the first two.
        return change_radio_display(task, source, orig)[:2]

    with block:
        with gr.Row():
            with gr.Column():
                # Per-session state: clicked points and the unmarked upload.
                selected_points = gr.State([])
                original_image = gr.State(None)
                task_types = ["segment"]
                if inpainting_enable:
                    task_types.append("inpainting")
                    task_types.append("pipeline")
                input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512)
                input_image.upload(
                    store_img,
                    [input_image],
                    [original_image, selected_points]
                )
                input_image.select(
                    get_point,
                    [input_image, selected_points],
                    [input_image]
                )
                with gr.Row():
                    # Always visible initially. The old `original_image is not None`
                    # test compared the gr.State object itself, which is never None.
                    with gr.Column():
                        undo_point_button = gr.Button("Undo point", visible=True)
                        undo_point_button.click(
                            fn=undo_button,
                            inputs=[original_image, selected_points],
                            outputs=[input_image]
                        )
                    with gr.Column():
                        clear_point_button = gr.Button("Clear point", visible=True)
                        clear_point_button.click(
                            fn=clear_button,
                            inputs=[original_image],
                            outputs=[input_image, selected_points]
                        )
                task_type = gr.Radio(task_types, value="segment",
                                     label='Task type', visible=True)
                mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
                                             value=mask_source_draw, label="Mask from",
                                             visible=False)
                segmentation_radio = gr.Radio(["SegFormer", "SAM"],
                                              value="SAM", label="Segementation Model",
                                              visible=True)
                dilation_mask_extend = gr.Textbox(label="Dilation kernel size", value='5', visible=False)
                num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
                # Button text is its `value`; the old `label=` keyword was ignored.
                run_button = gr.Button(value="Run", visible=True)

            with gr.Column():
                image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
                                           ).style(preview=True, columns=[5], object_fit="scale-down", height=512)
                time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)

        run_button.click(fn=run_anything_task, inputs=[
            input_image, selected_points, original_image, task_type,
            mask_source_radio, segmentation_radio, dilation_mask_extend],
            outputs=[image_gallery, image_gallery, time_cost, time_cost], show_progress=True, queue=True)
        mask_source_radio.change(fn=_mask_source_changed, inputs=[task_type, mask_source_radio, original_image],
                                 outputs=[mask_source_radio, num_relation])
        task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
                         outputs=[mask_source_radio, num_relation,
                                  image_gallery, segmentation_radio, dilation_mask_extend, input_image,
                                  selected_points, undo_point_button, clear_point_button])

    print(f'device = {device}')
    print(f'torch.cuda.is_available = {torch.cuda.is_available()}')
    computer_info()
    block.queue(max_size=10, api_open=False)
    block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    parser.add_argument("--port", "-p", type=int, default=7860, help="port")
    parser.add_argument("--cuda", "-c", type=str, default='cuda:0', help="cuda")
    # parse_known_args ignores unrecognised command-line flags instead of exiting.
    args, _ = parser.parse_known_args()
    print(f'args = {args}')

    set_device(args)  # sets the module-level `device`
    if device == 'cpu':
        kosmos_enable = False

    # Load every enabled model up front, before the UI starts serving requests.
    if sam_enable:
        load_sam_model(device)
    load_segformer(device)
    if inpainting_enable:
        load_sd_model(device)
        load_i2sb_model()

    main_gradio(args)