File size: 4,668 Bytes
2cd9d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad09938
2cd9d38
 
 
ad09938
2cd9d38
 
 
 
 
 
 
 
 
 
 
ad09938
2cd9d38
 
 
ad09938
2cd9d38
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from gradio_image_prompter import ImagePrompter


# Choose the compute device once at import time: CUDA when torch reports it
# as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Reference SAM checkpoint and its processor.
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
sam_model = sam_model.to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# SlimSAM checkpoint and its processor, loaded onto the same device.
slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform")
slimsam_model = slimsam_model.to(device)
slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")

def sam_box_inference(image, model, x_min, y_min, x_max, y_max):
    """Run ``model`` on ``image`` with a single bounding-box prompt.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        model: Model called with the processor outputs as keyword arguments.
        x_min, y_min, x_max, y_max: Corners of the box prompt.

    Returns:
        A one-element list ``[(mask, "mask")]`` where ``mask`` is the first
        post-processed mask with a new leading axis added.
    """
    # NOTE(review): the SAM processor is used regardless of which model is
    # passed in; presumably both checkpoints share preprocessing -- confirm.
    pil_image = Image.fromarray(image)
    box_prompt = [[[[x_min, y_min, x_max, y_max]]]]
    encoded = sam_processor(
        pil_image,
        input_boxes=box_prompt,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        prediction = model(**encoded)

    resized_masks = sam_processor.image_processor.post_process_masks(
        prediction.pred_masks.cpu(),
        encoded["original_sizes"].cpu(),
        encoded["reshaped_input_sizes"].cpu(),
    )
    # Keep only the first mask of the first prompt of the first image.
    first_mask = resized_masks[0][0][0].numpy()
    mask = first_mask[np.newaxis, ...]
    print(mask)
    print(mask.shape)
    return [(mask, "mask")]


def sam_point_inference(image, model, x, y):
    """Run ``model`` on ``image`` with a single point prompt.

    Args:
        image: Image handed directly to ``sam_processor`` (the caller in this
            file passes an RGB PIL image).
        model: Model to run. This is the model actually invoked, so passing
            the SlimSAM or the SAM model yields that model's own prediction.
        x, y: Coordinates of the point prompt.

    Returns:
        A one-element list ``[(mask, "mask")]`` where ``mask`` is the first
        post-processed mask with a new leading axis added.
    """
    # NOTE(review): the SAM processor is used regardless of which model is
    # passed in; presumably both checkpoints share preprocessing -- confirm.
    inputs = sam_processor(
        image,
        input_points=[[[x, y]]],
        return_tensors="pt",
    ).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        # Use the model supplied by the caller rather than the global SAM
        # model, otherwise both sides of the comparison show the same output.
        outputs = model(**inputs)

    masks = sam_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # Keep only the first mask of the first prompt of the first image.
    mask = masks[0][0][0].numpy()
    mask = mask[np.newaxis, ...]
    return [(mask, "mask")]

def infer_point(img):
    """Segment the uploaded image at the point the user drew on it.

    Args:
        img: Editor value with ``"background"`` (the uploaded image) and
            ``"layers"`` (drawn layers; the first one holds the point), or
            ``None`` when nothing has been uploaded.

    Returns:
        A pair ``(slimsam_output, sam_output)``; each is
        ``(image, annotations)`` with ``annotations`` coming from
        ``sam_point_inference``.

    Raises:
        gr.Error: If no image was uploaded or no point was drawn.
    """
    # gr.Error is an exception class: it only reaches the user when raised.
    if img is None or img["background"] is None:
        raise gr.Error("Please upload an image and select a point.")

    image = img["background"].convert("RGB")

    layers = img["layers"]
    if not layers:
        raise gr.Error("Please select a point on top of the image.")

    drawn = np.array(layers[0])
    if not np.any(drawn):
        raise gr.Error("Please select a point on top of the image.")

    # The first two axes of the layer are rows (y) and columns (x); the point
    # prompt is the centroid of every non-zero pixel.
    rows, cols = np.nonzero(drawn)[:2]
    center_x = int(np.mean(cols))
    center_y = int(np.mean(rows))

    return (
        (image, sam_point_inference(image, slimsam_model, center_x, center_y)),
        (image, sam_point_inference(image, sam_model, center_x, center_y)),
    )

def infer_box(prompts):
    """Segment the uploaded image inside the box the user drew on it.

    Args:
        prompts: Prompter value with ``"image"`` (the uploaded image) and
            ``"points"`` (the drawn prompts), or ``None`` when nothing has
            been submitted.

    Returns:
        A pair ``(slimsam_output, sam_output)``; each is
        ``(image, annotations)`` with ``annotations`` coming from
        ``sam_box_inference``.

    Raises:
        gr.Error: If no image was uploaded or no box was drawn.
    """
    # gr.Error is an exception class: it only reaches the user when raised.
    if prompts is None or prompts["image"] is None:
        raise gr.Error("Please upload an image and draw a box before submitting")
    image = prompts["image"]

    boxes = prompts["points"]
    # Check emptiness before indexing so a missing box produces a readable
    # message instead of an IndexError.
    if boxes is None or len(boxes) == 0 or boxes[0] is None:
        raise gr.Error("Please draw a box before submitting.")
    box = boxes[0]

    # Entry layout used here: x_min = box[0], y_min = box[1],
    # x_max = box[3], y_max = box[4].
    x_min, y_min, x_max, y_max = box[0], box[1], box[3], box[4]

    return (
        (image, sam_box_inference(image, slimsam_model, x_min, y_min, x_max, y_max)),
        (image, sam_box_inference(image, sam_model, x_min, y_min, x_max, y_max)),
    )
# Two-tab UI: each tab takes a prompt on the left and shows the SlimSAM and
# SAM annotations stacked on the right.
with gr.Blocks(title="SlimSAM") as demo:
    gr.Markdown("# SlimSAM")
    gr.Markdown("SlimSAM is the pruned-distilled version of SAM that is smaller.")
    gr.Markdown("In this demo, you can compare SlimSAM and SAM outputs in point and box prompts.")

    with gr.Tab("**Box Prompt**"):
        # Instructions row.
        with gr.Row(), gr.Column(scale=1):
            gr.Markdown("To try box prompting, simply upload and image and draw a box on it.")

        # Input column (prompter + submit button) next to the output column.
        with gr.Row():
            with gr.Column():
                im = ImagePrompter()
                btn = gr.Button("Submit")
            with gr.Column():
                output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_box_sam = gr.AnnotatedImage(label="SAM Output")

        # Box inference runs only when the user presses the button.
        btn.click(infer_box, inputs=im, outputs=[output_box_slimsam, output_box_sam])

    with gr.Tab("**Point Prompt**"):
        # Instructions row.
        with gr.Row(), gr.Column(scale=1):
            gr.Markdown("To try point prompting, simply upload and image and leave a dot on it.")

        # Input column (image editor) next to the output column.
        with gr.Row():
            with gr.Column():
                # `im` is rebound here; the box tab's handler was already
                # wired to the previous component above.
                im = gr.ImageEditor(type="pil")
            with gr.Column():
                output_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_sam = gr.AnnotatedImage(label="SAM Output")

        # Point inference runs on every change of the editor value.
        im.change(infer_point, inputs=im, outputs=[output_slimsam, output_sam])

demo.launch(debug=True)