vemodalen commited on
Commit
bf44779
·
verified ·
1 Parent(s): 3b67d42

Upload 7 files

Browse files

update model and v0.1 pipeline

Files changed (6) hide show
  1. .gitattributes +13 -0
  2. app.py +271 -0
  3. models/matting.pt +3 -0
  4. models/sod.pt +3 -0
  5. models/trimap.pt +3 -0
  6. requirements.txt +0 -0
.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+
37
+ ### Python ###
38
+ # Byte-compiled / optimized / DLL files
39
+ __pycache__/
40
+ *.py[cod]
41
+ *$py.class
42
+
43
+ # PyCharm
44
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
45
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
46
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
47
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
48
+ #.idea/
app.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from hashlib import sha1
2
+ from pathlib import Path
3
+
4
+ import cv2
5
+ import gradio as gr
6
+ import numpy as np
7
+ from PIL import Image
8
+ import PIL
9
+ import torch
10
+ from torchvision import transforms
11
+ import torch.nn.functional as F
12
+
13
+
14
+ def estimate_foreground_ml(image, alpha, return_background=False):
15
+ """
16
+ Estimates the foreground and background of an image based on an alpha mask.
17
+
18
+ Parameters:
19
+ - image: numpy array of shape (H, W, 3), the input RGB image.
20
+ - alpha: numpy array of shape (H, W), the alpha mask with values ranging from 0 to 1.
21
+ - return_background: boolean, if True, both foreground and background are returned.
22
+
23
+ Returns:
24
+ - If return_background is False, returns only the foreground.
25
+ - If return_background is True, returns a tuple (foreground, background).
26
+ """
27
+
28
+ # Estimating foreground
29
+ # Expand alpha dimensions from (H, W) to (H, W, 1) to make it compatible for element-wise multiplication with the RGB image
30
+ foreground = image * alpha
31
+
32
+ if return_background:
33
+ # Estimating background
34
+ # Inverse alpha mask to isolate background
35
+ background_alpha = 1 - alpha
36
+ # Assuming a white background. This can be changed based on the application or estimated from the image.
37
+ background = (image * background_alpha) + (1 - background_alpha) * 255
38
+
39
+ return foreground, background
40
+
41
+ return foreground
42
+
43
+
44
+ def load_entire_model(taskname):
45
+ model_ls = []
46
+ if (taskname == "mask"):
47
+ model = torch.jit.load(Path("./models/sod.pt"))
48
+ model.eval()
49
+ model_ls.append(model)
50
+ elif (taskname == "matting"):
51
+ model = torch.jit.load(Path("./models/trimap.pt"))
52
+ model.eval()
53
+ model_ls.append(model)
54
+
55
+ model = torch.jit.load(Path("./models/matting.pt"))
56
+ model.eval()
57
+ model_ls.append(model)
58
+ else:
59
+ model_ls = []
60
+
61
+ return model_ls
62
+
63
+
64
+ model_names = [
65
+ "matting",
66
+ "mask"
67
+ ]
68
+ model_dict = {
69
+ name: None
70
+ for name in model_names
71
+ }
72
+
73
+ last_result = {
74
+ "cache_key": None,
75
+ "algorithm": None,
76
+ }
77
+
78
+
79
+ def image_matting(
80
+ image: PIL.Image.Image,
81
+ result_type: str,
82
+ bg_color: str,
83
+ algorithm: str,
84
+ morph_op: str,
85
+ morph_op_factor: float,
86
+ ) -> np.ndarray:
87
+ image_np = np.ascontiguousarray(image)
88
+ width, height = image_np.shape[1], image_np.shape[0]
89
+ cache_key = sha1(image_np).hexdigest()
90
+ if cache_key == last_result["cache_key"] and algorithm == last_result["algorithm"]:
91
+ alpha = last_result["alpha"]
92
+ else:
93
+ model = load_entire_model(algorithm)
94
+ transform = transforms.Compose([
95
+ # transforms.ToPILImage(),
96
+ transforms.Resize((798, 798)),
97
+ transforms.ToTensor(),
98
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
99
+ ])
100
+ if (algorithm == "mask"):
101
+ input_tensor = transform(image).unsqueeze(0)
102
+ with torch.no_grad():
103
+ alpha = model[0](input_tensor).float()
104
+ alpha = F.interpolate(alpha, [height, width], mode="bilinear")
105
+ alpha = np.array(alpha* 255.).astype(np.uint8)[0][0]
106
+ alpha = np.stack((alpha,alpha,alpha),axis=2)
107
+ else:
108
+ transform2 = transforms.Compose([
109
+ transforms.Resize((800, 800)),
110
+ transforms.ToTensor(),
111
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
112
+ ])
113
+
114
+ input_tensor = transform(image).unsqueeze(0)
115
+ with torch.no_grad():
116
+ output = model[0](input_tensor).float()
117
+ output = F.interpolate(output, [height, width], mode="bilinear")
118
+
119
+ trimap = np.array(output[0][0])
120
+
121
+ ratio = 0.05
122
+ site = np.where(trimap > 0)
123
+ try:
124
+ bbox = [np.min(site[1]), np.min(site[0]), np.max(site[1]), np.max(site[0])]
125
+ except:
126
+ bbox = [0, 0, width, height]
127
+
128
+ x0, y0, x1, y1 = bbox
129
+ H = y1 - y0
130
+ W = x1 - x0
131
+ x0 = int(max(0, x0 - ratio * W))
132
+ x1 = int(min(width, x1 + ratio * W) )
133
+ y0 = int(max(0, y0 - ratio * H) )
134
+ y1 = int(min(height, y1 + ratio * H) )
135
+
136
+ Image_input = image.crop((x0, y0, x1, y1))
137
+ # Image_input.save('image.png')
138
+ input_tensor = transform2(Image_input).unsqueeze(0)
139
+
140
+ trimap = trimap[y0:y1, x0:x1]
141
+ trimap = np.where(trimap < 1, 0, trimap)
142
+ trimap = np.where(trimap > 1, 255, trimap)
143
+ trimap = np.where(trimap == 1, 128, trimap)
144
+ # cv2.imwrite("trimap.png", trimap)
145
+
146
+ trimap = Image.fromarray(np.uint8(trimap)).convert('L')
147
+ input_tensor2 = transform2(trimap).unsqueeze(0)
148
+ with torch.no_grad():
149
+ output = model[1]({'image': input_tensor, 'trimap': input_tensor2})['phas']
150
+ output = F.interpolate(output, [Image_input.size[1],Image_input.size[0]], mode="bilinear")[0].numpy()
151
+
152
+ numpy_image = (output * 255.).astype(np.uint8) # Scale to [0, 255] and convert to uint8
153
+
154
+ # Step 4: Remove the channel dimension since it's a grayscale image
155
+ numpy_image = numpy_image.squeeze(0) # Convert from (1, H, W) to (H, W)
156
+ pil_image = Image.fromarray(numpy_image, mode='L')
157
+ alpha = Image.new(mode='RGB', size=image.size)
158
+ alpha.paste(pil_image, (x0, y0))
159
+ # alpha.save('tmp.png')
160
+
161
+ alpha = np.array(alpha).astype(np.uint8)
162
+ last_result["cache_key"] = cache_key
163
+ last_result["algorithm"] = algorithm
164
+ last_result["alpha"] = alpha
165
+
166
+ # alpha = (alpha * 255).astype(np.uint8)
167
+ image = np.array(image)
168
+ kernel = np.ones((morph_op_factor, morph_op_factor), np.uint8)
169
+ if morph_op == "Dilate":
170
+ alpha = cv2.dilate(alpha, kernel, iterations=int(morph_op_factor))
171
+ elif morph_op == "Erode":
172
+ alpha = cv2.erode(alpha, kernel, iterations=int(morph_op_factor))
173
+ else:
174
+ alpha = alpha
175
+ alpha = (alpha / 255).astype("float32")
176
+
177
+ image = (image / 255.0).astype("float32")
178
+ fg = estimate_foreground_ml(image, alpha)
179
+
180
+ if result_type == "Remove BG":
181
+ result = fg
182
+ elif result_type == "Replace BG":
183
+ bg_r = int(bg_color[1:3], base=16)
184
+ bg_g = int(bg_color[3:5], base=16)
185
+ bg_b = int(bg_color[5:7], base=16)
186
+
187
+ bg = np.zeros_like(fg)
188
+ bg[:, :, 0] = bg_r / 255.
189
+ bg[:, :, 1] = bg_g / 255.
190
+ bg[:, :, 2] = bg_b / 255.
191
+
192
+ result = alpha * image + (1 - alpha) * bg
193
+ result = np.clip(result, 0, 1)
194
+ else:
195
+ result = alpha
196
+
197
+ return result
198
+
199
+
200
+ def main():
201
+ with gr.Blocks() as app:
202
+ gr.Markdown("Salient Object Matting")
203
+
204
+ with gr.Row(variant="panel"):
205
+ image_input = gr.Image(type='pil')
206
+ image_output = gr.Image()
207
+
208
+ with gr.Row(variant="panel"):
209
+ result_type = gr.Radio(
210
+ label="Mode",
211
+ show_label=True,
212
+ choices=[
213
+ "Remove BG",
214
+ "Replace BG",
215
+ "Generate Mask",
216
+ ],
217
+ value="Remove BG",
218
+ )
219
+ bg_color = gr.ColorPicker(
220
+ label="BG Color",
221
+ show_label=True,
222
+ value="#000000",
223
+ )
224
+ algorithm = gr.Dropdown(
225
+ label="Algorithm",
226
+ show_label=True,
227
+ choices=model_names,
228
+ value="matting"
229
+ )
230
+
231
+ with gr.Row(variant="panel"):
232
+ morph_op = gr.Radio(
233
+ label="Post-process",
234
+ show_label=True,
235
+ choices=[
236
+ "None",
237
+ "Erode",
238
+ "Dilate",
239
+ ],
240
+ value="None",
241
+ )
242
+
243
+ morph_op_factor = gr.Slider(
244
+ label="Factor",
245
+ show_label=True,
246
+ minimum=3,
247
+ maximum=20,
248
+ value=3,
249
+ step=2,
250
+ )
251
+
252
+ run_button = gr.Button("Run")
253
+
254
+ run_button.click(
255
+ image_matting,
256
+ inputs=[
257
+ image_input,
258
+ result_type,
259
+ bg_color,
260
+ algorithm,
261
+ morph_op,
262
+ morph_op_factor,
263
+ ],
264
+ outputs=image_output,
265
+ )
266
+
267
+ app.launch()
268
+
269
+
270
+ if __name__ == "__main__":
271
+ main()
models/matting.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00f2ab77b8f35af8509410df12f0dd14645b49d540da16ab84f78d9497a48d61
3
+ size 387204217
models/sod.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4783bd4a1fd075d43e486ec81224ad831772dd178817dafd251af4016f9048ca
3
+ size 356605803
models/trimap.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69173b0fe662c967c53e2274ea128a4d4ed68f88e60e05af4ca540d99b95e450
3
+ size 356607339
requirements.txt ADDED
Binary file (222 Bytes). View file