import functools

import gradio as gr
import PIL.Image
from PIL import ImageColor
from rembg import new_session, remove

from utils import *
# Maps the model name shown in the UI dropdown to the rembg model identifier
# that is passed to ``new_session``.
remove_bg_models = {
    "U2NET": "u2net",
    "U2NET Human Seg": "u2net_human_seg",
    "U2NET Cloth Seg": "u2net_cloth_seg",
}

# Dropdown choices: the display names above, in insertion order (dicts keep
# insertion order), so "U2NET" stays first.
model_choices = list(remove_bg_models)
def alpha_matting(state):
    """Return preset alpha-matting slider values for the checkbox *state*.

    Used as the ``chk_alm.change`` handler: ticking "Alpha Matting" fills the
    three sliders with usable values, unticking resets them to zero.

    Args:
        state: Current value of the "Alpha Matting" checkbox (truthy = on).

    Returns:
        A ``(foreground_threshold, background_threshold, erode_size)`` tuple.
    """
    if not state:
        return 0, 0, 0
    # rembg compares the thresholds against an 8-bit (0-255) mask, so a
    # foreground threshold above 255 (previously 270) never marks any pixel as
    # sure foreground. 240 is rembg's own default for this parameter.
    return 240, 20, 11
@functools.lru_cache(maxsize=None)
def _get_session(model_name):
    """Return the rembg session for *model_name*, creating it on first use.

    The original code called ``new_session`` on every request. Sessions are
    now created once per model and reused. The cache size is bounded by the
    number of entries in ``remove_bg_models``, because ``predict`` only passes
    validated names.
    """
    return new_session(model_name)


def _resolve_bg_color(new_bg_color, bg_color, transparency):
    """Build the RGBA background tuple for ``rembg.remove``.

    Returns None (keep the background transparent) when *new_bg_color* is
    falsy.

    Raises:
        gr.Error: if *bg_color* cannot be parsed by ``ImageColor.getcolor``.
    """
    if not new_bg_color:
        return None
    try:
        red, green, blue, _ = ImageColor.getcolor(bg_color, "RGBA")
    except (AttributeError, TypeError, ValueError) as err:
        raise gr.Error(f"Invalid background color: {bg_color!r}") from err
    # gr.Number yields None when the field is cleared and may yield a float;
    # the alpha channel must be an int in 0-255.
    alpha = 255 if transparency is None else max(0, min(255, int(transparency)))
    return red, green, blue, alpha


def predict(image,
            session,
            matting,
            only_mask,
            post_process_mask,
            foreground_threshold,
            background_threshold,
            matting_erode_size,
            new_bg_color,
            bg_color,
            transparency):
    """Remove the background from *image* using the selected rembg model.

    Args:
        image: Value of the ``gr.Image(type="numpy")`` input, or None when
            nothing was uploaded.
        session: Display name of the model; a key of ``remove_bg_models``.
        matting: Whether rembg should refine edges with alpha matting.
        only_mask: Return only the segmentation mask.
        post_process_mask: Forwarded to rembg's ``post_process_mask``.
        foreground_threshold: Alpha-matting foreground threshold.
        background_threshold: Alpha-matting background threshold.
        matting_erode_size: Alpha-matting erode size.
        new_bg_color: Whether to fill the removed background with a colour.
        bg_color: Colour value from the colour picker (presumably a hex
            string such as "#ff0000" - confirm against the Gradio version).
        transparency: Alpha of the new background colour, clamped to 0-255.

    Returns:
        Whatever ``rembg.remove`` returns for the given input.

    Raises:
        gr.Error: if no image was supplied, the model name is unknown, or the
            background colour cannot be parsed.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")
    model_name = remove_bg_models.get(session)
    if model_name is None:
        raise gr.Error(f"Unknown model: {session!r}")
    return remove(data=image,
                  session=_get_session(model_name),
                  alpha_matting=matting,
                  only_mask=only_mask,
                  post_process_mask=post_process_mask,
                  bgcolor=_resolve_bg_color(new_bg_color, bg_color, transparency),
                  alpha_matting_foreground_threshold=foreground_threshold,
                  alpha_matting_background_threshold=background_threshold,
                  alpha_matting_erode_size=matting_erode_size)
footer = r"""
Demo based on Rembg
"""

# Build the UI: input controls and options in the left column, the result and
# a clear button in the right column, and a footer row at the bottom.
with gr.Blocks(title="Remove background") as app:
    # The original literal was split across two lines (unterminated string,
    # i.e. a SyntaxError); it is restored as a single HTML heading.
    gr.HTML("<h1>Remove Background Tool</h1>")
    with gr.Row(equal_height=False):
        with gr.Column():
            input_img = gr.Image(type="numpy", label="Input image")
            drp_models = gr.Dropdown(choices=model_choices, label="Model Segment", value="U2NET")
            with gr.Row():
                chk_alm = gr.Checkbox(label="Alpha Matting", value=False)
                chk_psm = gr.Checkbox(label="Post process mask", value=False)
                chk_msk = gr.Checkbox(label="Only Mask", value=False)
            sld_aft = gr.Slider(0, 300, value=0, step=1, label="Alpha matting foreground threshold")
            sld_amb = gr.Slider(0, 50, value=0, step=1, label="Alpha matting background threshold")
            sld_aes = gr.Slider(0, 20, value=0, step=1, label="Alpha matting erode size")
            # gr.Group exists in Gradio 3.x and 4.x; gr.Box was removed in 4.0.
            with gr.Group():
                with gr.Row():
                    chk_col = gr.Checkbox(label="Change background color", value=False)
                    color = gr.ColorPicker(label="Pick a new color")
                    trans = gr.Number(label="Transparency level", value=255, precision=0, minimum=0, maximum=255)
            run_btn = gr.Button(value="Remove background", variant="primary")
        with gr.Column():
            output_img = gr.Image(type="pil", label="result")
            gr.ClearButton(components=[input_img, output_img])

    # Toggling alpha matting fills or resets the three matting sliders.
    chk_alm.change(alpha_matting, inputs=[chk_alm], outputs=[sld_aft, sld_amb, sld_aes])
    # Input order must match predict's positional parameters.
    run_btn.click(predict,
                  [input_img, drp_models,
                   chk_alm, chk_msk, chk_psm,
                   sld_aft, sld_amb, sld_aes, chk_col, color, trans],
                  [output_img])
    with gr.Row():
        gr.HTML(footer)

# launch(enable_queue=...) is deprecated in Gradio 3.x and removed in 4.x;
# Blocks.queue() is the supported way to enable the request queue.
app.queue()
app.launch(share=False, debug=True, show_error=True)