File size: 2,198 Bytes
413d4d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# ==========================================================
# Text-to-Video Generation
from .lavie import LaVie
from .videocrafter import VideoCrafter2
from .modelscope import ModelScope
from .streamingt2v import StreamingT2V
from .show_one import ShowOne
from .opensora import OpenSora
from .opensora_plan import OpenSoraPlan
from .t2v_turbo import T2VTurbo
from .opensora_12 import OpenSora12
from .cogvideox import CogVideoX

# from .cogvideo import CogVideo # Not supporting CogVideo ATM

# ==========================================================
# Image-to-Video Generation
from .seine import SEINE
from .consisti2v import ConsistI2V
from .dynamicrafter import DynamiCrafter
from .i2vgen_xl import I2VGenXL

# ==========================================================

import sys
from functools import partial


def get_model(model_name: str = None, init_with_default_params: bool = True):
    """
    Retrieves a model class or instance by its name.

    Args:
        model_name (str): Name of a model class exported by this module (e.g. ``"LaVie"``).
            A ``ValueError`` is raised if it is missing, not a string, or does not name a
            class defined/imported in this module.
        init_with_default_params (bool, optional): If True, returns an instance created by
            calling the class with no arguments; otherwise, returns the model class itself.
            Default is True. If set to True, be cautious of potential ``OutOfMemoryError``
            with insufficient CUDA memory.

    Returns:
        model_class or model_instance: Depending on ``init_with_default_params``, either the
        model class or an instance of the model.

    Raises:
        ValueError: If ``model_name`` is not a string, or does not name a class in this module.

    Examples::
        initialized_model = infermodels.get_model(model_name='<Model>', init_with_default_params=True)

        uninitialized_model = infermodels.get_model(model_name='<Model>', init_with_default_params=False)
        initialized_model = uninitialized_model(device="cuda", <...>)
    """
    # Validate up front: hasattr()/getattr() raise an opaque TypeError for a non-str name,
    # which includes this function's own default of None.
    if not isinstance(model_name, str):
        raise ValueError(f"model_name must be a string naming a model, got {model_name!r}.")

    model_class = getattr(sys.modules[__name__], model_name, None)
    # Only classes count as models. This rejects other names living in this namespace
    # (e.g. ``sys``, ``partial``, ``get_model``) so they are never "instantiated".
    if not isinstance(model_class, type):
        raise ValueError(f"No model named {model_name} found in infermodels.")

    if init_with_default_params:
        return model_class()
    return model_class


# Convenience aliases around ``get_model``:
#   * ``load_model`` pre-binds ``init_with_default_params=True`` as a keyword argument,
#     so by default it returns an instance built with no constructor arguments.
#   * ``load`` forwards every argument to ``get_model`` unchanged.
load_model, load = (
    partial(get_model, init_with_default_params=True),
    partial(get_model),
)