|
import audiofile |
|
import numpy as np |
|
import torch |
|
from audiocraft.loaders import load_compression_model, load_lm_model |
|
from audiocraft.conditioners import ConditioningAttributes |
|
|
|
|
|
|
|
|
|
class AudioGen():
    """Minimal text-to-audio wrapper pairing a language model with a decoder.

    The language model produces tokens conditioned on text descriptions and
    the compression model decodes those tokens into a waveform tensor.
    """

    def __init__(self,
                 compression_model=None,
                 lm=None,
                 duration=.74):
        """Store the two models and the requested clip length.

        Args:
            compression_model: Object exposing ``frame_rate`` and
                ``decode(tokens, scale)``.
            lm: Object exposing ``generate(conditions=..., max_gen_len=...)``.
            duration: Multiplied by ``frame_rate`` to obtain the number of
                tokens to generate (presumably seconds -- confirm against
                the model's documentation).
        """
        self.compression_model = compression_model
        self.lm = lm
        self.duration = duration

    @property
    def frame_rate(self):
        """Return the ``frame_rate`` reported by the compression model."""
        return self.compression_model.frame_rate

    def generate(self,
                 descriptions):
        """Generate audio for each description and join it into one row.

        Args:
            descriptions: Iterable of text prompts. A single ``str`` is
                accepted and treated as one prompt.

        Returns:
            The decoded output reshaped to ``(1, n_draw * n_time_samples)``,
            i.e. the clips for all descriptions laid end to end.
        """
        # A bare string would otherwise be iterated character by character,
        # producing one clip per letter instead of one clip for the prompt.
        if isinstance(descriptions, str):
            descriptions = [descriptions]

        with torch.no_grad():
            attributes = [
                ConditioningAttributes(text={'description': d})
                for d in descriptions]
            gen_tokens = self.lm.generate(
                conditions=attributes,
                max_gen_len=int(self.duration * self.frame_rate))
            x = self.compression_model.decode(gen_tokens, None)
            # The middle dimension is dropped from the target shape, so this
            # reshape only succeeds when that dimension has size 1.
            n_draw, _, n_time_samples = x.shape
            x = x.reshape(1, n_draw * n_time_samples)
            return x
|
|
|
|
|
|
|
|
|
# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the script still runs (slowly) on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Both halves of the pretrained checkpoint are loaded onto the same device and
# switched to eval mode; the language model is additionally cast to float32.
sound_generator = AudioGen(
    compression_model=load_compression_model('facebook/audiogen-medium', device=device).eval(),
    lm=load_lm_model('facebook/audiogen-medium', device=device).to(torch.float).eval(),
    duration=.74)

print('\n\n\n\n___________________')

txt = 'dogs barging in the street'

# generate() returns a single-row tensor; take that row and move it to NumPy.
x = sound_generator.generate([txt])[0].detach().cpu().numpy()
# Peak-normalise in place; the epsilon avoids dividing by zero on silence.
x /= np.abs(x).max() + 1e-7

# 16000 is passed as the sampling rate -- presumably the model's output rate;
# confirm against the checkpoint if the model is changed.
audiofile.write('del_seane.wav', x, 16000)
|
|