pg56714 commited on
Commit
77c9ee7
·
verified ·
1 Parent(s): 1080e5b

Update lama_inpaint.py

Browse files
Files changed (1) hide show
  1. lama_inpaint.py +27 -15
lama_inpaint.py CHANGED
@@ -5,7 +5,6 @@ import torch
5
  import yaml
6
  import glob
7
  import argparse
8
- from PIL import Image
9
  from omegaconf import OmegaConf
10
  from pathlib import Path
11
 
@@ -20,6 +19,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parent / "lama"))
20
  from saicinpainting.evaluation.utils import move_to_device
21
  from saicinpainting.training.trainers import load_checkpoint
22
  from saicinpainting.evaluation.data import pad_tensor_to_modulo
 
23
 
24
  from utils import load_img_to_array, save_array_to_img
25
 
@@ -53,8 +53,7 @@ def inpaint_img_with_lama(
53
  train_config, checkpoint_path, strict=False, map_location=device
54
  )
55
  model.freeze()
56
- if not predict_config.get("refine", False):
57
- model.to(device)
58
 
59
  batch = {}
60
  batch["image"] = img.permute(2, 0, 1).unsqueeze(0)
@@ -62,16 +61,30 @@ def inpaint_img_with_lama(
62
  unpad_to_size = [batch["image"].shape[2], batch["image"].shape[3]]
63
  batch["image"] = pad_tensor_to_modulo(batch["image"], mod)
64
  batch["mask"] = pad_tensor_to_modulo(batch["mask"], mod)
65
- batch = move_to_device(batch, device)
66
- batch["mask"] = (batch["mask"] > 0) * 1
67
-
68
- batch = model(batch)
69
- cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
70
- cur_res = cur_res.detach().cpu().numpy()
71
-
72
- if unpad_to_size is not None:
73
- orig_height, orig_width = unpad_to_size
74
- cur_res = cur_res[:orig_height, :orig_width]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
77
  return cur_res
@@ -98,8 +111,7 @@ def build_lama_model(config_p: str, ckpt_p: str, device="cuda"):
98
  train_config, checkpoint_path, strict=False, map_location=device
99
  )
100
  model.freeze()
101
- if not predict_config.get("refine", False):
102
- model.to(device)
103
 
104
  return model
105
 
 
5
  import yaml
6
  import glob
7
  import argparse
 
8
  from omegaconf import OmegaConf
9
  from pathlib import Path
10
 
 
19
  from saicinpainting.evaluation.utils import move_to_device
20
  from saicinpainting.training.trainers import load_checkpoint
21
  from saicinpainting.evaluation.data import pad_tensor_to_modulo
22
+ from saicinpainting.evaluation.refinement import refine_predict
23
 
24
  from utils import load_img_to_array, save_array_to_img
25
 
 
53
  train_config, checkpoint_path, strict=False, map_location=device
54
  )
55
  model.freeze()
56
+ model.to(device)
 
57
 
58
  batch = {}
59
  batch["image"] = img.permute(2, 0, 1).unsqueeze(0)
 
61
  unpad_to_size = [batch["image"].shape[2], batch["image"].shape[3]]
62
  batch["image"] = pad_tensor_to_modulo(batch["image"], mod)
63
  batch["mask"] = pad_tensor_to_modulo(batch["mask"], mod)
64
+ # batch = move_to_device(batch, device)
65
+ # batch["mask"] = (batch["mask"] > 0) * 1
66
+
67
+ # batch = model(batch)
68
+ # cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
69
+ # cur_res = cur_res.detach().cpu().numpy()
70
+ if predict_config.get("refine", False):
71
+ batch["unpad_to_size"] = [torch.tensor([size]) for size in unpad_to_size]
72
+ cur_res = refine_predict(batch, model, **predict_config.refiner)
73
+ cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
74
+ else:
75
+ batch = move_to_device(batch, device)
76
+ batch["mask"] = (batch["mask"] > 0) * 1
77
+ batch = model(batch)
78
+ cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
79
+ cur_res = cur_res.detach().cpu().numpy()
80
+
81
+ if unpad_to_size is not None:
82
+ orig_height, orig_width = unpad_to_size
83
+ cur_res = cur_res[:orig_height, :orig_width]
84
+
85
+ # if unpad_to_size is not None:
86
+ # orig_height, orig_width = unpad_to_size
87
+ # cur_res = cur_res[:orig_height, :orig_width]
88
 
89
  cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
90
  return cur_res
 
111
  train_config, checkpoint_path, strict=False, map_location=device
112
  )
113
  model.freeze()
114
+ model.to(device)
 
115
 
116
  return model
117