# ASAM / app.py
# Source: Hugging Face Space by xhk, commit 84b6050 ("Update app.py").
import os, sys
import random
import warnings
import copy
# Start-up environment setup: install the bundled ASAM / GroundingDINO checkouts
# (editable) and the pinned runtime dependencies. pip is invoked through the
# interpreter that is running this script, so packages are installed into the
# environment that imports them (a bare `python`/`pip` on PATH may differ).
import subprocess

for _pip_args in (
    ["-e", "asam"],
    ["-e", "GroundingDINO"],
    ["gradio==3.38.0"],
    ["opencv-python", "pycocotools", "matplotlib", "onnxruntime", "onnx", "ipykernel"],
):
    # check=False: like the previous os.system calls, a failed install is not
    # fatal at this point; any missing package surfaces at import time below.
    subprocess.run([sys.executable, "-m", "pip", "install", *_pip_args], check=False)

# Also make the local checkouts importable straight from the working directory.
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
sys.path.append(os.path.join(os.getcwd(), "asam"))
warnings.filterwarnings("ignore")
import gradio as gr
import argparse
import numpy as np
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont
from scipy import ndimage
# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
# segment anything
from segment_anything import build_sam_vit_b, SamPredictor
import numpy as np
# BLIP
from transformers import BlipProcessor, BlipForConditionalGeneration
def generate_caption(processor, blip_model, raw_image):
    """Return an unconditional BLIP caption for *raw_image*.

    Args:
        processor: BLIP processor used both to encode the image and to decode
            the generated token ids.
        blip_model: BLIP captioning model; its ``device`` attribute decides
            where the encoded inputs are sent.
        raw_image: image accepted by *processor*.

    Returns:
        The decoded caption string, with special tokens removed.
    """
    # Send the inputs to wherever the model actually lives instead of relying on
    # a module-level ``device`` global, so the helper works for any placement.
    inputs = processor(raw_image, return_tensors="pt").to(blip_model.device)
    output_ids = blip_model.generate(**inputs)
    return processor.decode(output_ids[0], skip_special_tokens=True)
def transform_image(image_pil):
    """Convert a PIL image into the normalised tensor fed to Grounding DINO.

    The image is resized (shorter side 800, longer side capped at 1333),
    converted to a tensor and normalised with fixed per-channel mean/std.
    Only the image tensor is returned (3, h, w per the original note); the
    transform's target output is discarded.
    """
    resize = T.RandomResize([800], max_size=1333)
    to_tensor = T.ToTensor()
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    pipeline = T.Compose([resize, to_tensor, normalize])
    tensor, _unused_target = pipeline(image_pil, None)
    return tensor
def load_model(model_config_path, model_checkpoint_path, device):
    """Build a Grounding DINO model from *model_config_path* and load its weights.

    The checkpoint is read onto the CPU and its ``"model"`` entry is applied
    non-strictly; the load report is printed. *device* is only recorded on the
    config handed to ``build_model``; this function does not call ``.to()``
    itself. The model is switched to eval mode before being returned.
    """
    cfg = SLConfig.fromfile(model_config_path)
    cfg.device = device
    net = build_model(cfg)
    state = torch.load(model_checkpoint_path, map_location="cpu")
    report = net.load_state_dict(clean_state_dict(state["model"]), strict=False)
    print(report)
    net.eval()
    return net
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
    """Run Grounding DINO on one image/caption pair and keep the confident boxes.

    The caption is lower-cased, stripped and terminated with "." before use.
    A query is kept when its highest (sigmoid) token score exceeds
    *box_threshold*; its phrase is built from the tokens scoring above
    *text_threshold*.

    Args:
        model: Grounding DINO model exposing ``tokenizer``.
        image: single image tensor (a batch dimension is added here).
        caption: text prompt.
        box_threshold: minimum max-token score for a query to be kept.
        text_threshold: token score needed for a token to join the phrase.
        with_logits: append "(score)" (first 4 characters) to each phrase.

    Returns:
        (boxes_filt, scores, pred_phrases): the kept rows of ``pred_boxes``
        (CPU, in the model's box format), a float32 tensor with each kept
        query's max token score, and the list of phrases.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Boolean-mask indexing already returns copies, so no explicit clone is needed.
    max_logits = logits.max(dim=1)[0]
    keep = max_logits > box_threshold
    logits_filt = logits[keep]  # num_filt, 256
    boxes_filt = boxes[keep]  # num_filt, 4
    scores = max_logits[keep].to(torch.float32)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    pred_phrases = []
    for logit, score in zip(logits_filt, scores.tolist()):
        phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            phrase += f"({str(score)[:4]})"
        pred_phrases.append(phrase)
    return boxes_filt, scores, pred_phrases
def draw_mask(mask, draw, random_color=False):
    """Paint every non-zero pixel of a 2-D *mask* onto *draw* in a single colour.

    Args:
        mask: 2-D array-like; a non-zero element at [row, col] paints the
            pixel (x=col, y=row).
        draw: ``ImageDraw.Draw`` (or any object with a compatible ``point``).
        random_color: use a random RGB colour instead of the default blue.
            Alpha is 153 in both cases.
    """
    if random_color:
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
    else:
        color = (30, 144, 255, 153)
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return
    # One point() call with the whole coordinate list lets Pillow run the
    # per-pixel loop in C instead of issuing one Python call per pixel.
    draw.point(list(zip(cols.tolist(), rows.tolist())), fill=color)
def draw_box(box, draw, label):
    """Outline *box* in a random colour and optionally label it.

    Args:
        box: indexable ``(x0, y0, x1, y1)`` in the draw target's pixel space.
        draw: ``ImageDraw.Draw`` to draw on.
        label: text drawn in white on a box-coloured background at the
            top-left corner; a falsy value skips the label.
    """
    color = tuple(np.random.randint(0, 255, size=3).tolist())
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    draw.rectangle(((x0, y0), (x1, y1)), outline=color, width=2)
    if not label:
        return
    text = str(label)
    font = ImageFont.load_default()
    if hasattr(font, "getbbox"):
        bbox = draw.textbbox((x0, y0), text, font)
    else:  # older Pillow: the default font has no getbbox
        w, h = draw.textsize(text, font)
        bbox = (x0, y0, x0 + w, y0 + h)
    draw.rectangle(bbox, fill=color)
    # Draw the label once, with the same font that was used to size its background.
    draw.text((x0, y0), text, fill="white", font=font)
def draw_point(point, draw, r=10, color='green'):
    """Draw a filled circle of radius *r* centred on every ``(x, y)`` in *point*.

    Args:
        point: iterable of ``(x, y)`` pairs.
        draw: ``ImageDraw.Draw`` to draw on.
        r: circle radius in pixels.
        color: fill colour (defaults to the previous hard-coded green).
    """
    for x, y in point:
        draw.ellipse((x - r, y - r, x + r, y + r), fill=color)
# Model files, given as paths relative to the working directory.
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
ckpt_filenmae = "groundingdino_swint_ogc.pth"  # Grounding DINO weights (name kept: referenced elsewhere)
sam_checkpoint = 'sam_vit_b_01ec64.pth'
asam_checkpoint = 'asam_vit_b.pth'
output_dir = "outputs"

# Use the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Lazily-initialised models; run_grounded_sam fills these in on first use.
blip_processor = blip_model = groundingdino_model = sam_predictor = None
# Default boxes for the bundled example images. Keys are the integer-truncated
# pixel sum of the transformed image (see _image_key); values are
# (x_min, y_min, x_max, y_max) in the pixel space of the uploaded image, which
# is what is passed to SAM together with the original image size.
_DEFAULT_BOXES = {
    -1683627: (204, 213, 813, 1023),  # example 1
    1137390: (125, 168, 842, 904),  # example 2
    1145309: (0, 486, 992, 899),  # example 3
    1091779: (2, 73, 981, 968),  # example 4
    -1335352: (201, 195, 811, 1023),  # example 5
    -1479645: (428, 0, 992, 799),  # example 6
    -544197: (106, 419, 312, 783),  # example 7
    -23873: (250, 25, 774, 803),  # example 8
    -1572157: (15, 88, 1006, 977),  # example 9
    -509470: (190, 0, 530, 395),  # example 10
    -42440: (282, 134, 534, 394),  # example 11
}

# Task types that prompt SAM with boxes (the other supported one is 'scribble_point').
_BOX_TASKS = ('default_box', 'automatic', 'scribble_box')


def _image_key(transformed_image):
    """Return the integer fingerprint (truncated pixel sum) of a transformed image."""
    return torch.sum(transformed_image).to(torch.int).item()


def _default_box(transformed_image):
    """Return the hard-coded box of a bundled example as a (1, 4) tensor.

    Raises:
        NotImplementedError: if the image is not one of the bundled examples.
    """
    key = _image_key(transformed_image)
    if key not in _DEFAULT_BOXES:
        print(key, "not defined")
        raise NotImplementedError(
            f"default_box is only available for the bundled examples (key {key} not defined)")
    return torch.tensor(np.array(_DEFAULT_BOXES[key])).unsqueeze(0)


def _scribble_centers(scribble):
    """Return an (N, 2) array holding the centroid of each connected scribble stroke.

    Only pixels whose first channel is >= 255 count as scribble. The array is
    transposed to (width, height) before labelling, so each centroid is
    (column, row), i.e. (x, y). A 2-D (single-channel) mask is accepted too.

    Raises:
        ValueError: if nothing was drawn.
    """
    channel = np.atleast_3d(scribble).transpose(2, 1, 0)[0]
    labeled, count = ndimage.label(channel >= 255)
    if count == 0:
        raise ValueError("No scribble found: please draw on the target instance first.")
    return np.array(ndimage.center_of_mass(channel, labeled, range(1, count + 1)))


def _compose_overlay(image_pil, masks, annotate):
    """Render *masks* in random colours and composite them over an annotated image.

    *image_pil* is drawn on in place by *annotate* (a callable receiving an
    ImageDraw). Returns ``[RGBA composite, RGBA mask layer]``.
    """
    mask_image = Image.new('RGBA', image_pil.size, color=(0, 0, 0, 0))
    mask_draw = ImageDraw.Draw(mask_image)
    for mask in masks:
        draw_mask(mask, mask_draw, random_color=True)
    annotate(ImageDraw.Draw(image_pil))
    composite = image_pil.convert('RGBA')
    composite.alpha_composite(mask_image)
    return [composite, mask_image]


def _segment_sam_and_asam(predictor, image, image_pil, predict, annotate):
    """Run *predict* with the SAM weights, then with the ASAM weights.

    Returns ``[[sam_composite, sam_mask_layer], [asam_composite, asam_mask_layer]]``.
    The image embedding is recomputed after every weight swap. ``set_image``
    computes the embedding with whatever weights are loaded at that moment, so
    reusing it across a swap would pair one checkpoint's encoder with the
    other's decoder.
    """
    asam_base = copy.deepcopy(image_pil)  # clean copy for the ASAM panel
    results = []
    for checkpoint, base in ((sam_checkpoint, image_pil), (asam_checkpoint, asam_base)):
        predictor.model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
        predictor.set_image(image)
        results.append(_compose_overlay(base, predict(predictor), annotate))
    return results


def run_grounded_sam(input_image, text_prompt, task_type, box_threshold, text_threshold, iou_threshold):
    """Gradio callback: segment the target instance with SAM and with ASAM.

    Args:
        input_image: dict from the sketch-enabled ``gr.Image`` with keys
            ``"image"`` (PIL image) and ``"mask"`` (the user's scribble layer).
        text_prompt: printed only; 'automatic' mode replaces it with a BLIP caption.
        task_type: 'default_box', 'automatic', 'scribble_point' or 'scribble_box'.
        box_threshold, text_threshold: Grounding DINO thresholds ('automatic' only).
        iou_threshold: NMS IoU threshold for detected boxes ('automatic' only).

    Returns:
        ``[[sam_overlay, sam_mask_layer], [asam_overlay, asam_mask_layer]]``.

    Raises:
        ValueError: unknown *task_type*, or a scribble task with nothing drawn.
        NotImplementedError: 'default_box' on an image that is not a bundled example.
    """
    global blip_processor, blip_model, groundingdino_model, sam_predictor
    print(text_prompt, type(text_prompt))
    if task_type not in _BOX_TASKS and task_type != 'scribble_point':
        raise ValueError("task_type:{} error!".format(task_type))

    os.makedirs(output_dir, exist_ok=True)
    scribble = np.array(input_image["mask"])
    image_pil = input_image["image"].convert("RGB")
    transformed_image = transform_image(image_pil)
    print('img sum:', _image_key(transformed_image))

    # Grounding DINO (and BLIP) are only needed when boxes must be detected.
    boxes_filt, pred_phrases = None, []
    if task_type == 'automatic':
        if groundingdino_model is None:
            groundingdino_model = load_model(config_file, ckpt_filenmae, device=device)
        if blip_processor is None:
            blip_processor = BlipProcessor.from_pretrained(
                "Salesforce/blip-image-captioning-large")
        if blip_model is None:
            blip_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-large").to(device)
        text_prompt = generate_caption(blip_processor, blip_model, image_pil)
        print(f"Caption: {text_prompt}")
        boxes_filt, scores, pred_phrases = get_grounding_output(
            groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold)
        # Normalised (cx, cy, w, h) -> absolute (x0, y0, x1, y1) pixel boxes.
        W, H = image_pil.size
        boxes_filt = boxes_filt * torch.Tensor([W, H, W, H])
        boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2
        boxes_filt[:, 2:] += boxes_filt[:, :2]
        boxes_filt = boxes_filt.cpu()
        print(f"Before NMS: {boxes_filt.shape[0]} boxes")
        keep = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
        boxes_filt = boxes_filt[keep]
        pred_phrases = [pred_phrases[idx] for idx in keep]
        print(f"After NMS: {boxes_filt.shape[0]} boxes")

    if sam_predictor is None:
        sam = build_sam_vit_b(checkpoint=sam_checkpoint)
        sam.to(device=device)
        sam_predictor = SamPredictor(sam)
    image = np.array(image_pil)

    if task_type == 'scribble_point':
        point_coords = _scribble_centers(scribble)
        point_labels = np.ones(point_coords.shape[0])

        def predict_points(predictor):
            masks, _, _ = predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                box=None,
                multimask_output=False,
            )
            return masks

        def annotate_points(draw):
            draw_point(point_coords, draw)

        return _segment_sam_and_asam(sam_predictor, image, image_pil, predict_points, annotate_points)

    # Box-prompted tasks.
    if task_type == 'default_box':
        bbox = _default_box(transformed_image)
    elif task_type == 'scribble_box':
        centers = _scribble_centers(scribble)
        xs, ys = centers[:, 0], centers[:, 1]
        bbox = torch.tensor(np.array([xs.min(), ys.min(), xs.max(), ys.max()])).unsqueeze(0)
    else:  # 'automatic': prompt SAM with every detected box
        bbox = None
    prompt_boxes = boxes_filt if bbox is None else bbox
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(
        prompt_boxes, image.shape[:2]).to(device)

    def predict_boxes(predictor):
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        # One single-channel mask per box; hand 2-D numpy arrays to draw_mask.
        return [mask[0].cpu().numpy() for mask in masks]

    def annotate_boxes(draw):
        if bbox is not None:
            for box in bbox:
                draw_box(box, draw, None)
        else:
            for box, label in zip(boxes_filt, pred_phrases):
                draw_box(box, draw, label)
            draw.text((10, 10), text_prompt, fill='black')

    return _segment_sam_and_asam(sam_predictor, image, image_pil, predict_boxes, annotate_boxes)
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
    parser.add_argument("--debug", action="store_true",
                        help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    parser.add_argument('--no-gradio-queue', action="store_true",
                        help='disable the Gradio request queue')
    args = parser.parse_args()
    print(args)

    block = gr.Blocks()
    if not args.no_gradio_queue:
        block = block.queue()

    with block:
        # gr.Markdown dedents its text, so the indentation below is not rendered.
        gr.Markdown(
            """
            # ASAM
            Welcome to the ASAM demo <br/>
            You may select different prompt types to get the output mask of the target instance.
            ## Usage
            You may check the instructions below, or check our GitHub page for more details.
            ## Mode
            You may select an example image or upload your own image to start. We support 4 prompt types:
            **default_box**: According to the mask label, automatically generate the default box prompt (only used for the examples).
            **automatic**: Automatically generate a text prompt and the corresponding box input with BLIP and Grounding-DINO.
            **scribble_point**: Click a point on the target instance.
            **scribble_box**: Click on two points, the top-left point and the bottom-right point, to represent a bounding box of the target instance.
            """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    source='upload', type="pil", value="example9.jpg", tool="sketch", brush_radius=20)
                task_type = gr.Dropdown(
                    ["default_box", "automatic", "scribble_point", "scribble_box"],
                    value="default_box", label="task_type")
                text_prompt = gr.Textbox(label="Text Prompt", placeholder="bench .", visible=False)
                # The first positional argument of gr.Button is its displayed text.
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    box_threshold = gr.Slider(
                        label="Box Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.001
                    )
                    text_threshold = gr.Slider(
                        label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                    )
                    iou_threshold = gr.Slider(
                        label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
                    )
            with gr.Column():
                # SAM output panel, headed by its logo.
                with gr.Row(equal_height=True):
                    gr.Image(source='upload', value="meta&sam.png", tool="label", show_download_button=False,
                             min_width=0, height=50, width=100, container=False,
                             style="width: 0.5px; height: 0.5px; margin-right: 0px;")
                gallery1 = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery"
                ).style(preview=True, grid=2, object_fit="scale-down")
                # ASAM output panel, headed by its logo.
                with gr.Row(equal_height=True):
                    gr.Image(source='upload', value="vivo&asam.png", tool="label", show_download_button=False,
                             min_width=0, height=50, width=100, container=False,
                             style="width: 0.5px; height: 0.5px; margin-right: 0px;")
                # Distinct elem_id: HTML element ids must be unique on the page.
                gallery2 = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery2"
                ).style(preview=True, grid=2, object_fit="scale-down")
        with gr.Row():
            # One example column per bundled image (example1.jpg ... example11.jpg).
            for example_index in range(1, 12):
                with gr.Column():
                    gr.Examples([f"example{example_index}.jpg"], inputs=input_image)
        run_button.click(
            fn=run_grounded_sam,
            inputs=[input_image, text_prompt, task_type, box_threshold, text_threshold, iou_threshold],
            outputs=[gallery1, gallery2],
        )
    block.launch(debug=args.debug, share=args.share, show_error=True)