character-360 / annotator /base_annotator.py
aki-0421
F: add
a3a3ae4 unverified
raw
history blame
2.1 kB
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABCMeta
import torch
import torch.nn as nn
from scepter.modules.annotator.registry import ANNOTATORS
from scepter.modules.model.base_model import BaseModel
from scepter.modules.utils.config import dict_to_yaml
@ANNOTATORS.register_class()
class BaseAnnotator(BaseModel, metaclass=ABCMeta):
    """Base class for annotators registered in the ``ANNOTATORS`` registry.

    The base ``forward`` only raises ``NotImplementedError``; concrete
    annotators are expected to override it.
    """

    # Parameter description handed to ``dict_to_yaml`` by
    # ``get_config_template``. Defined on the class, so the same dict object
    # is shared by every instance that does not rebind it.
    para_dict = {}

    def __init__(self, cfg, logger=None):
        """Pass ``cfg`` and ``logger`` through to the ``BaseModel`` constructor."""
        super().__init__(cfg, logger=logger)

    # ``torch.inference_mode`` is a context-manager factory and is called
    # here, exactly like ``torch.no_grad()`` above it. The bare form
    # ``@torch.inference_mode`` hands ``forward`` to the factory as its
    # ``mode`` argument, which only yields a proper decorator on PyTorch
    # releases that special-case a callable there.
    @torch.no_grad()
    @torch.inference_mode()
    def forward(self, *args, **kwargs):
        """Run the annotator; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @staticmethod
    def get_config_template():
        """Return ``dict_to_yaml`` output describing this class's parameters.

        ``__class__`` here is the implicit reference to the class whose body
        defines this function (``BaseAnnotator``), so the template is always
        named after ``BaseAnnotator`` and built from ``BaseAnnotator.para_dict``.
        """
        return dict_to_yaml('ANNOTATORS',
                            __class__.__name__,
                            BaseAnnotator.para_dict,
                            set_name=True)
@ANNOTATORS.register_class()
class GeneralAnnotator(BaseAnnotator, metaclass=ABCMeta):
    """Container that builds several annotators and applies them by key.

    Each entry of ``cfg['ANNOTATORS']`` is built through the ``ANNOTATORS``
    registry and tagged with the ``INPUT_KEYS`` / ``OUTPUT_KEYS`` read from
    its config. ``forward`` then maps entries of an input dict to entries of
    an output dict, one annotator call per matching input key.
    """

    def __init__(self, cfg, logger=None):
        """Build the child annotators described by ``cfg['ANNOTATORS']``.

        Args:
            cfg: Config object; ``cfg.get('ANNOTATORS', [])`` yields the
                per-annotator configs. Each may define ``INPUT_KEYS`` and
                ``OUTPUT_KEYS`` as a single string or a sequence.
            logger: Optional logger forwarded to the base class and to
                ``ANNOTATORS.build``.

        Raises:
            ValueError: if an annotator's ``INPUT_KEYS`` and ``OUTPUT_KEYS``
                differ in length.
        """
        super().__init__(cfg, logger=logger)
        self.annotators = nn.ModuleList()
        for anno_config in cfg.get('ANNOTATORS', []):
            annotator = ANNOTATORS.build(anno_config, logger=logger)
            annotator.input_keys = self._as_key_list(
                anno_config.get('INPUT_KEYS', []))
            annotator.output_keys = self._as_key_list(
                anno_config.get('OUTPUT_KEYS', []))
            # Raised explicitly rather than asserted: ``assert`` is removed
            # under ``python -O`` and the mismatch would then only show up
            # later, inside ``forward``.
            if len(annotator.input_keys) != len(annotator.output_keys):
                raise ValueError(
                    'INPUT_KEYS and OUTPUT_KEYS must have the same length, '
                    f'got {annotator.input_keys!r} and '
                    f'{annotator.output_keys!r}')
            self.annotators.append(annotator)

    @staticmethod
    def _as_key_list(keys):
        """Wrap a single string key in a list; return anything else as-is."""
        return [keys] if isinstance(keys, str) else keys

    def forward(self, input_dict):
        """Apply every annotator to the inputs it is configured for.

        Args:
            input_dict: Mapping from input key to the value passed to the
                annotator. Keys absent from the mapping are skipped.

        Returns:
            dict: Annotator results stored under the output key paired with
            each processed input key. A later result overwrites an earlier
            one written to the same output key.
        """
        output_dict = {}
        for annotator in self.annotators:
            key_pairs = zip(annotator.input_keys, annotator.output_keys)
            for in_key, out_key in key_pairs:
                if in_key in input_dict:
                    output_dict[out_key] = annotator(input_dict[in_key])
        return output_dict