Spaces:
Running
on
Zero
Running
on
Zero
# -*- coding: utf-8 -*- | |
# Copyright (c) Alibaba, Inc. and its affiliates. | |
import cv2 | |
import numpy as np | |
def resize_image(input_image, resolution):
    """Resize an image so its shorter side is about ``resolution`` pixels.

    The scaled height and width are each rounded to the nearest multiple of
    64 (and never below 64). The result therefore has spatial dims divisible
    by 64, and the aspect ratio is only approximately preserved.

    Args:
        input_image: Image array whose first two axes are height and width,
            e.g. ``(H, W)`` or ``(H, W, C)``.
        resolution: Target length of the shorter side, in pixels.

    Returns:
        A tuple ``(img, k)``. ``img`` is the resized array. ``k`` is the scale
        factor ``resolution / min(H, W)``, taken before snapping to 64.
    """
    # Only the first two axes are needed, so 2-D (grayscale) input works too.
    h, w = input_image.shape[:2]
    k = float(resolution) / min(h, w)
    # Snap each side to the nearest multiple of 64. Clamp to 64 so a small
    # ``resolution`` cannot round a side down to 0 (cv2.resize rejects 0).
    new_h = max(64, int(np.round(h * k / 64.0)) * 64)
    new_w = max(64, int(np.round(w * k / 64.0)) * 64)
    # Lanczos when enlarging, area averaging when shrinking.
    interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
    # cv2.resize takes dsize as (width, height).
    img = cv2.resize(input_image, (new_w, new_h), interpolation=interpolation)
    return img, k
def resize_image_ori(h, w, image, k):
    """Resize ``image`` to exactly ``h`` rows by ``w`` columns.

    The interpolation mode depends only on ``k``, not on the actual source
    and target sizes. Lanczos is used when ``k > 1``; area averaging is used
    otherwise.

    NOTE(review): if callers pass the forward scale factor returned by
    ``resize_image``, then ``k > 1`` means this call is shrinking the image.
    INTER_AREA is usually preferable for shrinking — confirm what ``k``
    callers actually pass.
    """
    if k > 1:
        mode = cv2.INTER_LANCZOS4
    else:
        mode = cv2.INTER_AREA
    target_size = (w, h)  # cv2.resize expects (width, height)
    return cv2.resize(image, target_size, interpolation=mode)
class AnnotatorProcessor:
    """Run one or more scepter annotators on an image.

    ``anno_type_map`` maps the short names accepted by ``__init__`` and
    ``run`` ('canny', 'hed', 'pose', 'depth', 'mlsd', 'color') to annotator
    configs. Each config's ``OUTPUT_KEYS`` names the key under which that
    annotator's result is expected in the output dict. Short names and output
    keys do not always coincide: 'pose' produces 'openpose'.
    """

    canny_cfg = {
        'NAME': 'CannyAnnotator',
        'LOW_THRESHOLD': 100,
        'HIGH_THRESHOLD': 200,
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['canny']
    }
    hed_cfg = {
        'NAME': 'HedAnnotator',
        'PRETRAINED_MODEL':
        'ms://damo/scepter_scedit@annotator/ckpts/ControlNetHED.pth',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['hed']
    }
    openpose_cfg = {
        'NAME': 'OpenposeAnnotator',
        'BODY_MODEL_PATH':
        'ms://damo/scepter_scedit@annotator/ckpts/body_pose_model.pth',
        'HAND_MODEL_PATH':
        'ms://damo/scepter_scedit@annotator/ckpts/hand_pose_model.pth',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['openpose']
    }
    midas_cfg = {
        'NAME': 'MidasDetector',
        'PRETRAINED_MODEL':
        'ms://damo/scepter_scedit@annotator/ckpts/dpt_hybrid-midas-501f0c75.pt',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['depth']
    }
    mlsd_cfg = {
        'NAME': 'MLSDdetector',
        'PRETRAINED_MODEL':
        'ms://damo/scepter_scedit@annotator/ckpts/mlsd_large_512_fp32.pth',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['mlsd']
    }
    color_cfg = {
        'NAME': 'ColorAnnotator',
        'RATIO': 64,
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['color']
    }
    # Short annotation name -> annotator config (shared by all instances).
    anno_type_map = {
        'canny': canny_cfg,
        'hed': hed_cfg,
        'pose': openpose_cfg,
        'depth': midas_cfg,
        'mlsd': mlsd_cfg,
        'color': color_cfg
    }

    def __init__(self, anno_type):
        """Build a ``GeneralAnnotator`` for the requested annotation types.

        Args:
            anno_type: A single name, or a list/tuple of names. Each name must
                be a key of ``anno_type_map``.

        Raises:
            TypeError: If ``anno_type`` is not a str, list or tuple.
            ValueError: If any requested name is not in ``anno_type_map``.
        """
        # Validate before importing scepter so bad arguments fail fast.
        # Explicit raises replace asserts, which are stripped under -O.
        if isinstance(anno_type, str):
            anno_type = [anno_type]
        elif not isinstance(anno_type, (list, tuple)):
            raise TypeError(f'Error anno_type: {anno_type}')
        unknown = [tp for tp in anno_type if tp not in self.anno_type_map]
        if unknown:
            raise ValueError(f'Unknown anno_type {unknown}; expected one of '
                             f'{sorted(self.anno_type_map)}')

        import copy

        from scepter.modules.annotator.registry import ANNOTATORS
        from scepter.modules.utils.config import Config
        from scepter.modules.utils.distribute import we

        general_dict = {
            'NAME': 'GeneralAnnotator',
            # The configs are class attributes shared by every instance.
            # Hand the builder copies in case Config/build mutate them.
            'ANNOTATORS': [
                copy.deepcopy(self.anno_type_map[tp]) for tp in anno_type
            ]
        }
        general_anno = Config(cfg_dict=general_dict, load=False)
        self.general_ins = ANNOTATORS.build(general_anno).to(we.device_id)

    def _resolve_output_key(self, tp, output_image):
        """Return the key of ``output_image`` that holds annotation ``tp``.

        ``tp`` may be an output key itself (e.g. 'openpose') or a short name
        from ``anno_type_map`` (e.g. 'pose'). Returns None if neither form is
        present in ``output_image``.
        """
        if tp in output_image:
            return tp
        cfg = self.anno_type_map.get(tp)
        if cfg is None:
            return None
        return next(
            (key for key in cfg['OUTPUT_KEYS'] if key in output_image), None)

    def run(self, image, anno_type=None):
        """Annotate ``image`` and return the requested result(s).

        Args:
            image: Input image, passed to the annotator as ``{'img': image}``.
            anno_type: Controls what is returned:
                - None: the full output dict.
                - A single name: that annotation alone.
                - A list/tuple of names: a dict keyed by the requested names.
                Names may be output keys or short names from
                ``anno_type_map``.

        Returns:
            The full output dict, a single annotation, or a filtered dict.
            Requested names that are unavailable are omitted from the dict.
            A single unavailable name yields ``{}`` (kept for backward
            compatibility).
        """
        output_image = self.general_ins({'img': image})
        if anno_type is None:
            return output_image
        if isinstance(anno_type, str):
            key = self._resolve_output_key(anno_type, output_image)
            # Never fall through to iterating the string's characters.
            return {} if key is None else output_image[key]
        result = {}
        for tp in anno_type:
            key = self._resolve_output_key(tp, output_image)
            if key is not None:
                result[tp] = output_image[key]
        return result