aki-0421
F: add
a3a3ae4 unverified
raw
history blame
3.5 kB
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
def resize_image(input_image, resolution):
    """Resize an image so that its shorter side is roughly ``resolution``.

    The scale factor is ``resolution / min(height, width)``. Both scaled sides
    are then rounded to the nearest multiple of 64 (and never below 64) before
    the image is resized with OpenCV.

    Args:
        input_image: Array whose first two axes are height and width. A
            trailing channel axis is optional, so 2-D (single channel) arrays
            are accepted as well as ``H x W x C`` ones.
        resolution: Desired length of the shorter side before rounding.

    Returns:
        tuple: ``(img, k)`` where ``img`` is the resized image and ``k`` is the
        scale factor that was applied before rounding to multiples of 64.
    """
    # Only the two leading axes matter; unpacking three values would reject
    # 2-D inputs that cv2.resize itself handles fine.
    height, width = input_image.shape[:2]
    k = float(resolution) / min(height, width)

    # Snap to multiples of 64. Clamp to one 64-pixel step so a small
    # ``resolution`` cannot round a side down to 0, which cv2.resize rejects.
    target_h = max(64, int(np.round(height * k / 64.0)) * 64)
    target_w = max(64, int(np.round(width * k / 64.0)) * 64)

    # INTER_LANCZOS4 when enlarging (k > 1), INTER_AREA otherwise.
    interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
    img = cv2.resize(input_image, (target_w, target_h),
                     interpolation=interpolation)
    return img, k
def resize_image_ori(h, w, image, k):
    """Resize ``image`` to an explicit ``w`` x ``h`` size.

    Args:
        h: Target height passed to ``cv2.resize``.
        w: Target width passed to ``cv2.resize``.
        image: Image array to resize.
        k: Scale factor used only to pick the interpolation mode:
            ``cv2.INTER_LANCZOS4`` when ``k > 1``, ``cv2.INTER_AREA`` otherwise.

    Returns:
        The resized image returned by ``cv2.resize``.
    """
    if k > 1:
        mode = cv2.INTER_LANCZOS4
    else:
        mode = cv2.INTER_AREA
    return cv2.resize(image, (w, h), interpolation=mode)
class AnnotatorProcessor():
    """Build a ``GeneralAnnotator`` from named presets and run it on images.

    Each ``*_cfg`` class attribute is the configuration dict of one annotator.
    ``anno_type_map`` maps the short public type name to that configuration.
    Note that the public name and the annotator's ``OUTPUT_KEYS`` can differ
    (``'pose'`` is produced under the output key ``'openpose'``); ``run``
    accepts either spelling.
    """

    canny_cfg = {
        'NAME': 'CannyAnnotator',
        'LOW_THRESHOLD': 100,
        'HIGH_THRESHOLD': 200,
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['canny']
    }
    hed_cfg = {
        'NAME': 'HedAnnotator',
        'PRETRAINED_MODEL':
        'ms://damo/scepter_scedit@annotator/ckpts/ControlNetHED.pth',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['hed']
    }
    openpose_cfg = {
        'NAME': 'OpenposeAnnotator',
        'BODY_MODEL_PATH':
        'ms://damo/scepter_scedit@annotator/ckpts/body_pose_model.pth',
        'HAND_MODEL_PATH':
        'ms://damo/scepter_scedit@annotator/ckpts/hand_pose_model.pth',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['openpose']
    }
    midas_cfg = {
        'NAME': 'MidasDetector',
        'PRETRAINED_MODEL':
        'ms://damo/scepter_scedit@annotator/ckpts/dpt_hybrid-midas-501f0c75.pt',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['depth']
    }
    mlsd_cfg = {
        'NAME': 'MLSDdetector',
        'PRETRAINED_MODEL':
        'ms://damo/scepter_scedit@annotator/ckpts/mlsd_large_512_fp32.pth',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['mlsd']
    }
    color_cfg = {
        'NAME': 'ColorAnnotator',
        'RATIO': 64,
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['color']
    }
    # Public annotation type name -> annotator configuration.
    anno_type_map = {
        'canny': canny_cfg,
        'hed': hed_cfg,
        'pose': openpose_cfg,
        'depth': midas_cfg,
        'mlsd': mlsd_cfg,
        'color': color_cfg
    }

    # Sentinel distinguishing "no such output" from any real output value.
    _MISSING = object()

    def __init__(self, anno_type):
        """Build the underlying general annotator for the requested types.

        Args:
            anno_type: A single type name or a list/tuple of type names; every
                name must be a key of ``anno_type_map``.

        Raises:
            Exception: If ``anno_type`` is neither a ``str`` nor a list/tuple.
            AssertionError: If any requested name is not in ``anno_type_map``.
        """
        # Imported lazily so the module can be imported without scepter.
        from scepter.modules.annotator.registry import ANNOTATORS
        from scepter.modules.utils.config import Config
        from scepter.modules.utils.distribute import we

        if isinstance(anno_type, str):
            anno_type = [anno_type]
        elif not isinstance(anno_type, (list, tuple)):
            raise Exception(f'Error anno_type: {anno_type}')

        unknown = [tp for tp in anno_type if tp not in self.anno_type_map]
        if unknown:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped under ``python -O``; the exception type is kept
            # for callers that already catch AssertionError.
            raise AssertionError(
                f'Unsupported anno_type {unknown}; '
                f'expected one of {list(self.anno_type_map)}')

        general_dict = {
            'NAME': 'GeneralAnnotator',
            'ANNOTATORS': [self.anno_type_map[tp] for tp in anno_type]
        }
        general_anno = Config(cfg_dict=general_dict, load=False)
        self.general_ins = ANNOTATORS.build(general_anno).to(we.device_id)

    def _lookup(self, outputs, name):
        """Return the output for ``name``, or ``_MISSING`` if there is none.

        ``name`` is tried as a raw output key first, then resolved through the
        ``OUTPUT_KEYS`` of its entry in ``anno_type_map`` (so ``'pose'`` finds
        the value stored under ``'openpose'``).
        """
        if name in outputs:
            return outputs[name]
        cfg = self.anno_type_map.get(name)
        if cfg is not None:
            for key in cfg['OUTPUT_KEYS']:
                if key in outputs:
                    return outputs[key]
        return self._MISSING

    def run(self, image, anno_type=None):
        """Annotate ``image`` and optionally select specific results.

        Args:
            image: Input passed to the annotator under the key ``'img'``.
            anno_type: ``None`` to get every output, a single name to get that
                one result, or an iterable of names to get a dict of results.
                A name may be either an output key or a key of
                ``anno_type_map``.

        Returns:
            * ``anno_type is None``: the annotator's full output mapping.
            * ``anno_type`` is a ``str``: the matching result, or an empty
              dict when nothing matches.
            * otherwise: ``{name: result}`` for each requested name that has a
              result; names without one are skipped.
        """
        outputs = self.general_ins({'img': image})
        if anno_type is None:
            return outputs

        if isinstance(anno_type, str):
            found = self._lookup(outputs, anno_type)
            # An unmatched string yields an empty dict, as callers previously
            # observed, instead of iterating over its characters.
            return {} if found is self._MISSING else found

        selected = {}
        for tp in anno_type:
            found = self._lookup(outputs, tp)
            if found is not self._MISSING:
                selected[tp] = found
        return selected