Spaces:
Runtime error
Runtime error
import gradio as gr | |
import sys | |
import cv2 | |
from scripts.td_abg import get_foreground | |
from scripts.convertor import pil2cv | |
class webui: | |
def __init__(self): | |
self.demo = gr.Blocks() | |
def processing(self, input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L): | |
image = pil2cv(input_image) | |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
mask, image = get_foreground(image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L) | |
return image, mask | |
def launch(self, share): | |
with self.demo: | |
with gr.Row(): | |
with gr.Column(): | |
input_image = gr.Image(type="pil") | |
with gr.Accordion("tile division ABG", open=True): | |
with gr.Box(): | |
td_abg_enabled = gr.Checkbox(label="enabled", show_label=True) | |
h_split = gr.Slider(1, 2048, value=256, step=4, label="horizontal split num", show_label=True) | |
v_split = gr.Slider(1, 2048, value=256, step=4, label="vertical split num", show_label=True) | |
n_cluster = gr.Slider(1, 1000, value=500, step=10, label="cluster num", show_label=True) | |
alpha = gr.Slider(1, 255, value=100, step=1, label="alpha threshold", show_label=True) | |
th_rate = gr.Slider(0, 1, value=0.1, step=0.01, label="mask content ratio", show_label=True) | |
with gr.Accordion("cascadePSP", open=True): | |
with gr.Box(): | |
cascadePSP_enabled = gr.Checkbox(label="enabled", show_label=True) | |
fast = gr.Checkbox(label="fast", show_label=True) | |
psp_L = gr.Slider(1, 2048, value=900, step=1, label="Memory usage", show_label=True) | |
submit = gr.Button(value="Submit") | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Tab("output"): | |
output_img = gr.Image() | |
with gr.Tab("mask"): | |
output_mask = gr.Image() | |
submit.click( | |
self.processing, | |
inputs=[input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L], | |
outputs=[output_img, output_mask] | |
) | |
self.demo.queue() | |
self.demo.launch(share=share) | |
if __name__ == "__main__": | |
ui = webui() | |
if len(sys.argv) > 1: | |
if sys.argv[1] == "share": | |
ui.launch(share=True) | |
else: | |
ui.launch(share=False) | |
else: | |
ui.launch(share=False) | |