|
|
|
|
|
|
|
|
|
|
|
|
|
import typing as tp |
|
import omegaconf |
|
import torch |
|
from .encodec import EncodecModel |
|
from .lm import LMModel |
|
from .seanet import SEANetDecoder |
|
from .codebooks_patterns import DelayedPatternProvider |
|
from .conditioners import ( |
|
BaseConditioner, |
|
ConditionFuser, |
|
ConditioningProvider, |
|
T5Conditioner, |
|
) |
|
from .unet import DiffusionUnet |
|
from .vq import ResidualVectorQuantizer |
|
from .utils.utils import dict_from_config |
|
from .diffusion_schedule import MultiBandProcessor, SampleProcessor |
|
|
|
|
|
def get_quantizer(quantizer: str, cfg, dimension: int):
    """Instantiate the quantizer named ``quantizer``.

    Args:
        quantizer: Quantizer name. Only ``'rvq'`` (``ResidualVectorQuantizer``)
            is available in this build.
        cfg: Config holding a section named after the quantizer
            (e.g. ``cfg.rvq``) whose entries become constructor kwargs.
        dimension: Value injected as the quantizer's ``dimension`` kwarg.

    Returns:
        The constructed quantizer.

    Raises:
        NotImplementedError: for ``'no_quant'``. There is no pass-through quantizer
            wired into this module, and previously this crashed with an
            opaque ``TypeError`` from calling ``None``.
        KeyError: for any other unknown quantizer name.
    """
    if quantizer == 'no_quant':
        raise NotImplementedError(
            "Quantizer 'no_quant' is not supported: no pass-through quantizer is available")
    quantizers = {
        'rvq': ResidualVectorQuantizer,
    }
    if quantizer not in quantizers:
        raise KeyError(f"Unexpected quantizer {quantizer}")
    kwargs = dict_from_config(getattr(cfg, quantizer))
    kwargs['dimension'] = dimension
    return quantizers[quantizer](**kwargs)
|
|
|
|
|
def get_encodec_autoencoder(cfg):
    """Build the SEANet decoder described by ``cfg.seanet``.

    Shared ``seanet`` entries form the base kwargs. Entries under
    ``seanet.decoder`` override them. The ``seanet.encoder`` section must be
    present, but it is discarded because only a decoder is built here.
    """
    seanet_kwargs = dict_from_config(getattr(cfg, 'seanet'))
    del seanet_kwargs['encoder']  # required key, intentionally unused
    overrides = seanet_kwargs.pop('decoder')
    merged = dict(seanet_kwargs)
    merged.update(overrides)
    return SEANetDecoder(**merged)
|
|
|
|
|
|
|
def get_compression_model(cfg):
    """Instantiate a compression model.

    Only ``cfg.compression_model == 'encodec'`` is supported. The model is built
    from a SEANet decoder (``get_encodec_autoencoder``) and the quantizer named
    by ``cfg.encodec.quantizer``. It is then moved to ``cfg.device``.

    Note: the quantizer dimension (128), frame rate (50), sample rate (16000),
    channel count (1) and ``causal=False`` are hard-coded below. Any other keys
    in ``cfg.encodec`` besides ``quantizer``/``renormalize`` are ignored.

    Raises:
        KeyError: if ``cfg.compression_model`` is not ``'encodec'``.
    """
    if cfg.compression_model != 'encodec':
        raise KeyError(f"Unexpected compression model {cfg.compression_model}")

    encodec_kwargs = dict_from_config(getattr(cfg, 'encodec'))
    quantizer_name = encodec_kwargs.pop('quantizer')
    decoder = get_encodec_autoencoder(cfg)
    quantizer = get_quantizer(quantizer_name, cfg, 128)
    renormalize = encodec_kwargs.pop('renormalize', False)
    # 'renorm' is discarded if present (presumably a legacy alias — confirm).
    encodec_kwargs.pop('renorm', None)

    model = EncodecModel(
        decoder=decoder,
        quantizer=quantizer,
        frame_rate=50,
        renormalize=renormalize,
        sample_rate=16000,
        channels=1,
        causal=False,
    )
    return model.to(cfg.device)
|
|
|
|
|
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
    """Instantiate a transformer LM.

    Both ``'transformer_lm'`` and ``'transformer_lm_magnet'`` build an
    ``LMModel`` from the ``cfg.transformer_lm`` section. The model is wired with
    a condition fuser, a conditioning provider, a codebook pattern provider and
    classifier-free-guidance settings. It is then moved to ``cfg.device``.

    If ``cfg.codebooks_pattern.modeling`` is None, a pattern config is synthesized
    from ``transformer_lm.q_modeling`` with delays ``0..n_q-1``.

    Raises:
        AssertionError: if neither a codebook pattern nor ``q_modeling`` is set.
        KeyError: if ``cfg.lm_model`` is not a supported LM name.
    """
    if cfg.lm_model not in ['transformer_lm', 'transformer_lm_magnet']:
        raise KeyError(f"Unexpected LM model {cfg.lm_model}")

    kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
    n_q = kwargs['n_q']
    # q_modeling is not an LMModel kwarg; it only feeds the pattern fallback.
    q_modeling = kwargs.pop('q_modeling', None)
    codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
    attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
    cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
    cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']

    fuser = get_condition_fuser(cfg)
    condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
    # Enable cross-attention only when some condition is routed to 'cross'.
    if len(fuser.fuse2cond['cross']) > 0:
        kwargs['cross_attention'] = True

    if codebooks_pattern_cfg.modeling is None:
        assert q_modeling is not None, \
            "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
        codebooks_pattern_cfg = omegaconf.OmegaConf.create(
            {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
        )
    pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)

    return LMModel(
        pattern_provider=pattern_provider,
        condition_provider=condition_provider,
        fuser=fuser,
        cfg_dropout=cfg_prob,
        cfg_coef=cfg_coef,
        attribute_dropout=attribute_dropout,
        dtype=getattr(torch, cfg.dtype),
        device=cfg.device,
        **kwargs
    ).to(cfg.device)
|
|
|
|
|
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
    """Instantiate a conditioning model.

    Each entry of ``cfg.conditioners`` (except the reserved ``args`` entry)
    becomes one conditioner, keyed by the entry name. The entry's ``model`` field
    selects the conditioner type. The sub-section named after that type supplies
    its kwargs.

    Args:
        output_dim: Output dimension passed to every conditioner.
        cfg: Full config; ``cfg.device`` and ``cfg.conditioners`` are read.

    Returns:
        A ``ConditioningProvider`` on ``cfg.device`` wrapping the conditioners.

    Raises:
        ValueError: if an entry's ``model`` is not a supported type (only ``'t5'``).
    """
    device = cfg.device
    conditioners_cfg = getattr(cfg, 'conditioners')
    dict_cfg = {} if conditioners_cfg is None else dict_from_config(conditioners_cfg)
    conditioners: tp.Dict[str, BaseConditioner] = {}

    # 'args' holds kwargs for the ConditioningProvider itself, not a conditioner.
    condition_provider_args = dict_cfg.pop('args', {})
    # These options are dropped before construction (presumably not accepted
    # by this ConditioningProvider — confirm).
    condition_provider_args.pop('merge_text_conditions_p', None)
    condition_provider_args.pop('drop_desc_p', None)

    for cond, cond_cfg in dict_cfg.items():
        model_type = cond_cfg['model']
        # Validate before indexing so an unknown type raises the intended error
        # rather than a KeyError on the missing per-type section.
        if model_type != 't5':
            raise ValueError(f"Unrecognized conditioning model: {model_type}")
        model_args = cond_cfg[model_type]
        conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)

    return ConditioningProvider(conditioners, device=device, **condition_provider_args)
|
|
|
|
|
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Instantiate a condition fuser object.

    The four fuse methods (``sum``, ``cross``, ``prepend``, ``input_interpolate``)
    are required keys of ``cfg.fuser``. Each maps to the conditions fused that
    way. Every other ``cfg.fuser`` entry is forwarded as a ``ConditionFuser``
    keyword argument.
    """
    fuser_cfg = cfg.fuser
    fuse_methods = ('sum', 'cross', 'prepend', 'input_interpolate')

    fuse2cond = {}
    for method in fuse_methods:
        fuse2cond[method] = fuser_cfg[method]

    extra_kwargs = {key: value for key, value in fuser_cfg.items() if key not in fuse2cond}
    return ConditionFuser(fuse2cond=fuse2cond, **extra_kwargs)
|
|
|
|
|
def get_codebooks_pattern_provider(n_q, cfg):
    """Build the codebook pattern provider selected by ``cfg.modeling``.

    If ``cfg`` has a section named after the modeling choice (e.g.
    ``cfg.delay``), its entries become constructor kwargs. Otherwise none are
    passed. Only ``'delay'`` (``DelayedPatternProvider``) is registered. Any
    other name raises ``KeyError``.
    """
    modeling = cfg.modeling
    if hasattr(cfg, modeling):
        provider_kwargs = dict_from_config(cfg.get(modeling))
    else:
        provider_kwargs = {}
    provider_cls = {'delay': DelayedPatternProvider}[modeling]
    return provider_cls(n_q, **provider_kwargs)
|
|
|
|
|
|
|
|
|
|
|
def get_diffusion_model(cfg: omegaconf.DictConfig):
    """Build a ``DiffusionUnet``.

    ``cfg.channels`` sets the input channels and ``cfg.schedule.num_steps`` sets
    the step count. All entries of ``cfg.diffusion_unet`` are forwarded as extra
    keyword arguments.
    """
    in_channels = cfg.channels
    steps = cfg.schedule.num_steps
    return DiffusionUnet(chin=in_channels, num_steps=steps, **cfg.diffusion_unet)
|
|
|
|
|
def get_processor(cfg, sample_rate: int = 24000):
    """Return the sample processor described by ``cfg``.

    A ``MultiBandProcessor`` is built only when ``cfg.use`` is truthy and
    ``cfg.name == "multi_band_processor"``. In that case every ``cfg`` entry
    other than ``use``/``name`` is forwarded to it. In all other cases a plain
    ``SampleProcessor`` is returned. Note that an unrecognized name therefore
    falls back silently.
    """
    fallback = SampleProcessor()
    if not cfg.use:
        return fallback

    options = dict(cfg)
    del options['use']
    del options['name']
    if cfg.name != "multi_band_processor":
        return fallback
    return MultiBandProcessor(sample_rate=sample_rate, **options)
|
|