hugo flores garcia commited on
Commit
41b9d24
·
0 Parent(s):

recovering from a gittastrophe

Browse files
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.rar filter=lfs diff=lfs merge=lfs -text
24
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
25
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/env.sh
108
+ venv/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ # Files created by experiments
131
+ output/
132
+ snapshot/
133
+ *.m4a
134
+ notebooks/scratch.ipynb
135
+ notebooks/inspect.ipynb
136
+ notebooks/effects.ipynb
137
+ notebooks/*.ipynb
138
+ notebooks/*.gif
139
+ notebooks/*.wav
140
+ notebooks/*.mp4
141
+ *runs/
142
+ boards/
143
+ samples/
144
+ *.ipynb
145
+
146
+ results.json
147
+ metrics.csv
148
+ mprofile_*
149
+ mem.png
150
+
151
+ results/
152
+ mprofile*
153
+ *.png
154
+ # do not ignore the test wav file
155
+ !tests/audio/short_test_audio.wav
156
+ !tests/audio/output.wav
157
+ */.DS_Store
158
+ .DS_Store
159
+ env.sh
160
+ _codebraid/
161
+ **/*.html
162
+ **/*.exec.md
163
+ flagged/
164
+ log.txt
165
+ ckpt/
166
+ .syncthing*
167
+ tests/assets/
168
+ archived/
169
+
170
+ scratch/
171
+
172
+ runs-archive
173
+ lyrebird-audiotools
174
+ lyrebird-audio-codec
175
+ samples-*/**
176
+
177
+ gradio-outputs/
178
+ samples*/
179
+ models-all/
180
+ models.zip
181
+ .git-old
182
+
183
+
184
+
185
+ gtzan.zip
186
+ .gtzan_emb_cache
187
+
188
+
189
+ data/
190
+ data
191
+ pyharp
192
+
193
+ models/vampnet/*
194
+ models/*
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Hugo Flores García and Prem Seetharaman
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: salad bowl (vampnet)
3
+ emoji: 🥗
4
+ colorFrom: yellow
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.37.2
8
+ python_version: 3.9.17
9
+ app_file: app.py
10
+ pinned: false
11
+ license: cc-by-nc-4.0
12
+ ---
13
+
14
+ # VampNet
15
+
16
+ This repository contains recipes for training generative music models on top of the Descript Audio Codec.
17
+
18
+ # Setting up
19
+
20
+ **Requires Python 3.9**.
21
+
22
+ you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
23
+
24
+ (for example, using conda)
25
+ ```bash
26
+ conda create -n vampnet python=3.9
27
+ conda activate vampnet
28
+ ```
29
+
30
+ install VampNet
31
+
32
+ ```bash
33
+ git clone https://github.com/hugofloresgarcia/vampnet.git
34
+ pip install -e ./vampnet
35
+ ```
36
+
37
+ # Usage
38
+
39
+
40
+
41
+ ## Launching the Gradio Interface
42
+ You can launch a gradio UI to play with vampnet.
43
+
44
+ ```bash
45
+ python app.py --args.load conf/interface.yml --Interface.device cuda
46
+ ```
47
+
48
+ # Training / Fine-tuning
49
+
50
+ ## Training a model
51
+
52
+ To train a model, run the following script:
53
+
54
+ ```bash
55
+ python scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/checkpoints
56
+ ```
57
+
58
+ for multi-gpu training, use torchrun:
59
+
60
+ ```bash
61
+ torchrun --nproc_per_node gpu scripts/exp/train.py --args.load conf/vampnet.yml --save_path path/to/ckpt
62
+ ```
63
+
64
+ You can edit `conf/vampnet.yml` to change the dataset paths or any training hyperparameters.
65
+
66
+ For coarse2fine models, you can use `conf/c2f.yml` as a starting configuration.
67
+
68
+ See `python scripts/exp/train.py -h` for a list of options.
69
+
70
+ ## Debugging training
71
+
72
+ To debug training, it's easier to debug with 1 gpu and 0 workers
73
+
74
+ ```bash
75
+ CUDA_VISIBLE_DEVICES=0 python -m pdb scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/checkpoints --num_workers 0
76
+ ```
77
+
78
+ ## Fine-tuning
79
+ To fine-tune a model, use the script in `scripts/exp/fine_tune.py` to generate 3 configuration files: `c2f.yml`, `coarse.yml`, and `interface.yml`.
80
+ The first two are used to fine-tune the coarse and fine models, respectively. The last one is used to launch the gradio interface.
81
+
82
+ ```bash
83
+ python scripts/exp/fine_tune.py "/path/to/audio1.mp3 /path/to/audio2/ /path/to/audio3.wav" <fine_tune_name>
84
+ ```
85
+
86
+ This will create a folder under `conf/<fine_tune_name>/` with the 3 configuration files.
87
+
88
+ The save_paths will be set to `runs/<fine_tune_name>/coarse` and `runs/<fine_tune_name>/c2f`.
89
+
90
+ launch the coarse job:
91
+ ```bash
92
+ python scripts/exp/train.py --args.load conf/generated/<fine_tune_name>/coarse.yml
93
+ ```
94
+
95
+ this will save the coarse model to `runs/<fine_tune_name>/coarse/ckpt/best/`.
96
+
97
+ launch the c2f job:
98
+ ```bash
99
+ python scripts/exp/train.py --args.load conf/generated/<fine_tune_name>/c2f.yml
100
+ ```
101
+
102
+ ## A note on argbind
103
+ This repository relies on [argbind](https://github.com/pseeth/argbind) to manage CLIs and config files.
104
+ Config files are stored in the `conf/` folder.
105
+
106
+ ### Licensing for Pretrained Models:
107
+ The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml).
108
+
109
+ Download the pretrained models from [this link](https://zenodo.org/record/8136629). Then, extract the models to the `models/` folder.
110
+
111
+
112
+
113
+
app.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ from pathlib import Path
3
+ import yaml
4
+ import time
5
+ import uuid
6
+
7
+ import numpy as np
8
+ import audiotools as at
9
+ import argbind
10
+ import shutil
11
+ import torch
12
+ from datetime import datetime
13
+
14
+ import gradio as gr
15
+ from vampnet.interface import Interface, signal_concat
16
+ from vampnet import mask as pmask
17
+
18
+
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ interface = Interface.default()
22
+
23
+ # populate the model choices with any interface.yml files in the generated confs
24
+ MODEL_CHOICES = {
25
+ "default": {
26
+ "Interface.coarse_ckpt": str(interface.coarse_path),
27
+ "Interface.coarse2fine_ckpt": str(interface.c2f_path),
28
+ "Interface.codec_ckpt": str(interface.codec_path),
29
+ }
30
+ }
31
+ generated_confs = Path("conf/generated")
32
+ for conf_file in generated_confs.glob("*/interface.yml"):
33
+ with open(conf_file) as f:
34
+ _conf = yaml.safe_load(f)
35
+
36
+ # check if the coarse, c2f, and codec ckpts exist
37
+ # otherwise, dont' add this model choice
38
+ if not (
39
+ Path(_conf["Interface.coarse_ckpt"]).exists() and
40
+ Path(_conf["Interface.coarse2fine_ckpt"]).exists() and
41
+ Path(_conf["Interface.codec_ckpt"]).exists()
42
+ ):
43
+ continue
44
+
45
+ MODEL_CHOICES[conf_file.parent.name] = _conf
46
+
47
+
48
+ def to_output(sig):
49
+ return sig.sample_rate, sig.cpu().detach().numpy()[0][0]
50
+
51
+
52
+
53
+ MAX_DURATION_S = 5
54
+ def load_audio(file):
55
+ print(file)
56
+ if isinstance(file, str):
57
+ filepath = file
58
+ elif isinstance(file, tuple):
59
+ # not a file
60
+ sr, samples = file
61
+ samples = samples / np.iinfo(samples.dtype).max
62
+ return sr, samples
63
+ else:
64
+ filepath = file.name
65
+ sig = at.AudioSignal.salient_excerpt(
66
+ filepath, duration=MAX_DURATION_S
67
+ )
68
+ sig = at.AudioSignal(filepath)
69
+ return to_output(sig)
70
+
71
+
72
+ def load_example_audio():
73
+ return load_audio("./assets/example.wav")
74
+
75
+ from torch_pitch_shift import pitch_shift, get_fast_shifts
76
+ def shift_pitch(signal, interval: int):
77
+ signal.samples = pitch_shift(
78
+ signal.samples,
79
+ shift=interval,
80
+ sample_rate=signal.sample_rate
81
+ )
82
+ return signal
83
+
84
+
85
+ @spaces.GPU
86
+ def _vamp(
87
+ seed, input_audio, model_choice,
88
+ pitch_shift_amt, periodic_p,
89
+ n_mask_codebooks, periodic_w, onset_mask_width,
90
+ dropout, sampletemp, typical_filtering,
91
+ typical_mass, typical_min_tokens, top_p,
92
+ sample_cutoff, stretch_factor, api=False
93
+ ):
94
+ t0 = time.time()
95
+ interface.to("cuda" if torch.cuda.is_available() else "cpu")
96
+ print(f"using device {interface.device}")
97
+ _seed = seed if seed > 0 else None
98
+ if _seed is None:
99
+ _seed = int(torch.randint(0, 2**32, (1,)).item())
100
+ at.util.seed(_seed)
101
+
102
+ sr, input_audio = input_audio
103
+ input_audio = input_audio / np.iinfo(input_audio.dtype).max
104
+
105
+ sig = at.AudioSignal(input_audio, sr)
106
+
107
+ # reload the model if necessary
108
+ interface.reload(
109
+ coarse_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse_ckpt"],
110
+ c2f_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse2fine_ckpt"],
111
+ )
112
+
113
+ if pitch_shift_amt != 0:
114
+ sig = shift_pitch(sig, pitch_shift_amt)
115
+
116
+ build_mask_kwargs = dict(
117
+ rand_mask_intensity=1.0,
118
+ prefix_s=0.0,
119
+ suffix_s=0.0,
120
+ periodic_prompt=int(periodic_p),
121
+ periodic_prompt_width=periodic_w,
122
+ onset_mask_width=onset_mask_width,
123
+ _dropout=dropout,
124
+ upper_codebook_mask=int(n_mask_codebooks),
125
+ )
126
+
127
+ vamp_kwargs = dict(
128
+ temperature=sampletemp,
129
+ typical_filtering=typical_filtering,
130
+ typical_mass=typical_mass,
131
+ typical_min_tokens=typical_min_tokens,
132
+ top_p=None,
133
+ seed=_seed,
134
+ sample_cutoff=1.0,
135
+ )
136
+
137
+ # save the mask as a txt file
138
+ interface.set_chunk_size(10.0)
139
+ sig, mask, codes = interface.ez_vamp(
140
+ sig,
141
+ batch_size=1 if api else 1,
142
+ feedback_steps=1,
143
+ time_stretch_factor=stretch_factor,
144
+ build_mask_kwargs=build_mask_kwargs,
145
+ vamp_kwargs=vamp_kwargs,
146
+ return_mask=True,
147
+ )
148
+ print(f"vamp took {time.time() - t0} seconds")
149
+
150
+
151
+ return to_output(sig)
152
+
153
+ def vamp(data):
154
+ return _vamp(
155
+ seed=data[seed],
156
+ input_audio=data[input_audio],
157
+ model_choice=data[model_choice],
158
+ pitch_shift_amt=data[pitch_shift_amt],
159
+ periodic_p=data[periodic_p],
160
+ n_mask_codebooks=data[n_mask_codebooks],
161
+ periodic_w=data[periodic_w],
162
+ onset_mask_width=data[onset_mask_width],
163
+ dropout=data[dropout],
164
+ sampletemp=data[sampletemp],
165
+ typical_filtering=data[typical_filtering],
166
+ typical_mass=data[typical_mass],
167
+ typical_min_tokens=data[typical_min_tokens],
168
+ top_p=data[top_p],
169
+ sample_cutoff=data[sample_cutoff],
170
+ stretch_factor=data[stretch_factor],
171
+ api=False,
172
+ )
173
+
174
+ def api_vamp(data):
175
+ return _vamp(
176
+ seed=data[seed],
177
+ input_audio=data[input_audio],
178
+ model_choice=data[model_choice],
179
+ pitch_shift_amt=data[pitch_shift_amt],
180
+ periodic_p=data[periodic_p],
181
+ n_mask_codebooks=data[n_mask_codebooks],
182
+ periodic_w=data[periodic_w],
183
+ onset_mask_width=data[onset_mask_width],
184
+ dropout=data[dropout],
185
+ sampletemp=data[sampletemp],
186
+ typical_filtering=data[typical_filtering],
187
+ typical_mass=data[typical_mass],
188
+ typical_min_tokens=data[typical_min_tokens],
189
+ top_p=data[top_p],
190
+ sample_cutoff=data[sample_cutoff],
191
+ stretch_factor=data[stretch_factor],
192
+ api=True,
193
+ )
194
+
195
+
196
+
197
+
198
+
199
+ with gr.Blocks() as demo:
200
+ with gr.Row():
201
+ with gr.Column():
202
+ manual_audio_upload = gr.File(
203
+ label=f"upload some audio (will be randomly trimmed to max of 100s)",
204
+ file_types=["audio"]
205
+ )
206
+ load_example_audio_button = gr.Button("or load example audio")
207
+
208
+ input_audio = gr.Audio(
209
+ label="input audio",
210
+ interactive=False,
211
+ type="numpy",
212
+ )
213
+
214
+ audio_mask = gr.Audio(
215
+ label="audio mask (listen to this to hear the mask hints)",
216
+ interactive=False,
217
+ type="numpy",
218
+ )
219
+
220
+ # connect widgets
221
+ load_example_audio_button.click(
222
+ fn=load_example_audio,
223
+ inputs=[],
224
+ outputs=[ input_audio]
225
+ )
226
+
227
+ manual_audio_upload.change(
228
+ fn=load_audio,
229
+ inputs=[manual_audio_upload],
230
+ outputs=[ input_audio]
231
+ )
232
+
233
+
234
+ # mask settings
235
+ with gr.Column():
236
+ with gr.Accordion("manual controls", open=True):
237
+ periodic_p = gr.Slider(
238
+ label="periodic prompt",
239
+ minimum=0,
240
+ maximum=13,
241
+ step=1,
242
+ value=3,
243
+ )
244
+
245
+ onset_mask_width = gr.Slider(
246
+ label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
247
+ minimum=0,
248
+ maximum=100,
249
+ step=1,
250
+ value=0, visible=False
251
+ )
252
+
253
+ n_mask_codebooks = gr.Slider(
254
+ label="compression prompt ",
255
+ value=3,
256
+ minimum=1,
257
+ maximum=14,
258
+ step=1,
259
+ )
260
+
261
+ maskimg = gr.Image(
262
+ label="mask image",
263
+ interactive=False,
264
+ type="filepath"
265
+ )
266
+
267
+ with gr.Accordion("extras ", open=False):
268
+ pitch_shift_amt = gr.Slider(
269
+ label="pitch shift amount (semitones)",
270
+ minimum=-12,
271
+ maximum=12,
272
+ step=1,
273
+ value=0,
274
+ )
275
+
276
+ stretch_factor = gr.Slider(
277
+ label="time stretch factor",
278
+ minimum=0,
279
+ maximum=8,
280
+ step=1,
281
+ value=1,
282
+ )
283
+
284
+ periodic_w = gr.Slider(
285
+ label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
286
+ minimum=1,
287
+ maximum=20,
288
+ step=1,
289
+ value=1,
290
+ )
291
+
292
+
293
+ with gr.Accordion("sampling settings", open=False):
294
+ sampletemp = gr.Slider(
295
+ label="sample temperature",
296
+ minimum=0.1,
297
+ maximum=10.0,
298
+ value=1.0,
299
+ step=0.001
300
+ )
301
+
302
+ top_p = gr.Slider(
303
+ label="top p (0.0 = off)",
304
+ minimum=0.0,
305
+ maximum=1.0,
306
+ value=0.0
307
+ )
308
+ typical_filtering = gr.Checkbox(
309
+ label="typical filtering ",
310
+ value=True
311
+ )
312
+ typical_mass = gr.Slider(
313
+ label="typical mass (should probably stay between 0.1 and 0.5)",
314
+ minimum=0.01,
315
+ maximum=0.99,
316
+ value=0.15
317
+ )
318
+ typical_min_tokens = gr.Slider(
319
+ label="typical min tokens (should probably stay between 1 and 256)",
320
+ minimum=1,
321
+ maximum=256,
322
+ step=1,
323
+ value=64
324
+ )
325
+ sample_cutoff = gr.Slider(
326
+ label="sample cutoff",
327
+ minimum=0.0,
328
+ maximum=0.9,
329
+ value=1.0,
330
+ step=0.01
331
+ )
332
+
333
+
334
+ dropout = gr.Slider(
335
+ label="mask dropout",
336
+ minimum=0.0,
337
+ maximum=1.0,
338
+ step=0.01,
339
+ value=0.0
340
+ )
341
+
342
+
343
+ seed = gr.Number(
344
+ label="seed (0 for random)",
345
+ value=0,
346
+ precision=0,
347
+ )
348
+
349
+
350
+ # mask settings
351
+ with gr.Column():
352
+
353
+ model_choice = gr.Dropdown(
354
+ label="model choice",
355
+ choices=list(MODEL_CHOICES.keys()),
356
+ value="default",
357
+ visible=True
358
+ )
359
+
360
+
361
+ vamp_button = gr.Button("generate (vamp)!!!")
362
+
363
+
364
+ audio_outs = []
365
+ use_as_input_btns = []
366
+ for i in range(1):
367
+ with gr.Column():
368
+ audio_outs.append(gr.Audio(
369
+ label=f"output audio {i+1}",
370
+ interactive=False,
371
+ type="numpy"
372
+ ))
373
+ use_as_input_btns.append(
374
+ gr.Button(f"use as input (feedback)")
375
+ )
376
+
377
+ thank_you = gr.Markdown("")
378
+
379
+ # download all the outputs
380
+ # download = gr.File(type="filepath", label="download outputs")
381
+
382
+
383
+ _inputs = {
384
+ input_audio,
385
+ sampletemp,
386
+ top_p,
387
+ periodic_p, periodic_w,
388
+ dropout,
389
+ stretch_factor,
390
+ onset_mask_width,
391
+ typical_filtering,
392
+ typical_mass,
393
+ typical_min_tokens,
394
+ seed,
395
+ model_choice,
396
+ n_mask_codebooks,
397
+ pitch_shift_amt,
398
+ sample_cutoff,
399
+ }
400
+
401
+ # connect widgets
402
+ vamp_button.click(
403
+ fn=vamp,
404
+ inputs=_inputs,
405
+ outputs=[audio_outs[0]],
406
+ )
407
+
408
+ api_vamp_button = gr.Button("api vamp", visible=True)
409
+ api_vamp_button.click(
410
+ fn=api_vamp,
411
+ inputs=_inputs,
412
+ outputs=[audio_outs[0]],
413
+ api_name="vamp"
414
+ )
415
+
416
+ for i, btn in enumerate(use_as_input_btns):
417
+ btn.click(
418
+ fn=load_audio,
419
+ inputs=[audio_outs[i]],
420
+ outputs=[input_audio]
421
+ )
422
+
423
+ try:
424
+ demo.queue()
425
+ demo.launch(share=True)
426
+ except KeyboardInterrupt:
427
+ shutil.rmtree("gradio-outputs", ignore_errors=True)
428
+ raise
assets/example.wav ADDED
Binary file (883 kB). View file
 
conf/c2f.yml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/vampnet.yml
3
+
4
+ VampNet.n_codebooks: 14
5
+ VampNet.n_conditioning_codebooks: 4
6
+
7
+ VampNet.embedding_dim: 1280
8
+ VampNet.n_layers: 16
9
+ VampNet.n_heads: 20
10
+
11
+ AudioDataset.duration: 3.0
12
+
13
+
14
+ AudioDataset.loudness_cutoff: -40.0
conf/interface.yml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Interface.coarse_ckpt: ./models/vampnet/coarse.pth
2
+ Interface.coarse2fine_ckpt: ./models/vampnet/c2f.pth
3
+ Interface.codec_ckpt: ./models/vampnet/codec.pth
4
+ Interface.coarse_chunk_size_s: 10
5
+ Interface.coarse2fine_chunk_size_s: 3
6
+ Interface.wavebeat_ckpt: ./models/wavebeat.pth
7
+
8
+ # AudioLoader.sources:
9
+ # - /media/CHONK/null
10
+
conf/lora/lora.yml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/vampnet.yml
3
+
4
+ fine_tune: True
5
+
6
+ train/AudioDataset.n_examples: 100000000
7
+ val/AudioDataset.n_examples: 500
8
+
9
+
10
+ NoamScheduler.warmup: 500
11
+
12
+ batch_size: 7
13
+ num_workers: 7
14
+ save_iters: [2000, 4000, 10000,20000, 40000, 100000]
15
+ sample_freq: 2000
16
+ val_freq: 1000
17
+
18
+ AdamW.lr: 0.0001
19
+
20
+ # let's us organize sound classes into folders and choose from those sound classes uniformly
21
+ AudioDataset.without_replacement: False
22
+ num_iters: 500000
conf/salad_bowl.yml ADDED
File without changes
conf/vampnet.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ codec_ckpt: ./models/vampnet/codec.pth
3
+ save_path: ckpt
4
+
5
+ num_iters: 1000000000
6
+ save_iters: [10000, 50000, 100000, 300000, 500000]
7
+ val_idx: [0,1,2,3,4,5,6,7,8,9]
8
+ sample_freq: 10000
9
+ val_freq: 1000
10
+
11
+ batch_size: 8
12
+ num_workers: 10
13
+
14
+ # Optimization
15
+ amp: false
16
+
17
+ CrossEntropyLoss.label_smoothing: 0.1
18
+
19
+ AdamW.lr: 0.001
20
+
21
+ NoamScheduler.factor: 2.0
22
+ NoamScheduler.warmup: 10000
23
+
24
+ VampNet.vocab_size: 1024
25
+ VampNet.n_codebooks: 4
26
+ VampNet.n_conditioning_codebooks: 0
27
+ VampNet.r_cond_dim: 0
28
+ VampNet.noise_mode: mask
29
+ VampNet.embedding_dim: 1280
30
+ VampNet.n_layers: 20
31
+ VampNet.n_heads: 20
32
+ VampNet.flash_attn: false
33
+ VampNet.dropout: 0.1
34
+
35
+ AudioLoader.relative_path: ""
36
+ AudioDataset.loudness_cutoff: -30.0
37
+ AudioDataset.without_replacement: true
38
+ AudioLoader.shuffle: true
39
+
40
+ AudioDataset.duration: 10.0
41
+
42
+ train/AudioDataset.n_examples: 10000000
43
+ train/AudioLoader.sources:
44
+ - /media/CHONK/hugo/spotdl/audio-train
45
+
46
+ val/AudioDataset.n_examples: 2000
47
+ val/AudioLoader.sources:
48
+ - /media/CHONK/hugo/spotdl/audio-val
49
+
hello.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import vampnet
3
+ import audiotools as at
4
+
5
+ # load the default vampnet model
6
+ interface = vampnet.interface.Interface.default()
7
+
8
+ # list available finetuned models
9
+ finetuned_model_choices = interface.available_models()
10
+ print(f"available finetuned models: {finetuned_model_choices}")
11
+
12
+ # pick a random finetuned model
13
+ model_choice = random.choice(finetuned_model_choices)
14
+ print(f"choosing model: {model_choice}")
15
+
16
+ # load a finetuned model
17
+ interface.load_finetuned(model_choice)
18
+
19
+ # load an example audio file
20
+ signal = at.AudioSignal("assets/example.wav")
21
+
22
+ # get the tokens for the audio
23
+ codes = interface.encode(signal)
24
+
25
+ # build a mask for the audio
26
+ mask = interface.build_mask(
27
+ codes, signal,
28
+ periodic_prompt=7,
29
+ upper_codebook_mask=3,
30
+ )
31
+
32
+ # generate the output tokens
33
+ output_tokens = interface.vamp(
34
+ codes, mask, return_mask=False,
35
+ temperature=1.0,
36
+ typical_filtering=True,
37
+ )
38
+
39
+ # convert them to a signal
40
+ output_signal = interface.decode(output_tokens)
41
+
42
+ # save the output signal
43
+ output_signal.write("scratch/output.wav")
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.1.0
2
+ argbind>=0.3.2
3
+ numpy==1.23
4
+ loralib
5
+ wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
6
+ lac @ git+https://github.com/hugofloresgarcia/lac.git
7
+ descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
8
+ -e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
9
+ torch_pitch_shift
10
+ gradio==4.37.2
scripts/exp/eval.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import os
3
+ from functools import partial
4
+
5
+ from frechet_audio_distance import FrechetAudioDistance
6
+ import pandas
7
+ import argbind
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ import audiotools
12
+ from audiotools import AudioSignal
13
+
14
+ @argbind.bind(without_prefix=True)
15
+ def eval(
16
+ exp_dir: str = None,
17
+ baseline_key: str = "baseline",
18
+ audio_ext: str = ".wav",
19
+ ):
20
+ assert exp_dir is not None
21
+ exp_dir = Path(exp_dir)
22
+ assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"
23
+
24
+ # set up our metrics
25
+ # sisdr_loss = audiotools.metrics.distance.SISDRLoss()
26
+ # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
27
+ mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
28
+ frechet = FrechetAudioDistance(
29
+ use_pca=False,
30
+ use_activation=False,
31
+ verbose=True,
32
+ audio_load_worker=4,
33
+ )
34
+ frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
35
+
36
+ # figure out what conditions we have
37
+ conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
38
+
39
+ assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
40
+ conditions.remove(baseline_key)
41
+
42
+ print(f"Found {len(conditions)} conditions in {exp_dir}")
43
+ print(f"conditions: {conditions}")
44
+
45
+ baseline_dir = exp_dir / baseline_key
46
+ baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
47
+
48
+ metrics = []
49
+ for condition in tqdm(conditions):
50
+ cond_dir = exp_dir / condition
51
+ cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
52
+
53
+ print(f"computing fad for {baseline_dir} and {cond_dir}")
54
+ frechet_score = frechet.score(baseline_dir, cond_dir)
55
+
56
+ # make sure we have the same number of files
57
+ num_files = min(len(baseline_files), len(cond_files))
58
+ baseline_files = baseline_files[:num_files]
59
+ cond_files = cond_files[:num_files]
60
+ assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
61
+
62
+ def process(baseline_file, cond_file):
63
+ # make sure the files match (same name)
64
+ assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
65
+
66
+ # load the files
67
+ baseline_sig = AudioSignal(str(baseline_file))
68
+ cond_sig = AudioSignal(str(cond_file))
69
+
70
+ cond_sig.resample(baseline_sig.sample_rate)
71
+ cond_sig.truncate_samples(baseline_sig.length)
72
+
73
+ # if our condition is inpainting, we need to trim the conditioning off
74
+ if "inpaint" in condition:
75
+ ctx_amt = float(condition.split("_")[-1])
76
+ ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
77
+ print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
78
+ cond_sig.trim(ctx_samples, ctx_samples)
79
+ baseline_sig.trim(ctx_samples, ctx_samples)
80
+
81
+ return {
82
+ # "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
83
+ # "stft": stft_loss(baseline_sig, cond_sig).item(),
84
+ "mel": mel_loss(baseline_sig, cond_sig).item(),
85
+ "frechet": frechet_score,
86
+ # "visqol": vsq,
87
+ "condition": condition,
88
+ "file": baseline_file.stem,
89
+ }
90
+
91
+ print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}")
92
+ metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)))
93
+
94
+ metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
95
+
96
+
97
+ for mk in metric_keys:
98
+ stat = pandas.DataFrame(metrics)
99
+ stat = stat.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
100
+ stat.to_csv(exp_dir / f"stats-{mk}.csv")
101
+
102
+ df = pandas.DataFrame(metrics)
103
+ df.to_csv(exp_dir / "metrics-all.csv", index=False)
104
+
105
+
106
+ if __name__ == "__main__":
107
+ args = argbind.parse_args()
108
+
109
+ with argbind.scope(args):
110
+ eval()
scripts/exp/experiment.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import random
3
+ from typing import List
4
+ import tempfile
5
+ import subprocess
6
+
7
+ import argbind
8
+ from tqdm import tqdm
9
+ import torch
10
+
11
+ from vampnet.interface import Interface
12
+ from vampnet import mask as pmask
13
+ import audiotools as at
14
+
15
+ Interface: Interface = argbind.bind(Interface)
16
+
17
+
18
+
19
+ def calculate_bitrate(
20
+ interface, num_codebooks,
21
+ downsample_factor
22
+ ):
23
+ bit_width = 10
24
+ sr = interface.codec.sample_rate
25
+ hop = interface.codec.hop_size
26
+ rate = (sr / hop) * ((bit_width * num_codebooks) / downsample_factor)
27
+ return rate
28
+
29
+ def baseline(sig, interface):
30
+ return interface.preprocess(sig)
31
+
32
+ def reconstructed(sig, interface):
33
+ return interface.decode(
34
+ interface.encode(sig)
35
+ )
36
+
37
+ def coarse2fine(sig, interface):
38
+ z = interface.encode(sig)
39
+ z = z[:, :interface.c2f.n_conditioning_codebooks, :]
40
+
41
+ z = interface.coarse_to_fine(z)
42
+ return interface.decode(z)
43
+
44
+ class CoarseCond:
45
+
46
+ def __init__(self, num_conditioning_codebooks, downsample_factor):
47
+ self.num_conditioning_codebooks = num_conditioning_codebooks
48
+ self.downsample_factor = downsample_factor
49
+
50
+ def __call__(self, sig, interface):
51
+ z = interface.encode(sig)
52
+ mask = pmask.full_mask(z)
53
+ mask = pmask.codebook_unmask(mask, self.num_conditioning_codebooks)
54
+ mask = pmask.periodic_mask(mask, self.downsample_factor)
55
+
56
+ zv = interface.coarse_vamp(z, mask)
57
+ zv = interface.coarse_to_fine(zv)
58
+ return interface.decode(zv)
59
+
60
+ def opus(sig, interface, bitrate=128):
61
+ sig = interface.preprocess(sig)
62
+
63
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
64
+ sig.write(f.name)
65
+
66
+ opus_name = Path(f.name).with_suffix(".opus")
67
+ # convert to opus
68
+ cmd = [
69
+ "ffmpeg", "-y", "-i", f.name,
70
+ "-c:a", "libopus",
71
+ "-b:a", f"{bitrate}",
72
+ opus_name
73
+ ]
74
+ subprocess.run(cmd, check=True)
75
+
76
+ # convert back to wav
77
+ output_name = Path(f"{f.name}-opus").with_suffix(".wav")
78
+ cmd = [
79
+ "ffmpeg", "-y", "-i", opus_name,
80
+ output_name
81
+ ]
82
+
83
+ subprocess.run(cmd, check=True)
84
+
85
+ sig = at.AudioSignal(
86
+ output_name,
87
+ sample_rate=sig.sample_rate
88
+ )
89
+ return sig
90
+
91
+ def mask_ratio_1_step(ratio=1.0):
92
+ def wrapper(sig, interface):
93
+ z = interface.encode(sig)
94
+ mask = pmask.linear_random(z, ratio)
95
+ zv = interface.coarse_vamp(
96
+ z,
97
+ mask,
98
+ sampling_steps=1,
99
+ )
100
+
101
+ return interface.decode(zv)
102
+ return wrapper
103
+
104
+ def num_sampling_steps(num_steps=1):
105
+ def wrapper(sig, interface: Interface):
106
+ z = interface.encode(sig)
107
+ mask = pmask.periodic_mask(z, 16)
108
+ zv = interface.coarse_vamp(
109
+ z,
110
+ mask,
111
+ sampling_steps=num_steps,
112
+ )
113
+
114
+ zv = interface.coarse_to_fine(zv)
115
+ return interface.decode(zv)
116
+ return wrapper
117
+
118
+ def beat_mask(ctx_time):
119
+ def wrapper(sig, interface):
120
+ beat_mask = interface.make_beat_mask(
121
+ sig,
122
+ before_beat_s=ctx_time/2,
123
+ after_beat_s=ctx_time/2,
124
+ invert=True
125
+ )
126
+
127
+ z = interface.encode(sig)
128
+
129
+ zv = interface.coarse_vamp(
130
+ z, beat_mask
131
+ )
132
+
133
+ zv = interface.coarse_to_fine(zv)
134
+ return interface.decode(zv)
135
+ return wrapper
136
+
137
+ def inpaint(ctx_time):
138
+ def wrapper(sig, interface: Interface):
139
+ z = interface.encode(sig)
140
+ mask = pmask.inpaint(z, interface.s2t(ctx_time), interface.s2t(ctx_time))
141
+
142
+ zv = interface.coarse_vamp(z, mask)
143
+ zv = interface.coarse_to_fine(zv)
144
+
145
+ return interface.decode(zv)
146
+ return wrapper
147
+
148
+ def token_noise(noise_amt):
149
+ def wrapper(sig, interface: Interface):
150
+ z = interface.encode(sig)
151
+ mask = pmask.random(z, noise_amt)
152
+ z = torch.where(
153
+ mask,
154
+ torch.randint_like(z, 0, interface.coarse.vocab_size),
155
+ z
156
+ )
157
+ return interface.decode(z)
158
+ return wrapper
159
+
160
+ EXP_REGISTRY = {}
161
+
162
+ EXP_REGISTRY["gen-compression"] = {
163
+ "baseline": baseline,
164
+ "reconstructed": reconstructed,
165
+ "coarse2fine": coarse2fine,
166
+ **{
167
+ f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_conditioning_codebooks=n, downsample_factor=x)
168
+ for (n, x) in (
169
+ (1, 1), # 1 codebook, no downsampling
170
+ (4, 4), # 4 codebooks, downsampled 4x
171
+ (4, 16), # 4 codebooks, downsampled 16x
172
+ (4, 32), # 4 codebooks, downsampled 16x
173
+ )
174
+ },
175
+ **{
176
+ f"token_noise_{x}": mask_ratio_1_step(ratio=x)
177
+ for x in [0.25, 0.5, 0.75]
178
+ },
179
+
180
+ }
181
+
182
+
183
+ EXP_REGISTRY["sampling-steps"] = {
184
+ # "codec": reconstructed,
185
+ **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
186
+ }
187
+
188
+
189
+ EXP_REGISTRY["musical-sampling"] = {
190
+ **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
191
+ **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
192
+ }
193
+
194
+ @argbind.bind(without_prefix=True)
195
+ def main(
196
+ sources=[
197
+ "/media/CHONK/hugo/spotdl/val",
198
+ ],
199
+ output_dir: str = "./samples",
200
+ max_excerpts: int = 2000,
201
+ exp_type: str = "gen-compression",
202
+ seed: int = 0,
203
+ ext: str = [".mp3"],
204
+ ):
205
+ at.util.seed(seed)
206
+ interface = Interface()
207
+
208
+ output_dir = Path(output_dir)
209
+ output_dir.mkdir(exist_ok=True, parents=True)
210
+
211
+ from audiotools.data.datasets import AudioLoader, AudioDataset
212
+
213
+ loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
214
+ dataset = AudioDataset(loader,
215
+ sample_rate=interface.codec.sample_rate,
216
+ duration=interface.coarse.chunk_size_s,
217
+ n_examples=max_excerpts,
218
+ without_replacement=True,
219
+ )
220
+
221
+ if exp_type in EXP_REGISTRY:
222
+ SAMPLE_CONDS = EXP_REGISTRY[exp_type]
223
+ else:
224
+ raise ValueError(f"Unknown exp_type {exp_type}")
225
+
226
+
227
+ indices = list(range(max_excerpts))
228
+ random.shuffle(indices)
229
+ for i in tqdm(indices):
230
+ # if all our files are already there, skip
231
+ done = []
232
+ for name in SAMPLE_CONDS:
233
+ o_dir = Path(output_dir) / name
234
+ done.append((o_dir / f"{i}.wav").exists())
235
+ if all(done):
236
+ continue
237
+
238
+ sig = dataset[i]["signal"]
239
+ results = {
240
+ name: cond(sig, interface).cpu()
241
+ for name, cond in SAMPLE_CONDS.items()
242
+ }
243
+
244
+ for name, sig in results.items():
245
+ o_dir = Path(output_dir) / name
246
+ o_dir.mkdir(exist_ok=True, parents=True)
247
+
248
+ sig.write(o_dir / f"{i}.wav")
249
+
250
+ if __name__ == "__main__":
251
+ args = argbind.parse_args()
252
+
253
+ with argbind.scope(args):
254
+ main()
scripts/exp/export.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ run_dir = Path("runs/sample-instrument")
4
+ name = run_dir.name
5
+
6
+ repo_dir = Path("models/vampnet")
7
+
8
+
9
+ for part in ("coarse", "c2f"):
10
+ outdir = repo_dir / "loras" / name
11
+ outdir.mkdir(parents=True, exist_ok=True)
12
+ outpath = outdir / f"{part}.pth"
13
+ path = run_dir / part / "latest" / "vampnet" / "weights.pth"
14
+ path.rename(outpath)
15
+ print(f"moved {path} to {outpath}")
16
+
17
+ # now, push to hub
18
+ from huggingface_hub import Repository
19
+ repo = Repository(repo_dir, git_user="hugofloresgarcia", git_email="[email protected]")
20
+ repo.push_to_hub(
21
+ commit_message=f"add {name}"
22
+ )
scripts/exp/fine_tune.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argbind
2
+ from pathlib import Path
3
+ import yaml
4
+ from typing import List
5
+
6
+
7
+
8
+
9
+ """example output: (yaml)
10
+
11
+ """
12
+
13
+ @argbind.bind(without_prefix=True, positional=True)
14
+ def fine_tune(audio_files_or_folders: List[str], name: str):
15
+
16
+ conf_dir = Path("conf")
17
+ assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?"
18
+
19
+ conf_dir = conf_dir / "generated"
20
+ conf_dir.mkdir(exist_ok=True)
21
+
22
+ finetune_dir = conf_dir / name
23
+ finetune_dir.mkdir(exist_ok=True)
24
+
25
+ finetune_c2f_conf = {
26
+ "$include": ["conf/lora/lora.yml"],
27
+ "fine_tune": True,
28
+ "train/AudioLoader.sources": audio_files_or_folders,
29
+ "val/AudioLoader.sources": audio_files_or_folders,
30
+ "VampNet.n_codebooks": 14,
31
+ "VampNet.n_conditioning_codebooks": 4,
32
+ "VampNet.embedding_dim": 1280,
33
+ "VampNet.n_layers": 16,
34
+ "VampNet.n_heads": 20,
35
+ "AudioDataset.duration": 3.0,
36
+ "AudioDataset.loudness_cutoff": -40.0,
37
+ "save_path": f"./runs/{name}/c2f",
38
+ "fine_tune_checkpoint": "./models/vampnet/c2f.pth"
39
+ }
40
+
41
+ finetune_coarse_conf = {
42
+ "$include": ["conf/lora/lora.yml"],
43
+ "fine_tune": True,
44
+ "train/AudioLoader.sources": audio_files_or_folders,
45
+ "val/AudioLoader.sources": audio_files_or_folders,
46
+ "save_path": f"./runs/{name}/coarse",
47
+ "fine_tune_checkpoint": "./models/vampnet/coarse.pth"
48
+ }
49
+
50
+ interface_conf = {
51
+ "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
52
+
53
+ "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
54
+ "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
55
+
56
+ "Interface.codec_ckpt": "./models/vampnet/codec.pth",
57
+ "AudioLoader.sources": [audio_files_or_folders],
58
+ }
59
+
60
+ # save the confs
61
+ with open(finetune_dir / "c2f.yml", "w") as f:
62
+ yaml.dump(finetune_c2f_conf, f)
63
+
64
+ with open(finetune_dir / "coarse.yml", "w") as f:
65
+ yaml.dump(finetune_coarse_conf, f)
66
+
67
+ with open(finetune_dir / "interface.yml", "w") as f:
68
+ yaml.dump(interface_conf, f)
69
+
70
+
71
+ print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
72
+
73
+ if __name__ == "__main__":
74
+ args = argbind.parse_args()
75
+
76
+ with argbind.scope(args):
77
+ fine_tune()
78
+
79
+
80
+
81
+
scripts/exp/train.py ADDED
@@ -0,0 +1,686 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+ from dataclasses import dataclass
7
+
8
+ import argbind
9
+ import audiotools as at
10
+ import torch
11
+ import torch.nn as nn
12
+ from audiotools import AudioSignal
13
+ from audiotools.data import transforms as tfm
14
+ from einops import rearrange
15
+ from rich import pretty
16
+ from rich.traceback import install
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ import vampnet
20
+ from vampnet.modules.transformer import VampNet
21
+ from vampnet.util import codebook_unflatten, codebook_flatten
22
+ from vampnet import mask as pmask
23
+ # from dac.model.dac import DAC
24
+ from lac.model.lac import LAC as DAC
25
+
26
+ from audiotools.ml.decorators import (
27
+ timer, Tracker, when
28
+ )
29
+
30
+ import loralib as lora
31
+
32
+ import torch._dynamo
33
+ torch._dynamo.config.verbose=True
34
+
35
+
36
+ # Enable cudnn autotuner to speed up training
37
+ # (can be altered by the funcs.seed function)
38
+ torch.backends.cudnn.benchmark = bool(int(os.getenv("CUDNN_BENCHMARK", 1)))
39
+ # Uncomment to trade memory for speed.
40
+
41
+ # Install to make things look nice
42
+ warnings.filterwarnings("ignore", category=UserWarning)
43
+ pretty.install()
44
+ install()
45
+
46
+ # optim
47
+ Accelerator = argbind.bind(at.ml.Accelerator, without_prefix=True)
48
+ CrossEntropyLoss = argbind.bind(nn.CrossEntropyLoss)
49
+ AdamW = argbind.bind(torch.optim.AdamW)
50
+ NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
51
+
52
+ # transforms
53
+ filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
54
+ "BaseTransform",
55
+ "Compose",
56
+ "Choose",
57
+ ]
58
+
59
+ # model
60
+ VampNet = argbind.bind(VampNet)
61
+
62
+
63
+ # data
64
+ AudioLoader = argbind.bind(at.datasets.AudioLoader)
65
+ AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")
66
+
67
+ IGNORE_INDEX = -100
68
+
69
+
70
+ @argbind.bind("train", "val", without_prefix=True)
71
+ def build_transform():
72
+ transform = tfm.Compose(
73
+ tfm.VolumeNorm(("const", -24)),
74
+ # tfm.PitchShift(),
75
+ tfm.RescaleAudio(),
76
+ )
77
+ return transform
78
+
79
+
80
+ @torch.no_grad()
81
+ def apply_transform(transform_fn, batch):
82
+ sig: AudioSignal = batch["signal"]
83
+ kwargs = batch["transform_args"]
84
+
85
+ sig: AudioSignal = transform_fn(sig.clone(), **kwargs)
86
+ return sig
87
+
88
+
89
+ def build_datasets(args, sample_rate: int):
90
+ with argbind.scope(args, "train"):
91
+ train_data = AudioDataset(
92
+ AudioLoader(), sample_rate, transform=build_transform()
93
+ )
94
+ with argbind.scope(args, "val"):
95
+ val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
96
+ return train_data, val_data
97
+
98
+
99
+ def rand_float(shape, low, high, rng):
100
+ return rng.draw(shape)[:, 0] * (high - low) + low
101
+
102
+
103
+ def flip_coin(shape, p, rng):
104
+ return rng.draw(shape)[:, 0] < p
105
+
106
+
107
+ def num_params_hook(o, p):
108
+ return o + f" {p/1e6:<.3f}M params."
109
+
110
+
111
+ def add_num_params_repr_hook(model):
112
+ import numpy as np
113
+ from functools import partial
114
+
115
+ for n, m in model.named_modules():
116
+ o = m.extra_repr()
117
+ p = sum([np.prod(p.size()) for p in m.parameters()])
118
+
119
+ setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
120
+
121
+
122
+ def accuracy(
123
+ preds: torch.Tensor,
124
+ target: torch.Tensor,
125
+ top_k: int = 1,
126
+ ignore_index: Optional[int] = None,
127
+ ) -> torch.Tensor:
128
+ # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
129
+ preds = rearrange(preds, "b p s -> (b s) p")
130
+ target = rearrange(target, "b s -> (b s)")
131
+
132
+ # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
133
+ if ignore_index is not None:
134
+ # Create a mask for the ignored index
135
+ mask = target != ignore_index
136
+ # Apply the mask to the target and predictions
137
+ preds = preds[mask]
138
+ target = target[mask]
139
+
140
+ # Get the top-k predicted classes and their indices
141
+ _, pred_indices = torch.topk(preds, k=top_k, dim=-1)
142
+
143
+ # Determine if the true target is in the top-k predicted classes
144
+ correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
145
+
146
+ # Calculate the accuracy
147
+ accuracy = torch.mean(correct.float())
148
+
149
+ return accuracy
150
+
151
+ def _metrics(z_hat, r, target, flat_mask, output):
152
+ for r_range in [(0, 0.5), (0.5, 1.0)]:
153
+ unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
154
+ masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
155
+
156
+ assert target.shape[0] == r.shape[0]
157
+ # grab the indices of the r values that are in the range
158
+ r_idx = (r >= r_range[0]) & (r < r_range[1])
159
+
160
+ # grab the target and z_hat values that are in the range
161
+ r_unmasked_target = unmasked_target[r_idx]
162
+ r_masked_target = masked_target[r_idx]
163
+ r_z_hat = z_hat[r_idx]
164
+
165
+ for topk in (1, 25):
166
+ s, e = r_range
167
+ tag = f"accuracy-{s}-{e}/top{topk}"
168
+
169
+ output[f"{tag}/unmasked"] = accuracy(
170
+ preds=r_z_hat,
171
+ target=r_unmasked_target,
172
+ ignore_index=IGNORE_INDEX,
173
+ top_k=topk,
174
+ )
175
+ output[f"{tag}/masked"] = accuracy(
176
+ preds=r_z_hat,
177
+ target=r_masked_target,
178
+ ignore_index=IGNORE_INDEX,
179
+ top_k=topk,
180
+ )
181
+
182
+
183
+ @dataclass
184
+ class State:
185
+ model: VampNet
186
+ codec: DAC
187
+
188
+ optimizer: AdamW
189
+ scheduler: NoamScheduler
190
+ criterion: CrossEntropyLoss
191
+ grad_clip_val: float
192
+
193
+ rng: torch.quasirandom.SobolEngine
194
+
195
+ train_data: AudioDataset
196
+ val_data: AudioDataset
197
+
198
+ tracker: Tracker
199
+
200
+
201
+ @timer()
202
+ def train_loop(state: State, batch: dict, accel: Accelerator):
203
+ state.model.train()
204
+ batch = at.util.prepare_batch(batch, accel.device)
205
+ signal = apply_transform(state.train_data.transform, batch)
206
+
207
+ output = {}
208
+ vn = accel.unwrap(state.model)
209
+ with accel.autocast():
210
+ with torch.inference_mode():
211
+ state.codec.to(accel.device)
212
+ z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
213
+ z = z[:, : vn.n_codebooks, :]
214
+
215
+ n_batch = z.shape[0]
216
+ r = state.rng.draw(n_batch)[:, 0].to(accel.device)
217
+
218
+ mask = pmask.random(z, r)
219
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
220
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
221
+
222
+ z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
223
+
224
+ dtype = torch.bfloat16 if accel.amp else None
225
+ with accel.autocast(dtype=dtype):
226
+ z_hat = state.model(z_mask_latent)
227
+
228
+ target = codebook_flatten(
229
+ z[:, vn.n_conditioning_codebooks :, :],
230
+ )
231
+
232
+ flat_mask = codebook_flatten(
233
+ mask[:, vn.n_conditioning_codebooks :, :],
234
+ )
235
+
236
+ # replace target with ignore index for masked tokens
237
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
238
+ output["loss"] = state.criterion(z_hat, t_masked)
239
+
240
+ _metrics(
241
+ r=r,
242
+ z_hat=z_hat,
243
+ target=target,
244
+ flat_mask=flat_mask,
245
+ output=output,
246
+ )
247
+
248
+
249
+ accel.backward(output["loss"])
250
+
251
+ output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
252
+ output["other/batch_size"] = z.shape[0]
253
+
254
+
255
+ accel.scaler.unscale_(state.optimizer)
256
+ output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
257
+ state.model.parameters(), state.grad_clip_val
258
+ )
259
+
260
+ accel.step(state.optimizer)
261
+ state.optimizer.zero_grad()
262
+
263
+ state.scheduler.step()
264
+ accel.update()
265
+
266
+
267
+ return {k: v for k, v in sorted(output.items())}
268
+
269
+
270
+ @timer()
271
+ @torch.no_grad()
272
+ def val_loop(state: State, batch: dict, accel: Accelerator):
273
+ state.model.eval()
274
+ state.codec.eval()
275
+ batch = at.util.prepare_batch(batch, accel.device)
276
+ signal = apply_transform(state.val_data.transform, batch)
277
+
278
+ vn = accel.unwrap(state.model)
279
+ z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
280
+ z = z[:, : vn.n_codebooks, :]
281
+
282
+ n_batch = z.shape[0]
283
+ r = state.rng.draw(n_batch)[:, 0].to(accel.device)
284
+
285
+ mask = pmask.random(z, r)
286
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
287
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
288
+
289
+ z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
290
+
291
+ z_hat = state.model(z_mask_latent)
292
+
293
+ target = codebook_flatten(
294
+ z[:, vn.n_conditioning_codebooks :, :],
295
+ )
296
+
297
+ flat_mask = codebook_flatten(
298
+ mask[:, vn.n_conditioning_codebooks :, :]
299
+ )
300
+
301
+ output = {}
302
+ # replace target with ignore index for masked tokens
303
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
304
+ output["loss"] = state.criterion(z_hat, t_masked)
305
+
306
+ _metrics(
307
+ r=r,
308
+ z_hat=z_hat,
309
+ target=target,
310
+ flat_mask=flat_mask,
311
+ output=output,
312
+ )
313
+
314
+ return output
315
+
316
+
317
+ def validate(state, val_dataloader, accel):
318
+ for batch in val_dataloader:
319
+ output = val_loop(state, batch, accel)
320
+ # Consolidate state dicts if using ZeroRedundancyOptimizer
321
+ if hasattr(state.optimizer, "consolidate_state_dict"):
322
+ state.optimizer.consolidate_state_dict()
323
+ return output
324
+
325
+
326
+ def checkpoint(state, save_iters, save_path, fine_tune):
327
+ if accel.local_rank != 0:
328
+ state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
329
+ return
330
+
331
+ metadata = {"logs": dict(state.tracker.history)}
332
+
333
+ tags = ["latest"]
334
+ state.tracker.print(f"Saving to {str(Path('.').absolute())}")
335
+
336
+ if state.tracker.step in save_iters:
337
+ tags.append(f"{state.tracker.step // 1000}k")
338
+
339
+ if state.tracker.is_best("val", "loss"):
340
+ state.tracker.print(f"Best model so far")
341
+ tags.append("best")
342
+
343
+ if fine_tune:
344
+ for tag in tags:
345
+ # save the lora model
346
+ (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
347
+ torch.save(
348
+ lora.lora_state_dict(accel.unwrap(state.model)),
349
+ f"{save_path}/{tag}/lora.pth"
350
+ )
351
+
352
+ for tag in tags:
353
+ model_extra = {
354
+ "optimizer.pth": state.optimizer.state_dict(),
355
+ "scheduler.pth": state.scheduler.state_dict(),
356
+ "tracker.pth": state.tracker.state_dict(),
357
+ "metadata.pth": metadata,
358
+ }
359
+
360
+ accel.unwrap(state.model).metadata = metadata
361
+ accel.unwrap(state.model).save_to_folder(
362
+ f"{save_path}/{tag}", model_extra, package=False
363
+ )
364
+
365
+
366
+ def save_sampled(state, z, writer):
367
+ num_samples = z.shape[0]
368
+
369
+ for i in range(num_samples):
370
+ sampled = accel.unwrap(state.model).generate(
371
+ codec=state.codec,
372
+ time_steps=z.shape[-1],
373
+ start_tokens=z[i : i + 1],
374
+ )
375
+ sampled.cpu().write_audio_to_tb(
376
+ f"sampled/{i}",
377
+ writer,
378
+ step=state.tracker.step,
379
+ plot_fn=None,
380
+ )
381
+
382
+
383
+ def save_imputation(state, z, val_idx, writer):
384
+ n_prefix = int(z.shape[-1] * 0.25)
385
+ n_suffix = int(z.shape[-1] * 0.25)
386
+
387
+ vn = accel.unwrap(state.model)
388
+
389
+ mask = pmask.inpaint(z, n_prefix, n_suffix)
390
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
391
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
392
+
393
+ imputed_noisy = vn.decode(z_mask, state.codec)
394
+ imputed_true = vn.decode(z, state.codec)
395
+
396
+ imputed = []
397
+ for i in range(len(z)):
398
+ imputed.append(
399
+ vn.generate(
400
+ codec=state.codec,
401
+ time_steps=z.shape[-1],
402
+ start_tokens=z[i][None, ...],
403
+ mask=mask[i][None, ...],
404
+ )
405
+ )
406
+ imputed = AudioSignal.batch(imputed)
407
+
408
+ for i in range(len(val_idx)):
409
+ imputed_noisy[i].cpu().write_audio_to_tb(
410
+ f"inpainted_prompt/{i}",
411
+ writer,
412
+ step=state.tracker.step,
413
+ plot_fn=None,
414
+ )
415
+ imputed[i].cpu().write_audio_to_tb(
416
+ f"inpainted_middle/{i}",
417
+ writer,
418
+ step=state.tracker.step,
419
+ plot_fn=None,
420
+ )
421
+ imputed_true[i].cpu().write_audio_to_tb(
422
+ f"reconstructed/{i}",
423
+ writer,
424
+ step=state.tracker.step,
425
+ plot_fn=None,
426
+ )
427
+
428
+
429
+ @torch.no_grad()
430
+ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
431
+ state.model.eval()
432
+ state.codec.eval()
433
+ vn = accel.unwrap(state.model)
434
+
435
+ batch = [state.val_data[i] for i in val_idx]
436
+ batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)
437
+
438
+ signal = apply_transform(state.val_data.transform, batch)
439
+
440
+ z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
441
+ z = z[:, : vn.n_codebooks, :]
442
+
443
+ r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
444
+
445
+
446
+ mask = pmask.random(z, r)
447
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
448
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
449
+
450
+ z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
451
+
452
+ z_hat = state.model(z_mask_latent)
453
+
454
+ z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
455
+ z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
456
+ z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
457
+
458
+ generated = vn.decode(z_pred, state.codec)
459
+ reconstructed = vn.decode(z, state.codec)
460
+ masked = vn.decode(z_mask.squeeze(1), state.codec)
461
+
462
+ for i in range(generated.batch_size):
463
+ audio_dict = {
464
+ "original": signal[i],
465
+ "masked": masked[i],
466
+ "generated": generated[i],
467
+ "reconstructed": reconstructed[i],
468
+ }
469
+ for k, v in audio_dict.items():
470
+ v.cpu().write_audio_to_tb(
471
+ f"onestep/_{i}.r={r[i]:0.2f}/{k}",
472
+ writer,
473
+ step=state.tracker.step,
474
+ plot_fn=None,
475
+ )
476
+
477
+ save_sampled(state=state, z=z, writer=writer)
478
+ save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
479
+
480
+
481
+
482
+ @argbind.bind(without_prefix=True)
483
+ def load(
484
+ args,
485
+ accel: at.ml.Accelerator,
486
+ tracker: Tracker,
487
+ save_path: str,
488
+ resume: bool = False,
489
+ tag: str = "latest",
490
+ fine_tune_checkpoint: Optional[str] = None,
491
+ grad_clip_val: float = 5.0,
492
+ ) -> State:
493
+ codec = DAC.load(args["codec_ckpt"], map_location="cpu")
494
+ codec.eval()
495
+
496
+ model, v_extra = None, {}
497
+
498
+ if args["fine_tune"]:
499
+ assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
500
+ model = torch.compile(
501
+ VampNet.load(location=Path(fine_tune_checkpoint),
502
+ map_location="cpu",
503
+ )
504
+ )
505
+
506
+ if resume:
507
+ kwargs = {
508
+ "folder": f"{save_path}/{tag}",
509
+ "map_location": "cpu",
510
+ "package": False,
511
+ }
512
+ tracker.print(f"Loading checkpoint from {kwargs['folder']}")
513
+ if (Path(kwargs["folder"]) / "vampnet").exists():
514
+ model, v_extra = VampNet.load_from_folder(**kwargs)
515
+ else:
516
+ raise ValueError(
517
+ f"Could not find a VampNet checkpoint in {kwargs['folder']}"
518
+ )
519
+
520
+
521
+
522
+
523
+ model = torch.compile(VampNet()) if model is None else model
524
+ model = accel.prepare_model(model)
525
+
526
+ # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
527
+ assert (
528
+ accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
529
+ )
530
+
531
+
532
+ if accel.world_size > 1:
533
+ from torch.distributed.optim import ZeroRedundancyOptimizer
534
+ optimizer = ZeroRedundancyOptimizer(model.parameters(), AdamW)
535
+ print(f"OPTIMIZER LR is {optimizer.param_groups[0]['lr']}")
536
+ else:
537
+ optimizer = AdamW(model.parameters())
538
+
539
+ scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
540
+ scheduler.step()
541
+
542
+ if "optimizer.pth" in v_extra:
543
+ optimizer.load_state_dict(v_extra["optimizer.pth"])
544
+ scheduler.load_state_dict(v_extra["scheduler.pth"])
545
+ if "tracker.pth" in v_extra:
546
+ tracker.load_state_dict(v_extra["tracker.pth"])
547
+
548
+ criterion = CrossEntropyLoss()
549
+
550
+ sample_rate = codec.sample_rate
551
+
552
+ # a better rng for sampling from our schedule
553
+ rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
554
+
555
+ # log a model summary w/ num params
556
+ if accel.local_rank == 0:
557
+ add_num_params_repr_hook(accel.unwrap(model))
558
+ with open(f"{save_path}/model.txt", "w") as f:
559
+ f.write(repr(accel.unwrap(model)))
560
+
561
+ # load the datasets
562
+ train_data, val_data = build_datasets(args, sample_rate)
563
+
564
+ return State(
565
+ tracker=tracker,
566
+ model=model,
567
+ codec=codec,
568
+ optimizer=optimizer,
569
+ scheduler=scheduler,
570
+ criterion=criterion,
571
+ rng=rng,
572
+ train_data=train_data,
573
+ val_data=val_data,
574
+ grad_clip_val=grad_clip_val,
575
+ )
576
+
577
+
578
+ @argbind.bind(without_prefix=True)
579
+ def train(
580
+ args,
581
+ accel: at.ml.Accelerator,
582
+ seed: int = 0,
583
+ codec_ckpt: str = None,
584
+ save_path: str = "ckpt",
585
+ num_iters: int = int(1000e6),
586
+ save_iters: list = [10000, 50000, 100000, 300000, 500000,],
587
+ sample_freq: int = 10000,
588
+ val_freq: int = 1000,
589
+ batch_size: int = 12,
590
+ val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
591
+ num_workers: int = 10,
592
+ fine_tune: bool = False,
593
+ ):
594
+ assert codec_ckpt is not None, "codec_ckpt is required"
595
+
596
+ seed = seed + accel.local_rank
597
+ at.util.seed(seed)
598
+ writer = None
599
+
600
+ if accel.local_rank == 0:
601
+ writer = SummaryWriter(log_dir=f"{save_path}/logs/")
602
+ argbind.dump_args(args, f"{save_path}/args.yml")
603
+
604
+ tracker = Tracker(
605
+ writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
606
+ )
607
+
608
+ # load the codec model
609
+ state: State = load(
610
+ args=args,
611
+ accel=accel,
612
+ tracker=tracker,
613
+ save_path=save_path)
614
+ print("initialized state.")
615
+
616
+ train_dataloader = accel.prepare_dataloader(
617
+ state.train_data,
618
+ start_idx=state.tracker.step * batch_size,
619
+ num_workers=num_workers,
620
+ batch_size=batch_size,
621
+ collate_fn=state.train_data.collate,
622
+ )
623
+ val_dataloader = accel.prepare_dataloader(
624
+ state.val_data,
625
+ start_idx=0,
626
+ num_workers=num_workers,
627
+ batch_size=batch_size,
628
+ collate_fn=state.val_data.collate,
629
+ persistent_workers=num_workers > 0,
630
+ )
631
+ print("initialized dataloader.")
632
+
633
+
634
+
635
+ if fine_tune:
636
+ lora.mark_only_lora_as_trainable(state.model)
637
+ print("marked only lora as trainable.")
638
+
639
+ # Wrap the functions so that they neatly track in TensorBoard + progress bars
640
+ # and only run when specific conditions are met.
641
+ global train_loop, val_loop, validate, save_samples, checkpoint
642
+
643
+ train_loop = tracker.log("train", "value", history=False)(
644
+ tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
645
+ )
646
+ val_loop = tracker.track("val", len(val_dataloader))(val_loop)
647
+ validate = tracker.log("val", "mean")(validate)
648
+
649
+ save_samples = when(lambda: accel.local_rank == 0)(save_samples)
650
+ checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
651
+
652
+ print("starting training loop.")
653
+ with tracker.live:
654
+ for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
655
+ train_loop(state, batch, accel)
656
+
657
+ last_iter = (
658
+ tracker.step == num_iters - 1 if num_iters is not None else False
659
+ )
660
+
661
+ if tracker.step % sample_freq == 0 or last_iter:
662
+ save_samples(state, val_idx, writer)
663
+
664
+ if tracker.step % val_freq == 0 or last_iter:
665
+ validate(state, val_dataloader, accel)
666
+ checkpoint(
667
+ state=state,
668
+ save_iters=save_iters,
669
+ save_path=save_path,
670
+ fine_tune=fine_tune)
671
+
672
+ # Reset validation progress bar, print summary since last validation.
673
+ tracker.done("val", f"Iteration {tracker.step}")
674
+
675
+ if last_iter:
676
+ break
677
+
678
+
679
+ if __name__ == "__main__":
680
+ args = argbind.parse_args()
681
+ args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
682
+ with argbind.scope(args):
683
+ with Accelerator() as accel:
684
+ if accel.local_rank != 0:
685
+ sys.tracebacklimit = 0
686
+ train(args, accel)
scripts/utils/README.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scripts
2
+
3
+ ## process_zip.py
4
+
5
+ Some requirements that may not be installed in the docker image:
6
+ * argbind
7
+ * wav2wav (pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git or `pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git@<branchname>`)
8
+
9
+ ### zip folder structure
10
+
11
+ The zip folder should have the following internal structure:
12
+
13
+ ```
14
+ base_folder/
15
+ test_case_1/
16
+ before.wav
17
+ test_case_2/
18
+ before.wav
19
+ ...
20
+ test_case_n/
21
+ before.wav
22
+ ```
23
+
24
+ Note: There can be issues with the output zip if the input zip folder structure is too deep or too shallow. IF you want/need to use a zip file with a different folder structure, adjust this:
25
+ https://github.com/descriptinc/lyrebird-wav2wav/blob/136c923ce19df03876a515ca0ed83854710cfa30/scripts/utils/process_zip.py#L28
26
+
27
+ ### Execution
28
+ `python process_zip.py <path/to/zip> -tag <string>`
scripts/utils/gtzan_embeddings.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TODO: train a linear probe
3
+ usage:
4
+ python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output
5
+ """
6
+ from pathlib import Path
7
+ from typing import List
8
+
9
+ import audiotools as at
10
+ from audiotools import AudioSignal
11
+ import argbind
12
+ import torch
13
+ import numpy as np
14
+ import zipfile
15
+ import json
16
+
17
+ from vampnet.interface import Interface
18
+ import tqdm
19
+
20
+ # bind the Interface to argbind
21
+ Interface = argbind.bind(Interface)
22
+
23
+ DEBUG = False
24
+
25
+ def smart_plotly_export(fig, save_path):
26
+ img_format = save_path.split('.')[-1]
27
+ if img_format == 'html':
28
+ fig.write_html(save_path)
29
+ elif img_format == 'bytes':
30
+ return fig.to_image(format='png')
31
+ #TODO: come back and make this prettier
32
+ elif img_format == 'numpy':
33
+ import io
34
+ from PIL import Image
35
+
36
+ def plotly_fig2array(fig):
37
+ #convert Plotly fig to an array
38
+ fig_bytes = fig.to_image(format="png", width=1200, height=700)
39
+ buf = io.BytesIO(fig_bytes)
40
+ img = Image.open(buf)
41
+ return np.asarray(img)
42
+
43
+ return plotly_fig2array(fig)
44
+ elif img_format == 'jpeg' or 'png' or 'webp':
45
+ fig.write_image(save_path)
46
+ else:
47
+ raise ValueError("invalid image format")
48
+
49
+ def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
50
+ """
51
+ dimensionality reduction for visualization!
52
+ saves an html plotly figure to save_path
53
+ parameters:
54
+ emb (np.ndarray): the samples to be reduces with shape (samples, features)
55
+ labels (list): list of labels for embedding
56
+ save_path (str): path where u wanna save ur figure
57
+ method (str): umap, tsne, or pca
58
+ title (str): title for ur figure
59
+ returns:
60
+ proj (np.ndarray): projection vector with shape (samples, dimensions)
61
+ """
62
+ import pandas as pd
63
+ import plotly.express as px
64
+ if method == 'umap':
65
+ from umap import UMAP
66
+ reducer = umap.UMAP(n_components=n_components)
67
+ elif method == 'tsne':
68
+ from sklearn.manifold import TSNE
69
+ reducer = TSNE(n_components=n_components)
70
+ elif method == 'pca':
71
+ from sklearn.decomposition import PCA
72
+ reducer = PCA(n_components=n_components)
73
+ else:
74
+ raise ValueError
75
+
76
+ proj = reducer.fit_transform(emb)
77
+
78
+ if n_components == 2:
79
+ df = pd.DataFrame(dict(
80
+ x=proj[:, 0],
81
+ y=proj[:, 1],
82
+ instrument=labels
83
+ ))
84
+ fig = px.scatter(df, x='x', y='y', color='instrument',
85
+ title=title+f"_{method}")
86
+
87
+ elif n_components == 3:
88
+ df = pd.DataFrame(dict(
89
+ x=proj[:, 0],
90
+ y=proj[:, 1],
91
+ z=proj[:, 2],
92
+ instrument=labels
93
+ ))
94
+ fig = px.scatter_3d(df, x='x', y='y', z='z',
95
+ color='instrument',
96
+ title=title)
97
+ else:
98
+ raise ValueError("cant plot more than 3 components")
99
+
100
+ fig.update_traces(marker=dict(size=6,
101
+ line=dict(width=1,
102
+ color='DarkSlateGrey')),
103
+ selector=dict(mode='markers'))
104
+
105
+ return smart_plotly_export(fig, save_path)
106
+
107
+
108
+
109
+ # per JukeMIR, we want the emebddings from the middle layer?
110
+ def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
111
+ with torch.inference_mode():
112
+ # preprocess the signal
113
+ sig = interface.preprocess(sig)
114
+
115
+ # get the coarse vampnet model
116
+ vampnet = interface.coarse
117
+
118
+ # get the tokens
119
+ z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
120
+ z_latents = vampnet.embedding.from_codes(z, interface.codec)
121
+
122
+ # do a forward pass through the model, get the embeddings
123
+ _z, embeddings = vampnet(z_latents, return_activations=True)
124
+ # print(f"got embeddings with shape {embeddings.shape}")
125
+ # [layer, batch, time, n_dims]
126
+ # [20, 1, 600ish, 768]
127
+
128
+
129
+ # squeeze batch dim (1 bc layer should be dim 0)
130
+ assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
131
+ embeddings = embeddings.squeeze(1)
132
+
133
+ num_layers = embeddings.shape[0]
134
+ assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
135
+
136
+ # do meanpooling over the time dimension
137
+ embeddings = embeddings.mean(dim=-2)
138
+ # [20, 768]
139
+
140
+ # return the embeddings
141
+ return embeddings
142
+
143
+ from dataclasses import dataclass, fields
144
+ @dataclass
145
+ class Embedding:
146
+ genre: str
147
+ filename: str
148
+ embedding: np.ndarray
149
+
150
+ def save(self, path):
151
+ """Save the Embedding object to a given path as a zip file."""
152
+ with zipfile.ZipFile(path, 'w') as archive:
153
+
154
+ # Save numpy array
155
+ with archive.open('embedding.npy', 'w') as f:
156
+ np.save(f, self.embedding)
157
+
158
+ # Save non-numpy data as json
159
+ non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
160
+ with archive.open('data.json', 'w') as f:
161
+ f.write(json.dumps(non_numpy_data).encode('utf-8'))
162
+
163
+ @classmethod
164
+ def load(cls, path):
165
+ """Load the Embedding object from a given zip path."""
166
+ with zipfile.ZipFile(path, 'r') as archive:
167
+
168
+ # Load numpy array
169
+ with archive.open('embedding.npy') as f:
170
+ embedding = np.load(f)
171
+
172
+ # Load non-numpy data from json
173
+ with archive.open('data.json') as f:
174
+ data = json.loads(f.read().decode('utf-8'))
175
+
176
+ return cls(embedding=embedding, **data)
177
+
178
+
179
+ @argbind.bind(without_prefix=True)
180
+ def main(
181
+ path_to_gtzan: str = None,
182
+ cache_dir: str = "./.gtzan_emb_cache",
183
+ output_dir: str = "./gtzan_vampnet_embeddings",
184
+ layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
185
+ ):
186
+ path_to_gtzan = Path(path_to_gtzan)
187
+ assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"
188
+
189
+ cache_dir = Path(cache_dir)
190
+ output_dir = Path(output_dir)
191
+ output_dir.mkdir(exist_ok=True, parents=True)
192
+
193
+ # load our interface
194
+ # argbind will automatically load the default config,
195
+ interface = Interface()
196
+
197
+ # gtzan should have a folder for each genre, so let's get the list of genres
198
+ genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
199
+ print(f"Found {len(genres)} genres")
200
+ print(f"genres: {genres}")
201
+
202
+ # collect audio files, genres, and embeddings
203
+ data = []
204
+ for genre in genres:
205
+ audio_files = list(at.util.find_audio(path_to_gtzan / genre))
206
+ print(f"Found {len(audio_files)} audio files for genre {genre}")
207
+
208
+ for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
209
+ # check if we have a cached embedding for this file
210
+ cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
211
+ if cached_path.exists():
212
+ # if so, load it
213
+ if DEBUG:
214
+ print(f"loading cached embedding for {cached_path.stem}")
215
+ embedding = Embedding.load(cached_path)
216
+ else:
217
+ try:
218
+ sig = AudioSignal(audio_file)
219
+ except Exception as e:
220
+ print(f"failed to load {audio_file.name} with error {e}")
221
+ print(f"skipping {audio_file.name}")
222
+ continue
223
+
224
+ # gets the embedding
225
+ emb = vampnet_embed(sig, interface).cpu().numpy()
226
+
227
+ # create an embedding we can save/load
228
+ embedding = Embedding(
229
+ genre=genre,
230
+ filename=audio_file.name,
231
+ embedding=emb
232
+ )
233
+
234
+ # cache the embeddings
235
+ cached_path.parent.mkdir(exist_ok=True, parents=True)
236
+ embedding.save(cached_path)
237
+ data.append(embedding)
238
+
239
+ # now, let's do a dim reduction on the embeddings
240
+ # and visualize them.
241
+
242
+ # collect a list of embeddings and labels
243
+ embeddings = [d.embedding for d in data]
244
+ labels = [d.genre for d in data]
245
+
246
+ # convert the embeddings to a numpy array
247
+ embeddings = np.stack(embeddings)
248
+
249
+ # do dimensionality reduction for each layer we're given
250
+ for layer in tqdm.tqdm(layers, desc="dim reduction"):
251
+ dim_reduce(
252
+ embeddings[:, layer, :], labels,
253
+ save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
254
+ n_components=2, method='tsne',
255
+ title=f'vampnet-gtzan-layer={layer}'
256
+ )
257
+
258
+
259
+
260
+
261
+ if __name__ == "__main__":
262
+ args = argbind.parse_args()
263
+ with argbind.scope(args):
264
+ main()
scripts/utils/plots.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import seaborn as sns
3
+ from pandas.api.types import CategoricalDtype
4
+
5
+ def plot_metrics(metrics, condition_to_latex, title, color_palette):
6
+ # Add a new column to your dataframe with the latex representation
7
+ metrics['condition_latex'] = metrics['condition'].map(condition_to_latex)
8
+
9
+ # Order condition_latex as per the condition_to_latex dictionary
10
+ cat_type = CategoricalDtype(categories=condition_to_latex.values(), ordered=True)
11
+ metrics['condition_latex'] = metrics['condition_latex'].astype(cat_type)
12
+
13
+ # Compute mean and std for each condition for each metric
14
+ grouped = metrics.groupby('condition_latex')[['mel', 'frechet']].agg(['mean', 'std'])
15
+
16
+ fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))
17
+
18
+ # Set the main title for the figure
19
+ fig.suptitle(title, fontsize=16)
20
+
21
+ # Get color for each bar in the plot
22
+ bar_colors = [color_palette[condition] for condition in grouped.index]
23
+
24
+ # Plot mel
25
+ sns.boxplot(x='condition_latex', y='mel', data=metrics, ax=axs[0], palette=color_palette, showfliers=False)
26
+ axs[0].set_ylabel('Mel Spectrogram Loss \u2190')
27
+ axs[0].set_xlabel('') # Remove x-axis label
28
+ axs[0].set_xticklabels(grouped.index, rotation=0, ha='center')
29
+
30
+ # Plot frechet
31
+ axs[1].bar(grouped.index, grouped['frechet']['mean'], yerr=grouped['frechet']['std'], color=bar_colors)
32
+ axs[1].set_ylabel('FAD \u2190')
33
+ axs[1].set_xlabel('') # Remove x-axis label
34
+ axs[1].set_xticklabels(grouped.index, rotation=0, ha='center')
35
+
36
+ # Adjust the space between plots
37
+ plt.subplots_adjust(hspace=0.1)
38
+
39
+ # Remove any unnecessary space around the plot
40
+ plt.tight_layout(rect=[0, 0, 1, 0.96])
41
+
42
+ # Reduce the space between suptitle and the plot
43
+ plt.subplots_adjust(top=0.92)
scripts/utils/remove_quiet_files.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # removes files with loudness below 24db
2
+
3
+ from pathlib import Path
4
+ import shutil
5
+ import audiotools as at
6
+ import argbind
7
+
8
+ @argbind.bind(without_prefix=True)
9
+ def remove_quiet_files(
10
+ src_dir: Path = None,
11
+ dest_dir: Path = None,
12
+ min_loudness: float = -30,
13
+ ):
14
+ # copy src to dest
15
+ dest_dir.mkdir(parents=True, exist_ok=True)
16
+ shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
17
+
18
+ audio_files = at.util.find_audio(dest_dir)
19
+ for audio_file in audio_files:
20
+ sig = at.AudioSignal(audio_file)
21
+ if sig.loudness() < min_loudness:
22
+ audio_file.unlink()
23
+ print(f"removed {audio_file}")
24
+
25
+ if __name__ == "__main__":
26
+ args = argbind.parse_args()
27
+
28
+ with argbind.scope(args):
29
+ remove_quiet_files()
scripts/utils/split.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import random
3
+ import shutil
4
+ import os
5
+ import json
6
+
7
+ import argbind
8
+ from tqdm import tqdm
9
+ from tqdm.contrib.concurrent import thread_map
10
+
11
+ from audiotools.core import util
12
+
13
+
14
+ @argbind.bind(without_prefix=True)
15
+ def train_test_split(
16
+ audio_folder: str = ".",
17
+ test_size: float = 0.2,
18
+ seed: int = 42,
19
+ ):
20
+ print(f"finding audio")
21
+
22
+ audio_folder = Path(audio_folder)
23
+ audio_files = util.find_audio(audio_folder)
24
+ print(f"found {len(audio_files)} audio files")
25
+
26
+ # split according to test_size
27
+ n_test = int(len(audio_files) * test_size)
28
+ n_train = len(audio_files) - n_test
29
+
30
+ # shuffle
31
+ random.seed(seed)
32
+ random.shuffle(audio_files)
33
+
34
+ train_files = audio_files[:n_train]
35
+ test_files = audio_files[n_train:]
36
+
37
+
38
+ print(f"Train files: {len(train_files)}")
39
+ print(f"Test files: {len(test_files)}")
40
+ continue_ = input("Continue [yn]? ") or "n"
41
+
42
+ if continue_ != "y":
43
+ return
44
+
45
+ for split, files in (
46
+ ("train", train_files), ("test", test_files)
47
+ ):
48
+ for file in tqdm(files):
49
+ out_file = audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
50
+ out_file.parent.mkdir(exist_ok=True, parents=True)
51
+ try:
52
+ os.symlink(file, out_file)
53
+ except FileExistsError:
54
+ print(f"File {out_file} already exists, skipping")
55
+
56
+ # save split as json
57
+ with open(Path(audio_folder) / f"{split}.json", "w") as f:
58
+ json.dump([str(f) for f in files], f)
59
+
60
+
61
+
62
+ if __name__ == "__main__":
63
+ args = argbind.parse_args()
64
+
65
+ with argbind.scope(args):
66
+ train_test_split()
scripts/utils/split_long_audio_file.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import argbind
3
+
4
+ import audiotools as at
5
+ import tqdm
6
+
7
+
8
+ @argbind.bind(without_prefix=True)
9
+ def split_long_audio_file(
10
+ file: str = None,
11
+ max_chunk_size_s: int = 60*10
12
+ ):
13
+ file = Path(file)
14
+ output_dir = file.parent / file.stem
15
+ output_dir.mkdir()
16
+
17
+ sig = at.AudioSignal(file)
18
+
19
+ # split into chunks
20
+ for i, sig in tqdm.tqdm(enumerate(sig.windows(
21
+ window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
22
+ preprocess=True))
23
+ ):
24
+ sig.write(output_dir / f"{i}.wav")
25
+
26
+ print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
27
+
28
+ return output_dir
29
+
30
+ if __name__ == "__main__":
31
+ args = argbind.parse_args()
32
+
33
+ with argbind.scope(args):
34
+ split_long_audio_file()
scripts/utils/stage.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from pathlib import Path
4
+
5
+ import argbind
6
+ import rich
7
+ from audiotools.ml import Experiment
8
+
9
+
10
+ @argbind.bind(without_prefix=True)
11
+ def run(
12
+ run_dir: str = os.getenv("PATH_TO_RUNS", "runs"),
13
+ name: str = None,
14
+ recent: bool = False,
15
+ ):
16
+ if recent:
17
+ paths = sorted(Path(run_dir).iterdir(), key=os.path.getmtime)
18
+ paths = [p.name for p in paths if p.is_dir()]
19
+ if paths:
20
+ name = paths[-1]
21
+
22
+ with Experiment(run_dir, name) as exp:
23
+ exp.snapshot()
24
+ rich.print(f"Created a snapshot of {exp.parent_directory} at {exp.exp_dir}")
25
+
26
+
27
+ if __name__ == "__main__":
28
+ args = argbind.parse_args()
29
+ with argbind.scope(args):
30
+ run()
scripts/utils/visualize_embeddings.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TODO: train a linear probe
3
+ usage:
4
+ python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output
5
+ """
6
+ from pathlib import Path
7
+ from typing import List
8
+
9
+ import audiotools as at
10
+ from audiotools import AudioSignal
11
+ import argbind
12
+ import torch
13
+ import numpy as np
14
+ import zipfile
15
+ import json
16
+
17
+ from vampnet.interface import Interface
18
+ import tqdm
19
+
20
+ # bind the Interface to argbind
21
+ Interface = argbind.bind(Interface)
22
+
23
+ DEBUG = False
24
+
25
+
26
+ def smart_plotly_export(fig, save_path: Path):
27
+ img_format = save_path.suffix[1:]
28
+ if img_format == "html":
29
+ fig.write_html(save_path)
30
+ elif img_format == 'bytes':
31
+ return fig.to_image(format='png')
32
+ #TODO: come back and make this prettier
33
+ elif img_format == 'numpy':
34
+ import io
35
+ from PIL import Image
36
+
37
+ def plotly_fig2array(fig):
38
+ #convert Plotly fig to an array
39
+ fig_bytes = fig.to_image(format="png", width=1200, height=700)
40
+ buf = io.BytesIO(fig_bytes)
41
+ img = Image.open(buf)
42
+ return np.asarray(img)
43
+
44
+ return plotly_fig2array(fig)
45
+ elif img_format == 'jpeg' or 'png' or 'webp':
46
+ fig.write_image(save_path)
47
+ else:
48
+ raise ValueError("invalid image format")
49
+
50
+
51
+ def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="tsne"):
52
+ """
53
+ dimensionality reduction for visualization!
54
+ saves an html plotly figure to save_path
55
+ parameters:
56
+ annotated_embeddings (list): the annotated enmbeddings to be reduced; embeddings have shape (samples, features)
57
+ labels (list): list of labels for embedding
58
+ save_path (str): path where u wanna save ur figure
59
+ method (str): umap, tsne, or pca
60
+ title (str): title for ur figure
61
+ returns:
62
+ proj (np.ndarray): projection vector with shape (samples, dimensions)
63
+ """
64
+ import pandas as pd
65
+ import plotly.express as px
66
+
67
+ fig_name = f"vampnet-embeddings-layer={layer}"
68
+ fig_title = f"{fig_name}_{method}"
69
+ save_path = (output_dir / fig_name).with_suffix(".html")
70
+
71
+ if method == "umap":
72
+ from umap import UMAP
73
+ reducer = umap.UMAP(n_components=n_components)
74
+ elif method == "tsne":
75
+ from sklearn.manifold import TSNE
76
+
77
+ reducer = TSNE(n_components=n_components)
78
+ elif method == "pca":
79
+ from sklearn.decomposition import PCA
80
+
81
+ reducer = PCA(n_components=n_components)
82
+ else:
83
+ raise ValueError(f"invalid method: {method}")
84
+
85
+ labels = [emb.label for emb in annotated_embeddings]
86
+ names = [emb.filename for emb in annotated_embeddings]
87
+ embs = [emb.embedding for emb in annotated_embeddings]
88
+ embs_at_layer = np.stack(embs)[:, layer, :]
89
+ projs = reducer.fit_transform(embs_at_layer)
90
+
91
+ df = pd.DataFrame(
92
+ {
93
+ "label": labels,
94
+ "name": names,
95
+ "x": projs[:, 0],
96
+ "y": projs[:, 1],
97
+ }
98
+ )
99
+ if n_components == 2:
100
+ fig = px.scatter(
101
+ df, x="x", y="y", color="label", hover_name="name", title=fig_title,
102
+ )
103
+
104
+ elif n_components == 3:
105
+ df['z'] = projs[:, 2]
106
+ fig = px.scatter_3d(
107
+ df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
108
+ )
109
+ else:
110
+ raise ValueError(f"can't plot {n_components} components")
111
+
112
+ fig.update_traces(
113
+ marker=dict(size=6, line=dict(width=1, color="DarkSlateGrey")),
114
+ selector=dict(mode="markers"),
115
+ )
116
+
117
+ return smart_plotly_export(fig, save_path)
118
+
119
+
120
+
121
+ # per JukeMIR, we want the emebddings from the middle layer?
122
+ def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
123
+ with torch.inference_mode():
124
+ # preprocess the signal
125
+ sig = interface.preprocess(sig)
126
+
127
+ # get the coarse vampnet model
128
+ vampnet = interface.coarse
129
+
130
+ # get the tokens
131
+ z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
132
+ z_latents = vampnet.embedding.from_codes(z, interface.codec)
133
+
134
+ # do a forward pass through the model, get the embeddings
135
+ _z, embeddings = vampnet(z_latents, return_activations=True)
136
+ # print(f"got embeddings with shape {embeddings.shape}")
137
+ # [layer, batch, time, n_dims]
138
+ # [20, 1, 600ish, 768]
139
+
140
+
141
+ # squeeze batch dim (1 bc layer should be dim 0)
142
+ assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
143
+ embeddings = embeddings.squeeze(1)
144
+
145
+ num_layers = embeddings.shape[0]
146
+ assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
147
+
148
+ # do meanpooling over the time dimension
149
+ embeddings = embeddings.mean(dim=-2)
150
+ # [20, 768]
151
+
152
+ # return the embeddings
153
+ return embeddings
154
+
155
+ from dataclasses import dataclass, fields
156
+ @dataclass
157
+ class AnnotatedEmbedding:
158
+ label: str
159
+ filename: str
160
+ embedding: np.ndarray
161
+
162
+ def save(self, path):
163
+ """Save the Embedding object to a given path as a zip file."""
164
+ with zipfile.ZipFile(path, 'w') as archive:
165
+
166
+ # Save numpy array
167
+ with archive.open('embedding.npy', 'w') as f:
168
+ np.save(f, self.embedding)
169
+
170
+ # Save non-numpy data as json
171
+ non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
172
+ with archive.open('data.json', 'w') as f:
173
+ f.write(json.dumps(non_numpy_data).encode('utf-8'))
174
+
175
+ @classmethod
176
+ def load(cls, path):
177
+ """Load the Embedding object from a given zip path."""
178
+ with zipfile.ZipFile(path, 'r') as archive:
179
+
180
+ # Load numpy array
181
+ with archive.open('embedding.npy') as f:
182
+ embedding = np.load(f)
183
+
184
+ # Load non-numpy data from json
185
+ with archive.open('data.json') as f:
186
+ data = json.loads(f.read().decode('utf-8'))
187
+
188
+ return cls(embedding=embedding, **data)
189
+
190
+
191
+ @argbind.bind(without_prefix=True)
192
+ def main(
193
+ path_to_audio: str = None,
194
+ cache_dir: str = "./.emb_cache",
195
+ output_dir: str = "./vampnet_embeddings",
196
+ layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
197
+ method: str = "tsne",
198
+ n_components: int = 2,
199
+ ):
200
+ path_to_audio = Path(path_to_audio)
201
+ assert path_to_audio.exists(), f"{path_to_audio} does not exist"
202
+
203
+ cache_dir = Path(cache_dir)
204
+ output_dir = Path(output_dir)
205
+ output_dir.mkdir(exist_ok=True, parents=True)
206
+
207
+ # load our interface
208
+ # argbind will automatically load the default config,
209
+ interface = Interface()
210
+
211
+ # we expect path_to_audio to consist of a folder for each label, so let's get the list of labels
212
+ labels = [Path(x).name for x in path_to_audio.iterdir() if x.is_dir()]
213
+ print(f"Found {len(labels)} labels")
214
+ print(f"labels: {labels}")
215
+
216
+ # collect audio files, labels, and embeddings
217
+ annotated_embeddings = []
218
+ for label in labels:
219
+ audio_files = list(at.util.find_audio(path_to_audio / label))
220
+ print(f"Found {len(audio_files)} audio files for label {label}")
221
+
222
+ for audio_file in tqdm.tqdm(audio_files, desc=f"embedding label {label}"):
223
+ # check if we have a cached embedding for this file
224
+ cached_path = cache_dir / f"{label}_{audio_file.stem}.emb"
225
+ if cached_path.exists():
226
+ # if so, load it
227
+ if DEBUG:
228
+ print(f"loading cached embedding for {cached_path.stem}")
229
+ embedding = AnnotatedEmbedding.load(cached_path)
230
+ else:
231
+ try:
232
+ sig = AudioSignal(audio_file)
233
+ except Exception as e:
234
+ print(f"failed to load {audio_file.name} with error {e}")
235
+ print(f"skipping {audio_file.name}")
236
+ continue
237
+
238
+ # gets the embedding
239
+ emb = vampnet_embed(sig, interface).cpu().numpy()
240
+
241
+ # create an embedding we can save/load
242
+ embedding = AnnotatedEmbedding(
243
+ label=label, filename=audio_file.name, embedding=emb
244
+ )
245
+
246
+ # cache the embeddings
247
+ cached_path.parent.mkdir(exist_ok=True, parents=True)
248
+ embedding.save(cached_path)
249
+ annotated_embeddings.append(embedding)
250
+
251
+ # now, let's do a dim reduction on the embeddings and visualize them.
252
+ for layer in tqdm.tqdm(layers, desc="dim reduction"):
253
+ dim_reduce(
254
+ annotated_embeddings,
255
+ layer,
256
+ output_dir=output_dir,
257
+ n_components=n_components,
258
+ method=method,
259
+ )
260
+
261
+
262
+ if __name__ == "__main__":
263
+ args = argbind.parse_args()
264
+ with argbind.scope(args):
265
+ main()
scripts/utils/xeno-canto-dl.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from xenopy import Query
2
+
3
+
4
+ SPECIES = [
5
+ "American Robin",
6
+ "Northern Cardinal",
7
+ "Mourning Dove",
8
+ "American Crow",
9
+ "Baltimore Oriole",
10
+ "Blue Jay",
11
+ "Eastern Bluebird",
12
+ "House Finch",
13
+ "American Goldfinch",
14
+ "House Sparrow",
15
+ "Song Sparrow",
16
+ "Tufted Titmouse",
17
+ "White-breasted Nuthatch",
18
+ "European Starling",
19
+ "American Redstart",
20
+ "Red-winged Blackbird",
21
+ "Brown-headed Cowbird",
22
+ "Common Grackle",
23
+ "Boat-tailed Grackle",
24
+ "Common Yellowthroat",
25
+ "Northern Mockingbird",
26
+ "Carolina Wren",
27
+ "Eastern Meadowlark",
28
+ "Chipping Sparrow",
29
+ "Tree Swallow",
30
+ "Barn Swallow",
31
+ "Cliff Swallow",
32
+ "Pine Siskin",
33
+ "Indigo Bunting",
34
+ "Eastern Towhee",
35
+ "Carolina Chickadee",
36
+ "Great Crested Flycatcher",
37
+ "Eastern Wood-Pewee",
38
+ "Ovenbird",
39
+ "Northern Flicker",
40
+ "Red-eyed Vireo",
41
+ "American Woodcock",
42
+ "Eastern Phoebe",
43
+ "Downy Woodpecker",
44
+ "Scarlet Tanager",
45
+ "Yellow Warbler",
46
+ "White-eyed Vireo",
47
+ "Common Loon",
48
+ "White-throated Sparrow",
49
+ "Yellow-throated Vireo",
50
+ "Great Blue Heron",
51
+ "Belted Kingfisher",
52
+ "Pied-billed Grebe",
53
+ "Wild Turkey",
54
+ "Wood Thrush",
55
+ "Rose-breasted Grosbeak",
56
+ "Field Sparrow",
57
+ "Hooded Warbler",
58
+ "Northern Parula",
59
+ "Chestnut-sided Warbler",
60
+ "Blue-winged Warbler",
61
+ "Red-bellied Woodpecker",
62
+ "Yellow-billed Cuckoo",
63
+ "Gray Catbird",
64
+ "Northern Saw-whet Owl",
65
+ "Osprey",
66
+ "Common Nighthawk",
67
+ "Broad-winged Hawk",
68
+ "Black-throated Green Warbler",
69
+ "Great Horned Owl",
70
+ "Common Raven",
71
+ "Barred Owl",
72
+ "Canada Warbler",
73
+ "Magnolia Warbler",
74
+ "Black-and-white Warbler",
75
+ "Eastern Kingbird",
76
+ "Swainson's Thrush",
77
+ "Worm-eating Warbler",
78
+ "Prairie Warbler",
79
+ "Baltimore Oriole",
80
+ "Black-throated Blue Warbler",
81
+ "Louisiana Waterthrush",
82
+ "Blackburnian Warbler",
83
+ "Black-capped Chickadee",
84
+ "Cerulean Warbler",
85
+ "Red-shouldered Hawk",
86
+ "Cooper's Hawk",
87
+ "Yellow-throated Warbler",
88
+ "Blue-headed Vireo",
89
+ "Blackpoll Warbler",
90
+ "Ruffed Grouse",
91
+ "Kentucky Warbler",
92
+ "Hermit Thrush",
93
+ "Cedar Waxwing",
94
+ "Eastern Screech-Owl",
95
+ "Northern Goshawk",
96
+ "Green Heron",
97
+ "Red-tailed Hawk",
98
+ "Black Vulture",
99
+ "Hairy Woodpecker",
100
+ "Golden-crowned Kinglet",
101
+ "Ruby-crowned Kinglet",
102
+ "Bicknell's Thrush",
103
+ "Blue-gray Gnatcatcher",
104
+ "Veery",
105
+ "Pileated Woodpecker",
106
+ "Purple Finch",
107
+ "White-crowned Sparrow",
108
+ "Snow Bunting",
109
+ "Pine Grosbeak",
110
+ "American Tree Sparrow",
111
+ "Dark-eyed Junco",
112
+ "Snowy Owl",
113
+ "White-winged Crossbill",
114
+ "Red Crossbill",
115
+ "Common Redpoll",
116
+ "Northern Shrike",
117
+ "Northern Harrier",
118
+ "Rough-legged Hawk",
119
+ "Long-eared Owl",
120
+ "Evening Grosbeak",
121
+ "Northern Pintail",
122
+ "American Black Duck",
123
+ "Mallard",
124
+ "Canvasback",
125
+ "Redhead",
126
+ "Ring-necked Duck",
127
+ "Greater Scaup",
128
+ "Lesser Scaup",
129
+ "Bufflehead",
130
+ "Common Goldeneye",
131
+ "Hooded Merganser",
132
+ "Common Merganser",
133
+ "Red-breasted Merganser",
134
+ "Ruddy Duck",
135
+ "Wood Duck",
136
+ "Gadwall",
137
+ "American Wigeon",
138
+ "Northern Shoveler",
139
+ "Green-winged Teal",
140
+ "Blue-winged Teal",
141
+ "Cinnamon Teal",
142
+ "Ringed Teal",
143
+ "Cape Teal",
144
+ "Northern Fulmar",
145
+ "Yellow-billed Loon",
146
+ "Red-throated Loon",
147
+ "Arctic Loon",
148
+ "Pacific Loon",
149
+ "Horned Grebe",
150
+ "Red-necked Grebe",
151
+ "Eared Grebe",
152
+ "Western Grebe",
153
+ "Clark's Grebe",
154
+ "Double-crested Cormorant",
155
+ "Pelagic Cormorant",
156
+ "Great Cormorant",
157
+ "American White Pelican",
158
+ "Brown Pelican",
159
+ "Brandt's Cormorant",
160
+ "Least Bittern",
161
+ "Great Egret",
162
+ "Snowy Egret",
163
+ "Little Blue Heron",
164
+ "Tricolored Heron",
165
+ "Reddish Egret",
166
+ "Black-crowned Night-Heron",
167
+ "Yellow-crowned Night-Heron",
168
+ "White Ibis",
169
+ "Glossy Ibis",
170
+ "Roseate Spoonbill",
171
+ "Wood Stork",
172
+ "Black-bellied Whistling-Duck",
173
+ "Fulvous Whistling-Duck",
174
+ "Greater White-fronted Goose",
175
+ "Snow Goose",
176
+ "Ross's Goose",
177
+ "Canada Goose",
178
+ "Brant",
179
+ "Mute Swan",
180
+ "Tundra Swan",
181
+ "Whooper Swan",
182
+ "Sandhill Crane",
183
+ "Black-necked Stilt",
184
+ "American Avocet",
185
+ "Northern Jacana",
186
+ "Greater Yellowlegs",
187
+ "Lesser Yellowlegs",
188
+ "Willet",
189
+ "Spotted Sandpiper",
190
+ "Upland Sandpiper",
191
+ "Whimbrel",
192
+ "Long-billed Curlew",
193
+ "Marbled Godwit",
194
+ "Ruddy Turnstone",
195
+ "Red Knot",
196
+ "Sanderling",
197
+ "Semipalmated Sandpiper",
198
+ "Western Sandpiper",
199
+ "Least Sandpiper",
200
+ "White-rumped Sandpiper",
201
+ "Baird's Sandpiper",
202
+ "Pectoral Sandpiper",
203
+ "Dunlin",
204
+ "Buff-breasted Sandpiper",
205
+ "Short-billed Dowitcher",
206
+ "Long-billed Dowitcher",
207
+ "Common Snipe",
208
+ "American Woodcock",
209
+ "Wilson's Phalarope",
210
+ "Red-necked Phalarope",
211
+ "Red Phalarope"
212
+ ]
213
+
214
+ from pathlib import Path
215
+
216
+ def remove_spaces(s):
217
+ return s.replace(" ", "")
218
+
219
+ for species in SPECIES:
220
+ if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
221
+ continue
222
+ try:
223
+ q = Query(
224
+ name=species, q="A", length="10-30",
225
+ )
226
+
227
+ # retrieve metadata
228
+ metafiles = q.retrieve_meta(verbose=True)
229
+ # retrieve recordings
230
+ q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
231
+
232
+ except:
233
+ print("Failed to download " + species)
234
+ continue
setup.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import find_packages
2
+ from setuptools import setup
3
+
4
+ with open("README.md") as f:
5
+ long_description = f.read()
6
+
7
+ setup(
8
+ name="vampnet",
9
+ version="0.0.1",
10
+ classifiers=[
11
+ "Intended Audience :: Developers",
12
+ "Natural Language :: English",
13
+ "Programming Language :: Python :: 3.7",
14
+ "Topic :: Artistic Software",
15
+ "Topic :: Multimedia",
16
+ "Topic :: Multimedia :: Sound/Audio",
17
+ "Topic :: Multimedia :: Sound/Audio :: Editors",
18
+ "Topic :: Software Development :: Libraries",
19
+ ],
20
+ description="Generative Music Modeling.",
21
+ long_description=long_description,
22
+ long_description_content_type="text/markdown",
23
+ author="Hugo Flores García, Prem Seetharaman",
24
+ author_email="[email protected]",
25
+ url="https://github.com/hugofloresgarcia/vampnet",
26
+ license="MIT",
27
+ packages=find_packages(),
28
+ install_requires=[
29
+ "torch",
30
+ "argbind>=0.3.2",
31
+ "numpy==1.23",
32
+ "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
33
+ "lac @ git+https://github.com/hugofloresgarcia/lac.git",
34
+ "descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git",
35
+ "gradio",
36
+ "loralib",
37
+ "torch_pitch_shift",
38
+ "plotly",
39
+ ],
40
+ )
token_telephone/tt.py ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ttutil import hsv_to_rgb, dbg, log, set_debug, pow2db, db2pow
2
+ from dataclasses import dataclass, field
3
+ import os
4
+ from pathlib import Path
5
+ import random
6
+ import time
7
+ from threading import Thread
8
+ import gc
9
+ gc.disable()
10
+
11
+ import sounddevice as sd
12
+
13
+ from blessed import Terminal
14
+
15
+ import numpy as np
16
+ import torch
17
+ from einops import rearrange
18
+
19
+ PROFILE = False
20
+ DEBUG = False
21
+ DEBUG_NO_VAMPNET = False
22
+ set_debug(DEBUG)
23
+ # if DEBUG:
24
+ # import gc
25
+ # # log when gc start and stops
26
+ # gc.set_debug(gc.DEBUG_STATS)
27
+
28
+ @dataclass
29
+ class LoadState:
30
+ t0: float = None
31
+ loaded: bool = False
32
+
33
+ load_state = LoadState()
34
+
35
+ def on_random_color():
36
+ def random_rgb_bg():
37
+ return np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255)
38
+ return term.on_color_rgb(*random_rgb_bg())
39
+
40
+ # draw the intro screen before slow imports
41
+ def color_tokenize_txt(text: str):
42
+ # apply a random bg color to each letter
43
+ return "".join(on_random_color()(letter) for letter in text)
44
+
45
+ def color_tokenize_words(text: str):
46
+ return " ".join(on_random_color()(word) for word in text.split(" "))
47
+
48
+ def draw_intro_screen():
49
+ global load_state
50
+ load_state.t0 = time.time()
51
+ avg_time = 20 # average loading time
52
+
53
+ while not load_state.loaded:
54
+ print(term.clear)
55
+ print(term.move_xy(0, 1) + term.center(color_tokenize_words("hugo flores garcía")))
56
+ print(term.move_xy(0, 3) + term.center(color_tokenize_words("and")))
57
+ print(term.move_xy(0, 5) + term.center(color_tokenize_words("stephan moore")))
58
+ print(term.move_xy(0, 7) + term.center(color_tokenize_words("present")))
59
+ print(term.move_xy(0, 9) + term.center(term.bold(color_tokenize_txt("token telephone"))))
60
+
61
+ # print(term.move_xy(0, 10) + term.center(color_tokenize_txt("loading ")), end="")
62
+ # make a little loading bar
63
+ elapsed = time.time() - load_state.t0
64
+ num_dots = int((elapsed / avg_time) * 20)
65
+ num_spaces = 20 - num_dots
66
+ print(term.move_xy(0, 12) + term.center(color_tokenize_words("loading")))
67
+ print(term.move_xy(0, 13) + term.center(color_tokenize_txt(f"[{'.' * num_dots}") + f"{' ' * num_spaces}]"))
68
+ time.sleep(0.3)
69
+
70
+ log(f"loading took {time.time() - load_state.t0} seconds")
71
+ return
72
+
73
+ # the program
74
+ term = Terminal()
75
+
76
+ # draw the intro screen on a background thread
77
+ Thread(target=draw_intro_screen).start()
78
+
79
+ # disable garbage collection
80
+ from audiotools import AudioSignal
81
+ from vamp_helper import load_interface, ez_variation
82
+
83
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
84
+ # ~~~~~~ configs! ~~~~~~~~
85
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
86
+
87
+ MAX_LOUDNESS = -20
88
+ MIN_LOUDNESS = -40
89
+ COLS = 40
90
+ ROWS = 13
91
+
92
+ device = 'Scarlett 4i4 4th Gen'
93
+ sample_rate = 48000
94
+ num_channels = 4
95
+ blocksize = 16384
96
+
97
+
98
+ # TODO:
99
+ # still some quirks to work around recording time:
100
+ # do we wanna stop recording and wait a full cycle before letting people record again?
101
+ # how do we wanna balance the volume of a new input vs what's currently gonig on?
102
+ # should people have to take turns in between new loops?
103
+ # otherwise, we're doing great i think
104
+ # we also need to add a crossfade. This means maybe cutting off the last 0.1 seconds of the loop, and the beginning 0.1
105
+ # and use that to crossfade.
106
+
107
+ # TODO: do I wanna train a diff model to swap every 2hrs or something?
108
+ # how lond does model swapping take? how can I make it faster?
109
+
110
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
111
+ # ~~~~~~ looper ~~~~~~~~
112
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
113
+
114
+ @dataclass
115
+ class State:
116
+ # looper state
117
+ feedback: float = 0.25
118
+ duration: float = 5.0
119
+ record_channel: int = 0
120
+
121
+ loopbuf: np.ndarray = None # the main loop buffer. the token telephone audio is here
122
+ looper_in: np.ndarray = None # a buffer that stores the audio that's being recorded
123
+
124
+ buf_in: np.ndarray = None # the input block with audio samples in the audio callbac
125
+ lookback_buf: np.ndarray = None # stores some lookback audio for when the threshold is passed, to propery capture transients
126
+
127
+ recording: bool = False
128
+ playing: bool = False
129
+
130
+ # ramps
131
+ record_ramp_in: bool = False
132
+ record_ramp_out: bool = False
133
+
134
+ # n_record_layers: int = 2 # number of times we'll record over before clearing
135
+ # cur_rec_layer: int = 0
136
+ recording_locked: bool = False
137
+
138
+ rec_time: float = 0
139
+ cur_hold_time: float = None
140
+ pos: int = 0
141
+ rms_db: float = float("-inf")
142
+
143
+ trig_threshold_db = -25 # a more sane default is -20
144
+ hold_seconds = 1.0
145
+ rel_threshold_db = -40 # a more sane default is -30
146
+
147
+ status: str = field(default=None)
148
+
149
+ # token telephone configs
150
+ z_buf: torch.Tensor = None
151
+ input_ready = False
152
+ input_channel = 0
153
+ token_telephone_processing: bool = False
154
+ num_telephone_chans = 4
155
+ tt_cur_ch = 0
156
+
157
+ def __post_init__(self):
158
+ self.loopbuf = np.zeros((num_channels, int(self.duration * sample_rate)))
159
+ self.looper_in = np.zeros((1, int(self.duration * sample_rate)))
160
+
161
+ # hold 200ms of lookback to account for rising attacks.
162
+ num_lookback_samples = max(int(sample_rate * 0.2), int(blocksize))
163
+ log(f"num_lookback_samples {num_lookback_samples} ({num_lookback_samples / sample_rate} seconds)")
164
+ self.lookback_buf = np.zeros((1, num_lookback_samples))
165
+
166
+ self.buf_in = np.zeros((num_channels, blocksize))
167
+
168
+
169
+
170
+ def check_if_record(st: State, ain: np.ndarray, on_release_callback=None):
171
+ # get our rms value
172
+ rms = pow2db(np.sqrt(np.mean(ain**2)))
173
+ st.rms_db = rms
174
+
175
+ # determine if we should ater the looper state
176
+ # if we werent recording and we cross the trigger threshold
177
+ # start recording
178
+ # if not st.recording and rms > st.trig_threshold_db and not st.recording_locked:
179
+ if not st.recording and rms > st.trig_threshold_db and not st.recording_locked:
180
+ st.recording = True
181
+ st.record_ramp_in = True
182
+
183
+ # if we were recording and we cross the release threshold
184
+ # begin the hold period
185
+ if (st.recording and rms < st.rel_threshold_db) or st.rec_time > (st.duration-st.hold_seconds):
186
+ # if we dont have a hold time, set it
187
+ if st.cur_hold_time is None:
188
+ st.cur_hold_time = time.time()
189
+
190
+ # release if we have a hold time and we've held for the required time,
191
+ if (time.time() - st.cur_hold_time) > st.hold_seconds:
192
+ st.record_ramp_out = True
193
+ st.rec_time = 0
194
+ if on_release_callback is not None:
195
+ st.input_ready = True
196
+ on_release_callback(st)
197
+ st.cur_hold_time = None
198
+ else:
199
+ pass
200
+ else:
201
+ st.cur_hold_time = None
202
+
203
+
204
+ def launch_token_telephone(st: State):
205
+ if interface is None:
206
+ log("no interface loaded, can't do token telephone!")
207
+ time.sleep(10)
208
+ return
209
+
210
+ # if we're already processing, do nothing
211
+ if st.token_telephone_processing:
212
+ return
213
+ else:
214
+ log("starting token telephone!")
215
+ Thread(target=do_token_telephone, args=(st,)).start()
216
+
217
+
218
+ def do_token_telephone(st: State,):
219
+ st.token_telephone_processing = True
220
+ while True:
221
+ lrc = st.record_channel
222
+ t0 = time.time()
223
+ cur_ch = st.tt_cur_ch
224
+
225
+ # if there was input ready, start back from the top.
226
+ if st.input_ready:
227
+ log(f"there was input ready, processing!")
228
+ # NOTE: hugo, trying something new here. what happens if
229
+ # we don't reset the channel when input is ready,
230
+ # and instead let it come in anywhere in the cycle?
231
+ # st.tt_cur_ch = 0 # uncomment to go back to reality
232
+
233
+ # clear the lrc, reset for next record.
234
+ st.input_ready = False
235
+
236
+ # reocrd the channel that we'll be processing in and lock recording
237
+ st.input_channel = cur_ch
238
+ st.recording_locked = True
239
+
240
+ # first, let's preprocess looper in
241
+ sig_looper_in = AudioSignal(
242
+ torch.from_numpy(st.looper_in).unsqueeze(0),
243
+ sample_rate=sample_rate
244
+ )
245
+ sig_loopbuf_curch = AudioSignal(
246
+ torch.from_numpy(st.loopbuf[cur_ch:cur_ch+1]).unsqueeze(0),
247
+ sample_rate=sample_rate
248
+ )
249
+ # make sure looperin matches the midpoint in loudness
250
+ ldns_mid = max(sig_loopbuf_curch.loudness(), sig_looper_in.loudness())
251
+ sig_looper_in = sig_looper_in.normalize(ldns_mid)
252
+ st.looper_in = sig_looper_in.samples.cpu().numpy().squeeze(0)
253
+
254
+ st.loopbuf[cur_ch:cur_ch + 1] = (
255
+ st.looper_in + st.loopbuf[cur_ch:cur_ch+1] * st.feedback
256
+ )
257
+ # also lower the volumes of the other channels
258
+ for i in range(4):
259
+ if i != cur_ch:
260
+ st.loopbuf[i:i+1] = st.loopbuf[i:i+1] * 0.5 # -3dB
261
+
262
+ st.looper_in = np.zeros_like(st.looper_in)
263
+
264
+ loop_input = st.loopbuf[cur_ch:cur_ch+1]
265
+
266
+ # ~~~ VAMPNET STUFF ~~~~
267
+ sig = AudioSignal(
268
+ torch.from_numpy(loop_input).unsqueeze(0),
269
+ sample_rate=sample_rate
270
+ )
271
+ input_loudness = sig.loudness()
272
+ log(f"INPUT loudness {input_loudness}")
273
+ if input_loudness > MAX_LOUDNESS:
274
+ log(f"input loudness {input_loudness} is over {MAX_LOUDNESS}!")
275
+ sig = sig.normalize(MAX_LOUDNESS)
276
+ elif input_loudness < MIN_LOUDNESS:
277
+ log(f"input loudness {input_loudness} is under {MIN_LOUDNESS}!")
278
+ sig = sig.normalize(MIN_LOUDNESS)
279
+
280
+ sig = ez_variation(interface, sig)
281
+ sig = sig.resample(sample_rate)
282
+
283
+ # notify if we've gone over the loudness
284
+ sig = sig.normalize(input_loudness)
285
+ outloudness = sig.loudness()
286
+ if outloudness > MAX_LOUDNESS:
287
+ log(f"out loudness {sig.loudness()} is over {MAX_LOUDNESS}!")
288
+ sig = sig.normalize(MAX_LOUDNESS)
289
+ elif outloudness < MIN_LOUDNESS:
290
+ log(f"out loudness {sig.loudness()} is under {MIN_LOUDNESS}!")
291
+ sig = sig.normalize(MIN_LOUDNESS)
292
+
293
+ # put it back in the loopbuf
294
+ # write to the next channel
295
+ # (TODO: instead of trimming to loopbuf.shape[1], maybe we can just have the loopbuf be the right size from init time.)
296
+ cur_ch = (cur_ch + 1) % st.num_telephone_chans
297
+ st.tt_cur_ch = cur_ch
298
+ if False: # HUGO: is there a time where we want feedback?
299
+ st.loopbuf[cur_ch:cur_ch+1] = (
300
+ sig.samples.cpu().numpy().squeeze(0)[:, :st.loopbuf.shape[1]]
301
+ + st.feedback * st.loopbuf[cur_ch:cur_ch+1]
302
+ )
303
+ else:
304
+ st.loopbuf[cur_ch:cur_ch+1] = (
305
+ sig.samples.cpu().numpy().squeeze(0)[:, :st.loopbuf.shape[1]]
306
+ )
307
+
308
+ log(f"output loudness {sig.loudness()}")
309
+ log(f"telephone loop took {time.time() - t0} seconds... next channel {cur_ch}\n\n")
310
+
311
+ # if we've made it back to the input channel, we can unlock the recording
312
+ log(f"cur_ch {cur_ch} input_channel {st.input_channel}")
313
+ if cur_ch == st.input_channel:
314
+ st.recording_locked = False
315
+ log(f"recording unlocked!")
316
+
317
+
318
+ # unlock the recording if we've successfully written to all channels
319
+ # if st.recording_locked and cur_ch == 0:
320
+ # st.recording_locked = False
321
+ # log(f"recording locked {st.recording_locked}")
322
+
323
+ st.token_telephone_processing = False
324
+ return
325
+
326
+ # TODO: since we're using this really high threshold
327
+ # we always need to record about 100ms in advance, to catch the beginning of the attacks.
328
+
329
+ def looper_process_block(st, block: np.ndarray):
330
+ lrc = st.record_channel
331
+
332
+ # treat the lookback buffer as a circular buffer
333
+ st.lookback_buf = np.roll(st.lookback_buf, block.shape[1], axis=1)
334
+ st.lookback_buf[:, -block.shape[1]:] = block[lrc:lrc+1, :]
335
+
336
+
337
+ # check if we need to record.
338
+ if st.recording:
339
+ start_i = (st.pos + block.shape[1]) - st.lookback_buf.shape[1]
340
+ end_i = st.pos + st.lookback_buf.shape[1]
341
+
342
+ indices = np.take(
343
+ np.arange(st.loopbuf.shape[1]),
344
+ np.arange(start_i, end_i),
345
+ mode="wrap"
346
+ )
347
+ _audio_in = st.lookback_buf[:, :]
348
+ # ramp in if we need to
349
+ if st.record_ramp_in:
350
+ _audio_in = _audio_in * np.linspace(0, 1, _audio_in.shape[1])
351
+ st.record_ramp_in=False
352
+
353
+ if st.record_ramp_out:
354
+ _audio_in = _audio_in * np.linspace(1, 0, _audio_in.shape[1])
355
+ st.record_ramp_out=False
356
+ st.recording = False
357
+
358
+ st.looper_in[:, indices] = (
359
+ 0.9 * st.looper_in[:, indices] + _audio_in
360
+ )
361
+
362
+ # incremement the recording time
363
+ st.rec_time += st.lookback_buf.shape[1] / sample_rate
364
+
365
+ # check if we need to play
366
+ crossfade_samples = int(0.1 * sample_rate)
367
+ if st.playing:
368
+ play_pos = (st.pos + block.shape[1]) % st.loopbuf.shape[1] # read one buffer ahead
369
+ indices = np.arange(play_pos, play_pos + block.shape[1])
370
+ block = st.loopbuf.take(indices, axis=1, mode="wrap")[:, :] # this doesn't have any crossfading. # TODO: this is still not working!
371
+
372
+ # if we've recorded more than the loop size
373
+ if st.rec_time > st.duration and st.recording:
374
+ # play the loop
375
+ play_pos = st.pos + block.shape[1] # read one buffer ahead
376
+ indices = np.arange(play_pos, play_pos + block.shape[1])
377
+
378
+ block[lrc:lrc] = st.looper_in.take(indices, axis=1, mode="wrap")[:, :]
379
+
380
+ # advance looper state
381
+ st.pos = (st.pos + block.shape[1]) % st.loopbuf.shape[1]
382
+
383
+ return block
384
+
385
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
386
+ # ~~~~~~ drawing ~~~~~~~~
387
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
388
+
389
+ def draw_rms_bar(st, x, y, width, height):
390
+ rms_min = -50
391
+ rms_max = -10
392
+ rms = st.rms_db
393
+ rms = max(rms, rms_min)
394
+ threshold = st.trig_threshold_db
395
+ rel_threshold = st.rel_threshold_db
396
+
397
+ rms_block = int((rms - rms_min) / (rms_max - rms_min) * height)
398
+ threshold_block = (threshold - rms_min) / (rms_max - rms_min) * height
399
+ rel_threshold_block = (rel_threshold - rms_min) / (rms_max - rms_min) * height
400
+
401
+ # draw the rms curve
402
+ for i in range(rms_block, height+4):
403
+ with term.location(x+4, y+height-i):
404
+ print(term.clear_bol)
405
+ for i in range(rms_block):
406
+ rms_val = i * (rms_max - rms_min) / height + rms_min
407
+ with term.location(x, y+height-2-i):
408
+ if i < threshold_block:
409
+ print(" " + term.on_green(f"*"))
410
+ else:
411
+ print(" " + term.on_red(f"*"))
412
+
413
+ # at the very bottom of the bar, draw the rms value
414
+ with term.location(x, y+height-1):
415
+ print(f"{rms:.1f}dB")
416
+ # print(f" rms")
417
+
418
+
419
+ def draw_looper(st):
420
+ x = 0
421
+ y = 0
422
+ width = COLS
423
+ height = ROWS
424
+
425
+ tt_refresh_every = 0.3
426
+ if not hasattr(draw_looper, "last_draw"):
427
+ draw_looper.last_draw = 0
428
+ should_draw = True
429
+ else:
430
+ should_draw = (time.time() - draw_looper.last_draw) > tt_refresh_every
431
+ if should_draw:
432
+ draw_looper.last_draw = time.time()
433
+
434
+
435
+ draw_rms_bar(st, x, y, width - 10, height)
436
+
437
+ if should_draw:
438
+ with term.location(width // 2-4, 1):
439
+ for i, letter in enumerate("token telephone"):
440
+ print(on_random_color()(letter), end="")
441
+
442
+ # with term.location(ROWS-2, COLS // 2):
443
+ # print(f"status {st.status}!!!")
444
+
445
+
446
+ # if we're recording, draw a red unlderlined "rec" sign on the bottom right
447
+ # with term.location(width-8, height-1):
448
+ # if st.recording:
449
+ # print(term.on_red("rec"))
450
+ # else:
451
+ # print(term.on_gray50("rec"))
452
+
453
+ # # if we're playing draw a green underline "play" sign on the bottom right
454
+ # with term.location(width-4, height-1):
455
+ # if st.playing:
456
+ # print(term.on_green("play"))
457
+ # else:
458
+ # print(term.on_gray50("play"))
459
+
460
+
461
+ # draw the timeline at the bottom using ---
462
+ with term.location(6, height):
463
+ timeline = ["-"] * (width - 12)
464
+ playhead = int((st.pos / st.loopbuf.shape[1]) * (width - 12))
465
+ timeline[playhead] = "v"
466
+ print("|"+"".join(timeline) + "|")
467
+
468
+
469
+ # draw the main message at the very center:
470
+ msg_loc = (width // 2, height // 2+1)
471
+ _x, _y = msg_loc
472
+ if not st.recording:
473
+ if not st.recording_locked:
474
+ print(term.move_xy(0, _y-1) + term.center("make a sound", width=width+5))
475
+ print(term.move_xy(0, _y+0) + term.center("to", width=width+5))
476
+ print(term.move_xy(0, _y+1) + term.center("record", width=width+5))
477
+ else:
478
+ # how many seconds left until we can record again?
479
+ # how many more chs do we need to go through before we can record again?
480
+ if st.tt_cur_ch < st.input_channel:
481
+ chs_remaining = st.input_channel - st.tt_cur_ch
482
+ else:
483
+ chs_remaining = 4-st.tt_cur_ch + st.input_channel
484
+ locked_time_remaining = chs_remaining * st.duration + st.duration - (st.pos / sample_rate)
485
+ print(term.move_xy(0, _y-1) + term.center("please wait", width=width+5))
486
+ print(term.move_xy(0, _y+0) + term.center(term.on_green(f"{locked_time_remaining:.1f}s"), width=width+5))
487
+ print(term.move_xy(0, _y+1) + term.center("for your turn :)", width=width+5))
488
+ else:
489
+ print(term.move_xy(0, _y-1) + term.center(term.on_red("recording"), width=width+5))
490
+ print(term.move_xy(0, _y+0) + term.center(f"{(st.duration) - st.rec_time:.1f}s left", width=width+5))
491
+ print(term.move_xy(0, _y+1) + term.center("", width=width+5))
492
+
493
+
494
+ # we'll draw channel 0 (1) on the bottom right corner
495
+ # channel 1 (2) on the top right corner
496
+ # channel 2 (3) on the top left corner
497
+ # channel 3 (4) on the bottom left corner
498
+ my = 3 # margin
499
+ mx = 10
500
+ locations = {
501
+ 1: (width - mx, height - my),
502
+ 2: (width - mx, 1+my),
503
+ 3: (mx, 1+my),
504
+ 4: (mx, height - my),
505
+ }
506
+ for i in range(1, 5):
507
+ if should_draw:
508
+ if st.tt_cur_ch == i - 1 and st.token_telephone_processing:
509
+ x, y = locations[i]
510
+ on_random_colors = lambda n: "".join(on_random_color()(" ") for _ in range(n))
511
+ print(term.move_xy(x, y-1) + on_random_colors(5))
512
+ print(term.move_xy(x, y) + on_random_color()(" ") + f" {i} " + on_random_color()(" "))
513
+ print(term.move_xy(x, y+1) + on_random_colors(5))
514
+ else:
515
+ # same thing, but a gray instead of random colors
516
+ x, y = locations[i]
517
+ on_gray_colors = lambda n: "".join(term.on_gray50(" ") for _ in range(n))
518
+ print(term.move_xy(x, y-1) + on_gray_colors(5))
519
+ print(term.move_xy(x, y) + term.on_gray50(" ") + f" {i} " + term.on_gray50(" "))
520
+ print(term.move_xy(x, y+1) + on_gray_colors(5))
521
+
522
+
523
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
524
+ # ~~~~~~ live audio ~~~~~~~~
525
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
526
+ def audio_init():
527
+ sd.default.samplerate = sample_rate
528
+ sd.default.device = device
529
+
530
+ # ~~~~~~ the main audio callback ~~~~~~~~~
531
+ def callback(st, indata, outdata, frames, _time, status):
532
+ t0 = time.time()
533
+ lrc = st.record_channel
534
+
535
+ if status:
536
+ log(f"status is {status}")
537
+ st.status = status
538
+
539
+ # log dtype, status, frames, time, max min
540
+ # log(f"indata {indata.dtype} max {indata.max()} min {indata.min()} {status} {frames} {_time}")
541
+
542
+
543
+ ain = rearrange(indata, 't n -> n t', n=num_channels)
544
+
545
+ # convert audio to from int32 to float32
546
+ ain = ain.astype(np.float32) / np.iinfo(np.int16).max
547
+ buf_in = ain
548
+
549
+ # if it's all zeros, we're not recording
550
+ # so we can just pass it through
551
+ if np.all(buf_in == 0):
552
+ st.status = st.status + "no input"
553
+ return
554
+
555
+ st.buf_in = buf_in
556
+ check_if_record(
557
+ st, buf_in,
558
+ on_release_callback=launch_token_telephone
559
+ )
560
+ buf_in = looper_process_block(st, buf_in)
561
+
562
+ # pass our st.loopbuf to the output
563
+ ain = buf_in
564
+
565
+ # convert back to int32
566
+ ain = (ain * np.iinfo(np.int16).max).astype(np.int16)
567
+
568
+ outdata[:] = rearrange(ain, 'n t -> t n')
569
+
570
+ # log(f"outdata {outdata.dtype} max {outdata.max()} min {outdata.min()} --- took {time.time() - t0} seconds")
571
+
572
+
573
+
574
+ if DEBUG_NO_VAMPNET:
575
+ interface=None
576
+ else:
577
+ interface = load_interface(model_choice="opera")
578
+
579
+ load_state.loaded = True
580
+
581
+ def main():
582
+ if PROFILE:
583
+ import yappi
584
+ yappi.start()
585
+
586
+ try:
587
+ audio_init()
588
+ st = State()
589
+ st.playing = True
590
+
591
+ from functools import partial
592
+ cb = partial(callback, st)
593
+
594
+ with term.fullscreen(), term.cbreak():
595
+ with sd.Stream(channels=num_channels, callback=cb, blocksize=blocksize, prime_output_buffers_using_stream_callback=True, dtype=np.int16):
596
+ while True:
597
+ with term.hidden_cursor():
598
+ if DEBUG:
599
+ time.sleep(100)
600
+ else:
601
+ draw_looper(st)
602
+
603
+ except KeyboardInterrupt:
604
+ print(term.clear)
605
+ if PROFILE:
606
+ yappi.stop()
607
+
608
+ # retrieve thread stats by their thread id (given by yappi)
609
+ threads = yappi.get_thread_stats()
610
+ for thread in threads:
611
+ print(
612
+ "Function stats for (%s) (%d)" % (thread.name, thread.id)
613
+ ) # it is the Thread.__class__.__name__
614
+ yappi.get_func_stats(ctx_id=thread.id).print_all()
615
+
616
+ main()
token_telephone/ttutil.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+ ROOT = Path(__file__).parent
4
+
5
+ import numpy as np
6
+ from queue import Queue
7
+
8
+ # make a log file!!
9
+ logfile= ROOT / "log.txt"
10
+ if logfile.exists():
11
+ logfile.unlink()
12
+ logging.basicConfig(filename=logfile, level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S", format="%(asctime)s | %(levelname)s | %(message)s")
13
+
14
+
15
+ def hsv_to_rgb(h, s, v):
16
+ # from https://en.wikipedia.org/wiki/HSL_and_HSV#From_HSV
17
+ c = v * s
18
+ h_ = h / 60
19
+ x = c * (1 - abs(h_ % 2 - 1))
20
+ m = v - c
21
+
22
+ if h_ < 1:
23
+ r, g, b = c, x, 0
24
+ elif h_ < 2:
25
+ r, g, b = x, c, 0
26
+ elif h_ < 3:
27
+ r, g, b = 0, c, x
28
+ elif h_ < 4:
29
+ r, g, b = 0, x, c
30
+ elif h_ < 5:
31
+ r, g, b = x, 0, c
32
+ else:
33
+ r, g, b = c, 0, x
34
+
35
+ return r + m, g + m, b + m
36
+
37
+
38
+ def dbg(*args):
39
+ print(" ".join(map(str, args)))
40
+
41
+
42
+ # we'll want to log on a separate thread
43
+ # so that we can log without blocking the main thread
44
+
45
+ # make a queue for logging
46
+ log_queue = Queue()
47
+
48
+ # log to a file instead of the console
49
+ def log(msg):
50
+ # log_queue.put(msg)
51
+ logging.info(msg)
52
+ pass
53
+
54
+ def set_debug(debug):
55
+ if debug:
56
+ # print log to console
57
+ logging.getLogger().addHandler(logging.StreamHandler())
58
+
59
+
60
+ def pow2db(x):
61
+ return 10 * np.log10(x + 1e-6)
62
+
63
+
64
+ def db2pow(x):
65
+ return 10 ** (x / 10)
token_telephone/vamp_helper.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import time
3
+ import os
4
+ from contextlib import contextmanager
5
+ import random
6
+
7
+ import numpy as np
8
+ import audiotools as at
9
+ from audiotools import AudioSignal
10
+ import argbind
11
+ import shutil
12
+ import torch
13
+ import yaml
14
+
15
+
16
+ from vampnet.interface import Interface, signal_concat
17
+ from vampnet import mask as pmask
18
+
19
+ from ttutil import log
20
+
21
+ # TODO: incorporate discord bot (if mem allows)
22
+ # in a separate thread, send audio samples for listening
23
+ # and send back the results
24
+ # as well as the params for sampling
25
+ # also a command that lets you clear the current signal
26
+ # if you want to start over
27
+
28
+
29
+ device = "cuda" if torch.cuda.is_available() else "cpu"
30
+
31
+ VAMPNET_DIR = Path(".").resolve()
32
+
33
+ @contextmanager
34
+ def chdir(path):
35
+ old_dir = os.getcwd()
36
+ os.chdir(path)
37
+ try:
38
+ yield
39
+ finally:
40
+ os.chdir(old_dir)
41
+
42
+ def load_interface(model_choice="default") -> Interface:
43
+ with chdir(VAMPNET_DIR):
44
+
45
+
46
+ # populate the model choices with any interface.yml files in the generated confs
47
+ MODEL_CHOICES = {
48
+ "default": {
49
+ "Interface.coarse_ckpt": "models/vampnet/coarse.pth",
50
+ "Interface.coarse2fine_ckpt": "models/vampnet/c2f.pth",
51
+ "Interface.codec_ckpt": "models/vampnet/codec.pth",
52
+ }
53
+ }
54
+ generated_confs = Path("conf/generated")
55
+ for conf_file in generated_confs.glob("*/interface.yml"):
56
+ with open(conf_file) as f:
57
+ _conf = yaml.safe_load(f)
58
+
59
+ # check if the coarse, c2f, and codec ckpts exist
60
+ # otherwise, dont' add this model choice
61
+ if not (
62
+ Path(_conf["Interface.coarse_ckpt"]).exists() and
63
+ Path(_conf["Interface.coarse2fine_ckpt"]).exists() and
64
+ Path(_conf["Interface.codec_ckpt"]).exists()
65
+ ):
66
+ continue
67
+
68
+ MODEL_CHOICES[conf_file.parent.name] = _conf
69
+
70
+ interface = Interface(
71
+ device=device,
72
+ coarse_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse_ckpt"],
73
+ coarse2fine_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse2fine_ckpt"],
74
+ codec_ckpt=MODEL_CHOICES[model_choice]["Interface.codec_ckpt"],
75
+ )
76
+
77
+ interface.model_choices = MODEL_CHOICES
78
+ interface.to("cuda" if torch.cuda.is_available() else "cpu")
79
+ return interface
80
+
81
+ def load_model(interface: Interface, model_choice: str):
82
+ interface.reload(
83
+ interface.model_choices[model_choice]["Interface.coarse_ckpt"],
84
+ interface.model_choices[model_choice]["Interface.coarse2fine_ckpt"],
85
+ )
86
+
87
+ def ez_variation(
88
+ interface,
89
+ sig: AudioSignal,
90
+ seed: int = None,
91
+ model_choice: str = None,
92
+ ):
93
+ t0 = time.time()
94
+
95
+ if seed is None:
96
+ seed = int(torch.randint(0, 2**32, (1,)).item())
97
+ at.util.seed(seed)
98
+
99
+ # reload the model if necessary
100
+ if model_choice is not None:
101
+ load_model(interface, model_choice)
102
+
103
+ # SAMPLING MASK PARAMS, hard code for now, we'll prob want a more preset-ey thing for the actual thin
104
+ # we probably honestly just want to oscillate between the same 4 presets
105
+ # in a predictable order such that they have a predictable outcome
106
+ periodic_p = random.choice([3])
107
+ n_mask_codebooks = 3
108
+ sampletemp = random.choice([1.0,])
109
+ dropout = random.choice([0.0, 0.0])
110
+
111
+ top_p = None # NOTE: top p may be the culprit behind the collapse into single pitches.
112
+
113
+ # parameters for the build_mask function
114
+ build_mask_kwargs = dict(
115
+ rand_mask_intensity=1.0,
116
+ prefix_s=0.0,
117
+ suffix_s=0.0,
118
+ periodic_prompt=int(periodic_p),
119
+ periodic_prompt2=int(periodic_p),
120
+ periodic_prompt_width=1,
121
+ _dropout=dropout,
122
+ upper_codebook_mask=int(n_mask_codebooks),
123
+ upper_codebook_mask_2=int(n_mask_codebooks),
124
+ )
125
+
126
+ # parameters for the vamp function
127
+ vamp_kwargs = dict(
128
+ temperature=sampletemp,
129
+ typical_filtering=True,
130
+ typical_mass=0.15,
131
+ typical_min_tokens=64,
132
+ top_p=top_p,
133
+ seed=seed,
134
+ sample_cutoff=1.0,
135
+ )
136
+
137
+ # save the mask as a txt file
138
+ interface.set_chunk_size(10.0)
139
+ sig, mask, codes = interface.ez_vamp(
140
+ sig,
141
+ batch_size=1,
142
+ feedback_steps=1,
143
+ time_stretch_factor=1,
144
+ build_mask_kwargs=build_mask_kwargs,
145
+ vamp_kwargs=vamp_kwargs,
146
+ return_mask=True,
147
+ )
148
+
149
+ log(f"vamp took {time.time() - t0} seconds")
150
+ return sig
151
+
152
+
153
+
154
+ def main():
155
+ import tqdm
156
+
157
+ interface = load_interface()
158
+ sig = AudioSignal.excerpt("assets/example.wav", duration=7.0)
159
+ sig = interface.preprocess(sig)
160
+ sig.write('ttout/in.wav')
161
+ insig = sig.clone()
162
+
163
+ fdbk_every = 4
164
+ fdbk = 0.5
165
+
166
+ for i in tqdm.tqdm(range(1000)):
167
+ sig = ez_variation(interface, sig, model_choice="orchestral")
168
+ sig.write(f'ttout/out{i}.wav')
169
+
170
+
171
+ if __name__ == "__main__":
172
+ main()
vampnet/__init__.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from . import modules
3
+ from pathlib import Path
4
+ from . import scheduler
5
+ from .interface import Interface
6
+ from .modules.transformer import VampNet
7
+
8
+
9
+ __version__ = "0.0.1"
10
+
11
+ ROOT = Path(__file__).parent.parent
12
+ MODELS_DIR = ROOT / "models" / "vampnet"
13
+
14
+ from huggingface_hub import hf_hub_download, HfFileSystem
15
+ DEFAULT_HF_MODEL_REPO = "hugggof/vampnet"
16
+ FS = HfFileSystem()
17
+
18
+ def download_codec():
19
+ # from dac.model.dac import DAC
20
+ from lac.model.lac import LAC as DAC
21
+ repo_id = DEFAULT_HF_MODEL_REPO
22
+ filename = "codec.pth"
23
+ codec_path = hf_hub_download(
24
+ repo_id=repo_id,
25
+ filename=filename,
26
+ subfolder=None,
27
+ local_dir=MODELS_DIR
28
+ )
29
+ return codec_path
30
+
31
+
32
+ def download_default():
33
+ filenames = ["coarse.pth", "c2f.pth"]
34
+ repo_id = DEFAULT_HF_MODEL_REPO
35
+ paths = []
36
+ for filename in filenames:
37
+ path = f"{MODELS_DIR}/{filename}"
38
+ if not Path(path).exists():
39
+ path = hf_hub_download(
40
+ repo_id=repo_id,
41
+ filename=filename,
42
+ subfolder=None,
43
+ local_dir=MODELS_DIR,
44
+ local_dir_use_symlinks=False,
45
+ local_files_only=False
46
+ )
47
+ paths.append(path)
48
+
49
+ # load the models
50
+ return paths[0], paths[1]
51
+
52
+
53
+ def download_finetuned(name):
54
+ repo_id = f"{DEFAULT_HF_MODEL_REPO}"
55
+ filenames = ["coarse.pth", "c2f.pth"]
56
+ paths = []
57
+ for filename in filenames:
58
+ path = f"{MODELS_DIR}/{name}/loras/{filename}"
59
+ if not Path(path).exists():
60
+ path = hf_hub_download(
61
+ repo_id=repo_id,
62
+ filename=filename,
63
+ subfolder=f"loras/{name}",
64
+ local_dir=MODELS_DIR,
65
+ local_dir_use_symlinks=False,
66
+ local_files_only=False
67
+ )
68
+ paths.append(path)
69
+
70
+ # load the models
71
+ return paths[0], paths[1]
72
+
73
+ def list_finetuned():
74
+ diritems = FS.listdir(f"{DEFAULT_HF_MODEL_REPO}/loras")
75
+ # iterate through all the names
76
+ valid_diritems = []
77
+ for item in diritems:
78
+ model_file_items = FS.listdir(item["name"])
79
+ item_names = [item["name"].split("/")[-1] for item in model_file_items]
80
+ # check that theres a "c2f.pth" and "coarse.pth" in the items
81
+ c2f_exists = "c2f.pth" in item_names
82
+ coarse_exists = "coarse.pth" in item_names
83
+ if c2f_exists and coarse_exists:
84
+ valid_diritems.append(item)
85
+
86
+ # get the names of the valid items
87
+ names = [item["name"].split("/")[-1] for item in valid_diritems]
88
+ return names
89
+
90
+
vampnet/beats.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any
7
+ from typing import List
8
+ from typing import Tuple
9
+ from typing import Union
10
+
11
+ import librosa
12
+ import torch
13
+ import numpy as np
14
+ from audiotools import AudioSignal
15
+
16
+
17
+ logging.basicConfig(level=logging.INFO)
18
+
19
+ ###################
20
+ # beat sync utils #
21
+ ###################
22
+
23
+ AGGREGATOR_REGISTRY = {
24
+ "mean": np.mean,
25
+ "median": np.median,
26
+ "max": np.max,
27
+ "min": np.min,
28
+ }
29
+
30
+
31
+ def list_aggregators() -> list:
32
+ return list(AGGREGATOR_REGISTRY.keys())
33
+
34
+
35
+ @dataclass
36
+ class TimeSegment:
37
+ start: float
38
+ end: float
39
+
40
+ @property
41
+ def duration(self):
42
+ return self.end - self.start
43
+
44
+ def __str__(self) -> str:
45
+ return f"{self.start} - {self.end}"
46
+
47
+ def find_overlapping_segment(
48
+ self, segments: List["TimeSegment"]
49
+ ) -> Union["TimeSegment", None]:
50
+ """Find the first segment that overlaps with this segment, or None if no segment overlaps"""
51
+ for s in segments:
52
+ if s.start <= self.start and s.end >= self.end:
53
+ return s
54
+ return None
55
+
56
+
57
+ def mkdir(path: Union[Path, str]) -> Path:
58
+ p = Path(path)
59
+ p.mkdir(parents=True, exist_ok=True)
60
+ return p
61
+
62
+
63
+
64
+ ###################
65
+ # beat data #
66
+ ###################
67
+ @dataclass
68
+ class BeatSegment(TimeSegment):
69
+ downbeat: bool = False # if there's a downbeat on the start_time
70
+
71
+
72
+ class Beats:
73
+ def __init__(self, beat_times, downbeat_times):
74
+ if isinstance(beat_times, np.ndarray):
75
+ beat_times = beat_times.tolist()
76
+ if isinstance(downbeat_times, np.ndarray):
77
+ downbeat_times = downbeat_times.tolist()
78
+ self._beat_times = beat_times
79
+ self._downbeat_times = downbeat_times
80
+ self._use_downbeats = False
81
+
82
+ def use_downbeats(self, use_downbeats: bool = True):
83
+ """use downbeats instead of beats when calling beat_times"""
84
+ self._use_downbeats = use_downbeats
85
+
86
+ def beat_segments(self, signal: AudioSignal) -> List[BeatSegment]:
87
+ """
88
+ segments a song into time segments corresponding to beats.
89
+ the first segment starts at 0 and ends at the first beat time.
90
+ the last segment starts at the last beat time and ends at the end of the song.
91
+ """
92
+ beat_times = self._beat_times.copy()
93
+ downbeat_times = self._downbeat_times
94
+ beat_times.insert(0, 0)
95
+ beat_times.append(signal.signal_duration)
96
+
97
+ downbeat_ids = np.intersect1d(beat_times, downbeat_times, return_indices=True)[
98
+ 1
99
+ ]
100
+ is_downbeat = [
101
+ True if i in downbeat_ids else False for i in range(len(beat_times))
102
+ ]
103
+ segments = [
104
+ BeatSegment(start_time, end_time, downbeat)
105
+ for start_time, end_time, downbeat in zip(
106
+ beat_times[:-1], beat_times[1:], is_downbeat
107
+ )
108
+ ]
109
+ return segments
110
+
111
+ def get_beats(self) -> np.ndarray:
112
+ """returns an array of beat times, in seconds
113
+ if downbeats is True, returns an array of downbeat times, in seconds
114
+ """
115
+ return np.array(
116
+ self._downbeat_times if self._use_downbeats else self._beat_times
117
+ )
118
+
119
+ @property
120
+ def beat_times(self) -> np.ndarray:
121
+ """return beat times"""
122
+ return np.array(self._beat_times)
123
+
124
+ @property
125
+ def downbeat_times(self) -> np.ndarray:
126
+ """return downbeat times"""
127
+ return np.array(self._downbeat_times)
128
+
129
+ def beat_times_to_feature_frames(
130
+ self, signal: AudioSignal, features: np.ndarray
131
+ ) -> np.ndarray:
132
+ """convert beat times to frames, given an array of time-varying features"""
133
+ beat_times = self.get_beats()
134
+ beat_frames = (
135
+ beat_times * signal.sample_rate / signal.signal_length * features.shape[-1]
136
+ ).astype(np.int64)
137
+ return beat_frames
138
+
139
+ def sync_features(
140
+ self, feature_frames: np.ndarray, features: np.ndarray, aggregate="median"
141
+ ) -> np.ndarray:
142
+ """sync features to beats"""
143
+ if aggregate not in AGGREGATOR_REGISTRY:
144
+ raise ValueError(f"unknown aggregation method {aggregate}")
145
+
146
+ return librosa.util.sync(
147
+ features, feature_frames, aggregate=AGGREGATOR_REGISTRY[aggregate]
148
+ )
149
+
150
+ def to_json(self) -> dict:
151
+ """return beats and downbeats as json"""
152
+ return {
153
+ "beats": self._beat_times,
154
+ "downbeats": self._downbeat_times,
155
+ "use_downbeats": self._use_downbeats,
156
+ }
157
+
158
+ @classmethod
159
+ def from_dict(cls, data: dict):
160
+ """load beats and downbeats from json"""
161
+ inst = cls(data["beats"], data["downbeats"])
162
+ inst.use_downbeats(data["use_downbeats"])
163
+ return inst
164
+
165
+ def save(self, output_dir: Path):
166
+ """save beats and downbeats to json"""
167
+ mkdir(output_dir)
168
+ with open(output_dir / "beats.json", "w") as f:
169
+ json.dump(self.to_json(), f)
170
+
171
+ @classmethod
172
+ def load(cls, input_dir: Path):
173
+ """load beats and downbeats from json"""
174
+ beats_file = Path(input_dir) / "beats.json"
175
+ with open(beats_file, "r") as f:
176
+ data = json.load(f)
177
+ return cls.from_dict(data)
178
+
179
+
180
+ ###################
181
+ # beat tracking #
182
+ ###################
183
+
184
+
185
+ class BeatTracker:
186
+ def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
187
+ """extract beats from an audio signal"""
188
+ raise NotImplementedError
189
+
190
+ def __call__(self, signal: AudioSignal) -> Beats:
191
+ """extract beats from an audio signal
192
+ NOTE: if the first beat (and/or downbeat) is detected within the first 100ms of the audio,
193
+ it is discarded. This is to avoid empty bins with no beat synced features in the first beat.
194
+ Args:
195
+ signal (AudioSignal): signal to beat track
196
+ Returns:
197
+ Tuple[np.ndarray, np.ndarray]: beats and downbeats
198
+ """
199
+ beats, downbeats = self.extract_beats(signal)
200
+ return Beats(beats, downbeats)
201
+
202
+
203
+ class WaveBeat(BeatTracker):
204
+ def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
205
+ from wavebeat.dstcn import dsTCNModel
206
+
207
+ model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
208
+ model.eval()
209
+
210
+ self.device = device
211
+ self.model = model
212
+
213
+ def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
214
+ """returns beat and downbeat times, in seconds"""
215
+ # extract beats
216
+ beats, downbeats = self.model.predict_beats_from_array(
217
+ audio=signal.audio_data.squeeze(0),
218
+ sr=signal.sample_rate,
219
+ use_gpu=self.device != "cpu",
220
+ )
221
+
222
+ return beats, downbeats
223
+
224
+
225
+ class MadmomBeats(BeatTracker):
226
+ def __init__(self):
227
+ raise NotImplementedError
228
+
229
+ def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
230
+ """returns beat and downbeat times, in seconds"""
231
+ pass
232
+
233
+
234
+ BEAT_TRACKER_REGISTRY = {
235
+ "wavebeat": WaveBeat,
236
+ "madmom": MadmomBeats,
237
+ }
238
+
239
+
240
+ def list_beat_trackers() -> list:
241
+ return list(BEAT_TRACKER_REGISTRY.keys())
242
+
243
+
244
+ def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker:
245
+ if beat_tracker not in BEAT_TRACKER_REGISTRY:
246
+ raise ValueError(
247
+ f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
248
+ )
249
+
250
+ return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs)
vampnet/interface.py ADDED
@@ -0,0 +1,623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import math
4
+ import logging
5
+
6
+ import torch
7
+ import numpy as np
8
+ from audiotools import AudioSignal
9
+ import tqdm
10
+
11
+ from .modules.transformer import VampNet
12
+ from .beats import WaveBeat
13
+ from .mask import *
14
+
15
+ # from dac.model.dac import DAC
16
+ from lac.model.lac import LAC as DAC
17
+
18
+
19
+ def signal_concat(
20
+ audio_signals: list,
21
+ ):
22
+ audio_data = torch.cat([x.audio_data for x in audio_signals], dim=-1)
23
+
24
+ return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
25
+
26
+
27
+ def _load_model(
28
+ ckpt: str,
29
+ lora_ckpt: str = None,
30
+ device: str = "cpu",
31
+ chunk_size_s: int = 10,
32
+ ):
33
+ # we need to set strict to False if the model has lora weights to add later
34
+ model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)
35
+
36
+ # load lora weights if needed
37
+ if lora_ckpt is not None:
38
+ if not Path(lora_ckpt).exists():
39
+ should_cont = input(
40
+ f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
41
+ )
42
+ if should_cont != "y":
43
+ raise Exception("aborting")
44
+ else:
45
+ model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)
46
+
47
+ model.to(device)
48
+ model.eval()
49
+ model.chunk_size_s = chunk_size_s
50
+ return model
51
+
52
+
53
+
54
+ class Interface(torch.nn.Module):
55
+ def __init__(
56
+ self,
57
+ coarse_ckpt: str = None,
58
+ coarse_lora_ckpt: str = None,
59
+ coarse2fine_ckpt: str = None,
60
+ coarse2fine_lora_ckpt: str = None,
61
+ codec_ckpt: str = None,
62
+ wavebeat_ckpt: str = None,
63
+ device: str = "cpu",
64
+ coarse_chunk_size_s: int = 10,
65
+ coarse2fine_chunk_size_s: int = 3,
66
+ compile=True,
67
+ ):
68
+ super().__init__()
69
+ assert codec_ckpt is not None, "must provide a codec checkpoint"
70
+ self.codec = DAC.load(Path(codec_ckpt))
71
+ self.codec.eval()
72
+ self.codec.to(device)
73
+ self.codec_path = Path(codec_ckpt)
74
+
75
+ assert coarse_ckpt is not None, "must provide a coarse checkpoint"
76
+ self.coarse = _load_model(
77
+ ckpt=coarse_ckpt,
78
+ lora_ckpt=coarse_lora_ckpt,
79
+ device=device,
80
+ chunk_size_s=coarse_chunk_size_s,
81
+ )
82
+ self.coarse_path = Path(coarse_ckpt)
83
+
84
+ # check if we have a coarse2fine ckpt
85
+ if coarse2fine_ckpt is not None:
86
+ self.c2f_path = Path(coarse2fine_ckpt)
87
+ self.c2f = _load_model(
88
+ ckpt=coarse2fine_ckpt,
89
+ lora_ckpt=coarse2fine_lora_ckpt,
90
+ device=device,
91
+ chunk_size_s=coarse2fine_chunk_size_s,
92
+ )
93
+ else:
94
+ self.c2f_path = None
95
+ self.c2f = None
96
+
97
+ if wavebeat_ckpt is not None:
98
+ logging.debug(f"loading wavebeat from {wavebeat_ckpt}")
99
+ self.beat_tracker = WaveBeat(wavebeat_ckpt)
100
+ self.beat_tracker.model.to(device)
101
+ else:
102
+ self.beat_tracker = None
103
+
104
+ self.device = device
105
+ self.loudness = -24.0
106
+
107
+ if compile:
108
+ logging.debug(f"compiling models")
109
+ self.coarse = torch.compile(self.coarse)
110
+ if self.c2f is not None:
111
+ self.c2f = torch.compile(self.c2f)
112
+ self.codec = torch.compile(self.codec)
113
+
114
+
115
+ @classmethod
116
+ def default(cls):
117
+ from . import download_codec, download_default
118
+ print(f"loading default vampnet")
119
+ codec_path = download_codec()
120
+ coarse_path, c2f_path = download_default()
121
+
122
+ return Interface(
123
+ coarse_ckpt=coarse_path,
124
+ coarse2fine_ckpt=c2f_path,
125
+ codec_ckpt=codec_path,
126
+ )
127
+
128
+ @classmethod
129
+ def available_models(cls):
130
+ from . import list_finetuned
131
+ return list_finetuned()
132
+
133
+
134
+ def load_finetuned(self, name: str):
135
+ assert name in self.available_models(), f"{name} is not a valid model name"
136
+ from . import download_finetuned
137
+ coarse_path, c2f_path = download_finetuned(name)
138
+ self.reload(
139
+ coarse_ckpt=coarse_path,
140
+ c2f_ckpt=c2f_path,
141
+ )
142
+
143
+ def reload(
144
+ self,
145
+ coarse_ckpt: str = None,
146
+ c2f_ckpt: str = None,
147
+ ):
148
+ if coarse_ckpt is not None:
149
+ # check if we already loaded, if so, don't reload
150
+ if self.coarse_path == Path(coarse_ckpt):
151
+ logging.debug(f"already loaded {coarse_ckpt}")
152
+ else:
153
+ self.coarse = _load_model(
154
+ ckpt=coarse_ckpt,
155
+ device=self.device,
156
+ chunk_size_s=self.coarse.chunk_size_s,
157
+ )
158
+ self.coarse_path = Path(coarse_ckpt)
159
+ logging.debug(f"loaded {coarse_ckpt}")
160
+
161
+ if c2f_ckpt is not None:
162
+ if self.c2f_path == Path(c2f_ckpt):
163
+ logging.debug(f"already loaded {c2f_ckpt}")
164
+ else:
165
+ self.c2f = _load_model(
166
+ ckpt=c2f_ckpt,
167
+ device=self.device,
168
+ chunk_size_s=self.c2f.chunk_size_s,
169
+ )
170
+ self.c2f_path = Path(c2f_ckpt)
171
+ logging.debug(f"loaded {c2f_ckpt}")
172
+
173
+ def s2t(self, seconds: float):
174
+ """seconds to tokens"""
175
+ if isinstance(seconds, np.ndarray):
176
+ return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
177
+ else:
178
+ return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
179
+
180
+ def s2t2s(self, seconds: float):
181
+ """seconds to tokens to seconds"""
182
+ return self.t2s(self.s2t(seconds))
183
+
184
+ def t2s(self, tokens: int):
185
+ """tokens to seconds"""
186
+ return tokens * self.codec.hop_length / self.codec.sample_rate
187
+
188
+ def to(self, device):
189
+ self.device = device
190
+ self.coarse.to(device)
191
+ self.codec.to(device)
192
+
193
+ if self.c2f is not None:
194
+ self.c2f.to(device)
195
+
196
+ if self.beat_tracker is not None:
197
+ self.beat_tracker.model.to(device)
198
+ return self
199
+
200
+ def decode(self, z: torch.Tensor):
201
+ return self.coarse.decode(z, self.codec)
202
+
203
+ def _preprocess(self, signal: AudioSignal):
204
+ signal = (
205
+ signal.clone()
206
+ .resample(self.codec.sample_rate)
207
+ .to_mono()
208
+ .normalize(self.loudness)
209
+ .ensure_max_of_audio(1.0)
210
+ )
211
+ logging.debug(f"length before codec preproc: {signal.samples.shape}")
212
+ signal.samples, length = self.codec.preprocess(signal.samples, signal.sample_rate)
213
+ logging.debug(f"length after codec preproc: {signal.samples.shape}")
214
+ return signal
215
+
216
+ @torch.inference_mode()
217
+ def encode(self, signal: AudioSignal):
218
+ signal = signal.to(self.device)
219
+ signal = self._preprocess(signal)
220
+ z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
221
+ return z
222
+
223
+ def snap_to_beats(
224
+ self,
225
+ signal: AudioSignal
226
+ ):
227
+ assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
228
+ beats, downbeats = self.beat_tracker.extract_beats(signal)
229
+
230
+ # trim the signa around the first beat time
231
+ samples_begin = int(beats[0] * signal.sample_rate )
232
+ samples_end = int(beats[-1] * signal.sample_rate)
233
+ logging.debug(beats[0])
234
+ signal = signal.clone().trim(samples_begin, signal.length - samples_end)
235
+
236
+ return signal
237
+
238
+ def make_beat_mask(self,
239
+ signal: AudioSignal,
240
+ before_beat_s: float = 0.0,
241
+ after_beat_s: float = 0.02,
242
+ mask_downbeats: bool = True,
243
+ mask_upbeats: bool = True,
244
+ downbeat_downsample_factor: int = None,
245
+ beat_downsample_factor: int = None,
246
+ dropout: float = 0.0,
247
+ invert: bool = True,
248
+ ):
249
+ """make a beat synced mask. that is, make a mask that
250
+ places 1s at and around the beat, and 0s everywhere else.
251
+ """
252
+ assert self.beat_tracker is not None, "No beat tracker loaded"
253
+
254
+ # get the beat times
255
+ beats, downbeats = self.beat_tracker.extract_beats(signal)
256
+
257
+ # get the beat indices in z
258
+ beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
259
+
260
+ # remove downbeats from beats
261
+ beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
262
+ beats_z = beats_z.tolist()
263
+ downbeats_z = downbeats_z.tolist()
264
+
265
+ # make the mask
266
+ seq_len = self.s2t(signal.duration)
267
+ mask = torch.zeros(seq_len, device=self.device)
268
+
269
+ mask_b4 = self.s2t(before_beat_s)
270
+ mask_after = self.s2t(after_beat_s)
271
+
272
+ if beat_downsample_factor is not None:
273
+ if beat_downsample_factor < 1:
274
+ raise ValueError("mask_beat_downsample_factor must be >= 1 or None")
275
+ else:
276
+ beat_downsample_factor = 1
277
+
278
+ if downbeat_downsample_factor is not None:
279
+ if downbeat_downsample_factor < 1:
280
+ raise ValueError("mask_beat_downsample_factor must be >= 1 or None")
281
+ else:
282
+ downbeat_downsample_factor = 1
283
+
284
+ beats_z = beats_z[::beat_downsample_factor]
285
+ downbeats_z = downbeats_z[::downbeat_downsample_factor]
286
+ logging.debug(f"beats_z: {len(beats_z)}")
287
+ logging.debug(f"downbeats_z: {len(downbeats_z)}")
288
+
289
+ if mask_upbeats:
290
+ for beat_idx in beats_z:
291
+ _slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
292
+ num_steps = mask[_slice[0]:_slice[1]].shape[0]
293
+ _m = torch.ones(num_steps, device=self.device)
294
+ _m_mask = torch.bernoulli(_m * (1 - dropout))
295
+ _m = _m * _m_mask.long()
296
+
297
+ mask[_slice[0]:_slice[1]] = _m
298
+
299
+ if mask_downbeats:
300
+ for downbeat_idx in downbeats_z:
301
+ _slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
302
+ num_steps = mask[_slice[0]:_slice[1]].shape[0]
303
+ _m = torch.ones(num_steps, device=self.device)
304
+ _m_mask = torch.bernoulli(_m * (1 - dropout))
305
+ _m = _m * _m_mask.long()
306
+
307
+ mask[_slice[0]:_slice[1]] = _m
308
+
309
+ mask = mask.clamp(0, 1)
310
+ if invert:
311
+ mask = 1 - mask
312
+
313
+ mask = mask[None, None, :].bool().long()
314
+ if self.c2f is not None:
315
+ mask = mask.repeat(1, self.c2f.n_codebooks, 1)
316
+ else:
317
+ mask = mask.repeat(1, self.coarse.n_codebooks, 1)
318
+ return mask
319
+
320
+ def set_chunk_size(self, chunk_size_s: float):
321
+ self.coarse.chunk_size_s = chunk_size_s
322
+
323
+ @torch.inference_mode()
324
+ def coarse_to_fine(
325
+ self,
326
+ z: torch.Tensor,
327
+ mask: torch.Tensor = None,
328
+ return_mask: bool = False,
329
+ **kwargs
330
+ ):
331
+ assert self.c2f is not None, "No coarse2fine model loaded"
332
+ length = z.shape[-1]
333
+ chunk_len = self.s2t(self.c2f.chunk_size_s)
334
+ n_chunks = math.ceil(z.shape[-1] / chunk_len)
335
+
336
+ # zero pad to chunk_len
337
+ if length % chunk_len != 0:
338
+ pad_len = chunk_len - (length % chunk_len)
339
+ z = torch.nn.functional.pad(z, (0, pad_len))
340
+ mask = torch.nn.functional.pad(mask, (0, pad_len), value=1) if mask is not None else None
341
+
342
+ n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
343
+ if n_codebooks_to_append > 0:
344
+ z = torch.cat([
345
+ z,
346
+ torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
347
+ ], dim=1)
348
+ logging.debug(f"appended {n_codebooks_to_append} codebooks to z")
349
+
350
+ # set the mask to 0 for all conditioning codebooks
351
+ if mask is not None:
352
+ mask = mask.clone()
353
+ mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
354
+
355
+ fine_z = []
356
+ for i in range(n_chunks):
357
+ chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
358
+ mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None
359
+
360
+ with torch.autocast("cuda", dtype=torch.bfloat16):
361
+ chunk = self.c2f.generate(
362
+ codec=self.codec,
363
+ time_steps=chunk_len,
364
+ start_tokens=chunk,
365
+ return_signal=False,
366
+ mask=mask_chunk,
367
+ cfg_guidance=None,
368
+ **kwargs
369
+ )
370
+ fine_z.append(chunk)
371
+
372
+ fine_z = torch.cat(fine_z, dim=-1)
373
+ if return_mask:
374
+ return fine_z[:, :, :length].clone(), apply_mask(fine_z, mask, self.c2f.mask_token)[0][:, :, :length].clone()
375
+
376
+ return fine_z[:, :, :length].clone()
377
+
378
+ @torch.inference_mode()
379
+ def coarse_vamp(
380
+ self,
381
+ z,
382
+ mask,
383
+ return_mask=False,
384
+ gen_fn=None,
385
+ **kwargs
386
+ ):
387
+ # coarse z
388
+ cz = z[:, : self.coarse.n_codebooks, :].clone()
389
+ mask = mask[:, : self.coarse.n_codebooks, :]
390
+ # assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
391
+
392
+ # cut into chunks, keep the last chunk separate if it's too small
393
+ chunk_len = self.s2t(self.coarse.chunk_size_s)
394
+ n_chunks = math.ceil(cz.shape[-1] / chunk_len)
395
+ last_chunk_len = cz.shape[-1] % chunk_len
396
+
397
+ cz_chunks = []
398
+ mask_chunks = []
399
+ for i in range(n_chunks):
400
+ chunk = cz[:, :, i * chunk_len : (i + 1) * chunk_len]
401
+ mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len]
402
+
403
+ # make sure that the very first and last timestep of each chunk is 0 so that we don't get a weird
404
+ # discontinuity when we stitch the chunks back together
405
+ # only if there's already a 0 somewhere in the chunk
406
+ if torch.any(mask_chunk == 0):
407
+ mask_chunk[:, :, 0] = 0
408
+ mask_chunk[:, :, -1] = 0
409
+
410
+ cz_chunks.append(chunk)
411
+ mask_chunks.append(mask_chunk)
412
+
413
+ # now vamp each chunk
414
+ cz_masked_chunks = []
415
+ cz_vamped_chunks = []
416
+ for chunk, mask_chunk in zip(cz_chunks, mask_chunks):
417
+ cz_masked_chunk, mask_chunk = apply_mask(chunk, mask_chunk, self.coarse.mask_token)
418
+ cz_masked_chunk = cz_masked_chunk[:, : self.coarse.n_codebooks, :]
419
+ cz_masked_chunks.append(cz_masked_chunk)
420
+
421
+
422
+ gen_fn = gen_fn or self.coarse.generate
423
+ with torch.autocast("cuda", dtype=torch.bfloat16):
424
+ c_vamp_chunk = gen_fn(
425
+ codec=self.codec,
426
+ time_steps=chunk_len,
427
+ start_tokens=cz_masked_chunk,
428
+ return_signal=False,
429
+ mask=mask_chunk,
430
+ **kwargs
431
+ )
432
+ cz_vamped_chunks.append(c_vamp_chunk)
433
+
434
+ # stitch the chunks back together
435
+ cz_masked = torch.cat(cz_masked_chunks, dim=-1)
436
+ c_vamp = torch.cat(cz_vamped_chunks, dim=-1)
437
+
438
+ # add the fine codes back in
439
+ c_vamp = torch.cat(
440
+ [c_vamp, z[:, self.coarse.n_codebooks :, :]],
441
+ dim=1
442
+ )
443
+
444
+ if return_mask:
445
+ return c_vamp, cz_masked
446
+
447
+ return c_vamp
448
+
449
+ def build_mask(self,
450
+ z: torch.Tensor,
451
+ sig: AudioSignal = None,
452
+ rand_mask_intensity: float = 1.0,
453
+ prefix_s: float = 0.0,
454
+ suffix_s: float = 0.0,
455
+ periodic_prompt: int = 7,
456
+ periodic_prompt_width: int = 1,
457
+ onset_mask_width: int = 0,
458
+ _dropout: float = 0.0,
459
+ upper_codebook_mask: int = 3,
460
+ ncc: int = 0,
461
+ ):
462
+ mask = linear_random(z, rand_mask_intensity)
463
+ mask = mask_and(
464
+ mask,
465
+ inpaint(z, self.s2t(prefix_s), self.s2t(suffix_s)),
466
+ )
467
+
468
+ pmask = periodic_mask(z, periodic_prompt, periodic_prompt_width, random_roll=True)
469
+ mask = mask_and(mask, pmask)
470
+
471
+ if onset_mask_width > 0:
472
+ assert sig is not None, f"must provide a signal to use onset mask"
473
+ mask = mask_and(
474
+ mask, onset_mask(
475
+ sig, z, self,
476
+ width=onset_mask_width
477
+ )
478
+ )
479
+
480
+ mask = dropout(mask, _dropout)
481
+ mask = codebook_unmask(mask, ncc)
482
+
483
+ mask = codebook_mask(mask, int(upper_codebook_mask), None)
484
+ return mask
485
+
486
+ def vamp(
487
+ self,
488
+ codes: torch.Tensor,
489
+ mask: torch.Tensor,
490
+ batch_size: int = 1,
491
+ feedback_steps: int = 1,
492
+ time_stretch_factor: int = 1,
493
+ return_mask: bool = False,
494
+ **kwargs,
495
+ ):
496
+ z = codes
497
+
498
+ # expand z to batch size
499
+ z = z.expand(batch_size, -1, -1)
500
+ mask = mask.expand(batch_size, -1, -1)
501
+
502
+ # stretch mask and z to match the time stretch factor
503
+ # we'll add (stretch_factor - 1) mask tokens in between each timestep of z
504
+ # and we'll make the mask 1 in all the new slots we added
505
+ if time_stretch_factor > 1:
506
+ z = z.repeat_interleave(time_stretch_factor, dim=-1)
507
+ mask = mask.repeat_interleave(time_stretch_factor, dim=-1)
508
+ added_mask = torch.ones_like(mask)
509
+ added_mask[:, :, ::time_stretch_factor] = 0
510
+ mask = mask.bool() | added_mask.bool()
511
+ mask = mask.long()
512
+
513
+ # the forward pass
514
+ logging.debug(z.shape)
515
+ logging.debug("coarse!")
516
+ zv, mask_z = self.coarse_vamp(
517
+ z,
518
+ mask=mask,
519
+ return_mask=True,
520
+ **kwargs
521
+ )
522
+
523
+ # add the top codebooks back in
524
+ if zv.shape[1] < z.shape[1]:
525
+ logging.debug(f"adding {z.shape[1] - zv.shape[1]} codebooks back in")
526
+ zv = torch.cat(
527
+ [zv, z[:, self.coarse.n_codebooks :, :]],
528
+ dim=1
529
+ )
530
+
531
+ # now, coarse2fine
532
+ logging.debug(f"coarse2fine!")
533
+ zv, fine_zv_mask = self.coarse_to_fine(
534
+ zv,
535
+ mask=mask,
536
+ typical_filtering=True,
537
+ _sampling_steps=[2],
538
+ return_mask=True
539
+ )
540
+ mask_z = torch.cat(
541
+ [mask_z[:, :self.coarse.n_codebooks, :], fine_zv_mask[:, self.coarse.n_codebooks:, :]],
542
+ dim=1
543
+ )
544
+
545
+ z = zv
546
+
547
+ if return_mask:
548
+ return z, mask_z.cpu(),
549
+ else:
550
+ return z
551
+
552
+ def visualize_codes(self, z: torch.Tensor):
553
+ import matplotlib.pyplot as plt
554
+ # make sure the figsize is square when imshow is called
555
+ fig = plt.figure(figsize=(10, 7))
556
+ # in subplots, plot z[0] and the mask
557
+ # set title to "codes" and "mask"
558
+ fig.add_subplot(2, 1, 1)
559
+ plt.imshow(z[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
560
+ plt.title("codes")
561
+ plt.ylabel("codebook index")
562
+ # set the xticks to seconds
563
+
564
+
565
+ if __name__ == "__main__":
566
+ import audiotools as at
567
+ import logging
568
+ logger = logging.getLogger()
569
+ logger.setLevel(logging.INFO)
570
+ torch.set_logging.debugoptions(threshold=10000)
571
+ at.util.seed(42)
572
+
573
+ interface = Interface(
574
+ coarse_ckpt="./models/vampnet/coarse.pth",
575
+ coarse2fine_ckpt="./models/vampnet/c2f.pth",
576
+ codec_ckpt="./models/vampnet/codec.pth",
577
+ device="cuda",
578
+ wavebeat_ckpt="./models/wavebeat.pth"
579
+ )
580
+
581
+
582
+ sig = at.AudioSignal('assets/example.wav')
583
+
584
+ z = interface.encode(sig)
585
+
586
+
587
+ mask = interface.build_mask(
588
+ z=z,
589
+ sig=sig,
590
+ rand_mask_intensity=1.0,
591
+ prefix_s=0.0,
592
+ suffix_s=0.0,
593
+ periodic_prompt=7,
594
+ periodic_prompt2=7,
595
+ periodic_prompt_width=1,
596
+ onset_mask_width=5,
597
+ _dropout=0.0,
598
+ upper_codebook_mask=3,
599
+ upper_codebook_mask_2=None,
600
+ ncc=0,
601
+ )
602
+
603
+ zv, mask_z = interface.coarse_vamp(
604
+ z,
605
+ mask=mask,
606
+ return_mask=True,
607
+ gen_fn=interface.coarse.generate
608
+ )
609
+
610
+
611
+ use_coarse2fine = True
612
+ if use_coarse2fine:
613
+ zv = interface.coarse_to_fine(zv, mask=mask)
614
+ breakpoint()
615
+
616
+ mask = interface.decode(mask_z).cpu()
617
+
618
+ sig = interface.decode(zv).cpu()
619
+
620
+
621
+ logging.debug("done")
622
+
623
+
vampnet/mask.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from audiotools import AudioSignal
5
+
6
+ from .util import scalar_to_batch_tensor
7
+
8
+ def _gamma(r):
9
+ return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
10
+
11
+ def _invgamma(y):
12
+ if not torch.is_tensor(y):
13
+ y = torch.tensor(y)[None]
14
+ return 2 * y.acos() / torch.pi
15
+
16
+ def full_mask(x: torch.Tensor):
17
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
18
+ return torch.ones_like(x).long()
19
+
20
+ def empty_mask(x: torch.Tensor):
21
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
22
+ return torch.zeros_like(x).long()
23
+
24
+ def apply_mask(
25
+ x: torch.Tensor,
26
+ mask: torch.Tensor,
27
+ mask_token: int
28
+ ):
29
+ assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
30
+ assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
31
+ assert mask.dtype == torch.long, "mask must be long dtype, but got {mask.dtype}"
32
+ assert ~torch.any(mask > 1), "mask must be binary"
33
+ assert ~torch.any(mask < 0), "mask must be binary"
34
+
35
+ fill_x = torch.full_like(x, mask_token)
36
+ x = x * (1 - mask) + fill_x * mask
37
+
38
+ return x, mask
39
+
40
+ def random(
41
+ x: torch.Tensor,
42
+ r: torch.Tensor
43
+ ):
44
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
45
+ if not isinstance(r, torch.Tensor):
46
+ r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
47
+
48
+ r = _gamma(r)[:, None, None]
49
+ probs = torch.ones_like(x) * r
50
+
51
+ mask = torch.bernoulli(probs)
52
+ mask = mask.round().long()
53
+
54
+ return mask
55
+
56
+ def linear_random(
57
+ x: torch.Tensor,
58
+ r: torch.Tensor,
59
+ ):
60
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
61
+ if not isinstance(r, torch.Tensor):
62
+ r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
63
+ r = r[:, None, None]
64
+
65
+ probs = torch.ones_like(x).to(x.device).float()
66
+ # expand to batch and codebook dims
67
+ probs = probs.expand(x.shape[0], x.shape[1], -1)
68
+ probs = probs * r
69
+
70
+ mask = torch.bernoulli(probs)
71
+ mask = mask.round().long()
72
+
73
+ return mask
74
+
75
+ def inpaint(x: torch.Tensor,
76
+ n_prefix,
77
+ n_suffix,
78
+ ):
79
+ assert n_prefix is not None
80
+ assert n_suffix is not None
81
+
82
+ mask = full_mask(x)
83
+
84
+ # if we have a prefix or suffix, set their mask prob to 0
85
+ if n_prefix > 0:
86
+ if not isinstance(n_prefix, torch.Tensor):
87
+ n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
88
+ for i, n in enumerate(n_prefix):
89
+ if n > 0:
90
+ mask[i, :, :n] = 0.0
91
+ if n_suffix > 0:
92
+ if not isinstance(n_suffix, torch.Tensor):
93
+ n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
94
+ for i, n in enumerate(n_suffix):
95
+ if n > 0:
96
+ mask[i, :, -n:] = 0.0
97
+
98
+
99
+ return mask
100
+
101
+ def periodic_mask(x: torch.Tensor,
102
+ period: int,width: int = 1,
103
+ random_roll=False,
104
+ ):
105
+ mask = full_mask(x)
106
+ if period == 0:
107
+ return mask
108
+
109
+ if not isinstance(period, torch.Tensor):
110
+ period = scalar_to_batch_tensor(period, x.shape[0])
111
+ for i, factor in enumerate(period):
112
+ if factor == 0:
113
+ continue
114
+ for j in range(mask.shape[-1]):
115
+ if j % factor == 0:
116
+ # figure out how wide the mask should be
117
+ j_start = max(0, j - width // 2 )
118
+ j_end = min(mask.shape[-1] - 1, j + width // 2 ) + 1
119
+ # flip a coin for each position in the mask
120
+ j_mask = torch.bernoulli(torch.ones(j_end - j_start))
121
+ assert torch.all(j_mask == 1)
122
+ j_fill = torch.ones_like(j_mask) * (1 - j_mask)
123
+ assert torch.all(j_fill == 0)
124
+ # fill
125
+ mask[i, :, j_start:j_end] = j_fill
126
+ if random_roll:
127
+ # add a random offset to the mask
128
+ offset = torch.randint(0, period[0], (1,))
129
+ mask = torch.roll(mask, offset.item(), dims=-1)
130
+
131
+ return mask
132
+
133
+ def codebook_unmask(
134
+ mask: torch.Tensor,
135
+ n_conditioning_codebooks: int
136
+ ):
137
+ if n_conditioning_codebooks == None:
138
+ return mask
139
+ # if we have any conditioning codebooks, set their mask to 0
140
+ mask = mask.clone()
141
+ mask[:, :n_conditioning_codebooks, :] = 0
142
+ return mask
143
+
144
+ def codebook_mask(mask: torch.Tensor, val1: int, val2: int = None):
145
+ mask = mask.clone()
146
+ mask[:, val1:, :] = 1
147
+ # val2 = val2 or val1
148
+ # vs = torch.linspace(val1, val2, mask.shape[1])
149
+ # for t, v in enumerate(vs):
150
+ # v = int(v)
151
+ # mask[:, v:, t] = 1
152
+
153
+ return mask
154
+
155
+ def mask_and(
156
+ mask1: torch.Tensor,
157
+ mask2: torch.Tensor
158
+ ):
159
+ assert mask1.shape == mask2.shape, "masks must be same shape"
160
+ return torch.min(mask1, mask2)
161
+
162
+ def dropout(
163
+ mask: torch.Tensor,
164
+ p: float,
165
+ ):
166
+ assert 0 <= p <= 1, "p must be between 0 and 1"
167
+ assert mask.max() <= 1, "mask must be binary"
168
+ assert mask.min() >= 0, "mask must be binary"
169
+ mask = (~mask.bool()).float()
170
+ mask = torch.bernoulli(mask * (1 - p))
171
+ mask = ~mask.round().bool()
172
+ return mask.long()
173
+
174
+ def mask_or(
175
+ mask1: torch.Tensor,
176
+ mask2: torch.Tensor
177
+ ):
178
+ assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
179
+ assert mask1.max() <= 1, "mask1 must be binary"
180
+ assert mask2.max() <= 1, "mask2 must be binary"
181
+ assert mask1.min() >= 0, "mask1 must be binary"
182
+ assert mask2.min() >= 0, "mask2 must be binary"
183
+ return (mask1 + mask2).clamp(0, 1)
184
+
185
+ def time_stretch_mask(
186
+ x: torch.Tensor,
187
+ stretch_factor: int,
188
+ ):
189
+ assert stretch_factor >= 1, "stretch factor must be >= 1"
190
+ c_seq_len = x.shape[-1]
191
+ x = x.repeat_interleave(stretch_factor, dim=-1)
192
+
193
+ # trim cz to the original length
194
+ x = x[:, :, :c_seq_len]
195
+
196
+ mask = periodic_mask(x, stretch_factor, width=1)
197
+ return mask
198
+
199
+ def onset_mask(
200
+ sig: AudioSignal,
201
+ z: torch.Tensor,
202
+ interface,
203
+ width: int = 1,
204
+ ):
205
+ import librosa
206
+
207
+ onset_frame_idxs = librosa.onset.onset_detect(
208
+ y=sig.samples[0][0].detach().cpu().numpy(), sr=sig.sample_rate,
209
+ hop_length=interface.codec.hop_length,
210
+ backtrack=True,
211
+ )
212
+ if len(onset_frame_idxs) == 0:
213
+ print("no onsets detected")
214
+ print("onset_frame_idxs", onset_frame_idxs)
215
+ print("mask shape", z.shape)
216
+
217
+ mask = torch.ones_like(z)
218
+ for idx in onset_frame_idxs:
219
+ mask[:, :, idx-width:idx+width] = 0
220
+
221
+ return mask
222
+
223
+
224
+
225
+ if __name__ == "__main__":
226
+ sig = AudioSignal("assets/example.wav")
vampnet/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import audiotools
2
+
3
+ audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
4
+ audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
5
+
6
+ from .transformer import VampNet
vampnet/modules/activations.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class NewGELU(nn.Module):
10
+ """
11
+ Implementation of the GELU activation function currently in Google BERT repo
12
+ (identical to OpenAI GPT). Also see the Gaussian Error Linear Units
13
+ paper: https://arxiv.org/abs/1606.08415
14
+ """
15
+
16
+ def forward(self, x):
17
+ return (
18
+ 0.5
19
+ * x
20
+ * (
21
+ 1.0
22
+ + torch.tanh(
23
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
24
+ )
25
+ )
26
+ )
27
+
28
+ class GatedGELU(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.gelu = NewGELU()
32
+
33
+ def forward(self, x, dim: int = -1):
34
+ p1, p2 = x.chunk(2, dim=dim)
35
+ return p1 * self.gelu(p2)
36
+
37
+ class Snake1d(nn.Module):
38
+ def __init__(self, channels):
39
+ super().__init__()
40
+ self.alpha = nn.Parameter(torch.ones(channels))
41
+
42
+ def forward(self, x):
43
+ return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2)
44
+
45
+ def get_activation(name: str = "relu"):
46
+ if name == "relu":
47
+ return nn.ReLU
48
+ elif name == "gelu":
49
+ return NewGELU
50
+ elif name == "geglu":
51
+ return GatedGELU
52
+ elif name == "snake":
53
+ return Snake1d
54
+ else:
55
+ raise ValueError(f"Unrecognized activation {name}")
vampnet/modules/layers.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import Optional
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch.nn.utils import weight_norm
10
+
11
+ # Scripting this brings model speed up 1.4x
12
+ @torch.jit.script
13
+ def snake(x, alpha):
14
+ shape = x.shape
15
+ x = x.reshape(shape[0], shape[1], -1)
16
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
17
+ x = x.reshape(shape)
18
+ return x
19
+
20
+
21
+ class Snake1d(nn.Module):
22
+ def __init__(self, channels):
23
+ super().__init__()
24
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
25
+
26
+ def forward(self, x):
27
+ return snake(x, self.alpha)
28
+
29
+
30
+ def num_params(model):
31
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
32
+
33
+
34
+ def recurse_children(module, fn):
35
+ for child in module.children():
36
+ if isinstance(child, nn.ModuleList):
37
+ for c in child:
38
+ yield recurse_children(c, fn)
39
+ if isinstance(child, nn.ModuleDict):
40
+ for c in child.values():
41
+ yield recurse_children(c, fn)
42
+
43
+ yield recurse_children(child, fn)
44
+ yield fn(child)
45
+
46
+
47
+ def WNConv1d(*args, **kwargs):
48
+ return weight_norm(nn.Conv1d(*args, **kwargs))
49
+
50
+
51
+ def WNConvTranspose1d(*args, **kwargs):
52
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
53
+
54
+
55
+ class SequentialWithFiLM(nn.Module):
56
+ """
57
+ handy wrapper for nn.Sequential that allows FiLM layers to be
58
+ inserted in between other layers.
59
+ """
60
+
61
+ def __init__(self, *layers):
62
+ super().__init__()
63
+ self.layers = nn.ModuleList(layers)
64
+
65
+ @staticmethod
66
+ def has_film(module):
67
+ mod_has_film = any(
68
+ [res for res in recurse_children(module, lambda c: isinstance(c, FiLM))]
69
+ )
70
+ return mod_has_film
71
+
72
+ def forward(self, x, cond):
73
+ for layer in self.layers:
74
+ if self.has_film(layer):
75
+ x = layer(x, cond)
76
+ else:
77
+ x = layer(x)
78
+ return x
79
+
80
+
81
+ class FiLM(nn.Module):
82
+ def __init__(self, input_dim: int, output_dim: int):
83
+ super().__init__()
84
+
85
+ self.input_dim = input_dim
86
+ self.output_dim = output_dim
87
+
88
+ if input_dim > 0:
89
+ self.beta = nn.Linear(input_dim, output_dim)
90
+ self.gamma = nn.Linear(input_dim, output_dim)
91
+
92
+ def forward(self, x, r):
93
+ if self.input_dim == 0:
94
+ return x
95
+ else:
96
+ beta, gamma = self.beta(r), self.gamma(r)
97
+ beta, gamma = (
98
+ beta.view(x.size(0), self.output_dim, 1),
99
+ gamma.view(x.size(0), self.output_dim, 1),
100
+ )
101
+ x = x * (gamma + 1) + beta
102
+ return x
103
+
104
+
105
+ class CodebookEmbedding(nn.Module):
106
+ def __init__(
107
+ self,
108
+ vocab_size: int,
109
+ latent_dim: int,
110
+ n_codebooks: int,
111
+ emb_dim: int,
112
+ special_tokens: Optional[Tuple[str]] = None,
113
+ ):
114
+ super().__init__()
115
+ self.n_codebooks = n_codebooks
116
+ self.emb_dim = emb_dim
117
+ self.latent_dim = latent_dim
118
+ self.vocab_size = vocab_size
119
+
120
+ if special_tokens is not None:
121
+ for tkn in special_tokens:
122
+ self.special = nn.ParameterDict(
123
+ {
124
+ tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
125
+ for tkn in special_tokens
126
+ }
127
+ )
128
+ self.special_idxs = {
129
+ tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
130
+ }
131
+
132
+ self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
133
+
134
+ def from_codes(self, codes: torch.Tensor, codec):
135
+ """
136
+ get a sequence of continuous embeddings from a sequence of discrete codes.
137
+ unlike it's counterpart in the original VQ-VAE, this function adds for any special tokens
138
+ necessary for the language model, like <MASK>.
139
+ """
140
+ n_codebooks = codes.shape[1]
141
+ latent = []
142
+ for i in range(n_codebooks):
143
+ c = codes[:, i, :]
144
+
145
+ lookup_table = codec.quantizer.quantizers[i].codebook.weight
146
+ if hasattr(self, "special"):
147
+ special_lookup = torch.cat(
148
+ [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
149
+ )
150
+ lookup_table = torch.cat([lookup_table, special_lookup], dim=0)
151
+
152
+ l = F.embedding(c, lookup_table).transpose(1, 2)
153
+ latent.append(l)
154
+
155
+ latent = torch.cat(latent, dim=1)
156
+ return latent
157
+
158
+ def forward(self, latents: torch.Tensor):
159
+ """
160
+ project a sequence of latents to a sequence of embeddings
161
+ """
162
+ x = self.out_proj(latents)
163
+ return x
164
+
vampnet/modules/transformer.py ADDED
@@ -0,0 +1,965 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import logging
3
+ from typing import Optional, Tuple, Union, List
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ import loralib as lora
11
+ import audiotools as at
12
+
13
+ from .activations import get_activation
14
+ from .layers import CodebookEmbedding
15
+ from .layers import FiLM
16
+ from .layers import SequentialWithFiLM
17
+ from .layers import WNConv1d
18
+ from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten
19
+ from ..mask import _gamma
20
+
21
+ LORA_R = 8
22
+
23
+ # def log(t, eps=1e-20):
24
+ # return torch.log(t + eps)
25
+
26
+
27
+ def gumbel_noise_like(t):
28
+ noise = torch.zeros_like(t).uniform_(1e-20, 1)
29
+ return -torch.log(-torch.log(noise))
30
+
31
+
32
+ def gumbel_sample(t, temperature=1.0, dim=-1):
33
+ return ((t / max(temperature, 1e-10)) + gumbel_noise_like(t)).argmax(dim=dim)
34
+
35
+
36
+ class RMSNorm(nn.Module):
37
+ def __init__(self, hidden_size: int, eps=1e-6):
38
+ super().__init__()
39
+ self.weight = nn.Parameter(torch.ones(hidden_size))
40
+ self.var_eps = eps
41
+
42
+ def forward(self, x):
43
+ """Returns root mean square normalized version of input `x`
44
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known
45
+ # as Root Mean Square Layer Normalization https://arxiv.org/abs/1910.07467
46
+ # thus varience is calculated w/o mean and there is no bias
47
+ Parameters
48
+ ----------
49
+ x : Tensor[B x T x D]
50
+ Returns
51
+ -------
52
+ Tensor[B x T x D]
53
+ """
54
+ var = x.pow(2).mean(-1, keepdim=True)
55
+ x = x * torch.rsqrt(var + self.var_eps)
56
+
57
+ return self.weight * x
58
+
59
+
60
+ class FeedForward(nn.Module):
61
+ def __init__(
62
+ self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu"
63
+ ):
64
+ super().__init__()
65
+ factor = 2 if activation == "geglu" else 1
66
+ self.w_1 = lora.Linear(d_model, d_model * 4, bias=False, r=LORA_R)
67
+ self.w_2 = lora.Linear(d_model * 4 // factor, d_model, bias=False, r=LORA_R)
68
+ self.drop = nn.Dropout(dropout)
69
+ self.act = get_activation(activation)()
70
+
71
+ def forward(self, x):
72
+ """Computes position-wise feed-forward layer
73
+ Parameters
74
+ ----------
75
+ x : Tensor[B x T x D]
76
+ Returns
77
+ -------
78
+ Tensor[B x T x D]
79
+ """
80
+ x = self.w_1(x)
81
+ x = self.act(x)
82
+ x = self.drop(x)
83
+ x = self.w_2(x)
84
+ return x
85
+
86
+
87
+ class MultiHeadRelativeAttention(nn.Module):
88
+ def __init__(
89
+ self,
90
+ n_head: int = 8,
91
+ d_model: int = 512,
92
+ dropout: float = 0.1,
93
+ bidirectional: bool = True,
94
+ has_relative_attention_bias: bool = True,
95
+ attention_num_buckets: int = 32,
96
+ attention_max_distance: int = 128,
97
+ ):
98
+ super().__init__()
99
+ d_head = d_model // n_head
100
+ self.n_head = n_head
101
+ self.d_head = d_head
102
+ self.bidirectional = bidirectional
103
+ self.has_relative_attention_bias = has_relative_attention_bias
104
+ self.attention_num_buckets = attention_num_buckets
105
+ self.attention_max_distance = attention_max_distance
106
+
107
+ # Create linear query, key, value projections
108
+ self.w_qs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
109
+ self.w_ks = nn.Linear(d_model, d_model, bias=False)
110
+ self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
111
+
112
+ # Create linear final output projection
113
+ self.fc = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
114
+
115
+ # Dropout for attention output weights
116
+ self.dropout = nn.Dropout(dropout)
117
+
118
+ # Create relative positional embeddings (if turned on)
119
+ if has_relative_attention_bias:
120
+ self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head)
121
+
122
+ def _relative_position_bucket(self, relative_position):
123
+ """Converts unbounded relative position into bounded set of buckets
124
+ with half "exact" buckets (1 position = 1 bucket) and half "log-spaced"
125
+ buckets
126
+ Parameters
127
+ ----------
128
+ relative_position : Tensor[T_q x T_kv]
129
+ Relative positions between queries and key_value items
130
+ Returns
131
+ -------
132
+ Tensor[T_q x T_kv]
133
+ Input relative positions converted into buckets
134
+ """
135
+ relative_buckets = 0
136
+ num_buckets = self.attention_num_buckets
137
+ max_distance = self.attention_max_distance
138
+
139
+ # Convert relative position for (-inf, inf) to [0, inf]
140
+ # Negative relative positions correspond to past
141
+ # Positive relative positions correspond to future
142
+ if self.bidirectional:
143
+ # use half buckets for each side (past / future)
144
+ num_buckets //= 2
145
+
146
+ # Shift the position positions by `num_buckets` to wrap around
147
+ # negative positions
148
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
149
+ relative_position = torch.abs(relative_position)
150
+ else:
151
+ # If not bidirectional, ignore positive positions and wrap
152
+ # negative positions to positive
153
+ relative_position = -torch.min(
154
+ relative_position, torch.zeros_like(relative_position)
155
+ )
156
+
157
+ # Allocate half of the buckets are for exact increments in positions
158
+ max_exact = num_buckets // 2
159
+ is_small = relative_position < max_exact
160
+
161
+ # The other half of the buckets are for logarithmically bigger bins in
162
+ # positions up to `max_distance`
163
+ relative_postion_if_large = max_exact + (
164
+ torch.log(relative_position.float() / max_exact)
165
+ / math.log(max_distance / max_exact)
166
+ * (num_buckets - max_exact)
167
+ ).to(torch.long)
168
+
169
+ # Clip the max relative position to `num_buckets - 1`
170
+ relative_postion_if_large = torch.min(
171
+ relative_postion_if_large,
172
+ torch.full_like(relative_postion_if_large, num_buckets - 1),
173
+ )
174
+
175
+ # Choose relative buckets based on small or large positions
176
+ relative_buckets += torch.where(
177
+ is_small, relative_position, relative_postion_if_large
178
+ )
179
+
180
+ return relative_buckets
181
+
182
+ def compute_bias(self, query_length, key_length):
183
+ """Computes a position bias scalar for each index in query_length x key_length
184
+ Parameters
185
+ ----------
186
+ query_length : int
187
+ key_length : int
188
+ Returns
189
+ -------
190
+ Tensor[heads x 1 x T_q x T_kv]
191
+ Position bias to be applied on attention logits
192
+ """
193
+
194
+ query_position = torch.arange(query_length, dtype=torch.long)[:, None]
195
+ key_position = torch.arange(key_length, dtype=torch.long)[None, :]
196
+ relative_position = key_position - query_position
197
+
198
+ # Convert relative position to buckets
199
+ relative_position_bucket = self._relative_position_bucket(relative_position)
200
+ relative_position_bucket = relative_position_bucket.to(
201
+ self.relative_attention_bias.weight.device
202
+ )
203
+
204
+ # Index attention bias values
205
+ values = self.relative_attention_bias(relative_position_bucket)
206
+ values = rearrange(values, "q k h -> h 1 q k")
207
+
208
+ return values
209
+
210
+ def forward(self, q, k, v, mask=None, position_bias=None):
211
+ """Computes attention over (keys, values) for every timestep in query
212
+ Parameters
213
+ ----------
214
+ q : Tensor[B x T_q x d_model]
215
+ Query vectors
216
+ k : Tensor[B x T_kv x d_model]
217
+ Key vectors to compute attention over
218
+ v : Tensor[B x T_kv x d_model]
219
+ Value vectors corresponding to the keys
220
+ mask : Tensor[B x T_q x T_kv], optional
221
+ position_bias: Tensor[head x 1 x T_q x T_kv]
222
+ Returns
223
+ -------
224
+ Tensor[B x T_q x d_model]
225
+ Outputs after attending (key, value) using queries
226
+ """
227
+ # Compute query, key, value projections
228
+ q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head)
229
+ k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head)
230
+ v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head)
231
+
232
+ # Compute attention matrix
233
+ attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1])
234
+
235
+ # Add relative position bias to attention scores
236
+ if position_bias is None:
237
+ if self.has_relative_attention_bias:
238
+ position_bias = self.compute_bias(q.size(-2), k.size(-2))
239
+ else:
240
+ position_bias = torch.zeros_like(attn)
241
+ attn += position_bias
242
+
243
+ # Apply mask to attention scores to prevent looking up invalid locations
244
+ if mask is not None:
245
+ attn = attn.masked_fill(mask[None] == 0, -1e9)
246
+
247
+ # Normalize attention scores and add dropout
248
+ attn = torch.softmax(attn, dim=3)
249
+ attn = self.dropout(attn)
250
+
251
+ # Compute attended outputs (product of attention matrix and values)
252
+ output = torch.einsum("hblt,hbtv->hblv", [attn, v])
253
+ output = rearrange(output, "head b l v -> b l (head v)")
254
+ output = self.fc(output)
255
+
256
+ return output, position_bias
257
+
258
+
259
+ class TransformerLayer(nn.Module):
260
+ def __init__(
261
+ self,
262
+ d_model: int = 512,
263
+ d_cond: int = 64,
264
+ n_heads: int = 8,
265
+ bidirectional: bool = True,
266
+ is_decoder: bool = False,
267
+ has_relative_attention_bias: bool = False,
268
+ flash_attn: bool = False,
269
+ dropout: float = 0.1,
270
+ ):
271
+ super().__init__()
272
+ # Store args
273
+ self.is_decoder = is_decoder
274
+
275
+ # Create self-attention layer
276
+ self.norm_1 = RMSNorm(d_model)
277
+ self.film_1 = FiLM(d_cond, d_model)
278
+ self.flash_attn = flash_attn
279
+
280
+ if flash_attn:
281
+ from flash_attn.flash_attention import FlashMHA
282
+ self.self_attn = FlashMHA(
283
+ embed_dim=d_model,
284
+ num_heads=n_heads,
285
+ attention_dropout=dropout,
286
+ causal=False,
287
+ )
288
+ else:
289
+ self.self_attn = MultiHeadRelativeAttention(
290
+ n_heads, d_model, dropout, bidirectional, has_relative_attention_bias
291
+ )
292
+
293
+ # (Optional) Create cross-attention layer
294
+ if is_decoder:
295
+ self.norm_2 = RMSNorm(d_model)
296
+ self.film_2 = FiLM(d_cond, d_model)
297
+ self.cross_attn = MultiHeadRelativeAttention(
298
+ n_heads,
299
+ d_model,
300
+ dropout,
301
+ bidirectional=True,
302
+ has_relative_attention_bias=False,
303
+ )
304
+
305
+ # Create last feed-forward layer
306
+ self.norm_3 = RMSNorm(d_model)
307
+ self.film_3 = FiLM(d_cond, d_model)
308
+ self.feed_forward = FeedForward(d_model=d_model, dropout=dropout)
309
+
310
+ # Create dropout
311
+ self.dropout = nn.Dropout(dropout)
312
+
313
+ def forward(
314
+ self,
315
+ x,
316
+ x_mask,
317
+ cond,
318
+ src=None,
319
+ src_mask=None,
320
+ position_bias=None,
321
+ encoder_decoder_position_bias=None,
322
+ ):
323
+ """Computes one transformer layer consisting of self attention, (op) cross attention
324
+ and feedforward layer
325
+ Parameters
326
+ ----------
327
+ x : Tensor[B x T_q x D]
328
+ x_mask : Tensor[B x T_q]
329
+ src : Tensor[B x T_kv x D], optional
330
+ src_mask : Tensor[B x T_kv x D], optional
331
+ position_bias : Tensor[heads x B x T_q x T_q], optional
332
+ Relative position bias for self attention layer
333
+ encoder_decoder_position_bias : Tensor[heads x B x T_q x T_kv], optional
334
+ Relative position bias for cross attention layer
335
+ Returns
336
+ -------
337
+ Tensor[B x T_q x D]
338
+ """
339
+ y = self.norm_1(x)
340
+ y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1)
341
+ if self.flash_attn:
342
+ with torch.autocast(y.device.type, dtype=torch.bfloat16):
343
+ y = self.self_attn(y)[0]
344
+ else:
345
+ y, position_bias = self.self_attn(y, y, y, x_mask, position_bias)
346
+ x = x + self.dropout(y)
347
+
348
+ if self.is_decoder:
349
+ y = self.norm_2(x)
350
+ y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1)
351
+ y, encoder_decoder_position_bias = self.cross_attn(
352
+ y, src, src, src_mask, encoder_decoder_position_bias
353
+ )
354
+ x = x + self.dropout(y)
355
+
356
+ y = self.norm_3(x)
357
+ y = self.film_3(
358
+ y.permute(
359
+ 0,
360
+ 2,
361
+ 1,
362
+ ),
363
+ cond,
364
+ ).permute(0, 2, 1)
365
+ y = self.feed_forward(y)
366
+ x = x + self.dropout(y)
367
+
368
+ return x, position_bias, encoder_decoder_position_bias
369
+
370
+
371
+ class TransformerStack(nn.Module):
372
+ def __init__(
373
+ self,
374
+ d_model: int = 512,
375
+ d_cond: int = 64,
376
+ n_heads: int = 8,
377
+ n_layers: int = 8,
378
+ last_layer: bool = True,
379
+ bidirectional: bool = True,
380
+ flash_attn: bool = False,
381
+ is_decoder: bool = False,
382
+ dropout: float = 0.1,
383
+ ):
384
+ super().__init__()
385
+ # Store args
386
+ self.bidirectional = bidirectional
387
+ self.is_decoder = is_decoder
388
+
389
+ # Create transformer layers
390
+ # In T5, relative attention bias is shared by all layers in the stack
391
+ self.layers = nn.ModuleList(
392
+ [
393
+ TransformerLayer(
394
+ d_model,
395
+ d_cond,
396
+ n_heads,
397
+ bidirectional,
398
+ is_decoder,
399
+ has_relative_attention_bias=True if (i == 0) else False,
400
+ flash_attn=flash_attn,
401
+ dropout=dropout,
402
+ )
403
+ for i in range(n_layers)
404
+ ]
405
+ )
406
+
407
+ # Perform last normalization
408
+ self.norm = RMSNorm(d_model) if last_layer else None
409
+
410
+ def subsequent_mask(self, size):
411
+ return torch.ones(1, size, size).tril().bool()
412
+
413
+ def forward(self, x, x_mask, cond=None, src=None, src_mask=None,
414
+ return_activations: bool = False
415
+ ):
416
+ """Computes a full transformer stack
417
+ Parameters
418
+ ----------
419
+ x : Tensor[B x T_q x D]
420
+ x_mask : Tensor[B x T_q]
421
+ src : Tensor[B x T_kv x D], optional
422
+ src_mask : Tensor[B x T_kv], optional
423
+ Returns
424
+ -------
425
+ Tensor[B x T_q x D]
426
+ """
427
+
428
+ # Convert `src_mask` to (B x T_q x T_kv) shape for cross attention masking
429
+ if self.is_decoder:
430
+ src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2)
431
+
432
+ # Convert `x_mask` to (B x T_q x T_q) shape for self attention masking
433
+ x_mask = x_mask.unsqueeze(-2)
434
+ if not self.bidirectional:
435
+ x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device)
436
+
437
+ # Initialize position biases
438
+ position_bias = None
439
+ encoder_decoder_position_bias = None
440
+
441
+ # Compute transformer layers
442
+ if return_activations:
443
+ activations = []
444
+ for layer in self.layers:
445
+ x, position_bias, encoder_decoder_position_bias = layer(
446
+ x=x,
447
+ x_mask=x_mask,
448
+ cond=cond,
449
+ src=src,
450
+ src_mask=src_mask,
451
+ position_bias=position_bias,
452
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
453
+ )
454
+ if return_activations:
455
+ activations.append(x.detach())
456
+
457
+
458
+ out = self.norm(x) if self.norm is not None else x
459
+ if return_activations:
460
+ return out, torch.stack(activations)
461
+ else:
462
+ return out
463
+
464
+
465
+ class VampNet(at.ml.BaseModel):
466
+ def __init__(
467
+ self,
468
+ n_heads: int = 20,
469
+ n_layers: int = 16,
470
+ r_cond_dim: int = 0,
471
+ n_codebooks: int = 9,
472
+ n_conditioning_codebooks: int = 0,
473
+ latent_dim: int = 8,
474
+ embedding_dim: int = 1280,
475
+ vocab_size: int = 1024,
476
+ flash_attn: bool = True,
477
+ noise_mode: str = "mask",
478
+ dropout: float = 0.1
479
+ ):
480
+ super().__init__()
481
+ assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
482
+ self.n_heads = n_heads
483
+ self.n_layers = n_layers
484
+ self.r_cond_dim = r_cond_dim
485
+ self.n_codebooks = n_codebooks
486
+ self.n_conditioning_codebooks = n_conditioning_codebooks
487
+ self.embedding_dim = embedding_dim
488
+ self.vocab_size = vocab_size
489
+ self.latent_dim = latent_dim
490
+ self.flash_attn = flash_attn
491
+ self.noise_mode = noise_mode
492
+
493
+ assert self.noise_mode == "mask", "deprecated"
494
+
495
+ self.embedding = CodebookEmbedding(
496
+ latent_dim=latent_dim,
497
+ n_codebooks=n_codebooks,
498
+ vocab_size=vocab_size,
499
+ emb_dim=embedding_dim,
500
+ special_tokens=["MASK"],
501
+ )
502
+ self.mask_token = self.embedding.special_idxs["MASK"]
503
+
504
+ self.transformer = TransformerStack(
505
+ d_model=embedding_dim,
506
+ d_cond=r_cond_dim,
507
+ n_heads=n_heads,
508
+ n_layers=n_layers,
509
+ last_layer=True,
510
+ bidirectional=True,
511
+ flash_attn=flash_attn,
512
+ is_decoder=False,
513
+ dropout=dropout,
514
+ )
515
+
516
+ # Add final conv layer
517
+ self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks
518
+ self.classifier = SequentialWithFiLM(
519
+ WNConv1d(
520
+ embedding_dim,
521
+ vocab_size * self.n_predict_codebooks,
522
+ kernel_size=1,
523
+ padding="same",
524
+ # groups=self.n_predict_codebooks,
525
+ ),
526
+ )
527
+
528
+ def forward(self, x, return_activations: bool = False):
529
+ x = self.embedding(x)
530
+ x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
531
+
532
+ x = rearrange(x, "b d n -> b n d")
533
+ out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
534
+ if return_activations:
535
+ out, activations = out
536
+
537
+ out = rearrange(out, "b n d -> b d n")
538
+
539
+ out = self.classifier(out, None) # no cond here!
540
+
541
+ out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
542
+
543
+ if return_activations:
544
+ return out, activations
545
+ else:
546
+ return out
547
+
548
+ def r_embed(self, r, max_positions=10000):
549
+ if self.r_cond_dim > 0:
550
+ dtype = r.dtype
551
+
552
+ r = _gamma(r) * max_positions
553
+ half_dim = self.r_cond_dim // 2
554
+
555
+ emb = math.log(max_positions) / (half_dim - 1)
556
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
557
+
558
+ emb = r[:, None] * emb[None, :]
559
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
560
+
561
+ if self.r_cond_dim % 2 == 1: # zero pad
562
+ emb = nn.functional.pad(emb, (0, 1), mode="constant")
563
+
564
+ return emb.to(dtype)
565
+ else:
566
+ return r
567
+
568
+ @torch.no_grad()
569
+ def decode(self, z, codec):
570
+ """
571
+ convert a sequence of latents to a signal.
572
+ """
573
+ assert z.ndim == 3
574
+
575
+ # remove mask token
576
+ z = z.masked_fill(z == self.mask_token, 0)
577
+ signal = at.AudioSignal(
578
+ codec.decode(
579
+ codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
580
+ )["audio"],
581
+ codec.sample_rate,
582
+ )
583
+
584
+ # find where the mask token is and replace it with silence in the audio
585
+ for tstep in range(z.shape[-1]):
586
+ if torch.all(z[:, :, tstep] == self.mask_token):
587
+ sample_idx_0 = tstep * codec.hop_length
588
+ sample_idx_1 = sample_idx_0 + codec.hop_length
589
+ signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
590
+
591
+ return signal
592
+
593
+ @torch.inference_mode()
594
+ def generate(
595
+ self,
596
+ codec,
597
+ time_steps: int = 300,
598
+ _sampling_steps: List[int] = [12],
599
+ start_tokens: Optional[torch.Tensor] = None,
600
+ temperature: float = 1.0,
601
+ mask: Optional[torch.Tensor] = None,
602
+ mask_temperature: float = 10.5,
603
+ typical_filtering=True,
604
+ typical_mass=0.2,
605
+ typical_min_tokens=64,
606
+ top_p=None,
607
+ seed: int = None,
608
+ sample_cutoff: float = 1.0,
609
+ return_signal=True,
610
+ debug=False,
611
+ causal_weight: float = 0.0,
612
+ cfg_guidance: float = None,
613
+ ):
614
+ if seed is not None:
615
+ at.util.seed(seed)
616
+ sampling_steps = sum(_sampling_steps)
617
+ logging.debug(f"beginning generation with {sampling_steps} steps")
618
+
619
+ #####################
620
+ # resolve initial z #
621
+ #####################
622
+ z = start_tokens
623
+ nb = z.shape[0]
624
+
625
+ if z is None:
626
+ z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
627
+ self.device
628
+ )
629
+
630
+
631
+
632
+ #################
633
+ # resolve mask #
634
+ #################
635
+
636
+ if mask is None:
637
+ mask = torch.ones_like(z).to(self.device).int()
638
+ mask[:, : self.n_conditioning_codebooks, :] = 0.0
639
+ if mask.ndim == 2:
640
+ mask = mask[:, None, :].repeat(1, z.shape[1], 1)
641
+ # init_mask = mask.clone()
642
+
643
+
644
+
645
+ ###########
646
+ # set up #
647
+ ##########
648
+ # apply the mask to z
649
+ z_masked = z.masked_fill(mask.bool(), self.mask_token)
650
+ # logging.debug(f"z_masked: {z_masked}")
651
+
652
+ # how many mask tokens to begin with?
653
+ num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
654
+
655
+ # how many codebooks are we inferring vs conditioning on?
656
+ n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
657
+
658
+ if cfg_guidance is not None:
659
+ # we need to repeat our tensors
660
+ z_uncond = torch.full_like(z, self.mask_token)
661
+
662
+ z_masked = torch.cat(
663
+ (z_masked, z_uncond), dim=0
664
+ )
665
+ z = torch.cat(
666
+ (z, z_uncond), dim=0
667
+ )
668
+ mask = torch.cat(
669
+ (mask, torch.full_like(mask, 1)), dim=0
670
+ )
671
+
672
+ #################
673
+ # begin sampling #
674
+ #################
675
+ from tqdm import tqdm
676
+ for i in range(sampling_steps):
677
+
678
+ # our current schedule step
679
+ r = scalar_to_batch_tensor(
680
+ (i + 1) / sampling_steps,
681
+ z.shape[0]
682
+ ).to(z.device)
683
+
684
+ # get latents
685
+ latents = self.embedding.from_codes(z_masked, codec)
686
+
687
+
688
+ # infer from latents
689
+ # NOTE: this collapses the codebook dimension into the sequence dimension
690
+ logits = self.forward(latents) # b, prob, seq
691
+
692
+ if cfg_guidance is not None:
693
+ logits_cond, logits_uncond = logits[:nb], logits[nb:]
694
+ logits_cond = cfg_guidance * logits_cond + cfg_guidance * (1 - logits_uncond)
695
+
696
+ logits = logits.permute(0, 2, 1) # b, seq, prob
697
+ b = logits.shape[0]
698
+
699
+ sampled_z, selected_probs = sample_from_logits(
700
+ logits, sample=(
701
+ (i / sampling_steps) <= sample_cutoff
702
+ ),
703
+ temperature=temperature,
704
+ typical_filtering=typical_filtering, typical_mass=typical_mass,
705
+ typical_min_tokens=typical_min_tokens,
706
+ top_k=None, top_p=top_p, return_probs=True,
707
+ )
708
+
709
+
710
+ # flatten z_masked and mask, so we can deal with the sampling logic
711
+ # we'll unflatten them at the end of the loop for the next forward pass
712
+ # remove conditioning codebooks, we'll add them back at the end
713
+ z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
714
+
715
+ mask = (z_masked == self.mask_token).int()
716
+
717
+ # update the mask, remove conditioning codebooks from the mask
718
+ # add z back into sampled z where the mask was false
719
+ sampled_z = torch.where(
720
+ mask.bool(), sampled_z, z_masked
721
+ )
722
+
723
+ # ignore any tokens that weren't masked
724
+ selected_probs = torch.where(
725
+ mask.bool(), selected_probs, torch.inf
726
+ )
727
+
728
+ # get the num tokens to mask, according to the schedule
729
+ num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
730
+ logging.debug(f"num to mask: {num_to_mask}")
731
+
732
+ if i != (sampling_steps - 1):
733
+ num_to_mask = torch.maximum(
734
+ torch.tensor(1),
735
+ torch.minimum(
736
+ mask.sum(dim=-1, keepdim=True) - 1,
737
+ num_to_mask
738
+ )
739
+ )
740
+
741
+
742
+ # get our new mask
743
+ mask = mask_by_random_topk(
744
+ num_to_mask, selected_probs, mask_temperature * (1-r)
745
+ )
746
+
747
+ # update the mask
748
+ z_masked = torch.where(
749
+ mask.bool(), self.mask_token, sampled_z
750
+ )
751
+
752
+ z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
753
+ mask = codebook_unflatten(mask, n_infer_codebooks)
754
+
755
+ # add conditioning codebooks back to z_masked
756
+ z_masked = torch.cat(
757
+ (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
758
+ )
759
+
760
+ # add conditioning codebooks back to sampled_z
761
+ sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
762
+ sampled_z = torch.cat(
763
+ (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
764
+ )
765
+
766
+ if cfg_guidance is not None:
767
+ sampled_z = sampled_z[:nb]
768
+
769
+ if return_signal:
770
+ return self.decode(sampled_z, codec)
771
+ else:
772
+ return sampled_z
773
+
774
+
775
+
776
+
777
+
778
+ def sample_from_logits(
779
+ logits,
780
+ sample: bool = True,
781
+ temperature: float = 1.0,
782
+ top_k: int = None,
783
+ top_p: float = None,
784
+ typical_filtering: bool = False,
785
+ typical_mass: float = 0.2,
786
+ typical_min_tokens: int = 1,
787
+ return_probs: bool = False
788
+ ):
789
+ """Convenience function to sample from a categorial distribution with input as
790
+ unnormalized logits.
791
+
792
+ Parameters
793
+ ----------
794
+ logits : Tensor[..., vocab_size]
795
+ config: SamplingConfig
796
+ The set of hyperparameters to be used for sampling
797
+ sample : bool, optional
798
+ Whether to perform multinomial sampling, by default True
799
+ temperature : float, optional
800
+ Scaling parameter when multinomial samping, by default 1.0
801
+ top_k : int, optional
802
+ Restricts sampling to only `top_k` values acc. to probability,
803
+ by default None
804
+ top_p : float, optional
805
+ Restricts sampling to only those values with cumulative
806
+ probability = `top_p`, by default None
807
+
808
+ Returns
809
+ -------
810
+ Tensor[...]
811
+ Sampled tokens
812
+ """
813
+ shp = logits.shape[:-1]
814
+
815
+ if typical_filtering:
816
+ typical_filter(logits,
817
+ typical_mass=typical_mass,
818
+ typical_min_tokens=typical_min_tokens
819
+ )
820
+
821
+ # Apply top_k sampling
822
+ if top_k is not None:
823
+ v, _ = logits.topk(top_k)
824
+ logits[logits < v[..., [-1]]] = -float("inf")
825
+
826
+ # Apply top_p (nucleus) sampling
827
+ if top_p is not None and top_p < 1.0:
828
+ v, sorted_indices = logits.sort(descending=True)
829
+ cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
830
+
831
+ sorted_indices_to_remove = cumulative_probs > top_p
832
+ # Right shift indices_to_remove to keep 1st token over threshold
833
+ sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
834
+ ..., :-1
835
+ ]
836
+
837
+ # Compute indices_to_remove in unsorted array
838
+ indices_to_remove = sorted_indices_to_remove.scatter(
839
+ -1, sorted_indices, sorted_indices_to_remove
840
+ )
841
+
842
+ logits[indices_to_remove] = -float("inf")
843
+
844
+ # Perform multinomial sampling after normalizing logits
845
+ probs = (
846
+ F.softmax(logits / temperature, dim=-1)
847
+ if temperature > 0
848
+ else logits.softmax(dim=-1)
849
+ )
850
+ token = (
851
+ probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
852
+ if sample
853
+ else logits.argmax(-1)
854
+ )
855
+
856
+ if return_probs:
857
+ token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
858
+ return token, token_probs
859
+ else:
860
+ return token
861
+
862
+
863
+
864
+ def mask_by_random_topk(
865
+ num_to_mask: int,
866
+ probs: torch.Tensor,
867
+ temperature: float = 1.0,
868
+ ):
869
+ """
870
+ Args:
871
+ num_to_mask (int): number of tokens to mask
872
+ probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
873
+ temperature (float, optional): temperature. Defaults to 1.0.
874
+ """
875
+ logging.debug(f"masking by random topk")
876
+ logging.debug(f"num to mask: {num_to_mask}")
877
+ logging.debug(f"probs shape: {probs.shape}")
878
+ logging.debug(f"temperature: {temperature}")
879
+ logging.debug("")
880
+
881
+ noise = gumbel_noise_like(probs)
882
+ temperature = temperature.unsqueeze(-1)
883
+ confidence = torch.log(probs) + temperature * noise
884
+ logging.debug(f"confidence shape: {confidence.shape}")
885
+
886
+ sorted_confidence, sorted_idx = confidence.sort(dim=-1)
887
+ logging.debug(f"sorted confidence shape: {sorted_confidence.shape}")
888
+ logging.debug(f"sorted idx shape: {sorted_idx.shape}")
889
+
890
+ # get the cut off threshold, given the mask length
891
+ cut_off = torch.take_along_dim(
892
+ sorted_confidence, num_to_mask, axis=-1
893
+ )
894
+ logging.debug(f"cut off shape: {cut_off.shape}")
895
+
896
+ # mask out the tokens
897
+ mask = confidence < cut_off
898
+ logging.debug(f"mask shape: {mask.shape}")
899
+
900
+ return mask
901
+
902
+ def typical_filter(
903
+ logits,
904
+ typical_mass: float = 0.95,
905
+ typical_min_tokens: int = 1,):
906
+ nb, nt, _ = logits.shape
907
+ x_flat = rearrange(logits, "b t l -> (b t ) l")
908
+ x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
909
+ x_flat_norm_p = torch.exp(x_flat_norm)
910
+ entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
911
+
912
+ c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
913
+ c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
914
+ x_flat_cumsum = (
915
+ x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
916
+ )
917
+
918
+ last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
919
+ sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
920
+ 1, last_ind.view(-1, 1)
921
+ )
922
+ if typical_min_tokens > 1:
923
+ sorted_indices_to_remove[..., :typical_min_tokens] = 0
924
+ indices_to_remove = sorted_indices_to_remove.scatter(
925
+ 1, x_flat_indices, sorted_indices_to_remove
926
+ )
927
+ x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
928
+ logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
929
+ return logits
930
+
931
+
932
+ if __name__ == "__main__":
933
+ # import argbind
934
+ from .layers import num_params
935
+
936
+ VampNet = argbind.bind(VampNet)
937
+
938
+ @argbind.bind(without_prefix=True)
939
+ def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0):
940
+ seq_len = int(32000 / 512 * seq_len_s)
941
+
942
+ model = VampNet().to(device)
943
+
944
+ z = torch.randint(
945
+ 0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len)
946
+ ).to(device)
947
+
948
+ r = torch.zeros(batch_size).to(device)
949
+
950
+ z_mask_latent = torch.rand(
951
+ batch_size, model.latent_dim * model.n_codebooks, seq_len
952
+ ).to(device)
953
+ z_hat = model(z_mask_latent)
954
+
955
+ pred = z_hat.argmax(dim=1)
956
+ pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
957
+
958
+ logging.debug(f"model has {num_params(model)/1e6:<.3f}M parameters")
959
+ logging.debug(f"prediction has shape {pred.shape}")
960
+
961
+ args = argbind.parse_args()
962
+ with argbind.scope(args):
963
+ try_model()
964
+
965
+
vampnet/scheduler.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import List
3
+
4
+ import torch
5
+
6
+ class NoamScheduler:
7
+ """OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
8
+ Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
9
+ """
10
+
11
+ def __init__(
12
+ self,
13
+ optimizer: torch.optim.Optimizer,
14
+ d_model: int = 512,
15
+ factor: float = 1.0,
16
+ warmup: int = 4000,
17
+ ):
18
+ # Store hparams
19
+ self.warmup = warmup
20
+ self.factor = factor
21
+ self.d_model = d_model
22
+
23
+ # Initialize variables `lr` and `steps`
24
+ self.lr = None
25
+ self.steps = 0
26
+
27
+ # Store the optimizer
28
+ self.optimizer = optimizer
29
+
30
+ def state_dict(self):
31
+ return {
32
+ key: value for key, value in self.__dict__.items() if key != "optimizer"
33
+ }
34
+
35
+ def load_state_dict(self, state_dict):
36
+ self.__dict__.update(state_dict)
37
+
38
+ def step(self):
39
+ self.steps += 1
40
+ self.lr = self.factor * (
41
+ self.d_model ** (-0.5)
42
+ * min(self.steps ** (-0.5), self.steps * self.warmup ** (-1.5))
43
+ )
44
+
45
+ for p in self.optimizer.param_groups:
46
+ p["lr"] = self.lr
47
+
vampnet/util.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ def scalar_to_batch_tensor(x, batch_size):
7
+ return torch.tensor(x).repeat(batch_size)
8
+
9
+
10
+ def parallelize(
11
+ fn,
12
+ *iterables,
13
+ parallel: str = "thread_map",
14
+ **kwargs
15
+ ):
16
+ if parallel == "thread_map":
17
+ from tqdm.contrib.concurrent import thread_map
18
+ return thread_map(
19
+ fn,
20
+ *iterables,
21
+ **kwargs
22
+ )
23
+ elif parallel == "process_map":
24
+ from tqdm.contrib.concurrent import process_map
25
+ return process_map(
26
+ fn,
27
+ *iterables,
28
+ **kwargs
29
+ )
30
+ elif parallel == "single":
31
+ return [fn(x) for x in tqdm.tqdm(*iterables)]
32
+ else:
33
+ raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
34
+
35
+ def codebook_flatten(tokens: torch.Tensor):
36
+ """
37
+ flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
38
+ """
39
+ return rearrange(tokens, "b c t -> b (t c)")
40
+
41
+ def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
42
+ """
43
+ unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
44
+ """
45
+ tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
46
+ return tokens