Spaces:
Runtime error
Runtime error
File size: 2,198 Bytes
413d4d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
# ==========================================================
# Text-to-Video Generation
from .lavie import LaVie
from .videocrafter import VideoCrafter2
from .modelscope import ModelScope
from .streamingt2v import StreamingT2V
from .show_one import ShowOne
from .opensora import OpenSora
from .opensora_plan import OpenSoraPlan
from .t2v_turbo import T2VTurbo
from .opensora_12 import OpenSora12
from .cogvideox import CogVideoX
# from .cogvideo import CogVideo # Not supporting CogVideo ATM
# ==========================================================
# Image-to-Video Generation
from .seine import SEINE
from .consisti2v import ConsistI2V
from .dynamicrafter import DynamiCrafter
from .i2vgen_xl import I2VGenXL
# ==========================================================
import sys
from functools import partial
def get_model(model_name: str = None, init_with_default_params: bool = True):
    """
    Retrieves a model class or instance by its name.

    Args:
        model_name (str): Name of a model class exported by this module (e.g. one of the classes
            imported at the top of this file). Must be a non-empty, non-private string naming a
            class; anything else raises ``ValueError``.
        init_with_default_params (bool, optional): If True, returns an instance created by calling
            the model class with no arguments; otherwise, returns the model class itself.
            Default is True. If set to True, be cautious of potential ``OutOfMemoryError`` with
            insufficient CUDA memory.

    Returns:
        model_class or model_instance: Depending on ``init_with_default_params``, either the model
        class or an instance of the model.

    Raises:
        ValueError: If ``model_name`` is not a string, starts with ``_``, or does not name a class
            available in this module's namespace.

    Examples::

        initialized_model = infermodels.get_model(model_name='<Model>', init_with_default_params=True)
        uninitialized_model = infermodels.get_model(model_name='<Model>', init_with_default_params=False)
        initialized_model = uninitialized_model(device="cuda", <...>)
    """
    # Validate up front: hasattr()/getattr() raise an opaque TypeError for a non-str name
    # (including the default None), and private/dunder names such as ``__class__`` are never models.
    if not isinstance(model_name, str) or model_name.startswith("_"):
        raise ValueError(f"No model named {model_name} found in infermodels.")
    # Models are looked up as attributes of this very module (populated by the imports above).
    model_class = getattr(sys.modules[__name__], model_name, None)
    # Reject non-class attributes that also live in this namespace (e.g. the ``sys`` module or the
    # ``get_model`` / ``load`` helpers), which would otherwise be returned or called by mistake.
    if not isinstance(model_class, type):
        raise ValueError(f"No model named {model_name} found in infermodels.")
    if init_with_default_params:
        return model_class()
    return model_class
# Convenience aliases around get_model: ``load_model`` pins init_with_default_params=True
# (a caller may still override that keyword), while ``load`` simply mirrors get_model's defaults.
load_model, load = (
    partial(get_model, init_with_default_params=True),
    partial(get_model),
)
|