import spaces import multiprocessing as mp import numpy as np from PIL import Image import torch try: import detectron2 except: import os os.system('pip install git+https://github.com/facebookresearch/detectron2.git') from detectron2.config import get_cfg from detectron2.projects.deeplab import add_deeplab_config from detectron2.data.detection_utils import read_image from mask_adapter import add_maskformer2_config, add_fcclip_config, add_mask_adapter_config from mask_adapter.sam_maskadapter import SAMVisualizationDemo, SAMPointVisualizationDemo import gradio as gr import open_clip from sam2.build_sam import build_sam2 from mask_adapter.modeling.meta_arch.mask_adapter_head import build_mask_adapter def setup_cfg(config_file): cfg = get_cfg() add_deeplab_config(cfg) add_maskformer2_config(cfg) add_fcclip_config(cfg) add_mask_adapter_config(cfg) cfg.merge_from_file(config_file) cfg.freeze() return cfg @spaces.GPU @torch.no_grad() @torch.autocast(device_type="cuda", dtype=torch.float32) def inference_automatic(input_img, class_names): mp.set_start_method("spawn", force=True) config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml' cfg = setup_cfg(config_file) demo = SAMVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter) class_names = class_names.split(',') img = read_image(input_img, format="BGR") if len(class_names) == 1: class_names.append('others') txts = [f'a photo of {cls_name}' for cls_name in class_names] text = open_clip.tokenize(txts) text_features = clip_model.encode_text(text.cuda()) text_features /= text_features.norm(dim=-1, keepdim=True) _, visualized_output = demo.run_on_image(img, class_names,text_features) return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB') @spaces.GPU @torch.no_grad() @torch.autocast(device_type="cuda", dtype=torch.float32) def inference_point(input_img, evt: gr.SelectData,): mp.set_start_method("spawn", force=True) x, y = evt.index[0], evt.index[1] points = [[x, y]] print(f"Selected point: {points}") config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml' cfg = setup_cfg(config_file) demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter) img = read_image(input_img, format="BGR") text_features = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy")) _, visualized_output = demo.run_on_image_with_points(img, points,text_features) return visualized_output sam2_model = None clip_model = None mask_adapter = None def initialize_models(sam_path, adapter_pth, model_cfg, cfg): cfg = setup_cfg(cfg) global sam2_model, clip_model, mask_adapter if sam2_model is None: sam2_model = build_sam2(model_cfg, sam_path, device="cpu", apply_postprocessing=False) sam2_model = sam2_model.to("cuda") print("SAM2 model initialized.") if clip_model is None: clip_model, _, _ = open_clip.create_model_and_transforms("convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup") clip_model = clip_model.eval() clip_model = clip_model.to("cuda") print("CLIP model initialized.") if mask_adapter is None: mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead").to("cuda") mask_adapter = mask_adapter.eval() adapter_state_dict = torch.load(adapter_pth) mask_adapter.load_state_dict(adapter_state_dict) print("Mask Adapter model initialized.") model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" sam_path = './sam2.1_hiera_large.pt' adapter_pth = './model_0279999_with_sem_new.pth' cfg = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml' initialize_models(sam_path, adapter_pth, model_cfg, cfg) # Examples for testing examples = [ ['./demo/images/000000001025.jpg', 'dog, beach, trees, sea, sky, snow, person, rocks, buildings, birds, beach umbrella, beach chair'], ['./demo/images/ADE_val_00000979.jpg', 'sky,sea,mountain,pier,beach,island,,landscape,horizon'], ['./demo/images/ADE_val_00001200.jpg', 'bridge, mountains, trees, water, sky, buildings, boats, animals, flowers, waterfalls, grasslands, rocks'], ] output_labels = ['segmentation map'] title = '