#!/usr/bin/env python from __future__ import annotations import argparse import os import pathlib import subprocess import tarfile if os.getenv('SYSTEM') == 'spaces': import mim mim.uninstall('mmcv-full', confirm_yes=True) mim.install('mmcv-full==1.5.2', is_yes=True) subprocess.call('pip uninstall -y opencv-python'.split()) subprocess.call('pip uninstall -y opencv-python-headless'.split()) subprocess.call('pip install opencv-python-headless'.split()) import cv2 import gradio as gr import numpy as np from model import Model DEFAULT_MODEL_TYPE = 'detection' DEFAULT_MODEL_NAMES = { 'detection': 'YOLOX-l', 'instance_segmentation': 'QueryInst (R-50-FPN)', 'panoptic_segmentation': 'MaskFormer (R-50)', } DEFAULT_MODEL_NAME = DEFAULT_MODEL_NAMES[DEFAULT_MODEL_TYPE] def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cpu') parser.add_argument('--theme', type=str) parser.add_argument('--share', action='store_true') parser.add_argument('--port', type=int) parser.add_argument('--disable-queue', dest='enable_queue', action='store_false') return parser.parse_args() def extract_tar() -> None: if pathlib.Path('mmdet_configs/configs').exists(): return with tarfile.open('mmdet_configs/configs.tar') as f: f.extractall('mmdet_configs') def update_input_image(image: np.ndarray) -> dict: if image is None: return gr.Image.update(value=None) scale = 1500 / max(image.shape[:2]) if scale < 1: image = cv2.resize(image, None, fx=scale, fy=scale) return gr.Image.update(value=image) def update_model_name(model_type: str) -> dict: model_dict = getattr(Model, f'{model_type.upper()}_MODEL_DICT') model_names = list(model_dict.keys()) model_name = DEFAULT_MODEL_NAMES[model_type] return gr.Dropdown.update(choices=model_names, value=model_name) def update_visualization_score_threshold(model_type: str) -> dict: return gr.Slider.update(visible=model_type != 'panoptic_segmentation') def update_redraw_button(model_type: str) -> dict: return gr.Button.update(visible=model_type != 'panoptic_segmentation') def set_example_image(example: list) -> dict: return gr.Image.update(value=example[0]) def main(): args = parse_args() extract_tar() model = Model(DEFAULT_MODEL_NAME, args.device) css = ''' h1#title { text-align: center; } img#overview { max-width: 1000px; max-height: 600px; } ''' with gr.Blocks(theme=args.theme, css=css) as demo: gr.Markdown('''