File size: 2,687 Bytes
bf29adc
 
 
 
46abea7
 
bf29adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ebade2
bf29adc
 
a96b320
bf29adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import numpy as np
import cv2

import spaces

from PIL import Image

from src.plot_utils import show_masks
from gradio_image_annotation import image_annotator


from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Dropdown checkpoint name -> [model config file, checkpoint path].
# predict() unpacks the pair and hands both values to build_sam2; the keys
# must match the choices offered by the "Model Checkpoint" dropdown.
choice_mapping = dict(
    tiny=["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
    small=["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
    base_plus=["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
    large=["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
)

# @spaces.GPU
def predict(model_choice: str, annotations, image):
    """Segment the region inside the user's bounding box with SAM 2.

    Args:
        model_choice: Dropdown value; used as a key into ``choice_mapping`` to
            pick the model config file and checkpoint path.
        annotations: Value of the ``image_annotator`` component. Only the
            first entry of ``annotations["boxes"]`` is used, read through its
            ``xmin``/``ymin``/``xmax``/``ymax`` keys.
        image: Array from the ``gr.Image`` input (``type="numpy"``); passed
            to ``SAM2ImagePredictor.set_image``.

    Returns:
        A two-element list: the figure returned by ``show_masks`` and a
        visible ``gr.DownloadButton`` whose value is the saved mask PNG.

    Raises:
        gr.Error: If no image is loaded, no bounding box has been drawn, or
            the mask file cannot be written.
    """
    # Function-scope imports keep this block self-contained.
    import os
    import tempfile

    # Report a readable message instead of crashing with a TypeError /
    # KeyError / IndexError when the button is clicked too early.
    if image is None:
        raise gr.Error("Please upload an image first.")
    boxes = (annotations or {}).get("boxes") or []
    if not boxes:
        raise gr.Error("Please draw a bounding box before requesting a mask.")
    box = boxes[0]  # the annotator is configured with single_box=True

    config_file, ckpt_path = choice_mapping[str(model_choice)]
    # NOTE(review): the model is rebuilt from its checkpoint on every click;
    # consider caching it per model_choice if latency matters.
    sam2_model = build_sam2(config_file, ckpt_path, device="cuda")
    predictor = SAM2ImagePredictor(sam2_model)
    predictor.set_image(image)

    # Box as integers, ordered [xmin, ymin, xmax, ymax].
    coordinates = np.array(
        [
            int(box["xmin"]),
            int(box["ymin"]),
            int(box["xmax"]),
            int(box["ymax"]),
        ]
    )
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=coordinates[None, :],  # add a leading axis: shape (4,) -> (1, 4)
        multimask_output=False,
    )

    # Move the first axis last so the array is image-shaped for cv2
    # (assumes masks is (num_masks, H, W) holding 0/1 values — TODO confirm).
    mask = masks.transpose(1, 2, 0)
    mask_image = (mask * 255).astype(np.uint8)  # Convert to uint8 format

    # Save into a fresh per-request directory: a single shared "mask.png"
    # lets concurrent sessions overwrite, and download, each other's masks.
    # The directory is deliberately not removed here because the returned
    # path must remain readable after this function returns.
    mask_path = os.path.join(tempfile.mkdtemp(), "mask.png")
    if not cv2.imwrite(mask_path, mask_image):
        raise gr.Error("Failed to save the segmentation mask.")

    return [
        show_masks(image, masks, scores, box_coords=coordinates),
        gr.DownloadButton("Download Mask", value=mask_path, visible=True),
    ]


with gr.Blocks(delete_cache=(30, 30)) as demo:
    # NOTE(review): delete_cache is presumably (check_every_s, max_age_s) for
    # Gradio's file cache — confirm against the installed Gradio version.
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
        # Choices must match the keys of choice_mapping used by predict().
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown(
        """
        # 2. Upload an Image
        """
    )

    with gr.Row():
        img = gr.Image(value="./assets/img.png", type="numpy", label="Input Image")

    gr.Markdown(
        """
        # 3. Draw Bounding Box
        """
    )

    # The annotator starts on the same default image as the input component.
    annotator = image_annotator(
        value={"image": img.value["path"]},
        disable_edit_boxes=True,
        single_box=True,
        label="Draw a bounding box",
    )

    def _show_in_annotator(new_image):
        """Mirror a newly loaded input image into the annotator, dropping old boxes."""
        if new_image is None:
            return None
        return {"image": new_image, "boxes": []}

    # Without this, uploading a new image leaves the annotator on the default
    # picture, so the box is drawn on one image but applied to another.
    img.change(fn=_show_in_annotator, inputs=img, outputs=annotator)

    btn = gr.Button("Get Segmentation Mask")
    # Hidden until predict() returns a visible button with the mask file.
    # Starting with no value keeps app start-up independent of any
    # pre-existing "mask.png" on disk.
    download_btn = gr.DownloadButton("Download Mask", value=None, visible=False)
    btn.click(
        fn=predict, inputs=[model, annotator, img], outputs=[gr.Plot(), download_btn]
    )

demo.launch()