#!/usr/bin/env python
from __future__ import annotations

import html
import zipfile

import gradio as gr
import PIL.Image

from genTag import genTag
from checkIgnore import is_ignore
from createTagDom import create_tag_dom
def predict(image: PIL.Image.Image) -> tuple[str, str]:
    """Tag a single image and render the results for the "Single" tab.

    Args:
        image: The uploaded image (the input component uses ``type='pil'``).

    Returns:
        A ``(result_html, result_text)`` pair, both HTML strings:
        ``result_html`` wraps one ``create_tag_dom`` fragment per tag returned
        by ``genTag(image, 0.5)``, including tags that ``is_ignore`` flags
        (the flag is passed through to ``create_tag_dom``).
        ``result_text`` wraps the comma-separated, HTML-escaped list of tags
        that ``is_ignore(label, 1)`` does NOT flag.
    """
    # genTag returns a mapping of label -> probability; presumably only
    # labels that pass the 0.5 threshold (TODO confirm in genTag).
    tags = genTag(image, 0.5)

    # Evaluate the ignore check once per label and reuse it for both outputs.
    ignored = {label: is_ignore(label, 1) for label in tags}

    tag_doms = ''.join(
        create_tag_dom(label, ignored[label], prob)
        for label, prob in tags.items()
    )
    # NOTE(review): in the previous copy of this file the wrapper markup here
    # was an unterminated string literal (a syntax error). Plain <div>
    # wrappers are used instead. Restore any class/id that head.html targets.
    result_html = '<div>' + tag_doms + '</div>'

    kept = [label for label in tags if not ignored[label]]
    # The output goes into a gr.HTML component, so escape the tag names.
    # Tags such as ">_<" or "<3" would otherwise be parsed as markup.
    result_text = '<div>' + html.escape(', '.join(kept)) + '</div>'
    return result_html, result_text
def predict_batch(zip_file, progress=gr.Progress()):
    """Tag every image inside an uploaded ZIP archive for the "Batch" tab.

    Args:
        zip_file: Value of the ``gr.File`` input. It is handed straight to
            ``zipfile.ZipFile`` (a path or file-like object), and is ``None``
            when Run is clicked without an upload.
        progress: Gradio progress tracker, used to report per-entry progress
            via ``progress.tqdm``.

    Returns:
        Plain text with one entry per image: the archive member name, a
        newline, its comma-separated tags not flagged by ``is_ignore(tag, 2)``,
        then a blank line.

    Raises:
        gr.Error: if no ZIP file was uploaded.
    """
    if zip_file is None:
        raise gr.Error("Please upload a ZIP file first.")

    # Match case-insensitively so e.g. "IMG_0001.JPG" is not silently skipped.
    image_extensions = (".png", ".jpg", ".jpeg", ".webp")
    entries = []
    with zipfile.ZipFile(zip_file) as zf:
        for name in progress.tqdm(zf.namelist()):
            print(name)
            if not name.lower().endswith(image_extensions):
                continue
            # Archives made by macOS Finder contain AppleDouble metadata such
            # as "__MACOSX/._photo.png": it has an image extension but is not
            # an image.
            if name.startswith("__MACOSX/"):
                continue
            try:
                # Close both the archive member and the PIL file handle.
                # convert() decodes the pixels and returns an independent
                # image, so it stays usable after the handles are closed.
                with zf.open(name) as fp, PIL.Image.open(fp) as img:
                    image = img.convert("RGBA")
            except PIL.UnidentifiedImageError:
                # Skip a single unreadable file rather than failing the batch.
                print(f"skipping {name}: not a readable image")
                continue
            result_threshold = genTag(image, 0.5)
            tags = [tag for tag in result_threshold if not is_ignore(tag, 2)]
            entries.append(name + '\n' + ', '.join(tags) + '\n\n')
    return ''.join(entries)
# Two-tab UI: "Single" tags one uploaded/pasted image, "Batch" tags every
# image in an uploaded ZIP archive. head.html is injected into the page head.
with gr.Blocks(head_paths="head.html") as demo:
    with gr.Tab(label='Single'):
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(label='Upload an image',
                                 type='pil',
                                 elem_classes='m5dd_image',
                                 image_mode="RGBA",
                                 show_fullscreen_button=False,
                                 sources=["upload", "clipboard"])
                # Receives predict()'s second output: the non-ignored tag list.
                result_text = gr.HTML(value="", elem_classes='m5dd_html', padding=False)
            with gr.Column(scale=2):
                # Receives predict()'s first output: the rendered tag elements.
                result_html = gr.HTML(value="", elem_classes='m5dd_html', padding=False)
    with gr.Tab(label='Batch'):
        with gr.Row():
            with gr.Column(scale=1):
                batch_file = gr.File(label="Upload a ZIP file containing images",
                                     file_types=['.zip'])
                run_button2 = gr.Button('Run')
            with gr.Column(scale=2):
                result_text2 = gr.Textbox(lines=20,
                                          max_lines=20,
                                          label='Result',
                                          show_copy_button=True,
                                          autoscroll=False)

    # Event listeners must be registered inside the Blocks context.
    # The single-image tab runs as soon as an image is uploaded.
    image.upload(
        fn=predict,
        inputs=[image],
        outputs=[result_html, result_text],
        api_name='predict',
    )
    # The batch tab runs only when the Run button is clicked.
    run_button2.click(
        fn=predict_batch,
        inputs=[batch_file],
        outputs=[result_text2],
        api_name='predict_batch',
    )
if __name__ == "__main__":
    # Enable the request queue (at most 20 pending requests), then serve the app.
    queued_app = demo.queue(max_size=20)
    queued_app.launch()