"""Gradio demo that removes (or replaces) an image background using rembg."""

from functools import lru_cache

import PIL.Image
import gradio as gr
from PIL import ImageColor
from rembg import new_session, remove

from utils import *

# Maps the label shown in the model dropdown to the rembg model name.
remove_bg_models = {
    "U2NET": "u2net",
    "U2NET Human Seg": "u2net_human_seg",
    "U2NET Cloth Seg": "u2net_cloth_seg",
}

# NOTE(review): `keys` is not defined in this file; it presumably comes from
# `from utils import *` and returns the dict's keys -- confirm against utils.
model_choices = keys(remove_bg_models)


def alpha_matting(state: bool):
    """Return slider values for the three alpha-matting controls.

    Args:
        state: Current value of the "Alpha Matting" checkbox.

    Returns:
        A ``(foreground_threshold, background_threshold, erode_size)`` tuple:
        ``(270, 20, 11)`` when matting is enabled, ``(0, 0, 0)`` otherwise.
    """
    if state:
        return 270, 20, 11
    return 0, 0, 0


@lru_cache(maxsize=len(remove_bg_models))
def _get_session(model_name: str):
    """Return the rembg session for *model_name*, creating it on first use.

    Sessions are cached so the model is not reloaded on every request.
    The cache is bounded by the number of selectable models.
    """
    return new_session(model_name)


def predict(image, session, matting, only_mask, post_process_mask,
            foreground_threshold, background_threshold, matting_erode_size,
            new_bg_color, bg_color, transparency):
    """Run rembg background removal on *image* with the options from the UI.

    Args:
        image: Input image as delivered by the ``gr.Image`` input component.
        session: Dropdown label of the model; must be a key of
            ``remove_bg_models``.
        matting: Whether to enable alpha matting.
        only_mask: Whether to return only the mask.
        post_process_mask: Whether to post-process the mask.
        foreground_threshold: Alpha-matting foreground threshold.
        background_threshold: Alpha-matting background threshold.
        matting_erode_size: Alpha-matting erode size.
        new_bg_color: Whether to fill the background with ``bg_color``.
        bg_color: Colour string from the colour picker (any format
            ``PIL.ImageColor`` understands).
        transparency: Alpha channel (0-255) for the new background colour.

    Returns:
        Whatever ``rembg.remove`` returns for the given input.

    Raises:
        gr.Error: If no image was supplied, or if a background colour was
            requested but none was picked.
        KeyError: If ``session`` is not a known model label.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Keep the display label and the session object under separate names.
    rembg_session = _get_session(remove_bg_models[session])

    if new_bg_color:
        if not bg_color:
            # The colour picker may have no value until the user touches it.
            raise gr.Error("Please pick a background color.")
        r, g, b, _ = ImageColor.getcolor(bg_color, "RGBA")
        # A cleared number field yields None; fall back to the widget default
        # (fully opaque) and keep the value inside the valid 8-bit range.
        alpha = 255 if transparency is None else min(255, max(0, int(transparency)))
        bg_color = r, g, b, alpha
    else:
        bg_color = None

    return remove(
        data=image,
        session=rembg_session,
        alpha_matting=matting,
        only_mask=only_mask,
        post_process_mask=post_process_mask,
        bgcolor=bg_color,
        alpha_matting_foreground_threshold=foreground_threshold,
        alpha_matting_background_threshold=background_threshold,
        alpha_matting_erode_size=matting_erode_size,
    )


footer = r"""
Demo based on Rembg
"""

with gr.Blocks(title="Remove background") as app:
    # NOTE(review): the heading looks like it lost its HTML markup at some
    # point; the text is kept exactly as found -- confirm the intended tags.
    gr.HTML("\n\nRemove Background Tool\n\n")

    with gr.Row(equal_height=False):
        # Left column: input image and all processing options.
        with gr.Column():
            input_img = gr.Image(type="numpy", label="Input image")
            drp_models = gr.Dropdown(choices=model_choices, label="Model Segment", value="U2NET")

            with gr.Row():
                chk_alm = gr.Checkbox(label="Alpha Matting", value=False)
                chk_psm = gr.Checkbox(label="Post process mask", value=False)
                chk_msk = gr.Checkbox(label="Only Mask", value=False)

            sld_aft = gr.Slider(0, 300, value=0, step=1, label="Alpha matting foreground threshold")
            sld_amb = gr.Slider(0, 50, value=0, step=1, label="Alpha matting background threshold")
            sld_aes = gr.Slider(0, 20, value=0, step=1, label="Alpha matting erode size")

            with gr.Box():
                with gr.Row():
                    chk_col = gr.Checkbox(label="Change background color", value=False)
                    color = gr.ColorPicker(label="Pick a new color")
                    trans = gr.Number(label="Transparency level", value=255, precision=0, minimum=0, maximum=255)

            run_btn = gr.Button(value="Remove background", variant="primary")

        # Right column: result image and a button that clears both images.
        with gr.Column():
            output_img = gr.Image(type="pil", label="result")
            gr.ClearButton(components=[input_img, output_img])

    # Toggling alpha matting resets the three matting sliders.
    chk_alm.change(alpha_matting, inputs=[chk_alm], outputs=[sld_aft, sld_amb, sld_aes])

    # Input order must match the positional parameters of `predict`
    # (note: "Only Mask" comes before "Post process mask").
    run_btn.click(
        predict,
        [input_img, drp_models, chk_alm, chk_msk, chk_psm, sld_aft, sld_amb, sld_aes, chk_col, color, trans],
        [output_img],
    )

    with gr.Row():
        gr.HTML(footer)

app.launch(share=False, debug=True, enable_queue=True, show_error=True)