Spaces:
Runtime error
Runtime error
add dilation bar and improve UI
Browse files
app.py
CHANGED
@@ -1,19 +1,38 @@
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
from pathlib import Path
|
4 |
from matplotlib import pyplot as plt
|
5 |
import torch
|
6 |
import tempfile
|
7 |
-
import os
|
8 |
-
from omegaconf import OmegaConf
|
9 |
-
from sam_segment import predict_masks_with_sam
|
10 |
from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
|
11 |
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
|
12 |
show_mask, show_points
|
13 |
from PIL import Image
|
|
|
14 |
from segment_anything import SamPredictor, sam_model_registry
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
def mkstemp(suffix, dir=None):
|
18 |
fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
|
19 |
os.close(fd)
|
@@ -21,9 +40,7 @@ def mkstemp(suffix, dir=None):
|
|
21 |
|
22 |
|
23 |
def get_sam_feat(img):
|
24 |
-
# predictor.set_image(img)
|
25 |
model['sam'].set_image(img)
|
26 |
-
# self.is_image_set = False
|
27 |
features = model['sam'].features
|
28 |
orig_h = model['sam'].orig_h
|
29 |
orig_w = model['sam'].orig_w
|
@@ -33,24 +50,18 @@ def get_sam_feat(img):
|
|
33 |
return features, orig_h, orig_w, input_h, input_w
|
34 |
|
35 |
|
36 |
-
def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
|
37 |
point_coords = [w, h]
|
38 |
point_labels = [1]
|
39 |
-
dilate_kernel_size = 15
|
40 |
|
41 |
-
# model['sam'].is_image_set = False
|
42 |
model['sam'].is_image_set = True
|
43 |
model['sam'].features = features
|
44 |
model['sam'].orig_h = orig_h
|
45 |
model['sam'].orig_w = orig_w
|
46 |
model['sam'].input_h = input_h
|
47 |
model['sam'].input_w = input_w
|
48 |
-
|
49 |
-
# model['sam'].
|
50 |
-
# model['sam'].input_size = input_size
|
51 |
-
# model['sam'].is_image_set = True
|
52 |
-
|
53 |
-
model['sam'].set_image(img)
|
54 |
masks, _, _ = model['sam'].predict(
|
55 |
point_coords=np.array([point_coords]),
|
56 |
point_labels=np.array(point_labels),
|
@@ -77,6 +88,7 @@ def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
|
|
77 |
show_points(plt.gca(), [point_coords], point_labels,
|
78 |
size=(width*0.04)**2)
|
79 |
show_mask(plt.gca(), mask, random_color=False)
|
|
|
80 |
plt.savefig(tmp_p, bbox_inches='tight', pad_inches=0)
|
81 |
figs.append(fig)
|
82 |
plt.close()
|
@@ -84,8 +96,7 @@ def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w):
|
|
84 |
|
85 |
|
86 |
def get_inpainted_img(img, mask0, mask1, mask2):
|
87 |
-
lama_config =
|
88 |
-
# lama_ckpt = "pretrained_models/big-lama"
|
89 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
90 |
out = []
|
91 |
for mask in [mask0, mask1, mask2]:
|
@@ -97,25 +108,27 @@ def get_inpainted_img(img, mask0, mask1, mask2):
|
|
97 |
return out
|
98 |
|
99 |
|
100 |
-
|
|
|
|
|
|
|
|
|
101 |
model = {}
|
102 |
# build the sam model
|
103 |
model_type="vit_h"
|
104 |
-
ckpt_p=
|
105 |
model_sam = sam_model_registry[model_type](checkpoint=ckpt_p)
|
106 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
107 |
model_sam.to(device=device)
|
108 |
-
# predictor = SamPredictor(model_sam)
|
109 |
model['sam'] = SamPredictor(model_sam)
|
110 |
|
111 |
# build the lama model
|
112 |
-
lama_config =
|
113 |
-
lama_ckpt =
|
114 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
115 |
-
# model_lama = build_lama_model(lama_config, lama_ckpt, device=device)
|
116 |
model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)
|
117 |
|
118 |
-
|
119 |
with gr.Blocks() as demo:
|
120 |
features = gr.State(None)
|
121 |
orig_h = gr.State(None)
|
@@ -123,36 +136,59 @@ with gr.Blocks() as demo:
|
|
123 |
input_h = gr.State(None)
|
124 |
input_w = gr.State(None)
|
125 |
|
126 |
-
with gr.Row():
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
with gr.Row():
|
131 |
w = gr.Number(label="Point Coordinate W")
|
132 |
h = gr.Number(label="Point Coordinate H")
|
133 |
-
|
134 |
-
sam_mask = gr.Button("Predict Mask
|
135 |
-
lama = gr.Button("Inpaint Image
|
136 |
-
|
137 |
|
138 |
# todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
|
139 |
-
with gr.Row():
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
def get_select_coords(img, evt: gr.SelectData):
|
158 |
dpi = plt.rcParams['figure.dpi']
|
@@ -160,22 +196,17 @@ with gr.Blocks() as demo:
|
|
160 |
fig = plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
|
161 |
plt.imshow(img)
|
162 |
plt.axis('off')
|
|
|
163 |
show_points(plt.gca(), [[evt.index[0], evt.index[1]]], [1],
|
164 |
size=(width*0.04)**2)
|
165 |
return evt.index[0], evt.index[1], fig
|
166 |
|
167 |
img.select(get_select_coords, [img], [w, h, img_pointed])
|
168 |
-
# sam_feat.click(
|
169 |
-
# get_sam_feat,
|
170 |
-
# [img],
|
171 |
-
# []
|
172 |
-
# )
|
173 |
-
# img.change(get_sam_feat, [img], [])
|
174 |
img.upload(get_sam_feat, [img], [features, orig_h, orig_w, input_h, input_w])
|
175 |
|
176 |
sam_mask.click(
|
177 |
get_masked_img,
|
178 |
-
[img, w, h, features, orig_h, orig_w, input_h, input_w],
|
179 |
[img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
|
180 |
)
|
181 |
|
@@ -185,16 +216,16 @@ with gr.Blocks() as demo:
|
|
185 |
[img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
|
186 |
)
|
187 |
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
|
|
|
|
195 |
|
196 |
if __name__ == "__main__":
|
197 |
-
|
198 |
-
# demo.launch(max_threads=8)
|
199 |
-
demo.launch()
|
200 |
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
# sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
|
4 |
+
# os.chdir("../")
|
5 |
import gradio as gr
|
6 |
import numpy as np
|
7 |
from pathlib import Path
|
8 |
from matplotlib import pyplot as plt
|
9 |
import torch
|
10 |
import tempfile
|
|
|
|
|
|
|
11 |
from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
|
12 |
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
|
13 |
show_mask, show_points
|
14 |
from PIL import Image
|
15 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent / "third_party" / "segment-anything"))
|
16 |
from segment_anything import SamPredictor, sam_model_registry
|
17 |
+
import argparse
|
18 |
+
|
19 |
+
def setup_args(parser):
|
20 |
+
parser.add_argument(
|
21 |
+
"--lama_config", type=str,
|
22 |
+
default="./third_party/lama/configs/prediction/default.yaml",
|
23 |
+
help="The path to the config file of lama model. "
|
24 |
+
"Default: the config of big-lama",
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--lama_ckpt", type=str,
|
28 |
+
default="pretrained_models/big-lama",
|
29 |
+
help="The path to the lama checkpoint.",
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--sam_ckpt", type=str,
|
33 |
+
default="./pretrained_models/sam_vit_h_4b8939.pth",
|
34 |
+
help="The path to the SAM checkpoint to use for mask generation.",
|
35 |
+
)
|
36 |
def mkstemp(suffix, dir=None):
|
37 |
fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
|
38 |
os.close(fd)
|
|
|
40 |
|
41 |
|
42 |
def get_sam_feat(img):
|
|
|
43 |
model['sam'].set_image(img)
|
|
|
44 |
features = model['sam'].features
|
45 |
orig_h = model['sam'].orig_h
|
46 |
orig_w = model['sam'].orig_w
|
|
|
50 |
return features, orig_h, orig_w, input_h, input_w
|
51 |
|
52 |
|
53 |
+
def get_masked_img(img, w, h, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size):
|
54 |
point_coords = [w, h]
|
55 |
point_labels = [1]
|
|
|
56 |
|
|
|
57 |
model['sam'].is_image_set = True
|
58 |
model['sam'].features = features
|
59 |
model['sam'].orig_h = orig_h
|
60 |
model['sam'].orig_w = orig_w
|
61 |
model['sam'].input_h = input_h
|
62 |
model['sam'].input_w = input_w
|
63 |
+
|
64 |
+
# model['sam'].set_image(img) # todo : update here for accelerating
|
|
|
|
|
|
|
|
|
65 |
masks, _, _ = model['sam'].predict(
|
66 |
point_coords=np.array([point_coords]),
|
67 |
point_labels=np.array(point_labels),
|
|
|
88 |
show_points(plt.gca(), [point_coords], point_labels,
|
89 |
size=(width*0.04)**2)
|
90 |
show_mask(plt.gca(), mask, random_color=False)
|
91 |
+
plt.tight_layout()
|
92 |
plt.savefig(tmp_p, bbox_inches='tight', pad_inches=0)
|
93 |
figs.append(fig)
|
94 |
plt.close()
|
|
|
96 |
|
97 |
|
98 |
def get_inpainted_img(img, mask0, mask1, mask2):
|
99 |
+
lama_config = args.lama_config
|
|
|
100 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
101 |
out = []
|
102 |
for mask in [mask0, mask1, mask2]:
|
|
|
108 |
return out
|
109 |
|
110 |
|
111 |
+
# get args
|
112 |
+
parser = argparse.ArgumentParser()
|
113 |
+
setup_args(parser)
|
114 |
+
args = parser.parse_args(sys.argv[1:])
|
115 |
+
# build models
|
116 |
model = {}
|
117 |
# build the sam model
|
118 |
model_type="vit_h"
|
119 |
+
ckpt_p=args.sam_ckpt
|
120 |
model_sam = sam_model_registry[model_type](checkpoint=ckpt_p)
|
121 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
122 |
model_sam.to(device=device)
|
|
|
123 |
model['sam'] = SamPredictor(model_sam)
|
124 |
|
125 |
# build the lama model
|
126 |
+
lama_config = args.lama_config
|
127 |
+
lama_ckpt = args.lama_ckpt
|
128 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
129 |
model['lama'] = build_lama_model(lama_config, lama_ckpt, device=device)
|
130 |
|
131 |
+
button_size = (100,50)
|
132 |
with gr.Blocks() as demo:
|
133 |
features = gr.State(None)
|
134 |
orig_h = gr.State(None)
|
|
|
136 |
input_h = gr.State(None)
|
137 |
input_w = gr.State(None)
|
138 |
|
139 |
+
with gr.Row().style(mobile_collapse=False, equal_height=True):
|
140 |
+
with gr.Column(variant="panel"):
|
141 |
+
with gr.Row():
|
142 |
+
gr.Markdown("## Input Image")
|
143 |
+
with gr.Row():
|
144 |
+
img = gr.Image(label="Input Image").style(height="200px")
|
145 |
+
with gr.Column(variant="panel"):
|
146 |
+
with gr.Row():
|
147 |
+
gr.Markdown("## Pointed Image")
|
148 |
+
with gr.Row():
|
149 |
+
img_pointed = gr.Plot(label='Pointed Image')
|
150 |
+
with gr.Column(variant="panel"):
|
151 |
+
with gr.Row():
|
152 |
+
gr.Markdown("## Control Panel")
|
153 |
with gr.Row():
|
154 |
w = gr.Number(label="Point Coordinate W")
|
155 |
h = gr.Number(label="Point Coordinate H")
|
156 |
+
dilate_kernel_size = gr.Slider(label="Dilate Kernel Size", minimum=0, maximum=100, step=1, value=15)
|
157 |
+
sam_mask = gr.Button("Predict Mask", variant="primary").style(full_width=True, size="sm")
|
158 |
+
lama = gr.Button("Inpaint Image", variant="primary").style(full_width=True, size="sm")
|
159 |
+
clear_button_image = gr.Button(value="Reset", label="Reset", variant="secondary").style(full_width=True, size="sm")
|
160 |
|
161 |
# todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
|
162 |
+
with gr.Row(variant="panel"):
|
163 |
+
with gr.Column():
|
164 |
+
with gr.Row():
|
165 |
+
gr.Markdown("## Segmentation Mask")
|
166 |
+
with gr.Row():
|
167 |
+
mask_0 = gr.outputs.Image(type="numpy", label="Segmentation Mask 0").style(height="200px")
|
168 |
+
mask_1 = gr.outputs.Image(type="numpy", label="Segmentation Mask 1").style(height="200px")
|
169 |
+
mask_2 = gr.outputs.Image(type="numpy", label="Segmentation Mask 2").style(height="200px")
|
170 |
+
|
171 |
+
with gr.Row(variant="panel"):
|
172 |
+
with gr.Column():
|
173 |
+
with gr.Row():
|
174 |
+
gr.Markdown("## Image with Mask")
|
175 |
+
with gr.Row():
|
176 |
+
img_with_mask_0 = gr.Plot(label="Image with Segmentation Mask 0")
|
177 |
+
img_with_mask_1 = gr.Plot(label="Image with Segmentation Mask 1")
|
178 |
+
img_with_mask_2 = gr.Plot(label="Image with Segmentation Mask 2")
|
179 |
+
|
180 |
+
with gr.Row(variant="panel"):
|
181 |
+
with gr.Column():
|
182 |
+
with gr.Row():
|
183 |
+
gr.Markdown("## Image Removed with Mask")
|
184 |
+
with gr.Row():
|
185 |
+
img_rm_with_mask_0 = gr.outputs.Image(
|
186 |
+
type="numpy", label="Image Removed with Segmentation Mask 0").style(height="200px")
|
187 |
+
img_rm_with_mask_1 = gr.outputs.Image(
|
188 |
+
type="numpy", label="Image Removed with Segmentation Mask 1").style(height="200px")
|
189 |
+
img_rm_with_mask_2 = gr.outputs.Image(
|
190 |
+
type="numpy", label="Image Removed with Segmentation Mask 2").style(height="200px")
|
191 |
+
|
192 |
|
193 |
def get_select_coords(img, evt: gr.SelectData):
|
194 |
dpi = plt.rcParams['figure.dpi']
|
|
|
196 |
fig = plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
|
197 |
plt.imshow(img)
|
198 |
plt.axis('off')
|
199 |
+
plt.tight_layout()
|
200 |
show_points(plt.gca(), [[evt.index[0], evt.index[1]]], [1],
|
201 |
size=(width*0.04)**2)
|
202 |
return evt.index[0], evt.index[1], fig
|
203 |
|
204 |
img.select(get_select_coords, [img], [w, h, img_pointed])
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
img.upload(get_sam_feat, [img], [features, orig_h, orig_w, input_h, input_w])
|
206 |
|
207 |
sam_mask.click(
|
208 |
get_masked_img,
|
209 |
+
[img, w, h, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size],
|
210 |
[img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
|
211 |
)
|
212 |
|
|
|
216 |
[img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
|
217 |
)
|
218 |
|
219 |
+
|
220 |
+
def reset(*args):
|
221 |
+
return [None for _ in args]
|
222 |
+
|
223 |
+
clear_button_image.click(
|
224 |
+
reset,
|
225 |
+
[img, features, img_pointed, w, h, mask_0, mask_1, mask_2, img_with_mask_0, img_with_mask_1, img_with_mask_2, img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2],
|
226 |
+
[img, features, img_pointed, w, h, mask_0, mask_1, mask_2, img_with_mask_0, img_with_mask_1, img_with_mask_2, img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
|
227 |
+
)
|
228 |
|
229 |
if __name__ == "__main__":
|
230 |
+
demo.launch(share=True)
|
|
|
|
|
231 |
|