pg56714 committed on
Commit 1080e5b · verified · 1 parent: 952626c

Update lama_inpaint.py

Files changed (1)
  1. lama_inpaint.py +199 -200
lama_inpaint.py CHANGED
@@ -1,200 +1,199 @@
- import os
- import sys
- import numpy as np
- import torch
- import yaml
- import glob
- import argparse
- from PIL import Image
- from omegaconf import OmegaConf
- from pathlib import Path
-
- os.environ['OMP_NUM_THREADS'] = '1'
- os.environ['OPENBLAS_NUM_THREADS'] = '1'
- os.environ['MKL_NUM_THREADS'] = '1'
- os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
- os.environ['NUMEXPR_NUM_THREADS'] = '1'
-
- sys.path.insert(0, str(Path(__file__).resolve().parent / "lama"))
- from saicinpainting.evaluation.utils import move_to_device
- from saicinpainting.training.trainers import load_checkpoint
- from saicinpainting.evaluation.data import pad_tensor_to_modulo
-
- from utils import load_img_to_array, save_array_to_img
-
-
- @torch.no_grad()
- def inpaint_img_with_lama(
-         img: np.ndarray,
-         mask: np.ndarray,
-         config_p: str,
-         ckpt_p: str,
-         mod=8,
-         device="cuda"
- ):
-     assert len(mask.shape) == 2
-     if np.max(mask) == 1:
-         mask = mask * 255
-     img = torch.from_numpy(img).float().div(255.)
-     mask = torch.from_numpy(mask).float()
-     predict_config = OmegaConf.load(config_p)
-     predict_config.model.path = ckpt_p
-     # device = torch.device(predict_config.device)
-     device = torch.device(device)
-
-     train_config_path = os.path.join(
-         predict_config.model.path, 'config.yaml')
-
-     with open(train_config_path, 'r') as f:
-         train_config = OmegaConf.create(yaml.safe_load(f))
-
-     train_config.training_model.predict_only = True
-     train_config.visualizer.kind = 'noop'
-
-     checkpoint_path = os.path.join(
-         predict_config.model.path, 'models',
-         predict_config.model.checkpoint
-     )
-     model = load_checkpoint(
-         train_config, checkpoint_path, strict=False, map_location='cpu')
-     model.freeze()
-     if not predict_config.get('refine', False):
-         model.to(device)
-
-     batch = {}
-     batch['image'] = img.permute(2, 0, 1).unsqueeze(0)
-     batch['mask'] = mask[None, None]
-     unpad_to_size = [batch['image'].shape[2], batch['image'].shape[3]]
-     batch['image'] = pad_tensor_to_modulo(batch['image'], mod)
-     batch['mask'] = pad_tensor_to_modulo(batch['mask'], mod)
-     batch = move_to_device(batch, device)
-     batch['mask'] = (batch['mask'] > 0) * 1
-
-     batch = model(batch)
-     cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
-     cur_res = cur_res.detach().cpu().numpy()
-
-     if unpad_to_size is not None:
-         orig_height, orig_width = unpad_to_size
-         cur_res = cur_res[:orig_height, :orig_width]
-
-     cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
-     return cur_res
-
-
- def build_lama_model(
-         config_p: str,
-         ckpt_p: str,
-         device="cuda"
- ):
-     predict_config = OmegaConf.load(config_p)
-     predict_config.model.path = ckpt_p
-     device = torch.device(device)
-
-     train_config_path = os.path.join(
-         predict_config.model.path, 'config.yaml')
-
-     with open(train_config_path, 'r') as f:
-         train_config = OmegaConf.create(yaml.safe_load(f))
-
-     train_config.training_model.predict_only = True
-     train_config.visualizer.kind = 'noop'
-
-     checkpoint_path = os.path.join(
-         predict_config.model.path, 'models',
-         predict_config.model.checkpoint
-     )
-     model = load_checkpoint(train_config, checkpoint_path, strict=False)
-     model.to(device)
-     model.freeze()
-     return model
-
-
- @torch.no_grad()
- def inpaint_img_with_builded_lama(
-         model,
-         img: np.ndarray,
-         mask: np.ndarray,
-         config_p=None,
-         mod=8,
-         device="cuda"
- ):
-     assert len(mask.shape) == 2
-     if np.max(mask) == 1:
-         mask = mask * 255
-     img = torch.from_numpy(img).float().div(255.)
-     mask = torch.from_numpy(mask).float()
-
-     batch = {}
-     batch['image'] = img.permute(2, 0, 1).unsqueeze(0)
-     batch['mask'] = mask[None, None]
-     unpad_to_size = [batch['image'].shape[2], batch['image'].shape[3]]
-     batch['image'] = pad_tensor_to_modulo(batch['image'], mod)
-     batch['mask'] = pad_tensor_to_modulo(batch['mask'], mod)
-     batch = move_to_device(batch, device)
-     batch['mask'] = (batch['mask'] > 0) * 1
-
-     batch = model(batch)
-     cur_res = batch["inpainted"][0].permute(1, 2, 0)
-     cur_res = cur_res.detach().cpu().numpy()
-
-     if unpad_to_size is not None:
-         orig_height, orig_width = unpad_to_size
-         cur_res = cur_res[:orig_height, :orig_width]
-
-     cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
-     return cur_res
-
-
-
- def setup_args(parser):
-     parser.add_argument(
-         "--input_img", type=str, required=True,
-         help="Path to a single input img",
-     )
-     parser.add_argument(
-         "--input_mask_glob", type=str, required=True,
-         help="Glob to input masks",
-     )
-     parser.add_argument(
-         "--output_dir", type=str, required=True,
-         help="Output path to the directory with results.",
-     )
-     parser.add_argument(
-         "--lama_config", type=str,
-         default="./lama/configs/prediction/default.yaml",
-         help="The path to the config file of lama model. "
-              "Default: the config of big-lama",
-     )
-     parser.add_argument(
-         "--lama_ckpt", type=str, required=True,
-         help="The path to the lama checkpoint.",
-     )
-
-
- if __name__ == "__main__":
-     """Example usage:
-     python lama_inpaint.py \
-         --input_img FA_demo/FA1_dog.png \
-         --input_mask_glob "results/FA1_dog/mask*.png" \
-         --output_dir results \
-         --lama_config lama/configs/prediction/default.yaml \
-         --lama_ckpt big-lama
-     """
-     parser = argparse.ArgumentParser()
-     setup_args(parser)
-     args = parser.parse_args(sys.argv[1:])
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     img_stem = Path(args.input_img).stem
-     mask_ps = sorted(glob.glob(args.input_mask_glob))
-     out_dir = Path(args.output_dir) / img_stem
-     out_dir.mkdir(parents=True, exist_ok=True)
-
-     img = load_img_to_array(args.input_img)
-     for mask_p in mask_ps:
-         mask = load_img_to_array(mask_p)
-         img_inpainted_p = out_dir / f"inpainted_with_{Path(mask_p).name}"
-         img_inpainted = inpaint_img_with_lama(
-             img, mask, args.lama_config, args.lama_ckpt, device=device)
-         save_array_to_img(img_inpainted, img_inpainted_p)
+ import os
+ import sys
+ import numpy as np
+ import torch
+ import yaml
+ import glob
+ import argparse
+ from PIL import Image
+ from omegaconf import OmegaConf
+ from pathlib import Path
+
+ os.environ["OMP_NUM_THREADS"] = "1"
+ os.environ["OPENBLAS_NUM_THREADS"] = "1"
+ os.environ["MKL_NUM_THREADS"] = "1"
+ os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
+
+ sys.path.insert(0, str(Path(__file__).resolve().parent / "lama"))
+
+ from saicinpainting.evaluation.utils import move_to_device
+ from saicinpainting.training.trainers import load_checkpoint
+ from saicinpainting.evaluation.data import pad_tensor_to_modulo
+
+ from utils import load_img_to_array, save_array_to_img
+
+
+ @torch.no_grad()
+ def inpaint_img_with_lama(
+     img: np.ndarray, mask: np.ndarray, config_p: str, ckpt_p: str, mod=8, device="cuda"
+ ):
+     assert len(mask.shape) == 2
+     if np.max(mask) == 1:
+         mask = mask * 255
+     img = torch.from_numpy(img).float().div(255.0)
+     mask = torch.from_numpy(mask).float()
+     predict_config = OmegaConf.load(config_p)
+     predict_config.model.path = ckpt_p
+     # device = torch.device(predict_config.device)
+     device = torch.device(device)
+
+     train_config_path = os.path.join(predict_config.model.path, "config.yaml")
+
+     with open(train_config_path, "r") as f:
+         train_config = OmegaConf.create(yaml.safe_load(f))
+
+     train_config.training_model.predict_only = True
+     train_config.visualizer.kind = "noop"
+
+     checkpoint_path = os.path.join(
+         predict_config.model.path, "models", predict_config.model.checkpoint
+     )
+     model = load_checkpoint(
+         train_config, checkpoint_path, strict=False, map_location=device
+     )
+     model.freeze()
+     if not predict_config.get("refine", False):
+         model.to(device)
+
+     batch = {}
+     batch["image"] = img.permute(2, 0, 1).unsqueeze(0)
+     batch["mask"] = mask[None, None]
+     unpad_to_size = [batch["image"].shape[2], batch["image"].shape[3]]
+     batch["image"] = pad_tensor_to_modulo(batch["image"], mod)
+     batch["mask"] = pad_tensor_to_modulo(batch["mask"], mod)
+     batch = move_to_device(batch, device)
+     batch["mask"] = (batch["mask"] > 0) * 1
+
+     batch = model(batch)
+     cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
+     cur_res = cur_res.detach().cpu().numpy()
+
+     if unpad_to_size is not None:
+         orig_height, orig_width = unpad_to_size
+         cur_res = cur_res[:orig_height, :orig_width]
+
+     cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
+     return cur_res
+
+
+ def build_lama_model(config_p: str, ckpt_p: str, device="cuda"):
+     predict_config = OmegaConf.load(config_p)
+     predict_config.model.path = ckpt_p
+     # device = torch.device(predict_config.device)
+     device = torch.device(device)
+
+     train_config_path = os.path.join(predict_config.model.path, "config.yaml")
+
+     with open(train_config_path, "r") as f:
+         train_config = OmegaConf.create(yaml.safe_load(f))
+
+     train_config.training_model.predict_only = True
+     train_config.visualizer.kind = "noop"
+
+     checkpoint_path = os.path.join(
+         predict_config.model.path, "models", predict_config.model.checkpoint
+     )
+     model = load_checkpoint(
+         train_config, checkpoint_path, strict=False, map_location=device
+     )
+     model.freeze()
+     if not predict_config.get("refine", False):
+         model.to(device)
+
+     return model
+
+
+ @torch.no_grad()
+ def inpaint_img_with_builded_lama(
+     model, img: np.ndarray, mask: np.ndarray, config_p: str, mod=8, device="cuda"
+ ):
+     assert len(mask.shape) == 2
+     if np.max(mask) == 1:
+         mask = mask * 255
+     img = torch.from_numpy(img).float().div(255.0)
+     mask = torch.from_numpy(mask).float()
+     predict_config = OmegaConf.load(config_p)
+
+     batch = {}
+     batch["image"] = img.permute(2, 0, 1).unsqueeze(0)
+     batch["mask"] = mask[None, None]
+     unpad_to_size = [batch["image"].shape[2], batch["image"].shape[3]]
+     batch["image"] = pad_tensor_to_modulo(batch["image"], mod)
+     batch["mask"] = pad_tensor_to_modulo(batch["mask"], mod)
+     batch = move_to_device(batch, device)
+     batch["mask"] = (batch["mask"] > 0) * 1
+
+     batch = model(batch)
+     cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
+     cur_res = cur_res.detach().cpu().numpy()
+
+     if unpad_to_size is not None:
+         orig_height, orig_width = unpad_to_size
+         cur_res = cur_res[:orig_height, :orig_width]
+
+     cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
+     return cur_res
+
+
+ def setup_args(parser):
+     parser.add_argument(
+         "--input_img",
+         type=str,
+         required=True,
+         help="Path to a single input img",
+     )
+     parser.add_argument(
+         "--input_mask_glob",
+         type=str,
+         required=True,
+         help="Glob to input masks",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         required=True,
+         help="Output path to the directory with results.",
+     )
+     parser.add_argument(
+         "--lama_config",
+         type=str,
+         default="./third_party/lama/configs/prediction/default.yaml",
+         help="The path to the config file of lama model. "
+         "Default: the config of big-lama",
+     )
+     parser.add_argument(
+         "--lama_ckpt",
+         type=str,
+         required=True,
+         help="The path to the lama checkpoint.",
+     )
+
+
+ if __name__ == "__main__":
+     """Example usage:
+     python lama_inpaint.py \
+         --input_img FA_demo/FA1_dog.png \
+         --input_mask_glob "results/FA1_dog/mask*.png" \
+         --output_dir results \
+         --lama_config lama/configs/prediction/default.yaml \
+         --lama_ckpt big-lama
+     """
+     parser = argparse.ArgumentParser()
+     setup_args(parser)
+     args = parser.parse_args(sys.argv[1:])
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     img_stem = Path(args.input_img).stem
+     mask_ps = sorted(glob.glob(args.input_mask_glob))
+     out_dir = Path(args.output_dir) / img_stem
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     img = load_img_to_array(args.input_img)
+     for mask_p in mask_ps:
+         mask = load_img_to_array(mask_p)
+         img_inpainted_p = out_dir / f"inpainted_with_{Path(mask_p).name}"
+         img_inpainted = inpaint_img_with_lama(
+             img, mask, args.lama_config, args.lama_ckpt, device=device
+         )
+         save_array_to_img(img_inpainted, img_inpainted_p)
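Note on the updated API: inpaint_img_with_builded_lama now takes config_p as a required argument and reads the output key from predict_config.out_key instead of the hard-coded "inpainted" key, and build_lama_model now moves the model to the device only when refine is not enabled, mirroring inpaint_img_with_lama. The sketch below is a minimal usage example of the new signatures; the image, mask, config, and checkpoint paths are illustrative, and load_img_to_array / save_array_to_img are the helpers from the repo's utils module used in the script's __main__ block.

    import torch

    from lama_inpaint import build_lama_model, inpaint_img_with_builded_lama
    from utils import load_img_to_array, save_array_to_img

    device = "cuda" if torch.cuda.is_available() else "cpu"
    config_p = "lama/configs/prediction/default.yaml"  # illustrative config path
    ckpt_p = "big-lama"                                # illustrative checkpoint directory

    # Build the model once and reuse it across masks.
    model = build_lama_model(config_p, ckpt_p, device=device)

    img = load_img_to_array("FA_demo/FA1_dog.png")         # HWC uint8 image
    mask = load_img_to_array("results/FA1_dog/mask0.png")  # 2D mask (illustrative path)

    # config_p is now required: the output key comes from predict_config.out_key.
    result = inpaint_img_with_builded_lama(model, img, mask, config_p, device=device)
    save_array_to_img(result, "results/FA1_dog/inpainted_with_mask0.png")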