Spaces:
Runtime error
Runtime error
# ========================================================== | |
# Text-to-Video Generation | |
from .lavie import LaVie | |
from .videocrafter import VideoCrafter2 | |
from .modelscope import ModelScope | |
from .streamingt2v import StreamingT2V | |
from .show_one import ShowOne | |
from .opensora import OpenSora | |
from .opensora_plan import OpenSoraPlan | |
from .t2v_turbo import T2VTurbo | |
from .opensora_12 import OpenSora12 | |
from .cogvideox import CogVideoX | |
# from .cogvideo import CogVideo  # CogVideo is not supported at the moment
# ========================================================== | |
# Image-to-Video Generation | |
from .seine import SEINE | |
from .consisti2v import ConsistI2V | |
from .dynamicrafter import DynamiCrafter | |
from .i2vgen_xl import I2VGenXL | |
# ========================================================== | |
import sys | |
from functools import partial | |
def get_model(model_name: str = None, init_with_default_params: bool = True):
    """
    Retrieves a model class or instance by its name.

    The name is looked up among the attributes of this module. Only
    attributes that are classes are accepted, so helper names that also live
    in this namespace (for example ``sys`` or ``get_model`` itself) are
    rejected with the same error as an unknown name.

    Args:
        model_name (str): Name of the model class. Triggers an error if the name does not exist.
        init_with_default_params (bool, optional): If True, returns an initialized model instance; otherwise, returns
            the model class. Default is True. If set to True, be cautious of potential ``OutOfMemoryError`` with
            insufficient CUDA memory.

    Returns:
        model_class or model_instance: Depending on ``init_with_default_params``, either the model class or an
        instance of the model created by calling the class with no arguments.

    Raises:
        ValueError: If ``model_name`` is not a string (including the default ``None``) or does not name a class
            available in this module.

    Examples::
        initialized_model = infermodels.get_model(model_name='<Model>', init_with_default_params=True)
        uninitialized_model = infermodels.get_model(model_name='<Model>', init_with_default_params=False)
        initialized_model = uninitialized_model(device="cuda", <...>)
    """
    this_module = sys.modules[__name__]

    # Single lookup; a non-string name (e.g. the default None) would make
    # getattr/hasattr raise TypeError, so treat it as "not found" instead.
    candidate = None
    if isinstance(model_name, str):
        candidate = getattr(this_module, model_name, None)

    # Modules, functions and other non-class attributes are not models.
    if not isinstance(candidate, type):
        raise ValueError(f"No model named {model_name} found in infermodels.")

    if init_with_default_params:
        return candidate()
    return candidate
# Convenience aliases for ``get_model``.
# ``load_model`` pre-binds ``init_with_default_params=True``; a caller can
# still override it by passing the keyword explicitly.
load_model = partial(
    get_model,
    init_with_default_params=True,
)
# ``load`` binds no arguments, so it forwards every call to ``get_model`` as-is.
load = partial(get_model)