RysonFeng commited on
Commit
c331e65
·
1 Parent(s): d63b142

Update the demo of remove anything!

Browse files
Files changed (3) hide show
  1. .gitignore +0 -3
  2. app.py +53 -15
  3. pretrained_models/big-lama/config.yaml +3 -157
.gitignore CHANGED
@@ -10,6 +10,3 @@ __pycache__/
10
 
11
  # tmp
12
  ~*
13
-
14
- # third_party git
15
- third_party/**.git
 
10
 
11
  # tmp
12
  ~*
 
 
 
app.py CHANGED
@@ -9,6 +9,7 @@ from sam_segment import predict_masks_with_sam
9
  from lama_inpaint import inpaint_img_with_lama
10
  from utils import load_img_to_array, save_array_to_img, dilate_mask, \
11
  show_mask, show_points
 
12
 
13
 
14
  def mkstemp(suffix, dir=None):
@@ -17,10 +18,12 @@ def mkstemp(suffix, dir=None):
17
  return Path(path)
18
 
19
 
20
- def get_masked_img(img, point_coords):
21
  point_labels = [1]
 
22
  dilate_kernel_size = 15
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
24
  masks, _, _ = predict_masks_with_sam(
25
  img,
26
  [point_coords],
@@ -29,11 +32,14 @@ def get_masked_img(img, point_coords):
29
  ckpt_p="pretrained_models/sam_vit_h_4b8939.pth",
30
  device=device,
31
  )
 
32
  masks = masks.astype(np.uint8) * 255
33
 
34
  # dilate mask to avoid unmasked edge effect
35
  if dilate_kernel_size is not None:
36
  masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]
 
 
37
 
38
  figs = []
39
  for idx, mask in enumerate(masks):
@@ -44,39 +50,71 @@ def get_masked_img(img, point_coords):
44
  fig = plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
45
  plt.imshow(img)
46
  plt.axis('off')
47
- # show_points(plt.gca(), [point_coords], point_labels,
48
- # size=(width*0.04)**2)
49
- # plt.savefig(tmp_p, bbox_inches='tight', pad_inches=0)
50
  show_mask(plt.gca(), mask, random_color=False)
51
  plt.savefig(tmp_p, bbox_inches='tight', pad_inches=0)
52
  figs.append(fig)
53
  plt.close()
54
- return figs
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  with gr.Blocks() as demo:
59
  with gr.Row():
60
  img = gr.Image(label="Image")
61
- with gr.Row(label="Image with Segmentation Mask"):
62
- img_with_mask_0 = gr.Plot()
63
- img_with_mask_1 = gr.Plot()
64
- img_with_mask_2 = gr.Plot()
 
 
 
65
  with gr.Row():
66
- w = gr.Number()
67
- h = gr.Number()
 
68
 
69
- predict_mask = gr.Button("Predict Mask Using SAM")
 
 
 
70
 
 
 
 
 
 
 
 
71
 
72
  def get_select_coords(evt: gr.SelectData):
73
  return evt.index[0], evt.index[1]
74
 
75
  img.select(get_select_coords, [], [w, h])
76
- predict_mask.click(
77
  get_masked_img,
78
- [img, [w, h]],
79
- [img_with_mask_0, img_with_mask_1, img_with_mask_2]
 
 
 
 
 
 
80
  )
81
 
82
 
 
9
  from lama_inpaint import inpaint_img_with_lama
10
  from utils import load_img_to_array, save_array_to_img, dilate_mask, \
11
  show_mask, show_points
12
+ from PIL import Image
13
 
14
 
15
  def mkstemp(suffix, dir=None):
 
18
  return Path(path)
19
 
20
 
21
+ def get_masked_img(img, w, h):
22
  point_labels = [1]
23
+ point_coords = [w, h]
24
  dilate_kernel_size = 15
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
+
27
  masks, _, _ = predict_masks_with_sam(
28
  img,
29
  [point_coords],
 
32
  ckpt_p="pretrained_models/sam_vit_h_4b8939.pth",
33
  device=device,
34
  )
35
+
36
  masks = masks.astype(np.uint8) * 255
37
 
38
  # dilate mask to avoid unmasked edge effect
39
  if dilate_kernel_size is not None:
40
  masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]
41
+ else:
42
+ masks = [mask for mask in masks]
43
 
44
  figs = []
45
  for idx, mask in enumerate(masks):
 
50
  fig = plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
51
  plt.imshow(img)
52
  plt.axis('off')
53
+ show_points(plt.gca(), [point_coords], point_labels,
54
+ size=(width*0.04)**2)
 
55
  show_mask(plt.gca(), mask, random_color=False)
56
  plt.savefig(tmp_p, bbox_inches='tight', pad_inches=0)
57
  figs.append(fig)
58
  plt.close()
59
+ return *figs, *masks
60
 
61
 
62
+ def get_inpainted_img(img, mask0, mask1, mask2):
63
+ lama_config = "third_party/lama/configs/prediction/default.yaml"
64
+ lama_ckpt = "pretrained_models/big-lama"
65
+ device = "cuda" if torch.cuda.is_available() else "cpu"
66
+ out = []
67
+ for mask in [mask0, mask1, mask2]:
68
+ if len(mask.shape)==3:
69
+ mask = mask[:,:,0]
70
+ img_inpainted = inpaint_img_with_lama(
71
+ img, mask, lama_config, lama_ckpt, device=device)
72
+ out.append(img_inpainted)
73
+ return out
74
+
75
 
76
  with gr.Blocks() as demo:
77
  with gr.Row():
78
  img = gr.Image(label="Image")
79
+ with gr.Column():
80
+ with gr.Row():
81
+ w = gr.Number(label="Point Coordinate W")
82
+ h = gr.Number(label="Point Coordinate H")
83
+ sam = gr.Button("Predict Mask Using SAM")
84
+ lama = gr.Button("Inpaint Image Using LaMA")
85
+
86
  with gr.Row():
87
+ mask_0 = gr.outputs.Image(type="numpy", label="Segmentation Mask 0")
88
+ mask_1 = gr.outputs.Image(type="numpy", label="Segmentation Mask 1")
89
+ mask_2 = gr.outputs.Image(type="numpy", label="Segmentation Mask 2")
90
 
91
+ with gr.Row():
92
+ img_with_mask_0 = gr.Plot(label="Image with Segmentation Mask 0")
93
+ img_with_mask_1 = gr.Plot(label="Image with Segmentation Mask 1")
94
+ img_with_mask_2 = gr.Plot(label="Image with Segmentation Mask 2")
95
 
96
+ with gr.Row():
97
+ img_rm_with_mask_0 = gr.outputs.Image(
98
+ type="numpy", label="Image Removed with Segmentation Mask 0")
99
+ img_rm_with_mask_1 = gr.outputs.Image(
100
+ type="numpy", label="Image Removed with Segmentation Mask 1")
101
+ img_rm_with_mask_2 = gr.outputs.Image(
102
+ type="numpy", label="Image Removed with Segmentation Mask 2")
103
 
104
  def get_select_coords(evt: gr.SelectData):
105
  return evt.index[0], evt.index[1]
106
 
107
  img.select(get_select_coords, [], [w, h])
108
+ sam.click(
109
  get_masked_img,
110
+ [img, w, h],
111
+ [img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
112
+ )
113
+
114
+ lama.click(
115
+ get_inpainted_img,
116
+ [img, mask_0, mask_1, mask_2],
117
+ [img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
118
  )
119
 
120
 
pretrained_models/big-lama/config.yaml CHANGED
@@ -1,157 +1,3 @@
1
- run_title: b18_ffc075_batch8x15
2
- training_model:
3
- kind: default
4
- visualize_each_iters: 1000
5
- concat_mask: true
6
- store_discr_outputs_for_vis: true
7
- losses:
8
- l1:
9
- weight_missing: 0
10
- weight_known: 10
11
- perceptual:
12
- weight: 0
13
- adversarial:
14
- kind: r1
15
- weight: 10
16
- gp_coef: 0.001
17
- mask_as_fake_target: true
18
- allow_scale_mask: true
19
- feature_matching:
20
- weight: 100
21
- resnet_pl:
22
- weight: 30
23
- weights_path: ${env:TORCH_HOME}
24
-
25
- optimizers:
26
- generator:
27
- kind: adam
28
- lr: 0.001
29
- discriminator:
30
- kind: adam
31
- lr: 0.0001
32
- visualizer:
33
- key_order:
34
- - image
35
- - predicted_image
36
- - discr_output_fake
37
- - discr_output_real
38
- - inpainted
39
- rescale_keys:
40
- - discr_output_fake
41
- - discr_output_real
42
- kind: directory
43
- outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
44
- location:
45
- data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
46
- out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
47
- tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
48
- data:
49
- batch_size: 15
50
- val_batch_size: 2
51
- num_workers: 3
52
- train:
53
- indir: ${location.data_root_dir}/train
54
- out_size: 256
55
- mask_gen_kwargs:
56
- irregular_proba: 1
57
- irregular_kwargs:
58
- max_angle: 4
59
- max_len: 200
60
- max_width: 100
61
- max_times: 5
62
- min_times: 1
63
- box_proba: 1
64
- box_kwargs:
65
- margin: 10
66
- bbox_min_size: 30
67
- bbox_max_size: 150
68
- max_times: 3
69
- min_times: 1
70
- segm_proba: 0
71
- segm_kwargs:
72
- confidence_threshold: 0.5
73
- max_object_area: 0.5
74
- min_mask_area: 0.07
75
- downsample_levels: 6
76
- num_variants_per_mask: 1
77
- rigidness_mode: 1
78
- max_foreground_coverage: 0.3
79
- max_foreground_intersection: 0.7
80
- max_mask_intersection: 0.1
81
- max_hidden_area: 0.1
82
- max_scale_change: 0.25
83
- horizontal_flip: true
84
- max_vertical_shift: 0.2
85
- position_shuffle: true
86
- transform_variant: distortions
87
- dataloader_kwargs:
88
- batch_size: ${data.batch_size}
89
- shuffle: true
90
- num_workers: ${data.num_workers}
91
- val:
92
- indir: ${location.data_root_dir}/val
93
- img_suffix: .png
94
- dataloader_kwargs:
95
- batch_size: ${data.val_batch_size}
96
- shuffle: false
97
- num_workers: ${data.num_workers}
98
- visual_test:
99
- indir: ${location.data_root_dir}/korean_test
100
- img_suffix: _input.png
101
- pad_out_to_modulo: 32
102
- dataloader_kwargs:
103
- batch_size: 1
104
- shuffle: false
105
- num_workers: ${data.num_workers}
106
- generator:
107
- kind: ffc_resnet
108
- input_nc: 4
109
- output_nc: 3
110
- ngf: 64
111
- n_downsampling: 3
112
- n_blocks: 18
113
- add_out_act: sigmoid
114
- init_conv_kwargs:
115
- ratio_gin: 0
116
- ratio_gout: 0
117
- enable_lfu: false
118
- downsample_conv_kwargs:
119
- ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
120
- ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
121
- enable_lfu: false
122
- resnet_conv_kwargs:
123
- ratio_gin: 0.75
124
- ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
125
- enable_lfu: false
126
- discriminator:
127
- kind: pix2pixhd_nlayer
128
- input_nc: 3
129
- ndf: 64
130
- n_layers: 4
131
- evaluator:
132
- kind: default
133
- inpainted_key: inpainted
134
- integral_kind: ssim_fid100_f1
135
- trainer:
136
- kwargs:
137
- gpus: -1
138
- accelerator: ddp
139
- max_epochs: 200
140
- gradient_clip_val: 1
141
- log_gpu_memory: None
142
- limit_train_batches: 25000
143
- val_check_interval: ${trainer.kwargs.limit_train_batches}
144
- log_every_n_steps: 1000
145
- precision: 32
146
- terminate_on_nan: false
147
- check_val_every_n_epoch: 1
148
- num_sanity_val_steps: 8
149
- limit_val_batches: 1000
150
- replace_sampler_ddp: false
151
- checkpoint_kwargs:
152
- verbose: true
153
- save_top_k: 5
154
- save_last: true
155
- period: 1
156
- monitor: val_ssim_fid100_f1_total_mean
157
- mode: max
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fdeed49926e13b101c4dd9e193acec9e58677dfdb4ba49dd6a3a8927964e2a7
3
+ size 3947