File size: 1,223 Bytes
a3a3ae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from scepter.modules.utils.config import Config
from scepter.modules.utils.registry import Registry, build_from_config


def build_annotator(cfg, registry, logger=None, *args, **kwargs):
    """Build an annotator from ``cfg`` and optionally load pretrained weights.

    The object is created by ``build_from_config``. When ``cfg`` has a
    non-None ``PRETRAINED_MODEL`` entry and the built object exposes a
    ``load_pretrained_model`` method, that method is called with the entry's
    value. Objects without that method are returned as built.

    Args:
        cfg (Config): Configuration describing the object to build.
        registry: Passed through to ``build_from_config``.
        logger: Optional logger passed through to ``build_from_config``.
        *args: Extra positional arguments passed to ``build_from_config``.
        **kwargs: Extra keyword arguments passed to ``build_from_config``.

    Returns:
        The object returned by ``build_from_config``.

    Raises:
        TypeError: If ``cfg`` is not a ``Config``, or if ``PRETRAINED_MODEL``
            is present but is neither ``None`` nor a string.
    """
    if not isinstance(cfg, Config):
        # The check is against Config, so the message names Config (not dict).
        raise TypeError(f'cfg must be a Config instance, got {type(cfg)}')

    pretrain_cfg = None
    if cfg.have('PRETRAINED_MODEL'):
        pretrain_cfg = cfg.PRETRAINED_MODEL
    # Validate before building so a bad value fails fast, prior to any
    # work done by build_from_config.
    if pretrain_cfg is not None and not isinstance(pretrain_cfg, str):
        raise TypeError('Pretrain parameter must be a string')

    # Positional unpacking is always bound before keyword arguments, so it is
    # written first; binding is the same as `logger=logger, *args`.
    model = build_from_config(cfg, registry, *args, logger=logger, **kwargs)

    # Loading is best-effort: objects without the hook are returned untouched.
    if pretrain_cfg is not None and hasattr(model, 'load_pretrained_model'):
        model.load_pretrained_model(pretrain_cfg)
    return model


# Module-level registry named 'ANNOTATORS'; build_annotator is supplied as its
# build function.
# NOTE(review): presumably annotator classes register themselves here and are
# instantiated through this registry — confirm against Registry's API.
ANNOTATORS = Registry('ANNOTATORS', build_func=build_annotator)