# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from abc import ABC, abstractmethod

import numpy as np
import omegaconf
import torch

from audiocraft.loaders import load_compression_model, load_lm_model
from .lm import LMModel
from .conditioners import ConditioningAttributes
from .utils.autocast import TorchAutocast


def _shift(x):
    """Roll the token sequence along the time axis by a random offset."""
    n = x.shape[2]
    # Draw an offset roughly in [0.24 * n, 0.74 * n); the max() guard keeps the
    # upper bound >= 1 so this also works for very short segments.
    i = np.random.randint(int(.24 * n), max(1, int(.74 * n)))
    x = torch.roll(x, i, dims=2)
    return x


class BaseGenModel(ABC):
    """Base generative model with a convenient generation API.

    Args:
        name (str): Name of the model.
        compression_model (CompressionModel): EnCodec compression model with a SEANet decoder.
        lm (LMModel): Language model over discrete audio tokens.
        max_duration (float, optional): Maximum duration the model can produce. Since generation
            uses top-250 token sampling, multiple sequences can be drawn per call.
    """
    def __init__(self, name, compression_model, lm, max_duration=None):
        self.name = name
        self.compression_model = compression_model
        self.lm = lm
        self.cfg: tp.Optional[omegaconf.DictConfig] = None
        # Just to be safe, let's put everything in eval mode.
        self.compression_model.eval()
        self.lm.eval()

        if hasattr(lm, 'cfg'):
            cfg = lm.cfg
            assert isinstance(cfg, omegaconf.DictConfig)
            self.cfg = cfg

        if max_duration is None:
            if self.cfg is not None:
                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
            else:
                raise ValueError("You must provide max_duration when building your GenModel directly.")
        assert max_duration is not None

        self.max_duration: float = max_duration
        self.duration = self.max_duration
        self.device = next(iter(lm.parameters())).device
        self.generation_params: dict = {}
        if self.device.type == 'cpu':
            self.autocast = TorchAutocast(enabled=False)
        else:
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.float16)

    @property
    def frame_rate(self) -> float:
        """Roughly the number of AR steps per second."""
        return self.compression_model.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.compression_model.sample_rate

    def generate(self, descriptions):
        attributes = [
            ConditioningAttributes(text={'description': d})
            for d in descriptions]
        tokens = self._generate_tokens(attributes)
        return self.generate_audio(tokens)

    def _generate_tokens(self, attributes):
        total_gen_len = int(self.duration * self.frame_rate)
        if self.duration <= self.max_duration:
            # Generate by sampling from the LM, simple case.
            with self.autocast:
                gen_tokens = self.lm.generate(conditions=attributes, max_gen_len=total_gen_len)
        else:
            raise ValueError(
                f"Requested duration {self.duration}s exceeds max_duration "
                f"{self.max_duration}s; long generation is not supported here.")
        # Flatten the batch as an extra sequence: the batch dimension is virtual,
        # holding n_draw multinomial samples of the same prompt.
        # gen_tokens has shape [B, K, T], e.g. [5, 4, 35].
        n_q = gen_tokens.shape[1]
        gen_tokens = gen_tokens.transpose(0, 1).reshape(n_q, -1)[None, :, :]
        # Roll the concatenated sequence a few times by a random offset.
        for _ in range(3):
            gen_tokens = _shift(gen_tokens)
        return gen_tokens

    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
        """Generate audio from tokens."""
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio


class AudioGen(BaseGenModel):
    def __init__(self, name, compression_model, lm, max_duration=None):
        super().__init__(name, compression_model, lm, max_duration)
        self.set_generation_params(duration=5)  # default duration

    @staticmethod
    def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
        """Return a pretrained model. We provide a single model for now:
        - facebook/audiogen-medium (1.5B), text to sound,
          see: https://huggingface.co/facebook/audiogen-medium
        """
        if device is None:
            device = 'cuda' if torch.cuda.device_count() else 'cpu'
        compression_model = load_compression_model(name, device=device)
        lm = load_lm_model(name, device=device)
        assert 'self_wav' not in lm.condition_provider.conditioners, \
            "AudioGen does not support waveform conditioning for now"
        return AudioGen(name, compression_model, lm)

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                              top_p: float = 0.0, temperature: float = 1.0,
                              duration: float = 10.0, cfg_coef: float = 2.4,
                              two_step_cfg: bool = False, extend_stride: float = 2):
        """Set the generation parameters for AudioGen.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding.
                Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 250.
            top_p (float, optional): top_p used for sampling; when set to 0, top_k is used.
                Defaults to 0.0.
            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
            duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
            cfg_coef (float, optional): Coefficient used for classifier free guidance.
                Defaults to 2.4.
            two_step_cfg (bool, optional): If True, performs 2 forward passes for classifier
                free guidance instead of batching the two together. This has some impact on
                how things are padded but seems to have little impact in practice.
            extend_stride: when doing extended generation (i.e. more than 10 seconds), by how
                much should we extend the audio each time. Larger values mean less context is
                preserved, and smaller values require extra computation.
        """
        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        self.extend_stride = extend_stride
        self.duration = duration
        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'cfg_coef': cfg_coef,
            'two_step_cfg': two_step_cfg,
        }
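

# Minimal usage sketch, not part of the library API: assumes network access to
# download the pretrained weights and that torchaudio is installed for saving.
# The prompt text and output filename are illustrative only.
if __name__ == '__main__':
    import torchaudio

    model = AudioGen.get_pretrained('facebook/audiogen-medium')
    model.set_generation_params(duration=5, top_k=250, cfg_coef=2.4)
    wav = model.generate(['dogs barking in the distance'])  # -> [B, C, T] float tensor
    torchaudio.save('audiogen_sample.wav', wav[0].cpu(), model.sample_rate)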