adefossez commited on
Commit
fad2862
·
1 Parent(s): 6457900

adding support for cpu

Browse files
audiocraft/models/loaders.py CHANGED
@@ -80,8 +80,6 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_di
80
  cfg = OmegaConf.create(pkg['xp.cfg'])
81
  cfg.device = str(device)
82
  if cfg.device == 'cpu':
83
- cfg.transformer_lm.memory_efficient = False
84
- cfg.transformer_lm.custom = True
85
  cfg.dtype = 'float32'
86
  else:
87
  cfg.dtype = 'float16'
 
80
  cfg = OmegaConf.create(pkg['xp.cfg'])
81
  cfg.device = str(device)
82
  if cfg.device == 'cpu':
 
 
83
  cfg.dtype = 'float32'
84
  else:
85
  cfg.dtype = 'float16'
audiocraft/models/musicgen.py CHANGED
@@ -68,7 +68,7 @@ class MusicGen:
68
  return self.compression_model.channels
69
 
70
  @staticmethod
71
- def get_pretrained(name: str = 'melody', device='cuda'):
72
  """Return pretrained model, we provide four models:
73
  - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
74
  - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
@@ -76,11 +76,17 @@ class MusicGen:
76
  - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
77
  """
78
 
 
 
 
 
 
 
79
  if name == 'debug':
80
  # used only for unit tests
81
  compression_model = get_debug_compression_model(device)
82
  lm = get_debug_lm_model(device)
83
- return MusicGen(name, compression_model, lm, max_duration=3.)
84
 
85
  if name not in HF_MODEL_CHECKPOINTS_MAP:
86
  raise ValueError(
@@ -313,7 +319,6 @@ class MusicGen:
313
  all_tokens.append(prompt_tokens)
314
  prompt_length = prompt_tokens.shape[-1]
315
 
316
-
317
  stride_tokens = int(self.frame_rate * self.extend_stride)
318
 
319
  while current_gen_offset + prompt_length < total_gen_len:
 
68
  return self.compression_model.channels
69
 
70
  @staticmethod
71
+ def get_pretrained(name: str = 'melody', device=None):
72
  """Return pretrained model, we provide four models:
73
  - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
74
  - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
 
76
  - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
77
  """
78
 
79
+ if device is None:
80
+ if torch.cuda.device_count():
81
+ device = 'cuda'
82
+ else:
83
+ device = 'cpu'
84
+
85
  if name == 'debug':
86
  # used only for unit tests
87
  compression_model = get_debug_compression_model(device)
88
  lm = get_debug_lm_model(device)
89
+ return MusicGen(name, compression_model, lm)
90
 
91
  if name not in HF_MODEL_CHECKPOINTS_MAP:
92
  raise ValueError(
 
319
  all_tokens.append(prompt_tokens)
320
  prompt_length = prompt_tokens.shape[-1]
321
 
 
322
  stride_tokens = int(self.frame_rate * self.extend_stride)
323
 
324
  while current_gen_offset + prompt_length < total_gen_len:
tests/models/test_musicgen.py CHANGED
@@ -51,6 +51,7 @@ class TestSEANetModel:
51
 
52
  def test_generate_long(self):
53
  mg = self.get_musicgen()
 
54
  mg.set_generation_params(duration=4., stride_extend=2.)
55
  wav = mg.generate(
56
  ['youpi', 'lapin dort'])
 
51
 
52
  def test_generate_long(self):
53
  mg = self.get_musicgen()
54
+ mg.max_duration = 3.
55
  mg.set_generation_params(duration=4., stride_extend=2.)
56
  wav = mg.generate(
57
  ['youpi', 'lapin dort'])