pg56714 commited on
Commit
b6c4754
·
verified ·
1 Parent(s): 238b74d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +466 -462
app.py CHANGED
@@ -1,462 +1,466 @@
1
- import os
2
- import sys
3
-
4
- sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
5
- # os.chdir("../")
6
- import cv2
7
- import gradio as gr
8
- import numpy as np
9
- from pathlib import Path
10
- from matplotlib import pyplot as plt
11
- import torch
12
- import tempfile
13
-
14
- from stable_diffusion_inpaint import fill_img_with_sd, replace_img_with_sd
15
- from lama_inpaint import (
16
- inpaint_img_with_lama,
17
- build_lama_model,
18
- inpaint_img_with_builded_lama,
19
- )
20
- from utils import (
21
- load_img_to_array,
22
- save_array_to_img,
23
- dilate_mask,
24
- show_mask,
25
- show_points,
26
- )
27
- from PIL import Image
28
- from segment_anything import SamPredictor, sam_model_registry
29
- import argparse
30
-
31
-
32
- def setup_args(parser):
33
- parser.add_argument(
34
- "--lama_config",
35
- type=str,
36
- default="./lama/configs/prediction/default.yaml",
37
- help="The path to the config file of lama model. "
38
- "Default: the config of big-lama",
39
- )
40
- parser.add_argument(
41
- "--lama_ckpt",
42
- type=str,
43
- default="pretrained_models/big-lama",
44
- help="The path to the lama checkpoint.",
45
- )
46
- parser.add_argument(
47
- "--sam_ckpt",
48
- type=str,
49
- default="./pretrained_models/sam_vit_h_4b8939.pth",
50
- help="The path to the SAM checkpoint to use for mask generation.",
51
- )
52
-
53
-
54
- def mkstemp(suffix, dir=None):
55
- fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
56
- os.close(fd)
57
- return Path(path)
58
-
59
-
60
- def get_sam_feat(img):
61
- model["sam"].set_image(img)
62
- features = model["sam"].features
63
- orig_h = model["sam"].orig_h
64
- orig_w = model["sam"].orig_w
65
- input_h = model["sam"].input_h
66
- input_w = model["sam"].input_w
67
- model["sam"].reset_image()
68
- return features, orig_h, orig_w, input_h, input_w
69
-
70
-
71
- def get_fill_img_with_sd(image, mask, image_resolution, text_prompt):
72
- device = "cuda" if torch.cuda.is_available() else "cpu"
73
- if len(mask.shape) == 3:
74
- mask = mask[:, :, 0]
75
- np_image = np.array(image, dtype=np.uint8)
76
- H, W, C = np_image.shape
77
- np_image = HWC3(np_image)
78
- np_image = resize_image(np_image, image_resolution)
79
- mask = cv2.resize(
80
- mask, (np_image.shape[1], np_image.shape[0]), interpolation=cv2.INTER_NEAREST
81
- )
82
-
83
- img_fill = fill_img_with_sd(np_image, mask, text_prompt, device=device)
84
- img_fill = img_fill.astype(np.uint8)
85
- return img_fill
86
-
87
-
88
- def get_replace_img_with_sd(image, mask, image_resolution, text_prompt):
89
- device = "cuda" if torch.cuda.is_available() else "cpu"
90
- if len(mask.shape) == 3:
91
- mask = mask[:, :, 0]
92
- np_image = np.array(image, dtype=np.uint8)
93
- H, W, C = np_image.shape
94
- np_image = HWC3(np_image)
95
- np_image = resize_image(np_image, image_resolution)
96
- mask = cv2.resize(
97
- mask, (np_image.shape[1], np_image.shape[0]), interpolation=cv2.INTER_NEAREST
98
- )
99
-
100
- img_replaced = replace_img_with_sd(np_image, mask, text_prompt, device=device)
101
- img_replaced = img_replaced.astype(np.uint8)
102
- return img_replaced
103
-
104
-
105
- def HWC3(x):
106
- assert x.dtype == np.uint8
107
- if x.ndim == 2:
108
- x = x[:, :, None]
109
- assert x.ndim == 3
110
- H, W, C = x.shape
111
- assert C == 1 or C == 3 or C == 4
112
- if C == 3:
113
- return x
114
- if C == 1:
115
- return np.concatenate([x, x, x], axis=2)
116
- if C == 4:
117
- color = x[:, :, 0:3].astype(np.float32)
118
- alpha = x[:, :, 3:4].astype(np.float32) / 255.0
119
- y = color * alpha + 255.0 * (1.0 - alpha)
120
- y = y.clip(0, 255).astype(np.uint8)
121
- return y
122
-
123
-
124
- def resize_image(input_image, resolution):
125
- H, W, C = input_image.shape
126
- k = float(resolution) / min(H, W)
127
- H = int(np.round(H * k / 64.0)) * 64
128
- W = int(np.round(W * k / 64.0)) * 64
129
- img = cv2.resize(
130
- input_image,
131
- (W, H),
132
- interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
133
- )
134
- return img
135
-
136
-
137
- def resize_points(clicked_points, original_shape, resolution):
138
- original_height, original_width, _ = original_shape
139
- original_height = float(original_height)
140
- original_width = float(original_width)
141
-
142
- scale_factor = float(resolution) / min(original_height, original_width)
143
- resized_points = []
144
-
145
- for point in clicked_points:
146
- x, y, lab = point
147
- resized_x = int(round(x * scale_factor))
148
- resized_y = int(round(y * scale_factor))
149
- resized_point = (resized_x, resized_y, lab)
150
- resized_points.append(resized_point)
151
-
152
- return resized_points
153
-
154
-
155
- def get_click_mask(
156
- clicked_points, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size
157
- ):
158
- # model['sam'].set_image(image)
159
- model["sam"].is_image_set = True
160
- model["sam"].features = features
161
- model["sam"].orig_h = orig_h
162
- model["sam"].orig_w = orig_w
163
- model["sam"].input_h = input_h
164
- model["sam"].input_w = input_w
165
-
166
- # Separate the points and labels
167
- points, labels = zip(*[(point[:2], point[2]) for point in clicked_points])
168
-
169
- # Convert the points and labels to numpy arrays
170
- input_point = np.array(points)
171
- input_label = np.array(labels)
172
-
173
- masks, _, _ = model["sam"].predict(
174
- point_coords=input_point,
175
- point_labels=input_label,
176
- multimask_output=False,
177
- )
178
- if dilate_kernel_size is not None:
179
- masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]
180
- else:
181
- masks = [mask for mask in masks]
182
-
183
- return masks
184
-
185
-
186
- def process_image_click(
187
- original_image,
188
- point_prompt,
189
- clicked_points,
190
- image_resolution,
191
- features,
192
- orig_h,
193
- orig_w,
194
- input_h,
195
- input_w,
196
- dilate_kernel_size,
197
- evt: gr.SelectData,
198
- ):
199
- if clicked_points is None:
200
- clicked_points = []
201
-
202
- # print("Received click event:", evt)
203
- if original_image is None:
204
- # print("No image loaded.")
205
- return None, clicked_points, None
206
-
207
- clicked_coords = evt.index
208
- if clicked_coords is None:
209
- # print("No valid coordinates received.")
210
- return None, clicked_points, None
211
-
212
- x, y = clicked_coords
213
- label = point_prompt
214
- lab = 1 if label == "Foreground Point" else 0
215
- clicked_points.append((x, y, lab))
216
- # print("Updated points list:", clicked_points)
217
-
218
- input_image = np.array(original_image, dtype=np.uint8)
219
- H, W, C = input_image.shape
220
- input_image = HWC3(input_image)
221
- img = resize_image(input_image, image_resolution)
222
- # print("Processed image size:", img.shape)
223
-
224
- resized_points = resize_points(clicked_points, input_image.shape, image_resolution)
225
- mask_click_np = get_click_mask(
226
- resized_points, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size
227
- )
228
- mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0
229
- mask_image = HWC3(mask_click_np.astype(np.uint8))
230
- mask_image = cv2.resize(mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
231
- # print("Mask image prepared.")
232
-
233
- edited_image = input_image
234
- for x, y, lab in clicked_points:
235
- color = (255, 0, 0) if lab == 1 else (0, 0, 255)
236
- edited_image = cv2.circle(edited_image, (x, y), 20, color, -1)
237
-
238
- opacity_mask = 0.75
239
- opacity_edited = 1.0
240
- overlay_image = cv2.addWeighted(
241
- edited_image,
242
- opacity_edited,
243
- (mask_image * np.array([0 / 255, 255 / 255, 0 / 255])).astype(np.uint8),
244
- opacity_mask,
245
- 0,
246
- )
247
-
248
- no_mask_overlay = edited_image.copy()
249
-
250
- return no_mask_overlay, overlay_image, clicked_points, mask_image
251
-
252
-
253
- def image_upload(image, image_resolution):
254
- if image is None:
255
- return None, None, None, None, None, None
256
- else:
257
- np_image = np.array(image, dtype=np.uint8)
258
- H, W, C = np_image.shape
259
- np_image = HWC3(np_image)
260
- np_image = resize_image(np_image, image_resolution)
261
- features, orig_h, orig_w, input_h, input_w = get_sam_feat(np_image)
262
- return image, features, orig_h, orig_w, input_h, input_w
263
-
264
-
265
- def get_inpainted_img(image, mask, image_resolution):
266
- lama_config = args.lama_config
267
- device = "cuda" if torch.cuda.is_available() else "cpu"
268
- if len(mask.shape) == 3:
269
- mask = mask[:, :, 0]
270
- img_inpainted = inpaint_img_with_builded_lama(
271
- model["lama"], image, mask, lama_config, device=device
272
- )
273
- return img_inpainted
274
-
275
-
276
- # get args
277
- parser = argparse.ArgumentParser()
278
- setup_args(parser)
279
- args = parser.parse_args(sys.argv[1:])
280
- # build models
281
- model = {}
282
- # build the sam model
283
- model_type = "vit_h"
284
- ckpt_p = args.sam_ckpt
285
- model_sam = sam_model_registry[model_type](checkpoint=ckpt_p)
286
- device = "cuda" if torch.cuda.is_available() else "cpu"
287
- model_sam.to(device=device)
288
- model["sam"] = SamPredictor(model_sam)
289
-
290
- # build the lama model
291
- lama_config = args.lama_config
292
- lama_ckpt = args.lama_ckpt
293
- device = "cuda" if torch.cuda.is_available() else "cpu"
294
- model["lama"] = build_lama_model(lama_config, lama_ckpt, device=device)
295
-
296
- button_size = (100, 50)
297
- with gr.Blocks() as demo:
298
- clicked_points = gr.State([])
299
- # origin_image = gr.State(None)
300
- click_mask = gr.State(None)
301
- features = gr.State(None)
302
- orig_h = gr.State(None)
303
- orig_w = gr.State(None)
304
- input_h = gr.State(None)
305
- input_w = gr.State(None)
306
-
307
- with gr.Row():
308
- with gr.Column(variant="panel"):
309
- with gr.Row():
310
- gr.Markdown("## Upload an image and click the region you want to edit.")
311
- with gr.Row():
312
- source_image_click = gr.Image(
313
- type="numpy",
314
- interactive=True,
315
- label="Upload and Edit Image",
316
- )
317
-
318
- image_edit_complete = gr.Image(
319
- type="numpy",
320
- interactive=False,
321
- label="Editing Complete",
322
- )
323
- with gr.Row():
324
- point_prompt = gr.Radio(
325
- choices=["Foreground Point", "Background Point"],
326
- value="Foreground Point",
327
- label="Point Label",
328
- interactive=True,
329
- show_label=False,
330
- )
331
- image_resolution = gr.Slider(
332
- label="Image Resolution",
333
- minimum=256,
334
- maximum=768,
335
- value=512,
336
- step=64,
337
- )
338
- dilate_kernel_size = gr.Slider(
339
- label="Dilate Kernel Size", minimum=0, maximum=30, value=15, step=1
340
- )
341
- with gr.Column(variant="panel"):
342
- with gr.Row():
343
- gr.Markdown("## Control Panel")
344
- text_prompt = gr.Textbox(label="Text Prompt")
345
- lama = gr.Button("Inpaint Image", variant="primary")
346
- fill_sd = gr.Button("Fill Anything with SD", variant="primary")
347
- replace_sd = gr.Button("Replace Anything with SD", variant="primary")
348
- clear_button_image = gr.Button(value="Reset", variant="secondary")
349
-
350
- # todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
351
- with gr.Row(variant="panel"):
352
- with gr.Column():
353
- with gr.Row():
354
- gr.Markdown("## Mask")
355
- with gr.Row():
356
- click_mask = gr.Image(
357
- type="numpy",
358
- label="Click Mask",
359
- interactive=False,
360
- )
361
- with gr.Column():
362
- with gr.Row():
363
- gr.Markdown("## Image Removed with Mask")
364
- with gr.Row():
365
- img_rm_with_mask = gr.Image(
366
- type="numpy",
367
- label="Image Removed with Mask",
368
- interactive=False,
369
- )
370
-
371
- with gr.Column():
372
- with gr.Row():
373
- gr.Markdown("## Fill Anything with Mask")
374
- with gr.Row():
375
- img_fill_with_mask = gr.Image(
376
- type="numpy",
377
- label="Image Fill Anything with Mask",
378
- interactive=False,
379
- )
380
-
381
- with gr.Column():
382
- with gr.Row():
383
- gr.Markdown("## Replace Anything with Mask")
384
- with gr.Row():
385
- img_replace_with_mask = gr.Image(
386
- type="numpy",
387
- label="Image Replace Anything with Mask",
388
- interactive=False,
389
- )
390
-
391
- source_image_click.upload(
392
- image_upload,
393
- inputs=[source_image_click, image_resolution],
394
- outputs=[source_image_click, features, orig_h, orig_w, input_h, input_w],
395
- )
396
-
397
- source_image_click.select(
398
- process_image_click,
399
- inputs=[
400
- source_image_click,
401
- point_prompt,
402
- clicked_points,
403
- image_resolution,
404
- features,
405
- orig_h,
406
- orig_w,
407
- input_h,
408
- input_w,
409
- dilate_kernel_size,
410
- ],
411
- outputs=[source_image_click, image_edit_complete, clicked_points, click_mask],
412
- show_progress=True,
413
- queue=True,
414
- )
415
-
416
- lama.click(
417
- get_inpainted_img,
418
- inputs=[source_image_click, click_mask, image_resolution],
419
- outputs=[img_rm_with_mask],
420
- )
421
-
422
- fill_sd.click(
423
- get_fill_img_with_sd,
424
- inputs=[source_image_click, click_mask, image_resolution, text_prompt],
425
- outputs=[img_fill_with_mask],
426
- )
427
-
428
- replace_sd.click(
429
- get_replace_img_with_sd,
430
- inputs=[source_image_click, click_mask, image_resolution, text_prompt],
431
- outputs=[img_replace_with_mask],
432
- )
433
-
434
- def reset(*args):
435
- return [None for _ in args]
436
-
437
- clear_button_image.click(
438
- reset,
439
- inputs=[
440
- source_image_click,
441
- image_edit_complete,
442
- clicked_points,
443
- click_mask,
444
- features,
445
- img_rm_with_mask,
446
- img_fill_with_mask,
447
- img_replace_with_mask,
448
- ],
449
- outputs=[
450
- source_image_click,
451
- image_edit_complete,
452
- clicked_points,
453
- click_mask,
454
- features,
455
- img_rm_with_mask,
456
- img_fill_with_mask,
457
- img_replace_with_mask,
458
- ],
459
- )
460
-
461
- if __name__ == "__main__":
462
- demo.launch(debug=False, show_error=True)
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
5
+ # os.chdir("../")
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ from pathlib import Path
10
+ from matplotlib import pyplot as plt
11
+ import torch
12
+ import tempfile
13
+
14
+ from stable_diffusion_inpaint import fill_img_with_sd, replace_img_with_sd
15
+ from lama_inpaint import (
16
+ inpaint_img_with_lama,
17
+ build_lama_model,
18
+ inpaint_img_with_builded_lama,
19
+ )
20
+ from utils import (
21
+ load_img_to_array,
22
+ save_array_to_img,
23
+ dilate_mask,
24
+ show_mask,
25
+ show_points,
26
+ )
27
+ from PIL import Image
28
+ from segment_anything import SamPredictor, sam_model_registry
29
+ import argparse
30
+
31
+
32
+ def setup_args(parser):
33
+ parser.add_argument(
34
+ "--lama_config",
35
+ type=str,
36
+ default="./lama/configs/prediction/default.yaml",
37
+ help="The path to the config file of lama model. "
38
+ "Default: the config of big-lama",
39
+ )
40
+ parser.add_argument(
41
+ "--lama_ckpt",
42
+ type=str,
43
+ default="./pretrained_models/big-lama",
44
+ help="The path to the lama checkpoint.",
45
+ )
46
+ parser.add_argument(
47
+ "--sam_ckpt",
48
+ type=str,
49
+ default="./pretrained_models/sam_vit_h_4b8939.pth",
50
+ help="The path to the SAM checkpoint to use for mask generation.",
51
+ )
52
+
53
+
54
+ def mkstemp(suffix, dir=None):
55
+ fd, path = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
56
+ os.close(fd)
57
+ return Path(path)
58
+
59
+
60
+ def get_sam_feat(img):
61
+ model["sam"].set_image(img)
62
+ features = model["sam"].features
63
+ orig_h = model["sam"].orig_h
64
+ orig_w = model["sam"].orig_w
65
+ input_h = model["sam"].input_h
66
+ input_w = model["sam"].input_w
67
+ model["sam"].reset_image()
68
+ return features, orig_h, orig_w, input_h, input_w
69
+
70
+
71
+ def get_fill_img_with_sd(image, mask, image_resolution, text_prompt):
72
+ device = "cuda" if torch.cuda.is_available() else "cpu"
73
+ if len(mask.shape) == 3:
74
+ mask = mask[:, :, 0]
75
+ np_image = np.array(image, dtype=np.uint8)
76
+ H, W, C = np_image.shape
77
+ np_image = HWC3(np_image)
78
+ np_image = resize_image(np_image, image_resolution)
79
+ mask = cv2.resize(
80
+ mask, (np_image.shape[1], np_image.shape[0]), interpolation=cv2.INTER_NEAREST
81
+ )
82
+
83
+ img_fill = fill_img_with_sd(np_image, mask, text_prompt, device=device)
84
+ img_fill = img_fill.astype(np.uint8)
85
+ return img_fill
86
+
87
+
88
+ def get_replace_img_with_sd(image, mask, image_resolution, text_prompt):
89
+ device = "cuda" if torch.cuda.is_available() else "cpu"
90
+ if len(mask.shape) == 3:
91
+ mask = mask[:, :, 0]
92
+ np_image = np.array(image, dtype=np.uint8)
93
+ H, W, C = np_image.shape
94
+ np_image = HWC3(np_image)
95
+ np_image = resize_image(np_image, image_resolution)
96
+ mask = cv2.resize(
97
+ mask, (np_image.shape[1], np_image.shape[0]), interpolation=cv2.INTER_NEAREST
98
+ )
99
+
100
+ img_replaced = replace_img_with_sd(np_image, mask, text_prompt, device=device)
101
+ img_replaced = img_replaced.astype(np.uint8)
102
+ return img_replaced
103
+
104
+
105
+ def HWC3(x):
106
+ assert x.dtype == np.uint8
107
+ if x.ndim == 2:
108
+ x = x[:, :, None]
109
+ assert x.ndim == 3
110
+ H, W, C = x.shape
111
+ assert C == 1 or C == 3 or C == 4
112
+ if C == 3:
113
+ return x
114
+ if C == 1:
115
+ return np.concatenate([x, x, x], axis=2)
116
+ if C == 4:
117
+ color = x[:, :, 0:3].astype(np.float32)
118
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
119
+ y = color * alpha + 255.0 * (1.0 - alpha)
120
+ y = y.clip(0, 255).astype(np.uint8)
121
+ return y
122
+
123
+
124
+ def resize_image(input_image, resolution):
125
+ H, W, C = input_image.shape
126
+ k = float(resolution) / min(H, W)
127
+ H = int(np.round(H * k / 64.0)) * 64
128
+ W = int(np.round(W * k / 64.0)) * 64
129
+ img = cv2.resize(
130
+ input_image,
131
+ (W, H),
132
+ interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
133
+ )
134
+ return img
135
+
136
+
137
+ def resize_points(clicked_points, original_shape, resolution):
138
+ original_height, original_width, _ = original_shape
139
+ original_height = float(original_height)
140
+ original_width = float(original_width)
141
+
142
+ scale_factor = float(resolution) / min(original_height, original_width)
143
+ resized_points = []
144
+
145
+ for point in clicked_points:
146
+ x, y, lab = point
147
+ resized_x = int(round(x * scale_factor))
148
+ resized_y = int(round(y * scale_factor))
149
+ resized_point = (resized_x, resized_y, lab)
150
+ resized_points.append(resized_point)
151
+
152
+ return resized_points
153
+
154
+
155
+ def get_click_mask(
156
+ clicked_points, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size
157
+ ):
158
+ # model['sam'].set_image(image)
159
+ model["sam"].is_image_set = True
160
+ model["sam"].features = features
161
+ model["sam"].orig_h = orig_h
162
+ model["sam"].orig_w = orig_w
163
+ model["sam"].input_h = input_h
164
+ model["sam"].input_w = input_w
165
+
166
+ # Separate the points and labels
167
+ points, labels = zip(*[(point[:2], point[2]) for point in clicked_points])
168
+
169
+ # Convert the points and labels to numpy arrays
170
+ input_point = np.array(points)
171
+ input_label = np.array(labels)
172
+
173
+ masks, _, _ = model["sam"].predict(
174
+ point_coords=input_point,
175
+ point_labels=input_label,
176
+ multimask_output=False,
177
+ )
178
+ if dilate_kernel_size is not None:
179
+ masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]
180
+ else:
181
+ masks = [mask for mask in masks]
182
+
183
+ return masks
184
+
185
+
186
+ def process_image_click(
187
+ original_image,
188
+ point_prompt,
189
+ clicked_points,
190
+ image_resolution,
191
+ features,
192
+ orig_h,
193
+ orig_w,
194
+ input_h,
195
+ input_w,
196
+ dilate_kernel_size,
197
+ evt: gr.SelectData,
198
+ ):
199
+ if clicked_points is None:
200
+ clicked_points = []
201
+
202
+ # print("Received click event:", evt)
203
+ if original_image is None:
204
+ # print("No image loaded.")
205
+ return None, clicked_points, None
206
+
207
+ clicked_coords = evt.index
208
+ if clicked_coords is None:
209
+ # print("No valid coordinates received.")
210
+ return None, clicked_points, None
211
+
212
+ x, y = clicked_coords
213
+ label = point_prompt
214
+ lab = 1 if label == "Foreground Point" else 0
215
+ clicked_points.append((x, y, lab))
216
+ # print("Updated points list:", clicked_points)
217
+
218
+ input_image = np.array(original_image, dtype=np.uint8)
219
+ H, W, C = input_image.shape
220
+ input_image = HWC3(input_image)
221
+ img = resize_image(input_image, image_resolution)
222
+ # print("Processed image size:", img.shape)
223
+
224
+ resized_points = resize_points(clicked_points, input_image.shape, image_resolution)
225
+ mask_click_np = get_click_mask(
226
+ resized_points, features, orig_h, orig_w, input_h, input_w, dilate_kernel_size
227
+ )
228
+ mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0
229
+ mask_image = HWC3(mask_click_np.astype(np.uint8))
230
+ mask_image = cv2.resize(mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
231
+ # print("Mask image prepared.")
232
+
233
+ edited_image = input_image
234
+ for x, y, lab in clicked_points:
235
+ color = (255, 0, 0) if lab == 1 else (0, 0, 255)
236
+ edited_image = cv2.circle(edited_image, (x, y), 20, color, -1)
237
+
238
+ opacity_mask = 0.75
239
+ opacity_edited = 1.0
240
+ overlay_image = cv2.addWeighted(
241
+ edited_image,
242
+ opacity_edited,
243
+ (mask_image * np.array([0 / 255, 255 / 255, 0 / 255])).astype(np.uint8),
244
+ opacity_mask,
245
+ 0,
246
+ )
247
+
248
+ no_mask_overlay = edited_image.copy()
249
+
250
+ return no_mask_overlay, overlay_image, clicked_points, mask_image
251
+
252
+
253
+ def image_upload(image, image_resolution):
254
+ if image is None:
255
+ return None, None, None, None, None, None
256
+ else:
257
+ np_image = np.array(image, dtype=np.uint8)
258
+ H, W, C = np_image.shape
259
+ np_image = HWC3(np_image)
260
+ np_image = resize_image(np_image, image_resolution)
261
+ features, orig_h, orig_w, input_h, input_w = get_sam_feat(np_image)
262
+ return image, features, orig_h, orig_w, input_h, input_w
263
+
264
+
265
+ def get_inpainted_img(image, mask, image_resolution):
266
+ lama_config = args.lama_config
267
+ device = "cuda" if torch.cuda.is_available() else "cpu"
268
+ if len(mask.shape) == 3:
269
+ mask = mask[:, :, 0]
270
+ img_inpainted = inpaint_img_with_builded_lama(
271
+ model["lama"], image, mask, lama_config, device=device
272
+ )
273
+ return img_inpainted
274
+
275
+
276
+ # get args
277
+ parser = argparse.ArgumentParser()
278
+ setup_args(parser)
279
+ args = parser.parse_args(sys.argv[1:])
280
+ # build models
281
+ model = {}
282
+ # build the sam model
283
+ model_type = "vit_h"
284
+ ckpt_p = args.sam_ckpt
285
+ model_sam = sam_model_registry[model_type](checkpoint=ckpt_p)
286
+ device = "cuda" if torch.cuda.is_available() else "cpu"
287
+ model_sam.to(device=device)
288
+ model["sam"] = SamPredictor(model_sam)
289
+
290
+ # build the lama model
291
+ lama_config = args.lama_config
292
+ lama_ckpt = args.lama_ckpt
293
+ device = "cuda" if torch.cuda.is_available() else "cpu"
294
+ model["lama"] = build_lama_model(lama_config, lama_ckpt, device=device)
295
+
296
+ button_size = (100, 50)
297
+ with gr.Blocks() as demo:
298
+ clicked_points = gr.State([])
299
+ # origin_image = gr.State(None)
300
+ click_mask = gr.State(None)
301
+ features = gr.State(None)
302
+ orig_h = gr.State(None)
303
+ orig_w = gr.State(None)
304
+ input_h = gr.State(None)
305
+ input_w = gr.State(None)
306
+
307
+ with gr.Row():
308
+ with gr.Column(variant="panel"):
309
+ with gr.Row():
310
+ gr.Markdown("## Upload an image and click the region you want to edit.")
311
+ with gr.Row():
312
+ source_image_click = gr.Image(
313
+ type="numpy",
314
+ interactive=True,
315
+ label="Upload and Edit Image",
316
+ )
317
+
318
+ image_edit_complete = gr.Image(
319
+ type="numpy",
320
+ interactive=False,
321
+ label="Editing Complete",
322
+ )
323
+ with gr.Row():
324
+ point_prompt = gr.Radio(
325
+ choices=["Foreground Point", "Background Point"],
326
+ value="Foreground Point",
327
+ label="Point Label",
328
+ interactive=True,
329
+ show_label=False,
330
+ )
331
+ image_resolution = gr.Slider(
332
+ label="Image Resolution",
333
+ minimum=256,
334
+ maximum=768,
335
+ value=512,
336
+ step=64,
337
+ )
338
+ dilate_kernel_size = gr.Slider(
339
+ label="Dilate Kernel Size", minimum=0, maximum=30, value=15, step=1
340
+ )
341
+ with gr.Column(variant="panel"):
342
+ with gr.Row():
343
+ gr.Markdown("## Control Panel")
344
+ text_prompt = gr.Textbox(label="Text Prompt")
345
+ lama = gr.Button("Inpaint Image", variant="primary")
346
+ fill_sd = gr.Button("Fill Anything with SD", variant="primary")
347
+ replace_sd = gr.Button("Replace Anything with SD", variant="primary")
348
+ clear_button_image = gr.Button(value="Reset", variant="secondary")
349
+
350
+ # todo: maybe we can delete this row, for it's unnecessary to show the original mask for customers
351
+ with gr.Row(variant="panel"):
352
+ with gr.Column():
353
+ with gr.Row():
354
+ gr.Markdown("## Mask")
355
+ with gr.Row():
356
+ click_mask = gr.Image(
357
+ type="numpy",
358
+ label="Click Mask",
359
+ interactive=False,
360
+ )
361
+ with gr.Column():
362
+ with gr.Row():
363
+ gr.Markdown("## Image Removed with Mask")
364
+ with gr.Row():
365
+ img_rm_with_mask = gr.Image(
366
+ type="numpy",
367
+ label="Image Removed with Mask",
368
+ interactive=False,
369
+ )
370
+
371
+ with gr.Column():
372
+ with gr.Row():
373
+ gr.Markdown("## Fill Anything with Mask")
374
+ with gr.Row():
375
+ img_fill_with_mask = gr.Image(
376
+ type="numpy",
377
+ label="Image Fill Anything with Mask",
378
+ interactive=False,
379
+ )
380
+
381
+ with gr.Column():
382
+ with gr.Row():
383
+ gr.Markdown("## Replace Anything with Mask")
384
+ with gr.Row():
385
+ img_replace_with_mask = gr.Image(
386
+ type="numpy",
387
+ label="Image Replace Anything with Mask",
388
+ interactive=False,
389
+ )
390
+
391
+ gr.Markdown(
392
+ "Github Source Code: [Link](https://github.com/pg56714/Inpaint-Anything-Gradio)"
393
+ )
394
+
395
+ source_image_click.upload(
396
+ image_upload,
397
+ inputs=[source_image_click, image_resolution],
398
+ outputs=[source_image_click, features, orig_h, orig_w, input_h, input_w],
399
+ )
400
+
401
+ source_image_click.select(
402
+ process_image_click,
403
+ inputs=[
404
+ source_image_click,
405
+ point_prompt,
406
+ clicked_points,
407
+ image_resolution,
408
+ features,
409
+ orig_h,
410
+ orig_w,
411
+ input_h,
412
+ input_w,
413
+ dilate_kernel_size,
414
+ ],
415
+ outputs=[source_image_click, image_edit_complete, clicked_points, click_mask],
416
+ show_progress=True,
417
+ queue=True,
418
+ )
419
+
420
+ lama.click(
421
+ get_inpainted_img,
422
+ inputs=[source_image_click, click_mask, image_resolution],
423
+ outputs=[img_rm_with_mask],
424
+ )
425
+
426
+ fill_sd.click(
427
+ get_fill_img_with_sd,
428
+ inputs=[source_image_click, click_mask, image_resolution, text_prompt],
429
+ outputs=[img_fill_with_mask],
430
+ )
431
+
432
+ replace_sd.click(
433
+ get_replace_img_with_sd,
434
+ inputs=[source_image_click, click_mask, image_resolution, text_prompt],
435
+ outputs=[img_replace_with_mask],
436
+ )
437
+
438
+ def reset(*args):
439
+ return [None for _ in args]
440
+
441
+ clear_button_image.click(
442
+ reset,
443
+ inputs=[
444
+ source_image_click,
445
+ image_edit_complete,
446
+ clicked_points,
447
+ click_mask,
448
+ features,
449
+ img_rm_with_mask,
450
+ img_fill_with_mask,
451
+ img_replace_with_mask,
452
+ ],
453
+ outputs=[
454
+ source_image_click,
455
+ image_edit_complete,
456
+ clicked_points,
457
+ click_mask,
458
+ features,
459
+ img_rm_with_mask,
460
+ img_fill_with_mask,
461
+ img_replace_with_mask,
462
+ ],
463
+ )
464
+
465
+ if __name__ == "__main__":
466
+ demo.launch(debug=False, show_error=True)