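"""Gradio demo for salient object matting.

Depending on the chosen algorithm, the app runs either a single
salient-object-detection model ("mask") or a trimap-prediction model followed
by a matting model ("matting"), all loaded as TorchScript modules from
./models, and then removes the background, replaces it with a solid color, or
returns the predicted mask.
"""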
from hashlib import sha1
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms


def estimate_foreground_ml(image, alpha, return_background=False):
    """
    Estimates the foreground and background of an image based on an alpha mask.

    Parameters:
    - image: numpy array of shape (H, W, 3), the input RGB image, normalized to [0, 1].
    - alpha: numpy array of shape (H, W, 3), the alpha mask with values ranging from 0 to 1.
    - return_background: boolean; if True, both foreground and background are returned.

    Returns:
    - If return_background is False, returns only the foreground.
    - If return_background is True, returns a tuple (foreground, background).
    """
    foreground = image * alpha

    if return_background:
        background_alpha = 1 - alpha
        # Keep the background and fill the foreground region with white
        # (1.0, since the inputs are in [0, 1]).
        background = image * background_alpha + (1 - background_alpha)
        return foreground, background

    return foreground


def load_entire_model(taskname):
    """Load the TorchScript model(s) for the given task ("mask" or "matting")."""
    model_ls = []
    if taskname == "mask":
        # Single salient-object-detection model that predicts the alpha directly.
        model = torch.jit.load(Path("./models/sod.pt"))
        model.eval()
        model_ls.append(model)
    elif taskname == "matting":
        # Trimap-prediction model followed by the trimap-based matting model.
        model = torch.jit.load(Path("./models/trimap.pt"))
        model.eval()
        model_ls.append(model)

        model = torch.jit.load(Path("./models/matting.pt"))
        model.eval()
        model_ls.append(model)
    else:
        # Unknown task name: return an empty list.
        model_ls = []

    return model_ls
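# Module-level state: model_dict is a currently unused placeholder, and
# last_result caches the alpha from the most recent run, keyed by a SHA-1 of
# the image pixels and the algorithm name, so re-running with different
# post-processing options skips the network inference.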
model_names = [
    "matting",
    "mask",
]
model_dict = {
    name: None
    for name in model_names
}

last_result = {
    "cache_key": None,
    "algorithm": None,
}


def image_matting(
    image: PIL.Image.Image,
    result_type: str,
    bg_color: str,
    algorithm: str,
    morph_op: str,
    morph_op_factor: float,
) -> np.ndarray:
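    """Run salient-object matting on `image` and return a float32 array in [0, 1].

    result_type selects "Remove BG", "Replace BG" (composite onto the bg_color
    "#rrggbb" value), or "Generate Mask" (return the alpha itself); algorithm
    chooses the "mask" or "matting" pipeline; morph_op and morph_op_factor
    optionally erode or dilate the predicted alpha before compositing.
    """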
    image_np = np.ascontiguousarray(image)
    width, height = image_np.shape[1], image_np.shape[0]
    cache_key = sha1(image_np).hexdigest()
    # Reuse the cached alpha when the same image was just processed with the
    # same algorithm; otherwise run the model(s).
    if cache_key == last_result["cache_key"] and algorithm == last_result["algorithm"]:
        alpha = last_result["alpha"]
    else:
        model = load_entire_model(algorithm)
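        # Preprocessing shared by the SOD and trimap models: resize to a fixed
        # 798x798 input and apply ImageNet mean/std normalization.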
        transform = transforms.Compose([
            transforms.Resize((798, 798)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        if algorithm == "mask":
            input_tensor = transform(image).unsqueeze(0)
            with torch.no_grad():
                alpha = model[0](input_tensor).float()
            # Resize the predicted alpha back to the original resolution and
            # expand it to a 3-channel uint8 mask.
            alpha = F.interpolate(alpha, [height, width], mode="bilinear")
            alpha = np.array(alpha * 255.).astype(np.uint8)[0][0]
            alpha = np.stack((alpha, alpha, alpha), axis=2)
        else:
            # Preprocessing for the second-stage matting model: 800x800, no normalization.
            transform2 = transforms.Compose([
                transforms.Resize((800, 800)),
                transforms.ToTensor(),
            ])

            # First stage: predict a trimap over the whole image.
            input_tensor = transform(image).unsqueeze(0)
            with torch.no_grad():
                output = model[0](input_tensor).float()
            output = F.interpolate(output, [height, width], mode="bilinear")

            trimap = np.array(output[0][0])

            # Bounding box of the salient region; fall back to the full image
            # if nothing was detected.
            ratio = 0.05
            site = np.where(trimap > 0)
            try:
                bbox = [np.min(site[1]), np.min(site[0]), np.max(site[1]), np.max(site[0])]
            except ValueError:
                bbox = [0, 0, width, height]
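            # Expand the box by 5% of its width/height on each side before
            # cropping, presumably to keep some context around the object.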
            x0, y0, x1, y1 = bbox
            H = y1 - y0
            W = x1 - x0
            x0 = int(max(0, x0 - ratio * W))
            x1 = int(min(width, x1 + ratio * W))
            y0 = int(max(0, y0 - ratio * H))
            y1 = int(min(height, y1 + ratio * H))

            image_crop = image.crop((x0, y0, x1, y1))

            input_tensor = transform2(image_crop).unsqueeze(0)

            # Quantize the trimap crop to the usual three levels:
            # 0 = background, 128 = unknown, 255 = foreground.
            trimap = trimap[y0:y1, x0:x1]
            trimap = np.where(trimap < 1, 0, trimap)
            trimap = np.where(trimap > 1, 255, trimap)
            trimap = np.where(trimap == 1, 128, trimap)

            trimap = Image.fromarray(np.uint8(trimap)).convert('L')
            input_tensor2 = transform2(trimap).unsqueeze(0)
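            # Second stage: the matting model takes the cropped RGB image plus
            # its trimap and returns the alpha matte under the 'phas' key.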
            with torch.no_grad():
                output = model[1]({'image': input_tensor, 'trimap': input_tensor2})['phas']
            # Resize the matte back to the crop's original size.
            output = F.interpolate(output, [image_crop.size[1], image_crop.size[0]], mode="bilinear")[0].numpy()

            numpy_image = (output * 255.).astype(np.uint8)

            numpy_image = numpy_image.squeeze(0)
            pil_image = Image.fromarray(numpy_image, mode='L')
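            # Paste the cropped matte into a full-size black canvas at its
            # original location, yielding a 3-channel alpha like the "mask" branch.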
            alpha = Image.new(mode='RGB', size=image.size)
            alpha.paste(pil_image, (x0, y0))

            alpha = np.array(alpha).astype(np.uint8)

        # Remember this alpha for the cache check above.
        last_result["cache_key"] = cache_key
        last_result["algorithm"] = algorithm
        last_result["alpha"] = alpha

    image = np.array(image)
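    # Optional morphological post-processing on the 8-bit alpha; the slider
    # value doubles as the (square) kernel size and the iteration count.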
    kernel = np.ones((int(morph_op_factor), int(morph_op_factor)), np.uint8)
    if morph_op == "Dilate":
        alpha = cv2.dilate(alpha, kernel, iterations=int(morph_op_factor))
    elif morph_op == "Erode":
        alpha = cv2.erode(alpha, kernel, iterations=int(morph_op_factor))
    alpha = (alpha / 255).astype("float32")

    image = (image / 255.0).astype("float32")
    fg = estimate_foreground_ml(image, alpha)

    if result_type == "Remove BG":
        result = fg
    elif result_type == "Replace BG":
        # Parse the "#rrggbb" hex string from the color picker.
        bg_r = int(bg_color[1:3], base=16)
        bg_g = int(bg_color[3:5], base=16)
        bg_b = int(bg_color[5:7], base=16)
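        # Fill a constant background image in [0, 1] with the chosen color and
        # composite it using the alpha.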
        bg = np.zeros_like(fg)
        bg[:, :, 0] = bg_r / 255.
        bg[:, :, 1] = bg_g / 255.
        bg[:, :, 2] = bg_b / 255.

        result = alpha * image + (1 - alpha) * bg
        result = np.clip(result, 0, 1)
    else:
        result = alpha

    return result


def main():
    with gr.Blocks() as app:
        gr.Markdown("Salient Object Matting")

        with gr.Row(variant="panel"):
            image_input = gr.Image(type='pil')
            image_output = gr.Image()

        with gr.Row(variant="panel"):
            result_type = gr.Radio(
                label="Mode",
                show_label=True,
                choices=[
                    "Remove BG",
                    "Replace BG",
                    "Generate Mask",
                ],
                value="Remove BG",
            )
            bg_color = gr.ColorPicker(
                label="BG Color",
                show_label=True,
                value="#000000",
            )
            algorithm = gr.Dropdown(
                label="Algorithm",
                show_label=True,
                choices=model_names,
                value="matting",
            )

        with gr.Row(variant="panel"):
            morph_op = gr.Radio(
                label="Post-process",
                show_label=True,
                choices=[
                    "None",
                    "Erode",
                    "Dilate",
                ],
                value="None",
            )

            morph_op_factor = gr.Slider(
                label="Factor",
                show_label=True,
                minimum=3,
                maximum=20,
                value=3,
                step=2,
            )

        run_button = gr.Button("Run")

        run_button.click(
            image_matting,
            inputs=[
                image_input,
                result_type,
                bg_color,
                algorithm,
                morph_op,
                morph_op_factor,
            ],
            outputs=image_output,
        )

    app.launch()


if __name__ == "__main__":
    main()