patrickvonplaten
committed on
Add hf loading & improve a couple of things on the README (#2)
Browse files- README.md +11 -5
- app.py +2 -2
- app_batched.py +2 -2
- audiocraft/models/loaders.py +37 -10
- audiocraft/models/musicgen.py +15 -20
- hf_loading.py +0 -61
- mypy.ini +1 -1
- requirements.txt +1 -0
README.md
CHANGED
@@ -40,15 +40,21 @@ You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./d
|
|
40 |
## API
|
41 |
|
42 |
We provide a simple API and 4 pre-trained models. The pre-trained models are:
|
43 |
-
- `small`: 300M model, text to music only
|
44 |
-
- `medium`: 1.5B model, text to music only
|
45 |
-
- `melody`: 1.5B model, text to music and text+melody to music
|
46 |
-
- `large`: 3.3B model, text to music only.
|
47 |
|
48 |
We observe the best trade-off between quality and compute with the `medium` or `melody` model.
|
49 |
In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
|
50 |
GPUs will be able to generate short sequences, or longer sequences with the `small` model.
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
See below for a quick example of using the API.
|
53 |
|
54 |
```python
|
@@ -68,7 +74,7 @@ wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), s
|
|
68 |
|
69 |
for idx, one_wav in enumerate(wav):
|
70 |
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
|
71 |
-
audio_write(f'{idx}', one_wav, model.sample_rate, strategy="loudness")
|
72 |
```
|
73 |
|
74 |
|
|
|
40 |
## API
|
41 |
|
42 |
We provide a simple API and 4 pre-trained models. The pre-trained models are:
|
43 |
+
- `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
|
44 |
+
- `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
|
45 |
+
- `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
|
46 |
+
- `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
|
47 |
|
48 |
We observe the best trade-off between quality and compute with the `medium` or `melody` model.
|
49 |
In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
|
50 |
GPUs will be able to generate short sequences, or longer sequences with the `small` model.
|
51 |
|
52 |
+
**Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using a newer version of `torchaudio`.
|
53 |
+
You can install it with:
|
54 |
+
```
|
55 |
+
apt-get install ffmpeg
|
56 |
+
```
|
57 |
+
|
58 |
See below for a quick example of using the API.
|
59 |
|
60 |
```python
|
|
|
74 |
|
75 |
for idx, one_wav in enumerate(wav):
|
76 |
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
|
77 |
+
audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness")
|
78 |
```
|
79 |
|
80 |
|
app.py
CHANGED
@@ -8,7 +8,7 @@ LICENSE file in the root directory of this source tree.
|
|
8 |
|
9 |
import torch
|
10 |
import gradio as gr
|
11 |
-
from
|
12 |
|
13 |
|
14 |
MODEL = None
|
@@ -16,7 +16,7 @@ MODEL = None
|
|
16 |
|
17 |
def load_model(version):
|
18 |
print("Loading model", version)
|
19 |
-
return get_pretrained(version)
|
20 |
|
21 |
|
22 |
def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
|
|
|
8 |
|
9 |
import torch
|
10 |
import gradio as gr
|
11 |
+
from audiocraft.models import MusicGen
|
12 |
|
13 |
|
14 |
MODEL = None
|
|
|
16 |
|
17 |
def load_model(version):
|
18 |
print("Loading model", version)
|
19 |
+
return MusicGen.get_pretrained(version)
|
20 |
|
21 |
|
22 |
def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
|
app_batched.py
CHANGED
@@ -11,7 +11,7 @@ import torch
|
|
11 |
import gradio as gr
|
12 |
from audiocraft.data.audio_utils import convert_audio
|
13 |
from audiocraft.data.audio import audio_write
|
14 |
-
from
|
15 |
|
16 |
|
17 |
MODEL = None
|
@@ -19,7 +19,7 @@ MODEL = None
|
|
19 |
|
20 |
def load_model():
|
21 |
print("Loading model")
|
22 |
-
return get_pretrained("melody")
|
23 |
|
24 |
|
25 |
def predict(texts, melodies):
|
|
|
11 |
import gradio as gr
|
12 |
from audiocraft.data.audio_utils import convert_audio
|
13 |
from audiocraft.data.audio import audio_write
|
14 |
+
from audiocraft.models import MusicGen
|
15 |
|
16 |
|
17 |
MODEL = None
|
|
|
19 |
|
20 |
def load_model():
|
21 |
print("Loading model")
|
22 |
+
return MusicGen.get_pretrained("melody")
|
23 |
|
24 |
|
25 |
def predict(texts, melodies):
|
audiocraft/models/loaders.py
CHANGED
@@ -20,7 +20,9 @@ of the returned model.
|
|
20 |
"""
|
21 |
|
22 |
from pathlib import Path
|
|
|
23 |
import typing as tp
|
|
|
24 |
|
25 |
from omegaconf import OmegaConf
|
26 |
import torch
|
@@ -28,18 +30,43 @@ import torch
|
|
28 |
from . import builders
|
29 |
|
30 |
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
# Return the state dict either from a file or url
|
33 |
-
|
34 |
-
assert isinstance(
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
else:
|
38 |
-
|
39 |
|
40 |
|
41 |
-
def load_compression_model(
|
42 |
-
pkg = _get_state_dict(
|
43 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
44 |
cfg.device = str(device)
|
45 |
model = builders.get_compression_model(cfg)
|
@@ -48,8 +75,8 @@ def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
|
|
48 |
return model
|
49 |
|
50 |
|
51 |
-
def load_lm_model(
|
52 |
-
pkg = _get_state_dict(
|
53 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
54 |
cfg.device = str(device)
|
55 |
if cfg.device == 'cpu':
|
|
|
20 |
"""
|
21 |
|
22 |
from pathlib import Path
|
23 |
+
from huggingface_hub import hf_hub_download
|
24 |
import typing as tp
|
25 |
+
import os
|
26 |
|
27 |
from omegaconf import OmegaConf
|
28 |
import torch
|
|
|
30 |
from . import builders
|
31 |
|
32 |
|
33 |
+
HF_MODEL_CHECKPOINTS_MAP = {
|
34 |
+
"small": "facebook/musicgen-small",
|
35 |
+
"medium": "facebook/musicgen-medium",
|
36 |
+
"large": "facebook/musicgen-large",
|
37 |
+
"melody": "facebook/musicgen-melody",
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
def _get_state_dict(
|
42 |
+
file_or_url_or_id: tp.Union[Path, str],
|
43 |
+
filename: tp.Optional[str] = None,
|
44 |
+
device='cpu',
|
45 |
+
cache_dir: tp.Optional[str] = None,
|
46 |
+
):
|
47 |
# Return the state dict either from a file or url
|
48 |
+
file_or_url_or_id = str(file_or_url_or_id)
|
49 |
+
assert isinstance(file_or_url_or_id, str)
|
50 |
+
|
51 |
+
if os.path.isfile(file_or_url_or_id):
|
52 |
+
return torch.load(file_or_url_or_id, map_location=device)
|
53 |
+
|
54 |
+
elif file_or_url_or_id.startswith('https://'):
|
55 |
+
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
|
56 |
+
|
57 |
+
elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
|
58 |
+
assert filename is not None, "filename needs to be defined if using HF checkpoints"
|
59 |
+
|
60 |
+
repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
|
61 |
+
file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
|
62 |
+
return torch.load(file, map_location=device)
|
63 |
+
|
64 |
else:
|
65 |
+
raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
|
66 |
|
67 |
|
68 |
+
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
69 |
+
pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
|
70 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
71 |
cfg.device = str(device)
|
72 |
model = builders.get_compression_model(cfg)
|
|
|
75 |
return model
|
76 |
|
77 |
|
78 |
+
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
79 |
+
pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
|
80 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
81 |
cfg.device = str(device)
|
82 |
if cfg.device == 'cpu':
|
audiocraft/models/musicgen.py
CHANGED
@@ -17,7 +17,7 @@ import torch
|
|
17 |
from .encodec import CompressionModel
|
18 |
from .lm import LMModel
|
19 |
from .builders import get_debug_compression_model, get_debug_lm_model
|
20 |
-
from .loaders import load_compression_model, load_lm_model
|
21 |
from ..data.audio_utils import convert_audio
|
22 |
from ..modules.conditioners import ConditioningAttributes, WavCondition
|
23 |
from ..utils.autocast import TorchAutocast
|
@@ -67,10 +67,10 @@ class MusicGen:
|
|
67 |
@staticmethod
|
68 |
def get_pretrained(name: str = 'melody', device='cuda'):
|
69 |
"""Return pretrained model, we provide four models:
|
70 |
-
- small (300M), text to music,
|
71 |
-
- medium (1.5B), text to music,
|
72 |
-
- melody (1.5B) text to music and text+melody to music,
|
73 |
-
- large (3.3B), text to music.
|
74 |
"""
|
75 |
|
76 |
if name == 'debug':
|
@@ -79,21 +79,16 @@ class MusicGen:
|
|
79 |
lm = get_debug_lm_model(device)
|
80 |
return MusicGen(name, compression_model, lm)
|
81 |
|
82 |
-
if
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
'large': '9b6e835c-1f0cf17b5e',
|
93 |
-
'melody': 'f79af192-61305ffc49',
|
94 |
-
}
|
95 |
-
sig = names[name]
|
96 |
-
lm = load_lm_model(ROOT + f'{sig}.th', device=device)
|
97 |
return MusicGen(name, compression_model, lm)
|
98 |
|
99 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
|
|
17 |
from .encodec import CompressionModel
|
18 |
from .lm import LMModel
|
19 |
from .builders import get_debug_compression_model, get_debug_lm_model
|
20 |
+
from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
|
21 |
from ..data.audio_utils import convert_audio
|
22 |
from ..modules.conditioners import ConditioningAttributes, WavCondition
|
23 |
from ..utils.autocast import TorchAutocast
|
|
|
67 |
@staticmethod
|
68 |
def get_pretrained(name: str = 'melody', device='cuda'):
|
69 |
"""Return pretrained model, we provide four models:
|
70 |
+
- small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
|
71 |
+
- medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
|
72 |
+
- melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
|
73 |
+
- large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
|
74 |
"""
|
75 |
|
76 |
if name == 'debug':
|
|
|
79 |
lm = get_debug_lm_model(device)
|
80 |
return MusicGen(name, compression_model, lm)
|
81 |
|
82 |
+
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
83 |
+
raise ValueError(
|
84 |
+
f"{name} is not a valid checkpoint name. "
|
85 |
+
f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
|
86 |
+
)
|
87 |
+
|
88 |
+
cache_dir = os.environ.get('MUSICGEN_ROOT', None)
|
89 |
+
compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
|
90 |
+
lm = load_lm_model(name, device=device, cache_dir=cache_dir)
|
91 |
+
|
|
|
|
|
|
|
|
|
|
|
92 |
return MusicGen(name, compression_model, lm)
|
93 |
|
94 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
hf_loading.py
DELETED
@@ -1,61 +0,0 @@
|
|
1 |
-
"""Utility for loading the models from HF."""
|
2 |
-
from pathlib import Path
|
3 |
-
import typing as tp
|
4 |
-
|
5 |
-
from omegaconf import OmegaConf
|
6 |
-
from huggingface_hub import hf_hub_download
|
7 |
-
import torch
|
8 |
-
|
9 |
-
from audiocraft.models import builders, MusicGen
|
10 |
-
|
11 |
-
MODEL_CHECKPOINTS_MAP = {
|
12 |
-
"small": "facebook/musicgen-small",
|
13 |
-
"medium": "facebook/musicgen-medium",
|
14 |
-
"large": "facebook/musicgen-large",
|
15 |
-
"melody": "facebook/musicgen-melody",
|
16 |
-
}
|
17 |
-
|
18 |
-
|
19 |
-
def _get_state_dict(file_or_url: tp.Union[Path, str],
|
20 |
-
filename="state_dict.bin", device='cpu'):
|
21 |
-
# Return the state dict either from a file or url
|
22 |
-
print("loading", file_or_url, filename)
|
23 |
-
file_or_url = str(file_or_url)
|
24 |
-
assert isinstance(file_or_url, str)
|
25 |
-
return torch.load(
|
26 |
-
hf_hub_download(repo_id=file_or_url, filename=filename), map_location=device)
|
27 |
-
|
28 |
-
|
29 |
-
def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
|
30 |
-
pkg = _get_state_dict(file_or_url, filename="compression_state_dict.bin")
|
31 |
-
cfg = OmegaConf.create(pkg['xp.cfg'])
|
32 |
-
cfg.device = str(device)
|
33 |
-
model = builders.get_compression_model(cfg)
|
34 |
-
model.load_state_dict(pkg['best_state'])
|
35 |
-
model.eval()
|
36 |
-
model.cfg = cfg
|
37 |
-
return model
|
38 |
-
|
39 |
-
|
40 |
-
def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
|
41 |
-
pkg = _get_state_dict(file_or_url)
|
42 |
-
cfg = OmegaConf.create(pkg['xp.cfg'])
|
43 |
-
cfg.device = str(device)
|
44 |
-
if cfg.device == 'cpu':
|
45 |
-
cfg.transformer_lm.memory_efficient = False
|
46 |
-
cfg.transformer_lm.custom = True
|
47 |
-
cfg.dtype = 'float32'
|
48 |
-
else:
|
49 |
-
cfg.dtype = 'float16'
|
50 |
-
model = builders.get_lm_model(cfg)
|
51 |
-
model.load_state_dict(pkg['best_state'])
|
52 |
-
model.eval()
|
53 |
-
model.cfg = cfg
|
54 |
-
return model
|
55 |
-
|
56 |
-
|
57 |
-
def get_pretrained(name: str = 'small', device='cuda'):
|
58 |
-
model_id = MODEL_CHECKPOINTS_MAP[name]
|
59 |
-
compression_model = load_compression_model(model_id, device=device)
|
60 |
-
lm = load_lm_model(model_id, device=device)
|
61 |
-
return MusicGen(name, compression_model, lm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mypy.ini
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
[mypy]
|
2 |
|
3 |
-
[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy]
|
4 |
ignore_missing_imports = True
|
|
|
1 |
[mypy]
|
2 |
|
3 |
+
[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub]
|
4 |
ignore_missing_imports = True
|
requirements.txt
CHANGED
@@ -11,6 +11,7 @@ sentencepiece
|
|
11 |
spacy==3.5.2
|
12 |
torch>=2.0.0
|
13 |
torchaudio>=2.0.0
|
|
|
14 |
tqdm
|
15 |
transformers
|
16 |
xformers
|
|
|
11 |
spacy==3.5.2
|
12 |
torch>=2.0.0
|
13 |
torchaudio>=2.0.0
|
14 |
+
huggingface_hub
|
15 |
tqdm
|
16 |
transformers
|
17 |
xformers
|