Spaces:
Sleeping
Sleeping
defaults
#3
by
hugggof
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- .gitignore +3 -7
- README.md +5 -16
- app.py +43 -229
- conf/generated-v0/berta-goldman-speech/c2f.yml +15 -0
- conf/generated-v0/berta-goldman-speech/coarse.yml +8 -0
- conf/generated-v0/berta-goldman-speech/interface.yml +5 -0
- conf/generated-v0/gamelan-xeno-canto/c2f.yml +17 -0
- conf/generated-v0/gamelan-xeno-canto/coarse.yml +10 -0
- conf/generated-v0/gamelan-xeno-canto/interface.yml +6 -0
- conf/generated-v0/nasralla/c2f.yml +15 -0
- conf/generated-v0/nasralla/coarse.yml +8 -0
- conf/generated-v0/nasralla/interface.yml +5 -0
- conf/generated/breaks-steps/c2f.yml +15 -0
- conf/generated/breaks-steps/coarse.yml +8 -0
- conf/generated/breaks-steps/interface.yml +7 -0
- conf/generated/bulgarian-tv-choir/c2f.yml +15 -0
- conf/generated/bulgarian-tv-choir/coarse.yml +8 -0
- conf/generated/bulgarian-tv-choir/interface.yml +7 -0
- conf/generated/dariacore/c2f.yml +15 -0
- conf/generated/dariacore/coarse.yml +8 -0
- conf/generated/dariacore/interface.yml +7 -0
- conf/generated/musica-bolero-marimba/c2f.yml +18 -0
- conf/generated/musica-bolero-marimba/coarse.yml +11 -0
- conf/generated/musica-bolero-marimba/interface.yml +8 -0
- conf/generated/panchos/c2f.yml +15 -0
- conf/generated/panchos/coarse.yml +8 -0
- conf/generated/panchos/interface.yml +7 -0
- conf/generated/titi-monkey/c2f.yml +15 -0
- conf/generated/titi-monkey/coarse.yml +8 -0
- conf/generated/titi-monkey/interface.yml +7 -0
- conf/generated/xeno-canto/c2f.yml +15 -0
- conf/generated/xeno-canto/coarse.yml +8 -0
- conf/generated/xeno-canto/interface.yml +7 -0
- conf/lora/birds.yml +10 -0
- conf/lora/birdss.yml +12 -0
- conf/lora/constructions.yml +10 -0
- conf/lora/ella-baila-sola.yml +10 -0
- conf/lora/gas-station.yml +10 -0
- conf/lora/lora-is-this-charlie-parker.yml +10 -0
- conf/lora/lora.yml +7 -7
- conf/lora/underworld.yml +10 -0
- conf/lora/xeno-canto/c2f.yml +21 -0
- conf/lora/xeno-canto/coarse.yml +10 -0
- conf/vampnet-musdb-drums.yml +22 -0
- conf/vampnet.yml +19 -9
- requirements.txt +2 -4
- scripts/exp/fine_tune.py +7 -6
- scripts/exp/train.py +425 -483
- scripts/utils/{data/augment.py → augment.py} +24 -38
- scripts/utils/gtzan_embeddings.py +0 -263
.gitignore
CHANGED
@@ -175,14 +175,10 @@ lyrebird-audio-codec
|
|
175 |
samples-*/**
|
176 |
|
177 |
gradio-outputs/
|
178 |
-
models/
|
179 |
samples*/
|
180 |
models-all/
|
181 |
models.zip
|
|
|
|
|
|
|
182 |
.git-old
|
183 |
-
conf/generated/*
|
184 |
-
runs*/
|
185 |
-
|
186 |
-
|
187 |
-
gtzan.zip
|
188 |
-
.gtzan_emb_cache
|
|
|
175 |
samples-*/**
|
176 |
|
177 |
gradio-outputs/
|
|
|
178 |
samples*/
|
179 |
models-all/
|
180 |
models.zip
|
181 |
+
audiotools/
|
182 |
+
descript-audio-codec/
|
183 |
+
# *.pth
|
184 |
.git-old
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -7,27 +7,16 @@ sdk: gradio
|
|
7 |
sdk_version: 3.36.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
|
11 |
---
|
12 |
|
13 |
# VampNet
|
14 |
|
15 |
-
This repository contains recipes for training generative music models on top of the
|
16 |
-
|
17 |
-
## try `unloop`
|
18 |
-
you can try vampnet in a co-creative looper called unloop. see this link: https://github.com/hugofloresgarcia/unloop
|
19 |
|
20 |
# Setting up
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
|
25 |
-
|
26 |
-
(for example, using conda)
|
27 |
-
```bash
|
28 |
-
conda create -n vampnet python=3.9
|
29 |
-
conda activate vampnet
|
30 |
-
```
|
31 |
|
32 |
|
33 |
install VampNet
|
@@ -46,7 +35,7 @@ Config files are stored in the `conf/` folder.
|
|
46 |
### Licensing for Pretrained Models:
|
47 |
The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml).
|
48 |
|
49 |
-
Download the pretrained models from [this link](https://zenodo.org/record/
|
50 |
|
51 |
|
52 |
# Usage
|
@@ -100,7 +89,7 @@ python scripts/exp/train.py --args.load conf/<fine_tune_name>/c2f.yml
|
|
100 |
|
101 |
launch the interface:
|
102 |
```bash
|
103 |
-
python
|
104 |
```
|
105 |
|
106 |
|
|
|
7 |
sdk_version: 3.36.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
duplicated_from: hugggof/vampnet
|
11 |
---
|
12 |
|
13 |
# VampNet
|
14 |
|
15 |
+
This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec.
|
|
|
|
|
|
|
16 |
|
17 |
# Setting up
|
18 |
|
19 |
+
Requires Python 3.9 or later.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
install VampNet
|
|
|
35 |
### Licensing for Pretrained Models:
|
36 |
The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml).
|
37 |
|
38 |
+
Download the pretrained models from [this link](https://zenodo.org/record/8136545). Then, extract the models to the `models/` folder.
|
39 |
|
40 |
|
41 |
# Usage
|
|
|
89 |
|
90 |
launch the interface:
|
91 |
```bash
|
92 |
+
python demo.py --args.load conf/generated/<fine_tune_name>/interface.yml
|
93 |
```
|
94 |
|
95 |
|
app.py
CHANGED
@@ -1,12 +1,3 @@
|
|
1 |
-
# huggingface space exclusive
|
2 |
-
import os
|
3 |
-
|
4 |
-
# print("installing pyharp")
|
5 |
-
# os.system('pip install "pyharp@git+https://github.com/audacitorch/pyharp.git"')
|
6 |
-
# print("installing madmom")
|
7 |
-
os.system('pip install cython')
|
8 |
-
os.system('pip install madmom')
|
9 |
-
|
10 |
from pathlib import Path
|
11 |
from typing import Tuple
|
12 |
import yaml
|
@@ -24,38 +15,27 @@ import gradio as gr
|
|
24 |
from vampnet.interface import Interface
|
25 |
from vampnet import mask as pmask
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
# loader = AudioLoader()
|
32 |
# AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
shift=interval,
|
42 |
-
sample_rate=signal.sample_rate
|
43 |
-
)
|
44 |
-
return signal
|
45 |
-
|
46 |
-
def load_interface():
|
47 |
-
interface = Interface(
|
48 |
-
coarse_ckpt="./models/vampnet/coarse.pth",
|
49 |
-
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
50 |
-
codec_ckpt="./models/vampnet/codec.pth",
|
51 |
-
wavebeat_ckpt="./models/wavebeat.pth",
|
52 |
-
device="cuda" if torch.cuda.is_available() else "cpu",
|
53 |
-
)
|
54 |
-
return interface
|
55 |
-
|
56 |
|
57 |
-
|
|
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
OUT_DIR = Path("gradio-outputs")
|
61 |
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
@@ -70,7 +50,7 @@ def load_audio(file):
|
|
70 |
)
|
71 |
sig = interface.preprocess(sig)
|
72 |
|
73 |
-
out_dir = OUT_DIR /
|
74 |
out_dir.mkdir(parents=True, exist_ok=True)
|
75 |
sig.write(out_dir / "input.wav")
|
76 |
return sig.path_to_file
|
@@ -88,10 +68,6 @@ def _vamp(data, return_mask=False):
|
|
88 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
89 |
out_dir.mkdir()
|
90 |
sig = at.AudioSignal(data[input_audio])
|
91 |
-
sig = interface.preprocess(sig)
|
92 |
-
|
93 |
-
if data[pitch_shift_amt] != 0:
|
94 |
-
sig = shift_pitch(sig, data[pitch_shift_amt])
|
95 |
|
96 |
z = interface.encode(sig)
|
97 |
|
@@ -131,58 +107,24 @@ def _vamp(data, return_mask=False):
|
|
131 |
mask = pmask.codebook_unmask(mask, ncc)
|
132 |
|
133 |
|
134 |
-
print(f"dropout {data[dropout]}")
|
135 |
-
print(f"masktemp {data[masktemp]}")
|
136 |
-
print(f"sampletemp {data[sampletemp]}")
|
137 |
-
print(f"top_p {data[top_p]}")
|
138 |
-
print(f"prefix_s {data[prefix_s]}")
|
139 |
-
print(f"suffix_s {data[suffix_s]}")
|
140 |
-
print(f"rand_mask_intensity {data[rand_mask_intensity]}")
|
141 |
-
print(f"num_steps {data[num_steps]}")
|
142 |
-
print(f"periodic_p {data[periodic_p]}")
|
143 |
-
print(f"periodic_w {data[periodic_w]}")
|
144 |
-
print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
|
145 |
-
print(f"use_coarse2fine {data[use_coarse2fine]}")
|
146 |
-
print(f"onset_mask_width {data[onset_mask_width]}")
|
147 |
-
print(f"beat_mask_width {data[beat_mask_width]}")
|
148 |
-
print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
|
149 |
-
print(f"stretch_factor {data[stretch_factor]}")
|
150 |
-
print(f"seed {data[seed]}")
|
151 |
-
print(f"pitch_shift_amt {data[pitch_shift_amt]}")
|
152 |
-
print(f"sample_cutoff {data[sample_cutoff]}")
|
153 |
-
|
154 |
-
|
155 |
-
_top_p = data[top_p] if data[top_p] > 0 else None
|
156 |
# save the mask as a txt file
|
157 |
np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
|
158 |
|
159 |
-
_seed = data[seed] if data[seed] > 0 else None
|
160 |
zv, mask_z = interface.coarse_vamp(
|
161 |
z,
|
162 |
mask=mask,
|
163 |
sampling_steps=data[num_steps],
|
164 |
-
|
165 |
-
sampling_temperature=data[sampletemp],
|
166 |
return_mask=True,
|
167 |
typical_filtering=data[typical_filtering],
|
168 |
typical_mass=data[typical_mass],
|
169 |
typical_min_tokens=data[typical_min_tokens],
|
170 |
-
top_p=_top_p,
|
171 |
gen_fn=interface.coarse.generate,
|
172 |
-
seed=_seed,
|
173 |
-
sample_cutoff=data[sample_cutoff],
|
174 |
)
|
175 |
|
176 |
if use_coarse2fine:
|
177 |
-
zv = interface.coarse_to_fine(
|
178 |
-
zv,
|
179 |
-
mask_temperature=data[masktemp]*10,
|
180 |
-
sampling_temperature=data[sampletemp],
|
181 |
-
mask=mask,
|
182 |
-
sampling_steps=data[num_steps] // 2,
|
183 |
-
sample_cutoff=data[sample_cutoff],
|
184 |
-
seed=_seed,
|
185 |
-
)
|
186 |
|
187 |
sig = interface.to_signal(zv).cpu()
|
188 |
print("done")
|
@@ -215,9 +157,7 @@ def save_vamp(data):
|
|
215 |
sig_out.write(out_dir / "output.wav")
|
216 |
|
217 |
_data = {
|
218 |
-
"
|
219 |
-
"sampletemp": data[sampletemp],
|
220 |
-
"top_p": data[top_p],
|
221 |
"prefix_s": data[prefix_s],
|
222 |
"suffix_s": data[suffix_s],
|
223 |
"rand_mask_intensity": data[rand_mask_intensity],
|
@@ -228,8 +168,6 @@ def save_vamp(data):
|
|
228 |
"n_conditioning_codebooks": data[n_conditioning_codebooks],
|
229 |
"use_coarse2fine": data[use_coarse2fine],
|
230 |
"stretch_factor": data[stretch_factor],
|
231 |
-
"seed": data[seed],
|
232 |
-
"samplecutoff": data[sample_cutoff],
|
233 |
}
|
234 |
|
235 |
# save with yaml
|
@@ -245,54 +183,13 @@ def save_vamp(data):
|
|
245 |
return f"saved! your save code is {out_dir.stem}", zip_path
|
246 |
|
247 |
|
248 |
-
def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
|
249 |
-
|
250 |
-
out_dir = OUT_DIR / str(uuid.uuid4())
|
251 |
-
out_dir.mkdir()
|
252 |
-
sig = at.AudioSignal(_input_audio)
|
253 |
-
sig = interface.preprocess(sig)
|
254 |
-
|
255 |
-
z = interface.encode(sig)
|
256 |
-
|
257 |
-
# build the mask
|
258 |
-
mask = pmask.linear_random(z, 1.0)
|
259 |
-
if _beat_mask_width > 0:
|
260 |
-
beat_mask = interface.make_beat_mask(
|
261 |
-
sig,
|
262 |
-
after_beat_s=(_beat_mask_width/1000),
|
263 |
-
)
|
264 |
-
mask = pmask.mask_and(mask, beat_mask)
|
265 |
-
|
266 |
-
# save the mask as a txt file
|
267 |
-
zv, mask_z = interface.coarse_vamp(
|
268 |
-
z,
|
269 |
-
mask=mask,
|
270 |
-
sampling_temperature=_sampletemp,
|
271 |
-
return_mask=True,
|
272 |
-
gen_fn=interface.coarse.generate,
|
273 |
-
)
|
274 |
-
|
275 |
-
|
276 |
-
zv = interface.coarse_to_fine(
|
277 |
-
zv,
|
278 |
-
sampling_temperature=_sampletemp,
|
279 |
-
mask=mask,
|
280 |
-
)
|
281 |
-
|
282 |
-
sig = interface.to_signal(zv).cpu()
|
283 |
-
print("done")
|
284 |
-
|
285 |
-
sig.write(out_dir / "output.wav")
|
286 |
-
|
287 |
-
return sig.path_to_file
|
288 |
-
|
289 |
with gr.Blocks() as demo:
|
290 |
|
291 |
with gr.Row():
|
292 |
with gr.Column():
|
293 |
-
gr.Markdown("# VampNet
|
294 |
gr.Markdown("""## Description:
|
295 |
-
This is a demo of
|
296 |
You can control the extent and nature of variation with a set of manual controls and presets.
|
297 |
Use this interface to experiment with different mask settings and explore the audio outputs.
|
298 |
""")
|
@@ -300,8 +197,8 @@ with gr.Blocks() as demo:
|
|
300 |
gr.Markdown("""
|
301 |
## Instructions:
|
302 |
1. You can start by uploading some audio, or by loading the example audio.
|
303 |
-
2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
|
304 |
-
3. Click the "generate (vamp)!!!" button to
|
305 |
4. Optionally, you can add some notes and save the result.
|
306 |
5. You can also use the output as the new input and continue experimenting!
|
307 |
""")
|
@@ -352,25 +249,19 @@ with gr.Blocks() as demo:
|
|
352 |
"beat_mask_downbeats": False,
|
353 |
},
|
354 |
"slight periodic variation": {
|
355 |
-
"periodic_p":
|
356 |
-
"onset_mask_width":
|
357 |
-
"beat_mask_width": 0,
|
358 |
-
"beat_mask_downbeats": False,
|
359 |
-
},
|
360 |
-
"moderate periodic variation": {
|
361 |
-
"periodic_p": 13,
|
362 |
-
"onset_mask_width": 5,
|
363 |
"beat_mask_width": 0,
|
364 |
"beat_mask_downbeats": False,
|
365 |
},
|
366 |
"strong periodic variation": {
|
367 |
-
"periodic_p":
|
368 |
"onset_mask_width": 5,
|
369 |
"beat_mask_width": 0,
|
370 |
"beat_mask_downbeats": False,
|
371 |
},
|
372 |
"very strong periodic variation": {
|
373 |
-
"periodic_p":
|
374 |
"onset_mask_width": 5,
|
375 |
"beat_mask_width": 0,
|
376 |
"beat_mask_downbeats": False,
|
@@ -378,15 +269,9 @@ with gr.Blocks() as demo:
|
|
378 |
"beat-driven variation": {
|
379 |
"periodic_p": 0,
|
380 |
"onset_mask_width": 0,
|
381 |
-
"beat_mask_width":
|
382 |
"beat_mask_downbeats": False,
|
383 |
},
|
384 |
-
"beat-driven variation (downbeats only)": {
|
385 |
-
"periodic_p": 0,
|
386 |
-
"onset_mask_width": 0,
|
387 |
-
"beat_mask_width": 50,
|
388 |
-
"beat_mask_downbeats": True,
|
389 |
-
},
|
390 |
"beat-driven variation (downbeats only, strong)": {
|
391 |
"periodic_p": 0,
|
392 |
"onset_mask_width": 0,
|
@@ -408,20 +293,20 @@ with gr.Blocks() as demo:
|
|
408 |
minimum=0,
|
409 |
maximum=128,
|
410 |
step=1,
|
411 |
-
value=
|
412 |
)
|
413 |
|
414 |
|
415 |
onset_mask_width = gr.Slider(
|
416 |
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
417 |
minimum=0,
|
418 |
-
maximum=
|
419 |
step=1,
|
420 |
value=5,
|
421 |
)
|
422 |
|
423 |
beat_mask_width = gr.Slider(
|
424 |
-
label="beat
|
425 |
minimum=0,
|
426 |
maximum=200,
|
427 |
value=0,
|
@@ -433,14 +318,6 @@ with gr.Blocks() as demo:
|
|
433 |
|
434 |
|
435 |
with gr.Accordion("extras ", open=False):
|
436 |
-
pitch_shift_amt = gr.Slider(
|
437 |
-
label="pitch shift amount (semitones)",
|
438 |
-
minimum=-12,
|
439 |
-
maximum=12,
|
440 |
-
step=1,
|
441 |
-
value=0,
|
442 |
-
)
|
443 |
-
|
444 |
rand_mask_intensity = gr.Slider(
|
445 |
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
446 |
minimum=0.0,
|
@@ -500,34 +377,21 @@ with gr.Blocks() as demo:
|
|
500 |
value=0.0
|
501 |
)
|
502 |
|
503 |
-
|
504 |
-
label="
|
505 |
minimum=0.0,
|
506 |
-
maximum=100.0,
|
507 |
-
value=1.5
|
508 |
-
)
|
509 |
-
sampletemp = gr.Slider(
|
510 |
-
label="sample temperature",
|
511 |
-
minimum=0.1,
|
512 |
maximum=10.0,
|
513 |
-
value=1.
|
514 |
-
step=0.001
|
515 |
)
|
516 |
-
|
517 |
|
518 |
|
519 |
with gr.Accordion("sampling settings", open=False):
|
520 |
-
top_p = gr.Slider(
|
521 |
-
label="top p (0.0 = off)",
|
522 |
-
minimum=0.0,
|
523 |
-
maximum=1.0,
|
524 |
-
value=0.9
|
525 |
-
)
|
526 |
typical_filtering = gr.Checkbox(
|
527 |
label="typical filtering ",
|
528 |
value=False
|
529 |
)
|
530 |
-
typical_mass = gr.Slider(
|
531 |
label="typical mass (should probably stay between 0.1 and 0.5)",
|
532 |
minimum=0.01,
|
533 |
maximum=0.99,
|
@@ -540,18 +404,10 @@ with gr.Blocks() as demo:
|
|
540 |
step=1,
|
541 |
value=64
|
542 |
)
|
543 |
-
sample_cutoff = gr.Slider(
|
544 |
-
label="sample cutoff",
|
545 |
-
minimum=0.0,
|
546 |
-
maximum=1.0,
|
547 |
-
value=0.5,
|
548 |
-
step=0.01
|
549 |
-
)
|
550 |
|
551 |
use_coarse2fine = gr.Checkbox(
|
552 |
label="use coarse2fine",
|
553 |
-
value=True
|
554 |
-
visible=False
|
555 |
)
|
556 |
|
557 |
num_steps = gr.Slider(
|
@@ -571,24 +427,8 @@ with gr.Blocks() as demo:
|
|
571 |
)
|
572 |
|
573 |
|
574 |
-
seed = gr.Number(
|
575 |
-
label="seed (0 for random)",
|
576 |
-
value=0,
|
577 |
-
precision=0,
|
578 |
-
)
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
# mask settings
|
583 |
with gr.Column():
|
584 |
-
|
585 |
-
# lora_choice = gr.Dropdown(
|
586 |
-
# label="lora choice",
|
587 |
-
# choices=list(loras.keys()),
|
588 |
-
# value=LORA_NONE,
|
589 |
-
# visible=False
|
590 |
-
# )
|
591 |
-
|
592 |
vamp_button = gr.Button("generate (vamp)!!!")
|
593 |
output_audio = gr.Audio(
|
594 |
label="output audio",
|
@@ -614,9 +454,7 @@ with gr.Blocks() as demo:
|
|
614 |
_inputs = {
|
615 |
input_audio,
|
616 |
num_steps,
|
617 |
-
|
618 |
-
sampletemp,
|
619 |
-
top_p,
|
620 |
prefix_s, suffix_s,
|
621 |
rand_mask_intensity,
|
622 |
periodic_p, periodic_w,
|
@@ -629,11 +467,7 @@ with gr.Blocks() as demo:
|
|
629 |
typical_mass,
|
630 |
typical_min_tokens,
|
631 |
beat_mask_width,
|
632 |
-
beat_mask_downbeats
|
633 |
-
seed,
|
634 |
-
# lora_choice,
|
635 |
-
pitch_shift_amt,
|
636 |
-
sample_cutoff
|
637 |
}
|
638 |
|
639 |
# connect widgets
|
@@ -663,24 +497,4 @@ with gr.Blocks() as demo:
|
|
663 |
outputs=[thank_you, download_file]
|
664 |
)
|
665 |
|
666 |
-
|
667 |
-
harp_inputs = [
|
668 |
-
input_audio,
|
669 |
-
beat_mask_width,
|
670 |
-
sampletemp,
|
671 |
-
]
|
672 |
-
|
673 |
-
build_endpoint(
|
674 |
-
inputs=harp_inputs,
|
675 |
-
output=output_audio,
|
676 |
-
process_fn=harp_vamp,
|
677 |
-
card=ModelCard(
|
678 |
-
name="vampnet",
|
679 |
-
description="Generate variations on music input, based on small prompts around the beat. NOTE: vampnet's has a maximum context length of 10 seconds. Please split all audio clips into 10 second chunks, or processing will result in an error. ",
|
680 |
-
author="Hugo Flores García",
|
681 |
-
tags=["music", "generative"]
|
682 |
-
),
|
683 |
-
visible=False
|
684 |
-
)
|
685 |
-
|
686 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from pathlib import Path
|
2 |
from typing import Tuple
|
3 |
import yaml
|
|
|
15 |
from vampnet.interface import Interface
|
16 |
from vampnet import mask as pmask
|
17 |
|
18 |
+
# Interface = argbind.bind(Interface)
|
|
|
|
|
|
|
|
|
19 |
# AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
|
20 |
|
21 |
+
interface = Interface(
|
22 |
+
coarse_ckpt="./models/vampnet/coarse.pth",
|
23 |
+
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
24 |
+
codec_ckpt="./models/vampnet/codec.pth",
|
25 |
+
wavebeat_ckpt="./models/wavebeat.pth",
|
26 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
27 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
+
# loader = AudioLoader()
|
30 |
+
print(f"interface device is {interface.device}")
|
31 |
|
32 |
+
# dataset = at.data.datasets.AudioDataset(
|
33 |
+
# loader,
|
34 |
+
# sample_rate=interface.codec.sample_rate,
|
35 |
+
# duration=interface.coarse.chunk_size_s,
|
36 |
+
# n_examples=5000,
|
37 |
+
# without_replacement=True,
|
38 |
+
# )
|
39 |
|
40 |
OUT_DIR = Path("gradio-outputs")
|
41 |
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
|
|
50 |
)
|
51 |
sig = interface.preprocess(sig)
|
52 |
|
53 |
+
out_dir = OUT_DIR / str(uuid.uuid4())
|
54 |
out_dir.mkdir(parents=True, exist_ok=True)
|
55 |
sig.write(out_dir / "input.wav")
|
56 |
return sig.path_to_file
|
|
|
68 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
69 |
out_dir.mkdir()
|
70 |
sig = at.AudioSignal(data[input_audio])
|
|
|
|
|
|
|
|
|
71 |
|
72 |
z = interface.encode(sig)
|
73 |
|
|
|
107 |
mask = pmask.codebook_unmask(mask, ncc)
|
108 |
|
109 |
|
110 |
+
print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
# save the mask as a txt file
|
112 |
np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
|
113 |
|
|
|
114 |
zv, mask_z = interface.coarse_vamp(
|
115 |
z,
|
116 |
mask=mask,
|
117 |
sampling_steps=data[num_steps],
|
118 |
+
temperature=float(data[temp]*10),
|
|
|
119 |
return_mask=True,
|
120 |
typical_filtering=data[typical_filtering],
|
121 |
typical_mass=data[typical_mass],
|
122 |
typical_min_tokens=data[typical_min_tokens],
|
|
|
123 |
gen_fn=interface.coarse.generate,
|
|
|
|
|
124 |
)
|
125 |
|
126 |
if use_coarse2fine:
|
127 |
+
zv = interface.coarse_to_fine(zv, temperature=data[temp])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
sig = interface.to_signal(zv).cpu()
|
130 |
print("done")
|
|
|
157 |
sig_out.write(out_dir / "output.wav")
|
158 |
|
159 |
_data = {
|
160 |
+
"temp": data[temp],
|
|
|
|
|
161 |
"prefix_s": data[prefix_s],
|
162 |
"suffix_s": data[suffix_s],
|
163 |
"rand_mask_intensity": data[rand_mask_intensity],
|
|
|
168 |
"n_conditioning_codebooks": data[n_conditioning_codebooks],
|
169 |
"use_coarse2fine": data[use_coarse2fine],
|
170 |
"stretch_factor": data[stretch_factor],
|
|
|
|
|
171 |
}
|
172 |
|
173 |
# save with yaml
|
|
|
183 |
return f"saved! your save code is {out_dir.stem}", zip_path
|
184 |
|
185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
with gr.Blocks() as demo:
|
187 |
|
188 |
with gr.Row():
|
189 |
with gr.Column():
|
190 |
+
gr.Markdown("# VampNet")
|
191 |
gr.Markdown("""## Description:
|
192 |
+
This is a demo of VampNet, a masked generative music model capable of doing music variations.
|
193 |
You can control the extent and nature of variation with a set of manual controls and presets.
|
194 |
Use this interface to experiment with different mask settings and explore the audio outputs.
|
195 |
""")
|
|
|
197 |
gr.Markdown("""
|
198 |
## Instructions:
|
199 |
1. You can start by uploading some audio, or by loading the example audio.
|
200 |
+
2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings. Click the load preset button.
|
201 |
+
3. Click the "generate (vamp)!!!" button to generate audio. Listen to the output audio, and the masked audio to hear the mask hints.
|
202 |
4. Optionally, you can add some notes and save the result.
|
203 |
5. You can also use the output as the new input and continue experimenting!
|
204 |
""")
|
|
|
249 |
"beat_mask_downbeats": False,
|
250 |
},
|
251 |
"slight periodic variation": {
|
252 |
+
"periodic_p": 7,
|
253 |
+
"onset_mask_width": 0,
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
"beat_mask_width": 0,
|
255 |
"beat_mask_downbeats": False,
|
256 |
},
|
257 |
"strong periodic variation": {
|
258 |
+
"periodic_p": 13,
|
259 |
"onset_mask_width": 5,
|
260 |
"beat_mask_width": 0,
|
261 |
"beat_mask_downbeats": False,
|
262 |
},
|
263 |
"very strong periodic variation": {
|
264 |
+
"periodic_p": 17,
|
265 |
"onset_mask_width": 5,
|
266 |
"beat_mask_width": 0,
|
267 |
"beat_mask_downbeats": False,
|
|
|
269 |
"beat-driven variation": {
|
270 |
"periodic_p": 0,
|
271 |
"onset_mask_width": 0,
|
272 |
+
"beat_mask_width": 20,
|
273 |
"beat_mask_downbeats": False,
|
274 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
"beat-driven variation (downbeats only, strong)": {
|
276 |
"periodic_p": 0,
|
277 |
"onset_mask_width": 0,
|
|
|
293 |
minimum=0,
|
294 |
maximum=128,
|
295 |
step=1,
|
296 |
+
value=13,
|
297 |
)
|
298 |
|
299 |
|
300 |
onset_mask_width = gr.Slider(
|
301 |
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
302 |
minimum=0,
|
303 |
+
maximum=20,
|
304 |
step=1,
|
305 |
value=5,
|
306 |
)
|
307 |
|
308 |
beat_mask_width = gr.Slider(
|
309 |
+
label="beat mask width (in milliseconds)",
|
310 |
minimum=0,
|
311 |
maximum=200,
|
312 |
value=0,
|
|
|
318 |
|
319 |
|
320 |
with gr.Accordion("extras ", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
rand_mask_intensity = gr.Slider(
|
322 |
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
323 |
minimum=0.0,
|
|
|
377 |
value=0.0
|
378 |
)
|
379 |
|
380 |
+
temp = gr.Slider(
|
381 |
+
label="temperature",
|
382 |
minimum=0.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
maximum=10.0,
|
384 |
+
value=1.8
|
|
|
385 |
)
|
386 |
+
|
387 |
|
388 |
|
389 |
with gr.Accordion("sampling settings", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
390 |
typical_filtering = gr.Checkbox(
|
391 |
label="typical filtering ",
|
392 |
value=False
|
393 |
)
|
394 |
+
typical_mass = gr.Slider(
|
395 |
label="typical mass (should probably stay between 0.1 and 0.5)",
|
396 |
minimum=0.01,
|
397 |
maximum=0.99,
|
|
|
404 |
step=1,
|
405 |
value=64
|
406 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
|
408 |
use_coarse2fine = gr.Checkbox(
|
409 |
label="use coarse2fine",
|
410 |
+
value=True
|
|
|
411 |
)
|
412 |
|
413 |
num_steps = gr.Slider(
|
|
|
427 |
)
|
428 |
|
429 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
# mask settings
|
431 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
vamp_button = gr.Button("generate (vamp)!!!")
|
433 |
output_audio = gr.Audio(
|
434 |
label="output audio",
|
|
|
454 |
_inputs = {
|
455 |
input_audio,
|
456 |
num_steps,
|
457 |
+
temp,
|
|
|
|
|
458 |
prefix_s, suffix_s,
|
459 |
rand_mask_intensity,
|
460 |
periodic_p, periodic_w,
|
|
|
467 |
typical_mass,
|
468 |
typical_min_tokens,
|
469 |
beat_mask_width,
|
470 |
+
beat_mask_downbeats
|
|
|
|
|
|
|
|
|
471 |
}
|
472 |
|
473 |
# connect widgets
|
|
|
497 |
outputs=[thank_you, download_file]
|
498 |
)
|
499 |
|
500 |
+
demo.queue().launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/generated-v0/berta-goldman-speech/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
save_path: ./runs/berta-goldman-speech/c2f
|
12 |
+
train/AudioLoader.sources:
|
13 |
+
- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
|
14 |
+
val/AudioLoader.sources:
|
15 |
+
- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
|
conf/generated-v0/berta-goldman-speech/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
save_path: ./runs/berta-goldman-speech/coarse
|
5 |
+
train/AudioLoader.sources:
|
6 |
+
- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
|
7 |
+
val/AudioLoader.sources:
|
8 |
+
- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
|
conf/generated-v0/berta-goldman-speech/interface.yml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
|
3 |
+
Interface.coarse2fine_ckpt: ./runs/berta-goldman-speech/c2f/best/vampnet/weights.pth
|
4 |
+
Interface.coarse_ckpt: ./runs/berta-goldman-speech/coarse/best/vampnet/weights.pth
|
5 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated-v0/gamelan-xeno-canto/c2f.yml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
save_path: ./runs/gamelan-xeno-canto/c2f
|
12 |
+
train/AudioLoader.sources:
|
13 |
+
- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
|
14 |
+
- /media/CHONK/hugo/loras/xeno-canto-2
|
15 |
+
val/AudioLoader.sources:
|
16 |
+
- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
|
17 |
+
- /media/CHONK/hugo/loras/xeno-canto-2
|
conf/generated-v0/gamelan-xeno-canto/coarse.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
save_path: ./runs/gamelan-xeno-canto/coarse
|
5 |
+
train/AudioLoader.sources:
|
6 |
+
- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
|
7 |
+
- /media/CHONK/hugo/loras/xeno-canto-2
|
8 |
+
val/AudioLoader.sources:
|
9 |
+
- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
|
10 |
+
- /media/CHONK/hugo/loras/xeno-canto-2
|
conf/generated-v0/gamelan-xeno-canto/interface.yml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
|
3 |
+
- /media/CHONK/hugo/loras/xeno-canto-2
|
4 |
+
Interface.coarse2fine_ckpt: ./runs/gamelan-xeno-canto/c2f/best/vampnet/weights.pth
|
5 |
+
Interface.coarse_ckpt: ./runs/gamelan-xeno-canto/coarse/best/vampnet/weights.pth
|
6 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated-v0/nasralla/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
save_path: ./runs/nasralla/c2f
|
12 |
+
train/AudioLoader.sources:
|
13 |
+
- /media/CHONK/hugo/nasralla
|
14 |
+
val/AudioLoader.sources:
|
15 |
+
- /media/CHONK/hugo/nasralla
|
conf/generated-v0/nasralla/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
save_path: ./runs/nasralla/coarse
|
5 |
+
train/AudioLoader.sources:
|
6 |
+
- /media/CHONK/hugo/nasralla
|
7 |
+
val/AudioLoader.sources:
|
8 |
+
- /media/CHONK/hugo/nasralla
|
conf/generated-v0/nasralla/interface.yml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- /media/CHONK/hugo/nasralla
|
3 |
+
Interface.coarse2fine_ckpt: ./runs/nasralla/c2f/best/vampnet/weights.pth
|
4 |
+
Interface.coarse_ckpt: ./runs/nasralla/coarse/best/vampnet/weights.pth
|
5 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated/breaks-steps/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
12 |
+
save_path: ./runs/breaks-steps/c2f
|
13 |
+
train/AudioLoader.sources: &id001
|
14 |
+
- /media/CHONK/hugo/breaks-steps
|
15 |
+
val/AudioLoader.sources: *id001
|
conf/generated/breaks-steps/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
5 |
+
save_path: ./runs/breaks-steps/coarse
|
6 |
+
train/AudioLoader.sources: &id001
|
7 |
+
- /media/CHONK/hugo/breaks-steps
|
8 |
+
val/AudioLoader.sources: *id001
|
conf/generated/breaks-steps/interface.yml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- - /media/CHONK/hugo/breaks-steps
|
3 |
+
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
4 |
+
Interface.coarse2fine_lora_ckpt: ./runs/breaks-steps/c2f/latest/lora.pth
|
5 |
+
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
6 |
+
Interface.coarse_lora_ckpt: ./runs/breaks-steps/coarse/latest/lora.pth
|
7 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated/bulgarian-tv-choir/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
12 |
+
save_path: ./runs/bulgarian-tv-choir/c2f
|
13 |
+
train/AudioLoader.sources: &id001
|
14 |
+
- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
|
15 |
+
val/AudioLoader.sources: *id001
|
conf/generated/bulgarian-tv-choir/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
5 |
+
save_path: ./runs/bulgarian-tv-choir/coarse
|
6 |
+
train/AudioLoader.sources: &id001
|
7 |
+
- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
|
8 |
+
val/AudioLoader.sources: *id001
|
conf/generated/bulgarian-tv-choir/interface.yml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
|
3 |
+
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
4 |
+
Interface.coarse2fine_lora_ckpt: ./runs/bulgarian-tv-choir/c2f/latest/lora.pth
|
5 |
+
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
6 |
+
Interface.coarse_lora_ckpt: ./runs/bulgarian-tv-choir/coarse/latest/lora.pth
|
7 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated/dariacore/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
12 |
+
save_path: ./runs/dariacore/c2f
|
13 |
+
train/AudioLoader.sources: &id001
|
14 |
+
- /media/CHONK/hugo/loras/dariacore
|
15 |
+
val/AudioLoader.sources: *id001
|
conf/generated/dariacore/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
5 |
+
save_path: ./runs/dariacore/coarse
|
6 |
+
train/AudioLoader.sources: &id001
|
7 |
+
- /media/CHONK/hugo/loras/dariacore
|
8 |
+
val/AudioLoader.sources: *id001
|
conf/generated/dariacore/interface.yml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- - /media/CHONK/hugo/loras/dariacore
|
3 |
+
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
4 |
+
Interface.coarse2fine_lora_ckpt: ./runs/dariacore/c2f/latest/lora.pth
|
5 |
+
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
6 |
+
Interface.coarse_lora_ckpt: ./runs/dariacore/coarse/latest/lora.pth
|
7 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated/musica-bolero-marimba/c2f.yml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
12 |
+
save_path: ./runs/musica-bolero-marimba/c2f
|
13 |
+
train/AudioLoader.sources:
|
14 |
+
- /media/CHONK/hugo/loras/boleros
|
15 |
+
- /media/CHONK/hugo/loras/marimba-honduras
|
16 |
+
val/AudioLoader.sources:
|
17 |
+
- /media/CHONK/hugo/loras/boleros
|
18 |
+
- /media/CHONK/hugo/loras/marimba-honduras
|
conf/generated/musica-bolero-marimba/coarse.yml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
5 |
+
save_path: ./runs/musica-bolero-marimba/coarse
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/loras/boleros
|
8 |
+
- /media/CHONK/hugo/loras/marimba-honduras
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/loras/boleros
|
11 |
+
- /media/CHONK/hugo/loras/marimba-honduras
|
conf/generated/musica-bolero-marimba/interface.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- /media/CHONK/hugo/loras/boleros
|
3 |
+
- /media/CHONK/hugo/loras/marimba-honduras
|
4 |
+
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
5 |
+
Interface.coarse2fine_lora_ckpt: ./runs/musica-bolero-marimba/c2f/latest/lora.pth
|
6 |
+
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
7 |
+
Interface.coarse_lora_ckpt: ./runs/musica-bolero-marimba/coarse/latest/lora.pth
|
8 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated/panchos/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
12 |
+
save_path: ./runs/panchos/c2f
|
13 |
+
train/AudioLoader.sources: &id001
|
14 |
+
- /media/CHONK/hugo/loras/panchos/
|
15 |
+
val/AudioLoader.sources: *id001
|
conf/generated/panchos/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
5 |
+
save_path: ./runs/panchos/coarse
|
6 |
+
train/AudioLoader.sources: &id001
|
7 |
+
- /media/CHONK/hugo/loras/panchos/
|
8 |
+
val/AudioLoader.sources: *id001
|
conf/generated/panchos/interface.yml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- - /media/CHONK/hugo/loras/panchos/
|
3 |
+
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
4 |
+
Interface.coarse2fine_lora_ckpt: ./runs/panchos/c2f/latest/lora.pth
|
5 |
+
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
6 |
+
Interface.coarse_lora_ckpt: ./runs/panchos/coarse/latest/lora.pth
|
7 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated/titi-monkey/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
12 |
+
save_path: ./runs/titi-monkey/c2f
|
13 |
+
train/AudioLoader.sources: &id001
|
14 |
+
- /media/CHONK/hugo/loras/titi-monkey.mp3
|
15 |
+
val/AudioLoader.sources: *id001
|
conf/generated/titi-monkey/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
5 |
+
save_path: ./runs/titi-monkey/coarse
|
6 |
+
train/AudioLoader.sources: &id001
|
7 |
+
- /media/CHONK/hugo/loras/titi-monkey.mp3
|
8 |
+
val/AudioLoader.sources: *id001
|
conf/generated/titi-monkey/interface.yml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- - /media/CHONK/hugo/loras/titi-monkey.mp3
|
3 |
+
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
4 |
+
Interface.coarse2fine_lora_ckpt: ./runs/titi-monkey/c2f/latest/lora.pth
|
5 |
+
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
6 |
+
Interface.coarse_lora_ckpt: ./runs/titi-monkey/coarse/latest/lora.pth
|
7 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated/xeno-canto/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
12 |
+
save_path: ./runs/xeno-canto/c2f
|
13 |
+
train/AudioLoader.sources: &id001
|
14 |
+
- /media/CHONK/hugo/loras/xeno-canto-2/
|
15 |
+
val/AudioLoader.sources: *id001
|
conf/generated/xeno-canto/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
5 |
+
save_path: ./runs/xeno-canto/coarse
|
6 |
+
train/AudioLoader.sources: &id001
|
7 |
+
- /media/CHONK/hugo/loras/xeno-canto-2/
|
8 |
+
val/AudioLoader.sources: *id001
|
conf/generated/xeno-canto/interface.yml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- - /media/CHONK/hugo/loras/xeno-canto-2/
|
3 |
+
Interface.coarse2fine_ckpt: ./mod els/spotdl/c2f.pth
|
4 |
+
Interface.coarse2fine_lora_ckpt: ./runs/xeno-canto/c2f/latest/lora.pth
|
5 |
+
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
6 |
+
Interface.coarse_lora_ckpt: ./runs/xeno-canto/coarse/latest/lora.pth
|
7 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/lora/birds.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/spotdl/subsets/birds
|
8 |
+
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/spotdl/subsets/birds
|
conf/lora/birdss.yml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/spotdl/subsets/birds
|
8 |
+
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/
|
9 |
+
|
10 |
+
val/AudioLoader.sources:
|
11 |
+
- /media/CHONK/hugo/spotdl/subsets/birds
|
12 |
+
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/
|
conf/lora/constructions.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3
|
8 |
+
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3
|
conf/lora/ella-baila-sola.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3
|
8 |
+
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3
|
conf/lora/gas-station.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
|
8 |
+
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
|
conf/lora/lora-is-this-charlie-parker.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3
|
8 |
+
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3
|
conf/lora/lora.yml
CHANGED
@@ -3,20 +3,20 @@ $include:
|
|
3 |
|
4 |
fine_tune: True
|
5 |
|
6 |
-
train/AudioDataset.n_examples:
|
7 |
-
|
|
|
8 |
|
9 |
|
10 |
NoamScheduler.warmup: 500
|
11 |
|
12 |
-
batch_size:
|
13 |
num_workers: 7
|
14 |
-
|
15 |
-
|
16 |
-
val_freq: 500
|
17 |
|
18 |
AdamW.lr: 0.0001
|
19 |
|
20 |
# let's us organize sound classes into folders and choose from those sound classes uniformly
|
21 |
AudioDataset.without_replacement: False
|
22 |
-
|
|
|
3 |
|
4 |
fine_tune: True
|
5 |
|
6 |
+
train/AudioDataset.n_examples: 10000000
|
7 |
+
|
8 |
+
val/AudioDataset.n_examples: 10
|
9 |
|
10 |
|
11 |
NoamScheduler.warmup: 500
|
12 |
|
13 |
+
batch_size: 7
|
14 |
num_workers: 7
|
15 |
+
epoch_length: 100
|
16 |
+
save_audio_epochs: 10
|
|
|
17 |
|
18 |
AdamW.lr: 0.0001
|
19 |
|
20 |
# let's us organize sound classes into folders and choose from those sound classes uniformly
|
21 |
AudioDataset.without_replacement: False
|
22 |
+
max_epochs: 500
|
conf/lora/underworld.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/spotdl/subsets/underworld.mp3
|
8 |
+
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/spotdl/subsets/underworld.mp3
|
conf/lora/xeno-canto/c2f.yml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/xeno-canto-2
|
8 |
+
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/xeno-canto-2
|
11 |
+
|
12 |
+
|
13 |
+
VampNet.n_codebooks: 14
|
14 |
+
VampNet.n_conditioning_codebooks: 4
|
15 |
+
|
16 |
+
VampNet.embedding_dim: 1280
|
17 |
+
VampNet.n_layers: 16
|
18 |
+
VampNet.n_heads: 20
|
19 |
+
|
20 |
+
AudioDataset.duration: 3.0
|
21 |
+
AudioDataset.loudness_cutoff: -40.0
|
conf/lora/xeno-canto/coarse.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioLoader.sources:
|
7 |
+
- /media/CHONK/hugo/xeno-canto-2
|
8 |
+
|
9 |
+
val/AudioLoader.sources:
|
10 |
+
- /media/CHONK/hugo/xeno-canto-2
|
conf/vampnet-musdb-drums.yml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/vampnet.yml
|
3 |
+
|
4 |
+
VampNet.embedding_dim: 512
|
5 |
+
VampNet.n_layers: 12
|
6 |
+
VampNet.n_heads: 8
|
7 |
+
|
8 |
+
AudioDataset.duration: 12.0
|
9 |
+
|
10 |
+
train/AudioDataset.n_examples: 10000000
|
11 |
+
train/AudioLoader.sources:
|
12 |
+
- /data/musdb18hq/train/**/*drums.wav
|
13 |
+
|
14 |
+
|
15 |
+
val/AudioDataset.n_examples: 500
|
16 |
+
val/AudioLoader.sources:
|
17 |
+
- /data/musdb18hq/test/**/*drums.wav
|
18 |
+
|
19 |
+
|
20 |
+
test/AudioDataset.n_examples: 1000
|
21 |
+
test/AudioLoader.sources:
|
22 |
+
- /data/musdb18hq/test/**/*drums.wav
|
conf/vampnet.yml
CHANGED
@@ -1,17 +1,21 @@
|
|
1 |
|
2 |
-
codec_ckpt: ./models/
|
3 |
save_path: ckpt
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
val_idx: [0,1,2,3,4,5,6,7,8,9]
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
10 |
|
11 |
batch_size: 8
|
12 |
num_workers: 10
|
13 |
|
14 |
# Optimization
|
|
|
15 |
amp: false
|
16 |
|
17 |
CrossEntropyLoss.label_smoothing: 0.1
|
@@ -21,6 +25,9 @@ AdamW.lr: 0.001
|
|
21 |
NoamScheduler.factor: 2.0
|
22 |
NoamScheduler.warmup: 10000
|
23 |
|
|
|
|
|
|
|
24 |
VampNet.vocab_size: 1024
|
25 |
VampNet.n_codebooks: 4
|
26 |
VampNet.n_conditioning_codebooks: 0
|
@@ -32,7 +39,7 @@ VampNet.n_heads: 20
|
|
32 |
VampNet.flash_attn: false
|
33 |
VampNet.dropout: 0.1
|
34 |
|
35 |
-
AudioLoader.relative_path:
|
36 |
AudioDataset.loudness_cutoff: -30.0
|
37 |
AudioDataset.without_replacement: true
|
38 |
AudioLoader.shuffle: true
|
@@ -41,9 +48,12 @@ AudioDataset.duration: 10.0
|
|
41 |
|
42 |
train/AudioDataset.n_examples: 10000000
|
43 |
train/AudioLoader.sources:
|
44 |
-
- /
|
45 |
|
46 |
val/AudioDataset.n_examples: 2000
|
47 |
val/AudioLoader.sources:
|
48 |
-
- /
|
49 |
|
|
|
|
|
|
|
|
1 |
|
2 |
+
codec_ckpt: ./models/spotdl/codec.pth
|
3 |
save_path: ckpt
|
4 |
+
max_epochs: 1000
|
5 |
+
epoch_length: 1000
|
6 |
+
save_audio_epochs: 2
|
7 |
val_idx: [0,1,2,3,4,5,6,7,8,9]
|
8 |
+
|
9 |
+
prefix_amt: 0.0
|
10 |
+
suffix_amt: 0.0
|
11 |
+
prefix_dropout: 0.1
|
12 |
+
suffix_dropout: 0.1
|
13 |
|
14 |
batch_size: 8
|
15 |
num_workers: 10
|
16 |
|
17 |
# Optimization
|
18 |
+
detect_anomaly: false
|
19 |
amp: false
|
20 |
|
21 |
CrossEntropyLoss.label_smoothing: 0.1
|
|
|
25 |
NoamScheduler.factor: 2.0
|
26 |
NoamScheduler.warmup: 10000
|
27 |
|
28 |
+
PitchShift.shift_amount: [const, 0]
|
29 |
+
PitchShift.prob: 0.0
|
30 |
+
|
31 |
VampNet.vocab_size: 1024
|
32 |
VampNet.n_codebooks: 4
|
33 |
VampNet.n_conditioning_codebooks: 0
|
|
|
39 |
VampNet.flash_attn: false
|
40 |
VampNet.dropout: 0.1
|
41 |
|
42 |
+
AudioLoader.relative_path: /data/
|
43 |
AudioDataset.loudness_cutoff: -30.0
|
44 |
AudioDataset.without_replacement: true
|
45 |
AudioLoader.shuffle: true
|
|
|
48 |
|
49 |
train/AudioDataset.n_examples: 10000000
|
50 |
train/AudioLoader.sources:
|
51 |
+
- /data/spotdl/audio/train
|
52 |
|
53 |
val/AudioDataset.n_examples: 2000
|
54 |
val/AudioLoader.sources:
|
55 |
+
- /data/spotdl/audio/val
|
56 |
|
57 |
+
test/AudioDataset.n_examples: 1000
|
58 |
+
test/AudioLoader.sources:
|
59 |
+
- /data/spotdl/audio/test
|
requirements.txt
CHANGED
@@ -1,10 +1,8 @@
|
|
1 |
torch
|
2 |
argbind>=0.3.2
|
3 |
-
numpy==1.
|
4 |
gradio
|
5 |
loralib
|
6 |
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
|
7 |
lac @ git+https://github.com/hugofloresgarcia/lac.git
|
8 |
-
|
9 |
-
-e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
|
10 |
-
torch_pitch_shift
|
|
|
1 |
torch
|
2 |
argbind>=0.3.2
|
3 |
+
numpy==1.22
|
4 |
gradio
|
5 |
loralib
|
6 |
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
|
7 |
lac @ git+https://github.com/hugofloresgarcia/lac.git
|
8 |
+
audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
|
|
|
|
scripts/exp/fine_tune.py
CHANGED
@@ -35,7 +35,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
35 |
"AudioDataset.duration": 3.0,
|
36 |
"AudioDataset.loudness_cutoff": -40.0,
|
37 |
"save_path": f"./runs/{name}/c2f",
|
38 |
-
"fine_tune_checkpoint": "./models/
|
39 |
}
|
40 |
|
41 |
finetune_coarse_conf = {
|
@@ -44,16 +44,17 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
44 |
"train/AudioLoader.sources": audio_files_or_folders,
|
45 |
"val/AudioLoader.sources": audio_files_or_folders,
|
46 |
"save_path": f"./runs/{name}/coarse",
|
47 |
-
"fine_tune_checkpoint": "./models/
|
48 |
}
|
49 |
|
50 |
interface_conf = {
|
51 |
-
"Interface.coarse_ckpt": f"./
|
|
|
52 |
|
53 |
-
"Interface.coarse2fine_ckpt": f"./
|
54 |
-
"Interface.
|
55 |
|
56 |
-
"Interface.codec_ckpt": "./models/
|
57 |
"AudioLoader.sources": [audio_files_or_folders],
|
58 |
}
|
59 |
|
|
|
35 |
"AudioDataset.duration": 3.0,
|
36 |
"AudioDataset.loudness_cutoff": -40.0,
|
37 |
"save_path": f"./runs/{name}/c2f",
|
38 |
+
"fine_tune_checkpoint": "./models/spotdl/c2f.pth"
|
39 |
}
|
40 |
|
41 |
finetune_coarse_conf = {
|
|
|
44 |
"train/AudioLoader.sources": audio_files_or_folders,
|
45 |
"val/AudioLoader.sources": audio_files_or_folders,
|
46 |
"save_path": f"./runs/{name}/coarse",
|
47 |
+
"fine_tune_checkpoint": "./models/spotdl/coarse.pth"
|
48 |
}
|
49 |
|
50 |
interface_conf = {
|
51 |
+
"Interface.coarse_ckpt": f"./models/spotdl/coarse.pth",
|
52 |
+
"Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
|
53 |
|
54 |
+
"Interface.coarse2fine_ckpt": f"./models/spotdl/c2f.pth",
|
55 |
+
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
56 |
|
57 |
+
"Interface.codec_ckpt": "./models/spotdl/codec.pth",
|
58 |
"AudioLoader.sources": [audio_files_or_folders],
|
59 |
}
|
60 |
|
scripts/exp/train.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import os
|
2 |
-
import
|
|
|
3 |
import warnings
|
4 |
from pathlib import Path
|
5 |
from typing import Optional
|
6 |
-
from dataclasses import dataclass
|
7 |
|
8 |
import argbind
|
9 |
import audiotools as at
|
@@ -14,7 +14,7 @@ from audiotools.data import transforms
|
|
14 |
from einops import rearrange
|
15 |
from rich import pretty
|
16 |
from rich.traceback import install
|
17 |
-
from
|
18 |
|
19 |
import vampnet
|
20 |
from vampnet.modules.transformer import VampNet
|
@@ -23,15 +23,6 @@ from vampnet import mask as pmask
|
|
23 |
# from dac.model.dac import DAC
|
24 |
from lac.model.lac import LAC as DAC
|
25 |
|
26 |
-
from audiotools.ml.decorators import (
|
27 |
-
timer, Tracker, when
|
28 |
-
)
|
29 |
-
|
30 |
-
import loralib as lora
|
31 |
-
|
32 |
-
import torch._dynamo
|
33 |
-
torch._dynamo.config.verbose=True
|
34 |
-
|
35 |
|
36 |
# Enable cudnn autotuner to speed up training
|
37 |
# (can be altered by the funcs.seed function)
|
@@ -94,7 +85,11 @@ def build_datasets(args, sample_rate: int):
|
|
94 |
)
|
95 |
with argbind.scope(args, "val"):
|
96 |
val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
|
97 |
-
|
|
|
|
|
|
|
|
|
98 |
|
99 |
|
100 |
def rand_float(shape, low, high, rng):
|
@@ -105,392 +100,16 @@ def flip_coin(shape, p, rng):
|
|
105 |
return rng.draw(shape)[:, 0] < p
|
106 |
|
107 |
|
108 |
-
def num_params_hook(o, p):
|
109 |
-
return o + f" {p/1e6:<.3f}M params."
|
110 |
-
|
111 |
-
|
112 |
-
def add_num_params_repr_hook(model):
|
113 |
-
import numpy as np
|
114 |
-
from functools import partial
|
115 |
-
|
116 |
-
for n, m in model.named_modules():
|
117 |
-
o = m.extra_repr()
|
118 |
-
p = sum([np.prod(p.size()) for p in m.parameters()])
|
119 |
-
|
120 |
-
setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
|
121 |
-
|
122 |
-
|
123 |
-
def accuracy(
|
124 |
-
preds: torch.Tensor,
|
125 |
-
target: torch.Tensor,
|
126 |
-
top_k: int = 1,
|
127 |
-
ignore_index: Optional[int] = None,
|
128 |
-
) -> torch.Tensor:
|
129 |
-
# Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
|
130 |
-
preds = rearrange(preds, "b p s -> (b s) p")
|
131 |
-
target = rearrange(target, "b s -> (b s)")
|
132 |
-
|
133 |
-
# return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
|
134 |
-
if ignore_index is not None:
|
135 |
-
# Create a mask for the ignored index
|
136 |
-
mask = target != ignore_index
|
137 |
-
# Apply the mask to the target and predictions
|
138 |
-
preds = preds[mask]
|
139 |
-
target = target[mask]
|
140 |
-
|
141 |
-
# Get the top-k predicted classes and their indices
|
142 |
-
_, pred_indices = torch.topk(preds, k=top_k, dim=-1)
|
143 |
-
|
144 |
-
# Determine if the true target is in the top-k predicted classes
|
145 |
-
correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
|
146 |
-
|
147 |
-
# Calculate the accuracy
|
148 |
-
accuracy = torch.mean(correct.float())
|
149 |
-
|
150 |
-
return accuracy
|
151 |
-
|
152 |
-
def _metrics(z_hat, r, target, flat_mask, output):
|
153 |
-
for r_range in [(0, 0.5), (0.5, 1.0)]:
|
154 |
-
unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
|
155 |
-
masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
156 |
-
|
157 |
-
assert target.shape[0] == r.shape[0]
|
158 |
-
# grab the indices of the r values that are in the range
|
159 |
-
r_idx = (r >= r_range[0]) & (r < r_range[1])
|
160 |
-
|
161 |
-
# grab the target and z_hat values that are in the range
|
162 |
-
r_unmasked_target = unmasked_target[r_idx]
|
163 |
-
r_masked_target = masked_target[r_idx]
|
164 |
-
r_z_hat = z_hat[r_idx]
|
165 |
-
|
166 |
-
for topk in (1, 25):
|
167 |
-
s, e = r_range
|
168 |
-
tag = f"accuracy-{s}-{e}/top{topk}"
|
169 |
-
|
170 |
-
output[f"{tag}/unmasked"] = accuracy(
|
171 |
-
preds=r_z_hat,
|
172 |
-
target=r_unmasked_target,
|
173 |
-
ignore_index=IGNORE_INDEX,
|
174 |
-
top_k=topk,
|
175 |
-
)
|
176 |
-
output[f"{tag}/masked"] = accuracy(
|
177 |
-
preds=r_z_hat,
|
178 |
-
target=r_masked_target,
|
179 |
-
ignore_index=IGNORE_INDEX,
|
180 |
-
top_k=topk,
|
181 |
-
)
|
182 |
-
|
183 |
-
|
184 |
-
@dataclass
|
185 |
-
class State:
|
186 |
-
model: VampNet
|
187 |
-
codec: DAC
|
188 |
-
|
189 |
-
optimizer: AdamW
|
190 |
-
scheduler: NoamScheduler
|
191 |
-
criterion: CrossEntropyLoss
|
192 |
-
grad_clip_val: float
|
193 |
-
|
194 |
-
rng: torch.quasirandom.SobolEngine
|
195 |
-
|
196 |
-
train_data: AudioDataset
|
197 |
-
val_data: AudioDataset
|
198 |
-
|
199 |
-
tracker: Tracker
|
200 |
-
|
201 |
-
|
202 |
-
@timer()
|
203 |
-
def train_loop(state: State, batch: dict, accel: Accelerator):
|
204 |
-
state.model.train()
|
205 |
-
batch = at.util.prepare_batch(batch, accel.device)
|
206 |
-
signal = apply_transform(state.train_data.transform, batch)
|
207 |
-
|
208 |
-
output = {}
|
209 |
-
vn = accel.unwrap(state.model)
|
210 |
-
with accel.autocast():
|
211 |
-
with torch.inference_mode():
|
212 |
-
state.codec.to(accel.device)
|
213 |
-
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
214 |
-
z = z[:, : vn.n_codebooks, :]
|
215 |
-
|
216 |
-
n_batch = z.shape[0]
|
217 |
-
r = state.rng.draw(n_batch)[:, 0].to(accel.device)
|
218 |
-
|
219 |
-
mask = pmask.random(z, r)
|
220 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
221 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
222 |
-
|
223 |
-
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
224 |
-
|
225 |
-
dtype = torch.bfloat16 if accel.amp else None
|
226 |
-
with accel.autocast(dtype=dtype):
|
227 |
-
z_hat = state.model(z_mask_latent)
|
228 |
-
|
229 |
-
target = codebook_flatten(
|
230 |
-
z[:, vn.n_conditioning_codebooks :, :],
|
231 |
-
)
|
232 |
-
|
233 |
-
flat_mask = codebook_flatten(
|
234 |
-
mask[:, vn.n_conditioning_codebooks :, :],
|
235 |
-
)
|
236 |
-
|
237 |
-
# replace target with ignore index for masked tokens
|
238 |
-
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
239 |
-
output["loss"] = state.criterion(z_hat, t_masked)
|
240 |
-
|
241 |
-
_metrics(
|
242 |
-
r=r,
|
243 |
-
z_hat=z_hat,
|
244 |
-
target=target,
|
245 |
-
flat_mask=flat_mask,
|
246 |
-
output=output,
|
247 |
-
)
|
248 |
-
|
249 |
-
|
250 |
-
accel.backward(output["loss"])
|
251 |
-
|
252 |
-
output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
|
253 |
-
output["other/batch_size"] = z.shape[0]
|
254 |
-
|
255 |
-
|
256 |
-
accel.scaler.unscale_(state.optimizer)
|
257 |
-
output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
|
258 |
-
state.model.parameters(), state.grad_clip_val
|
259 |
-
)
|
260 |
-
|
261 |
-
accel.step(state.optimizer)
|
262 |
-
state.optimizer.zero_grad()
|
263 |
-
|
264 |
-
state.scheduler.step()
|
265 |
-
accel.update()
|
266 |
-
|
267 |
-
|
268 |
-
return {k: v for k, v in sorted(output.items())}
|
269 |
-
|
270 |
-
|
271 |
-
@timer()
|
272 |
-
@torch.no_grad()
|
273 |
-
def val_loop(state: State, batch: dict, accel: Accelerator):
|
274 |
-
state.model.eval()
|
275 |
-
state.codec.eval()
|
276 |
-
batch = at.util.prepare_batch(batch, accel.device)
|
277 |
-
signal = apply_transform(state.val_data.transform, batch)
|
278 |
-
|
279 |
-
vn = accel.unwrap(state.model)
|
280 |
-
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
281 |
-
z = z[:, : vn.n_codebooks, :]
|
282 |
-
|
283 |
-
n_batch = z.shape[0]
|
284 |
-
r = state.rng.draw(n_batch)[:, 0].to(accel.device)
|
285 |
-
|
286 |
-
mask = pmask.random(z, r)
|
287 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
288 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
289 |
-
|
290 |
-
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
291 |
-
|
292 |
-
z_hat = state.model(z_mask_latent)
|
293 |
-
|
294 |
-
target = codebook_flatten(
|
295 |
-
z[:, vn.n_conditioning_codebooks :, :],
|
296 |
-
)
|
297 |
-
|
298 |
-
flat_mask = codebook_flatten(
|
299 |
-
mask[:, vn.n_conditioning_codebooks :, :]
|
300 |
-
)
|
301 |
-
|
302 |
-
output = {}
|
303 |
-
# replace target with ignore index for masked tokens
|
304 |
-
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
305 |
-
output["loss"] = state.criterion(z_hat, t_masked)
|
306 |
-
|
307 |
-
_metrics(
|
308 |
-
r=r,
|
309 |
-
z_hat=z_hat,
|
310 |
-
target=target,
|
311 |
-
flat_mask=flat_mask,
|
312 |
-
output=output,
|
313 |
-
)
|
314 |
-
|
315 |
-
return output
|
316 |
-
|
317 |
-
|
318 |
-
def validate(state, val_dataloader, accel):
|
319 |
-
for batch in val_dataloader:
|
320 |
-
output = val_loop(state, batch, accel)
|
321 |
-
# Consolidate state dicts if using ZeroRedundancyOptimizer
|
322 |
-
if hasattr(state.optimizer, "consolidate_state_dict"):
|
323 |
-
state.optimizer.consolidate_state_dict()
|
324 |
-
return output
|
325 |
-
|
326 |
-
|
327 |
-
def checkpoint(state, save_iters, save_path, fine_tune):
|
328 |
-
if accel.local_rank != 0:
|
329 |
-
state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
|
330 |
-
return
|
331 |
-
|
332 |
-
metadata = {"logs": dict(state.tracker.history)}
|
333 |
-
|
334 |
-
tags = ["latest"]
|
335 |
-
state.tracker.print(f"Saving to {str(Path('.').absolute())}")
|
336 |
-
|
337 |
-
if state.tracker.step in save_iters:
|
338 |
-
tags.append(f"{state.tracker.step // 1000}k")
|
339 |
-
|
340 |
-
if state.tracker.is_best("val", "loss"):
|
341 |
-
state.tracker.print(f"Best model so far")
|
342 |
-
tags.append("best")
|
343 |
-
|
344 |
-
if fine_tune:
|
345 |
-
for tag in tags:
|
346 |
-
# save the lora model
|
347 |
-
(Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
|
348 |
-
torch.save(
|
349 |
-
lora.lora_state_dict(accel.unwrap(state.model)),
|
350 |
-
f"{save_path}/{tag}/lora.pth"
|
351 |
-
)
|
352 |
-
|
353 |
-
for tag in tags:
|
354 |
-
model_extra = {
|
355 |
-
"optimizer.pth": state.optimizer.state_dict(),
|
356 |
-
"scheduler.pth": state.scheduler.state_dict(),
|
357 |
-
"tracker.pth": state.tracker.state_dict(),
|
358 |
-
"metadata.pth": metadata,
|
359 |
-
}
|
360 |
-
|
361 |
-
accel.unwrap(state.model).metadata = metadata
|
362 |
-
accel.unwrap(state.model).save_to_folder(
|
363 |
-
f"{save_path}/{tag}", model_extra, package=False
|
364 |
-
)
|
365 |
-
|
366 |
-
|
367 |
-
def save_sampled(state, z, writer):
|
368 |
-
num_samples = z.shape[0]
|
369 |
-
|
370 |
-
for i in range(num_samples):
|
371 |
-
sampled = accel.unwrap(state.model).generate(
|
372 |
-
codec=state.codec,
|
373 |
-
time_steps=z.shape[-1],
|
374 |
-
start_tokens=z[i : i + 1],
|
375 |
-
)
|
376 |
-
sampled.cpu().write_audio_to_tb(
|
377 |
-
f"sampled/{i}",
|
378 |
-
writer,
|
379 |
-
step=state.tracker.step,
|
380 |
-
plot_fn=None,
|
381 |
-
)
|
382 |
-
|
383 |
-
|
384 |
-
def save_imputation(state, z, val_idx, writer):
|
385 |
-
n_prefix = int(z.shape[-1] * 0.25)
|
386 |
-
n_suffix = int(z.shape[-1] * 0.25)
|
387 |
-
|
388 |
-
vn = accel.unwrap(state.model)
|
389 |
-
|
390 |
-
mask = pmask.inpaint(z, n_prefix, n_suffix)
|
391 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
392 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
393 |
-
|
394 |
-
imputed_noisy = vn.to_signal(z_mask, state.codec)
|
395 |
-
imputed_true = vn.to_signal(z, state.codec)
|
396 |
-
|
397 |
-
imputed = []
|
398 |
-
for i in range(len(z)):
|
399 |
-
imputed.append(
|
400 |
-
vn.generate(
|
401 |
-
codec=state.codec,
|
402 |
-
time_steps=z.shape[-1],
|
403 |
-
start_tokens=z[i][None, ...],
|
404 |
-
mask=mask[i][None, ...],
|
405 |
-
)
|
406 |
-
)
|
407 |
-
imputed = AudioSignal.batch(imputed)
|
408 |
-
|
409 |
-
for i in range(len(val_idx)):
|
410 |
-
imputed_noisy[i].cpu().write_audio_to_tb(
|
411 |
-
f"inpainted_prompt/{i}",
|
412 |
-
writer,
|
413 |
-
step=state.tracker.step,
|
414 |
-
plot_fn=None,
|
415 |
-
)
|
416 |
-
imputed[i].cpu().write_audio_to_tb(
|
417 |
-
f"inpainted_middle/{i}",
|
418 |
-
writer,
|
419 |
-
step=state.tracker.step,
|
420 |
-
plot_fn=None,
|
421 |
-
)
|
422 |
-
imputed_true[i].cpu().write_audio_to_tb(
|
423 |
-
f"reconstructed/{i}",
|
424 |
-
writer,
|
425 |
-
step=state.tracker.step,
|
426 |
-
plot_fn=None,
|
427 |
-
)
|
428 |
-
|
429 |
-
|
430 |
-
@torch.no_grad()
|
431 |
-
def save_samples(state: State, val_idx: int, writer: SummaryWriter):
|
432 |
-
state.model.eval()
|
433 |
-
state.codec.eval()
|
434 |
-
vn = accel.unwrap(state.model)
|
435 |
-
|
436 |
-
batch = [state.val_data[i] for i in val_idx]
|
437 |
-
batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)
|
438 |
-
|
439 |
-
signal = apply_transform(state.val_data.transform, batch)
|
440 |
-
|
441 |
-
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
442 |
-
z = z[:, : vn.n_codebooks, :]
|
443 |
-
|
444 |
-
r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
|
445 |
-
|
446 |
-
|
447 |
-
mask = pmask.random(z, r)
|
448 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
449 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
450 |
-
|
451 |
-
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
452 |
-
|
453 |
-
z_hat = state.model(z_mask_latent)
|
454 |
-
|
455 |
-
z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
|
456 |
-
z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
|
457 |
-
z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
|
458 |
-
|
459 |
-
generated = vn.to_signal(z_pred, state.codec)
|
460 |
-
reconstructed = vn.to_signal(z, state.codec)
|
461 |
-
masked = vn.to_signal(z_mask.squeeze(1), state.codec)
|
462 |
-
|
463 |
-
for i in range(generated.batch_size):
|
464 |
-
audio_dict = {
|
465 |
-
"original": signal[i],
|
466 |
-
"masked": masked[i],
|
467 |
-
"generated": generated[i],
|
468 |
-
"reconstructed": reconstructed[i],
|
469 |
-
}
|
470 |
-
for k, v in audio_dict.items():
|
471 |
-
v.cpu().write_audio_to_tb(
|
472 |
-
f"onestep/_{i}.r={r[i]:0.2f}/{k}",
|
473 |
-
writer,
|
474 |
-
step=state.tracker.step,
|
475 |
-
plot_fn=None,
|
476 |
-
)
|
477 |
-
|
478 |
-
save_sampled(state=state, z=z, writer=writer)
|
479 |
-
save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
@argbind.bind(without_prefix=True)
|
484 |
def load(
|
485 |
args,
|
486 |
accel: at.ml.Accelerator,
|
487 |
-
tracker: Tracker,
|
488 |
save_path: str,
|
489 |
resume: bool = False,
|
490 |
tag: str = "latest",
|
|
|
491 |
fine_tune_checkpoint: Optional[str] = None,
|
492 |
-
|
493 |
-
) -> State:
|
494 |
codec = DAC.load(args["codec_ckpt"], map_location="cpu")
|
495 |
codec.eval()
|
496 |
|
@@ -500,9 +119,8 @@ def load(
|
|
500 |
kwargs = {
|
501 |
"folder": f"{save_path}/{tag}",
|
502 |
"map_location": "cpu",
|
503 |
-
"package":
|
504 |
}
|
505 |
-
tracker.print(f"Loading checkpoint from {kwargs['folder']}")
|
506 |
if (Path(kwargs["folder"]) / "vampnet").exists():
|
507 |
model, v_extra = VampNet.load_from_folder(**kwargs)
|
508 |
else:
|
@@ -513,14 +131,11 @@ def load(
|
|
513 |
|
514 |
if args["fine_tune"]:
|
515 |
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
516 |
-
model =
|
517 |
-
|
518 |
-
map_location="cpu",
|
519 |
-
)
|
520 |
-
)
|
521 |
|
|
|
522 |
|
523 |
-
model = torch.compile(VampNet()) if model is None else model
|
524 |
model = accel.prepare_model(model)
|
525 |
|
526 |
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
@@ -532,57 +147,89 @@ def load(
|
|
532 |
scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
|
533 |
scheduler.step()
|
534 |
|
|
|
|
|
535 |
if "optimizer.pth" in v_extra:
|
536 |
optimizer.load_state_dict(v_extra["optimizer.pth"])
|
|
|
537 |
scheduler.load_state_dict(v_extra["scheduler.pth"])
|
538 |
-
if "
|
539 |
-
|
540 |
-
|
541 |
-
criterion = CrossEntropyLoss()
|
542 |
|
543 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
544 |
|
545 |
-
# a better rng for sampling from our schedule
|
546 |
-
rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
|
547 |
|
548 |
-
# log a model summary w/ num params
|
549 |
-
if accel.local_rank == 0:
|
550 |
-
add_num_params_repr_hook(accel.unwrap(model))
|
551 |
-
with open(f"{save_path}/model.txt", "w") as f:
|
552 |
-
f.write(repr(accel.unwrap(model)))
|
553 |
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
|
570 |
|
571 |
@argbind.bind(without_prefix=True)
|
572 |
def train(
|
573 |
args,
|
574 |
accel: at.ml.Accelerator,
|
575 |
-
seed: int = 0,
|
576 |
codec_ckpt: str = None,
|
|
|
577 |
save_path: str = "ckpt",
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
batch_size: int =
|
|
|
583 |
val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
584 |
num_workers: int = 10,
|
|
|
|
|
585 |
fine_tune: bool = False,
|
|
|
586 |
):
|
587 |
assert codec_ckpt is not None, "codec_ckpt is required"
|
588 |
|
@@ -594,79 +241,376 @@ def train(
|
|
594 |
writer = SummaryWriter(log_dir=f"{save_path}/logs/")
|
595 |
argbind.dump_args(args, f"{save_path}/args.yml")
|
596 |
|
597 |
-
tracker = Tracker(
|
598 |
-
writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
|
599 |
-
)
|
600 |
-
|
601 |
# load the codec model
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
609 |
train_dataloader = accel.prepare_dataloader(
|
610 |
-
|
611 |
-
start_idx=
|
612 |
num_workers=num_workers,
|
613 |
batch_size=batch_size,
|
614 |
-
collate_fn=
|
615 |
)
|
616 |
val_dataloader = accel.prepare_dataloader(
|
617 |
-
|
618 |
start_idx=0,
|
619 |
num_workers=num_workers,
|
620 |
batch_size=batch_size,
|
621 |
-
collate_fn=
|
622 |
-
persistent_workers=num_workers > 0,
|
623 |
)
|
624 |
-
print("initialized dataloader.")
|
625 |
|
626 |
-
|
627 |
|
628 |
if fine_tune:
|
629 |
-
lora
|
630 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
631 |
|
632 |
-
|
633 |
-
|
634 |
-
global train_loop, val_loop, validate, save_samples, checkpoint
|
635 |
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
val_loop = tracker.track("val", len(val_dataloader))(val_loop)
|
640 |
-
validate = tracker.log("val", "mean")(validate)
|
641 |
|
642 |
-
|
643 |
-
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
644 |
|
645 |
-
|
646 |
-
with tracker.live:
|
647 |
-
for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
|
648 |
-
train_loop(state, batch, accel)
|
649 |
|
650 |
-
|
651 |
-
|
652 |
)
|
653 |
|
654 |
-
|
655 |
-
|
|
|
656 |
|
657 |
-
|
658 |
-
|
659 |
-
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
664 |
|
665 |
-
|
666 |
-
tracker.done("val", f"Iteration {tracker.step}")
|
667 |
|
668 |
-
|
669 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
670 |
|
671 |
|
672 |
if __name__ == "__main__":
|
@@ -674,6 +618,4 @@ if __name__ == "__main__":
|
|
674 |
args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
675 |
with argbind.scope(args):
|
676 |
with Accelerator() as accel:
|
677 |
-
if accel.local_rank != 0:
|
678 |
-
sys.tracebacklimit = 0
|
679 |
train(args, accel)
|
|
|
1 |
import os
|
2 |
+
import subprocess
|
3 |
+
import time
|
4 |
import warnings
|
5 |
from pathlib import Path
|
6 |
from typing import Optional
|
|
|
7 |
|
8 |
import argbind
|
9 |
import audiotools as at
|
|
|
14 |
from einops import rearrange
|
15 |
from rich import pretty
|
16 |
from rich.traceback import install
|
17 |
+
from tensorboardX import SummaryWriter
|
18 |
|
19 |
import vampnet
|
20 |
from vampnet.modules.transformer import VampNet
|
|
|
23 |
# from dac.model.dac import DAC
|
24 |
from lac.model.lac import LAC as DAC
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
# Enable cudnn autotuner to speed up training
|
28 |
# (can be altered by the funcs.seed function)
|
|
|
85 |
)
|
86 |
with argbind.scope(args, "val"):
|
87 |
val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
|
88 |
+
with argbind.scope(args, "test"):
|
89 |
+
test_data = AudioDataset(
|
90 |
+
AudioLoader(), sample_rate, transform=build_transform()
|
91 |
+
)
|
92 |
+
return train_data, val_data, test_data
|
93 |
|
94 |
|
95 |
def rand_float(shape, low, high, rng):
|
|
|
100 |
return rng.draw(shape)[:, 0] < p
|
101 |
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
@argbind.bind(without_prefix=True)
|
104 |
def load(
|
105 |
args,
|
106 |
accel: at.ml.Accelerator,
|
|
|
107 |
save_path: str,
|
108 |
resume: bool = False,
|
109 |
tag: str = "latest",
|
110 |
+
load_weights: bool = False,
|
111 |
fine_tune_checkpoint: Optional[str] = None,
|
112 |
+
):
|
|
|
113 |
codec = DAC.load(args["codec_ckpt"], map_location="cpu")
|
114 |
codec.eval()
|
115 |
|
|
|
119 |
kwargs = {
|
120 |
"folder": f"{save_path}/{tag}",
|
121 |
"map_location": "cpu",
|
122 |
+
"package": not load_weights,
|
123 |
}
|
|
|
124 |
if (Path(kwargs["folder"]) / "vampnet").exists():
|
125 |
model, v_extra = VampNet.load_from_folder(**kwargs)
|
126 |
else:
|
|
|
131 |
|
132 |
if args["fine_tune"]:
|
133 |
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
134 |
+
model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu")
|
135 |
+
|
|
|
|
|
|
|
136 |
|
137 |
+
model = VampNet() if model is None else model
|
138 |
|
|
|
139 |
model = accel.prepare_model(model)
|
140 |
|
141 |
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
|
|
147 |
scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
|
148 |
scheduler.step()
|
149 |
|
150 |
+
trainer_state = {"state_dict": None, "start_idx": 0}
|
151 |
+
|
152 |
if "optimizer.pth" in v_extra:
|
153 |
optimizer.load_state_dict(v_extra["optimizer.pth"])
|
154 |
+
if "scheduler.pth" in v_extra:
|
155 |
scheduler.load_state_dict(v_extra["scheduler.pth"])
|
156 |
+
if "trainer.pth" in v_extra:
|
157 |
+
trainer_state = v_extra["trainer.pth"]
|
|
|
|
|
158 |
|
159 |
+
return {
|
160 |
+
"model": model,
|
161 |
+
"codec": codec,
|
162 |
+
"optimizer": optimizer,
|
163 |
+
"scheduler": scheduler,
|
164 |
+
"trainer_state": trainer_state,
|
165 |
+
}
|
166 |
|
|
|
|
|
167 |
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
+
def num_params_hook(o, p):
|
170 |
+
return o + f" {p/1e6:<.3f}M params."
|
171 |
+
|
172 |
+
|
173 |
+
def add_num_params_repr_hook(model):
|
174 |
+
import numpy as np
|
175 |
+
from functools import partial
|
176 |
+
|
177 |
+
for n, m in model.named_modules():
|
178 |
+
o = m.extra_repr()
|
179 |
+
p = sum([np.prod(p.size()) for p in m.parameters()])
|
180 |
+
|
181 |
+
setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
|
182 |
+
|
183 |
+
|
184 |
+
def accuracy(
|
185 |
+
preds: torch.Tensor,
|
186 |
+
target: torch.Tensor,
|
187 |
+
top_k: int = 1,
|
188 |
+
ignore_index: Optional[int] = None,
|
189 |
+
) -> torch.Tensor:
|
190 |
+
# Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
|
191 |
+
preds = rearrange(preds, "b p s -> (b s) p")
|
192 |
+
target = rearrange(target, "b s -> (b s)")
|
193 |
+
|
194 |
+
# return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
|
195 |
+
if ignore_index is not None:
|
196 |
+
# Create a mask for the ignored index
|
197 |
+
mask = target != ignore_index
|
198 |
+
# Apply the mask to the target and predictions
|
199 |
+
preds = preds[mask]
|
200 |
+
target = target[mask]
|
201 |
+
|
202 |
+
# Get the top-k predicted classes and their indices
|
203 |
+
_, pred_indices = torch.topk(preds, k=top_k, dim=-1)
|
204 |
+
|
205 |
+
# Determine if the true target is in the top-k predicted classes
|
206 |
+
correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
|
207 |
+
|
208 |
+
# Calculate the accuracy
|
209 |
+
accuracy = torch.mean(correct.float())
|
210 |
+
|
211 |
+
return accuracy
|
212 |
|
213 |
|
214 |
@argbind.bind(without_prefix=True)
|
215 |
def train(
|
216 |
args,
|
217 |
accel: at.ml.Accelerator,
|
|
|
218 |
codec_ckpt: str = None,
|
219 |
+
seed: int = 0,
|
220 |
save_path: str = "ckpt",
|
221 |
+
max_epochs: int = int(100e3),
|
222 |
+
epoch_length: int = 1000,
|
223 |
+
save_audio_epochs: int = 2,
|
224 |
+
save_epochs: list = [10, 50, 100, 200, 300, 400,],
|
225 |
+
batch_size: int = 48,
|
226 |
+
grad_acc_steps: int = 1,
|
227 |
val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
228 |
num_workers: int = 10,
|
229 |
+
detect_anomaly: bool = False,
|
230 |
+
grad_clip_val: float = 5.0,
|
231 |
fine_tune: bool = False,
|
232 |
+
quiet: bool = False,
|
233 |
):
|
234 |
assert codec_ckpt is not None, "codec_ckpt is required"
|
235 |
|
|
|
241 |
writer = SummaryWriter(log_dir=f"{save_path}/logs/")
|
242 |
argbind.dump_args(args, f"{save_path}/args.yml")
|
243 |
|
|
|
|
|
|
|
|
|
244 |
# load the codec model
|
245 |
+
loaded = load(args, accel, save_path)
|
246 |
+
model = loaded["model"]
|
247 |
+
codec = loaded["codec"]
|
248 |
+
optimizer = loaded["optimizer"]
|
249 |
+
scheduler = loaded["scheduler"]
|
250 |
+
trainer_state = loaded["trainer_state"]
|
251 |
|
252 |
+
sample_rate = codec.sample_rate
|
253 |
+
|
254 |
+
# a better rng for sampling from our schedule
|
255 |
+
rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=seed)
|
256 |
+
|
257 |
+
# log a model summary w/ num params
|
258 |
+
if accel.local_rank == 0:
|
259 |
+
add_num_params_repr_hook(accel.unwrap(model))
|
260 |
+
with open(f"{save_path}/model.txt", "w") as f:
|
261 |
+
f.write(repr(accel.unwrap(model)))
|
262 |
+
|
263 |
+
# load the datasets
|
264 |
+
train_data, val_data, _ = build_datasets(args, sample_rate)
|
265 |
train_dataloader = accel.prepare_dataloader(
|
266 |
+
train_data,
|
267 |
+
start_idx=trainer_state["start_idx"],
|
268 |
num_workers=num_workers,
|
269 |
batch_size=batch_size,
|
270 |
+
collate_fn=train_data.collate,
|
271 |
)
|
272 |
val_dataloader = accel.prepare_dataloader(
|
273 |
+
val_data,
|
274 |
start_idx=0,
|
275 |
num_workers=num_workers,
|
276 |
batch_size=batch_size,
|
277 |
+
collate_fn=val_data.collate,
|
|
|
278 |
)
|
|
|
279 |
|
280 |
+
criterion = CrossEntropyLoss()
|
281 |
|
282 |
if fine_tune:
|
283 |
+
import loralib as lora
|
284 |
+
lora.mark_only_lora_as_trainable(model)
|
285 |
+
|
286 |
+
|
287 |
+
class Trainer(at.ml.BaseTrainer):
|
288 |
+
_last_grad_norm = 0.0
|
289 |
+
|
290 |
+
def _metrics(self, vn, z_hat, r, target, flat_mask, output):
|
291 |
+
for r_range in [(0, 0.5), (0.5, 1.0)]:
|
292 |
+
unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
|
293 |
+
masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
294 |
+
|
295 |
+
assert target.shape[0] == r.shape[0]
|
296 |
+
# grab the indices of the r values that are in the range
|
297 |
+
r_idx = (r >= r_range[0]) & (r < r_range[1])
|
298 |
+
|
299 |
+
# grab the target and z_hat values that are in the range
|
300 |
+
r_unmasked_target = unmasked_target[r_idx]
|
301 |
+
r_masked_target = masked_target[r_idx]
|
302 |
+
r_z_hat = z_hat[r_idx]
|
303 |
+
|
304 |
+
for topk in (1, 25):
|
305 |
+
s, e = r_range
|
306 |
+
tag = f"accuracy-{s}-{e}/top{topk}"
|
307 |
+
|
308 |
+
output[f"{tag}/unmasked"] = accuracy(
|
309 |
+
preds=r_z_hat,
|
310 |
+
target=r_unmasked_target,
|
311 |
+
ignore_index=IGNORE_INDEX,
|
312 |
+
top_k=topk,
|
313 |
+
)
|
314 |
+
output[f"{tag}/masked"] = accuracy(
|
315 |
+
preds=r_z_hat,
|
316 |
+
target=r_masked_target,
|
317 |
+
ignore_index=IGNORE_INDEX,
|
318 |
+
top_k=topk,
|
319 |
+
)
|
320 |
+
|
321 |
+
def train_loop(self, engine, batch):
|
322 |
+
model.train()
|
323 |
+
batch = at.util.prepare_batch(batch, accel.device)
|
324 |
+
signal = apply_transform(train_data.transform, batch)
|
325 |
+
|
326 |
+
output = {}
|
327 |
+
vn = accel.unwrap(model)
|
328 |
+
with accel.autocast():
|
329 |
+
with torch.inference_mode():
|
330 |
+
codec.to(accel.device)
|
331 |
+
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
332 |
+
z = z[:, : vn.n_codebooks, :]
|
333 |
+
|
334 |
+
n_batch = z.shape[0]
|
335 |
+
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
336 |
+
|
337 |
+
mask = pmask.random(z, r)
|
338 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
339 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
340 |
+
|
341 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, codec)
|
342 |
+
|
343 |
+
dtype = torch.bfloat16 if accel.amp else None
|
344 |
+
with accel.autocast(dtype=dtype):
|
345 |
+
z_hat = model(z_mask_latent, r)
|
346 |
+
|
347 |
+
target = codebook_flatten(
|
348 |
+
z[:, vn.n_conditioning_codebooks :, :],
|
349 |
+
)
|
350 |
+
|
351 |
+
flat_mask = codebook_flatten(
|
352 |
+
mask[:, vn.n_conditioning_codebooks :, :],
|
353 |
+
)
|
354 |
+
|
355 |
+
# replace target with ignore index for masked tokens
|
356 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
357 |
+
output["loss"] = criterion(z_hat, t_masked)
|
358 |
+
|
359 |
+
self._metrics(
|
360 |
+
vn=vn,
|
361 |
+
r=r,
|
362 |
+
z_hat=z_hat,
|
363 |
+
target=target,
|
364 |
+
flat_mask=flat_mask,
|
365 |
+
output=output,
|
366 |
+
)
|
367 |
+
|
368 |
+
|
369 |
+
accel.backward(output["loss"] / grad_acc_steps)
|
370 |
+
|
371 |
+
output["other/learning_rate"] = optimizer.param_groups[0]["lr"]
|
372 |
+
output["other/batch_size"] = z.shape[0]
|
373 |
+
|
374 |
+
if (
|
375 |
+
(engine.state.iteration % grad_acc_steps == 0)
|
376 |
+
or (engine.state.iteration % epoch_length == 0)
|
377 |
+
or (engine.state.iteration % epoch_length == 1)
|
378 |
+
): # (or we reached the end of the epoch)
|
379 |
+
accel.scaler.unscale_(optimizer)
|
380 |
+
output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
|
381 |
+
model.parameters(), grad_clip_val
|
382 |
+
)
|
383 |
+
self._last_grad_norm = output["other/grad_norm"]
|
384 |
+
|
385 |
+
accel.step(optimizer)
|
386 |
+
optimizer.zero_grad()
|
387 |
+
|
388 |
+
scheduler.step()
|
389 |
+
accel.update()
|
390 |
+
else:
|
391 |
+
output["other/grad_norm"] = self._last_grad_norm
|
392 |
+
|
393 |
+
return {k: v for k, v in sorted(output.items())}
|
394 |
+
|
395 |
+
@torch.no_grad()
|
396 |
+
def val_loop(self, engine, batch):
|
397 |
+
model.eval()
|
398 |
+
codec.eval()
|
399 |
+
batch = at.util.prepare_batch(batch, accel.device)
|
400 |
+
signal = apply_transform(val_data.transform, batch)
|
401 |
+
|
402 |
+
vn = accel.unwrap(model)
|
403 |
+
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
404 |
+
z = z[:, : vn.n_codebooks, :]
|
405 |
|
406 |
+
n_batch = z.shape[0]
|
407 |
+
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
|
|
408 |
|
409 |
+
mask = pmask.random(z, r)
|
410 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
411 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
|
|
|
|
412 |
|
413 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, codec)
|
|
|
414 |
|
415 |
+
z_hat = model(z_mask_latent, r)
|
|
|
|
|
|
|
416 |
|
417 |
+
target = codebook_flatten(
|
418 |
+
z[:, vn.n_conditioning_codebooks :, :],
|
419 |
)
|
420 |
|
421 |
+
flat_mask = codebook_flatten(
|
422 |
+
mask[:, vn.n_conditioning_codebooks :, :]
|
423 |
+
)
|
424 |
|
425 |
+
output = {}
|
426 |
+
# replace target with ignore index for masked tokens
|
427 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
428 |
+
output["loss"] = criterion(z_hat, t_masked)
|
429 |
+
|
430 |
+
self._metrics(
|
431 |
+
vn=vn,
|
432 |
+
r=r,
|
433 |
+
z_hat=z_hat,
|
434 |
+
target=target,
|
435 |
+
flat_mask=flat_mask,
|
436 |
+
output=output,
|
437 |
+
)
|
438 |
|
439 |
+
return output
|
|
|
440 |
|
441 |
+
def checkpoint(self, engine):
|
442 |
+
if accel.local_rank != 0:
|
443 |
+
print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
|
444 |
+
return
|
445 |
+
|
446 |
+
metadata = {"logs": dict(engine.state.logs["epoch"])}
|
447 |
+
|
448 |
+
if self.state.epoch % save_audio_epochs == 0:
|
449 |
+
self.save_samples()
|
450 |
+
|
451 |
+
tags = ["latest"]
|
452 |
+
loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
|
453 |
+
self.print(f"Saving to {str(Path('.').absolute())}")
|
454 |
+
|
455 |
+
if self.state.epoch in save_epochs:
|
456 |
+
tags.append(f"epoch={self.state.epoch}")
|
457 |
+
|
458 |
+
if self.is_best(engine, loss_key):
|
459 |
+
self.print(f"Best model so far")
|
460 |
+
tags.append("best")
|
461 |
+
|
462 |
+
if fine_tune:
|
463 |
+
for tag in tags:
|
464 |
+
# save the lora model
|
465 |
+
(Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
|
466 |
+
torch.save(
|
467 |
+
lora.lora_state_dict(accel.unwrap(model)),
|
468 |
+
f"{save_path}/{tag}/lora.pth"
|
469 |
+
)
|
470 |
+
|
471 |
+
for tag in tags:
|
472 |
+
model_extra = {
|
473 |
+
"optimizer.pth": optimizer.state_dict(),
|
474 |
+
"scheduler.pth": scheduler.state_dict(),
|
475 |
+
"trainer.pth": {
|
476 |
+
"start_idx": self.state.iteration * batch_size,
|
477 |
+
"state_dict": self.state_dict(),
|
478 |
+
},
|
479 |
+
"metadata.pth": metadata,
|
480 |
+
}
|
481 |
+
|
482 |
+
accel.unwrap(model).metadata = metadata
|
483 |
+
accel.unwrap(model).save_to_folder(
|
484 |
+
f"{save_path}/{tag}", model_extra,
|
485 |
+
)
|
486 |
+
|
487 |
+
def save_sampled(self, z):
|
488 |
+
num_samples = z.shape[0]
|
489 |
+
|
490 |
+
for i in range(num_samples):
|
491 |
+
sampled = accel.unwrap(model).generate(
|
492 |
+
codec=codec,
|
493 |
+
time_steps=z.shape[-1],
|
494 |
+
start_tokens=z[i : i + 1],
|
495 |
+
)
|
496 |
+
sampled.cpu().write_audio_to_tb(
|
497 |
+
f"sampled/{i}",
|
498 |
+
self.writer,
|
499 |
+
step=self.state.epoch,
|
500 |
+
plot_fn=None,
|
501 |
+
)
|
502 |
+
|
503 |
+
|
504 |
+
def save_imputation(self, z: torch.Tensor):
|
505 |
+
n_prefix = int(z.shape[-1] * 0.25)
|
506 |
+
n_suffix = int(z.shape[-1] * 0.25)
|
507 |
+
|
508 |
+
vn = accel.unwrap(model)
|
509 |
+
|
510 |
+
mask = pmask.inpaint(z, n_prefix, n_suffix)
|
511 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
512 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
513 |
+
|
514 |
+
imputed_noisy = vn.to_signal(z_mask, codec)
|
515 |
+
imputed_true = vn.to_signal(z, codec)
|
516 |
+
|
517 |
+
imputed = []
|
518 |
+
for i in range(len(z)):
|
519 |
+
imputed.append(
|
520 |
+
vn.generate(
|
521 |
+
codec=codec,
|
522 |
+
time_steps=z.shape[-1],
|
523 |
+
start_tokens=z[i][None, ...],
|
524 |
+
mask=mask[i][None, ...],
|
525 |
+
)
|
526 |
+
)
|
527 |
+
imputed = AudioSignal.batch(imputed)
|
528 |
+
|
529 |
+
for i in range(len(val_idx)):
|
530 |
+
imputed_noisy[i].cpu().write_audio_to_tb(
|
531 |
+
f"imputed_noisy/{i}",
|
532 |
+
self.writer,
|
533 |
+
step=self.state.epoch,
|
534 |
+
plot_fn=None,
|
535 |
+
)
|
536 |
+
imputed[i].cpu().write_audio_to_tb(
|
537 |
+
f"imputed/{i}",
|
538 |
+
self.writer,
|
539 |
+
step=self.state.epoch,
|
540 |
+
plot_fn=None,
|
541 |
+
)
|
542 |
+
imputed_true[i].cpu().write_audio_to_tb(
|
543 |
+
f"imputed_true/{i}",
|
544 |
+
self.writer,
|
545 |
+
step=self.state.epoch,
|
546 |
+
plot_fn=None,
|
547 |
+
)
|
548 |
+
|
549 |
+
@torch.no_grad()
|
550 |
+
def save_samples(self):
|
551 |
+
model.eval()
|
552 |
+
codec.eval()
|
553 |
+
vn = accel.unwrap(model)
|
554 |
+
|
555 |
+
batch = [val_data[i] for i in val_idx]
|
556 |
+
batch = at.util.prepare_batch(val_data.collate(batch), accel.device)
|
557 |
+
|
558 |
+
signal = apply_transform(val_data.transform, batch)
|
559 |
+
|
560 |
+
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
561 |
+
z = z[:, : vn.n_codebooks, :]
|
562 |
+
|
563 |
+
r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
|
564 |
+
|
565 |
+
|
566 |
+
mask = pmask.random(z, r)
|
567 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
568 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
569 |
+
|
570 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, codec)
|
571 |
+
|
572 |
+
z_hat = model(z_mask_latent, r)
|
573 |
+
|
574 |
+
z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
|
575 |
+
z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
|
576 |
+
z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
|
577 |
+
|
578 |
+
generated = vn.to_signal(z_pred, codec)
|
579 |
+
reconstructed = vn.to_signal(z, codec)
|
580 |
+
masked = vn.to_signal(z_mask.squeeze(1), codec)
|
581 |
+
|
582 |
+
for i in range(generated.batch_size):
|
583 |
+
audio_dict = {
|
584 |
+
"original": signal[i],
|
585 |
+
"masked": masked[i],
|
586 |
+
"generated": generated[i],
|
587 |
+
"reconstructed": reconstructed[i],
|
588 |
+
}
|
589 |
+
for k, v in audio_dict.items():
|
590 |
+
v.cpu().write_audio_to_tb(
|
591 |
+
f"samples/_{i}.r={r[i]:0.2f}/{k}",
|
592 |
+
self.writer,
|
593 |
+
step=self.state.epoch,
|
594 |
+
plot_fn=None,
|
595 |
+
)
|
596 |
+
|
597 |
+
self.save_sampled(z)
|
598 |
+
self.save_imputation(z)
|
599 |
+
|
600 |
+
trainer = Trainer(writer=writer, quiet=quiet)
|
601 |
+
|
602 |
+
if trainer_state["state_dict"] is not None:
|
603 |
+
trainer.load_state_dict(trainer_state["state_dict"])
|
604 |
+
if hasattr(train_dataloader.sampler, "set_epoch"):
|
605 |
+
train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch)
|
606 |
+
|
607 |
+
trainer.run(
|
608 |
+
train_dataloader,
|
609 |
+
val_dataloader,
|
610 |
+
num_epochs=max_epochs,
|
611 |
+
epoch_length=epoch_length,
|
612 |
+
detect_anomaly=detect_anomaly,
|
613 |
+
)
|
614 |
|
615 |
|
616 |
if __name__ == "__main__":
|
|
|
618 |
args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
619 |
with argbind.scope(args):
|
620 |
with Accelerator() as accel:
|
|
|
|
|
621 |
train(args, accel)
|
scripts/utils/{data/augment.py → augment.py}
RENAMED
@@ -5,19 +5,34 @@ from audiotools import AudioSignal
|
|
5 |
|
6 |
import argbind
|
7 |
import tqdm
|
8 |
-
import torch
|
9 |
|
10 |
|
11 |
-
from
|
12 |
-
|
|
|
|
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
@argbind.bind(without_prefix=True)
|
18 |
def augment(
|
19 |
-
audio_folder: Path
|
20 |
-
dest_folder: Path
|
21 |
n_augmentations: int = 10,
|
22 |
):
|
23 |
"""
|
@@ -26,8 +41,7 @@ def augment(
|
|
26 |
The dest foler will contain a folder for each of the clean dataset's files.
|
27 |
Under each of these folders, there will be a clean file and many augmented files.
|
28 |
"""
|
29 |
-
|
30 |
-
assert dest_folder is not None
|
31 |
audio_files = at.util.find_audio(audio_folder)
|
32 |
|
33 |
for audio_file in tqdm.tqdm(audio_files):
|
@@ -35,33 +49,5 @@ def augment(
|
|
35 |
subdir = subtree / audio_file.stem
|
36 |
subdir.mkdir(parents=True, exist_ok=True)
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
|
42 |
-
# apply pedalboard transforms
|
43 |
-
for j in range(n_augmentations):
|
44 |
-
# pitch shift between -7 and 7 semitones
|
45 |
-
import random
|
46 |
-
dst = chunk.clone()
|
47 |
-
dst.samples = pitch_shift(
|
48 |
-
dst.samples,
|
49 |
-
shift=random.choice(get_fast_shifts(src.sample_rate,
|
50 |
-
condition=lambda x: x >= 0.25 and x <= 1.0)),
|
51 |
-
sample_rate=src.sample_rate
|
52 |
-
)
|
53 |
-
dst.samples = time_stretch(
|
54 |
-
dst.samples,
|
55 |
-
stretch=random.choice(get_fast_stretches(src.sample_rate,
|
56 |
-
condition=lambda x: x >= 0.667 and x <= 1.5, )),
|
57 |
-
sample_rate=src.sample_rate,
|
58 |
-
)
|
59 |
-
|
60 |
-
dst.cpu().write(subdir / f"{i}-{j}.wav")
|
61 |
-
|
62 |
-
|
63 |
-
if __name__ == "__main__":
|
64 |
-
args = argbind.parse_args()
|
65 |
-
|
66 |
-
with argbind.scope(args):
|
67 |
-
augment()
|
|
|
5 |
|
6 |
import argbind
|
7 |
import tqdm
|
|
|
8 |
|
9 |
|
10 |
+
from pedalboard import (
|
11 |
+
Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard
|
12 |
+
)
|
13 |
+
from pedalboard.io import AudioFile
|
14 |
|
15 |
+
# Read in a whole file, resampling to our desired sample rate:
|
16 |
+
samplerate = 44100.0
|
17 |
+
with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
|
18 |
+
audio = f.read(f.frames)
|
19 |
+
|
20 |
+
# Make a pretty interesting sounding guitar pedalboard:
|
21 |
+
board = Pedalboard([
|
22 |
+
Compressor(threshold_db=-50, ratio=25),
|
23 |
+
Gain(gain_db=30),
|
24 |
+
Chorus(),
|
25 |
+
LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
|
26 |
+
Phaser(),
|
27 |
+
Convolution("./guitar_amp.wav", 1.0),
|
28 |
+
Reverb(room_size=0.25),
|
29 |
+
])
|
30 |
|
31 |
|
32 |
@argbind.bind(without_prefix=True)
|
33 |
def augment(
|
34 |
+
audio_folder: Path,
|
35 |
+
dest_folder: Path,
|
36 |
n_augmentations: int = 10,
|
37 |
):
|
38 |
"""
|
|
|
41 |
The dest foler will contain a folder for each of the clean dataset's files.
|
42 |
Under each of these folders, there will be a clean file and many augmented files.
|
43 |
"""
|
44 |
+
|
|
|
45 |
audio_files = at.util.find_audio(audio_folder)
|
46 |
|
47 |
for audio_file in tqdm.tqdm(audio_files):
|
|
|
49 |
subdir = subtree / audio_file.stem
|
50 |
subdir.mkdir(parents=True, exist_ok=True)
|
51 |
|
52 |
+
# apply pedalboard transforms
|
53 |
+
for i in range(n_augmentations):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/utils/gtzan_embeddings.py
DELETED
@@ -1,263 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
TODO: train a linear probe
|
3 |
-
usage:
|
4 |
-
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output
|
5 |
-
"""
|
6 |
-
from pathlib import Path
|
7 |
-
from typing import List
|
8 |
-
|
9 |
-
import audiotools as at
|
10 |
-
from audiotools import AudioSignal
|
11 |
-
import argbind
|
12 |
-
import torch
|
13 |
-
import numpy as np
|
14 |
-
import zipfile
|
15 |
-
import json
|
16 |
-
|
17 |
-
from vampnet.interface import Interface
|
18 |
-
import tqdm
|
19 |
-
|
20 |
-
# bind the Interface to argbind
|
21 |
-
Interface = argbind.bind(Interface)
|
22 |
-
|
23 |
-
DEBUG = False
|
24 |
-
|
25 |
-
def smart_plotly_export(fig, save_path):
|
26 |
-
img_format = save_path.split('.')[-1]
|
27 |
-
if img_format == 'html':
|
28 |
-
fig.write_html(save_path)
|
29 |
-
elif img_format == 'bytes':
|
30 |
-
return fig.to_image(format='png')
|
31 |
-
#TODO: come back and make this prettier
|
32 |
-
elif img_format == 'numpy':
|
33 |
-
import io
|
34 |
-
from PIL import Image
|
35 |
-
|
36 |
-
def plotly_fig2array(fig):
|
37 |
-
#convert Plotly fig to an array
|
38 |
-
fig_bytes = fig.to_image(format="png", width=1200, height=700)
|
39 |
-
buf = io.BytesIO(fig_bytes)
|
40 |
-
img = Image.open(buf)
|
41 |
-
return np.asarray(img)
|
42 |
-
|
43 |
-
return plotly_fig2array(fig)
|
44 |
-
elif img_format == 'jpeg' or 'png' or 'webp':
|
45 |
-
fig.write_image(save_path)
|
46 |
-
else:
|
47 |
-
raise ValueError("invalid image format")
|
48 |
-
|
49 |
-
def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
|
50 |
-
"""
|
51 |
-
dimensionality reduction for visualization!
|
52 |
-
saves an html plotly figure to save_path
|
53 |
-
parameters:
|
54 |
-
emb (np.ndarray): the samples to be reduces with shape (samples, features)
|
55 |
-
labels (list): list of labels for embedding
|
56 |
-
save_path (str): path where u wanna save ur figure
|
57 |
-
method (str): umap, tsne, or pca
|
58 |
-
title (str): title for ur figure
|
59 |
-
returns:
|
60 |
-
proj (np.ndarray): projection vector with shape (samples, dimensions)
|
61 |
-
"""
|
62 |
-
import pandas as pd
|
63 |
-
import plotly.express as px
|
64 |
-
if method == 'umap':
|
65 |
-
reducer = umap.UMAP(n_components=n_components)
|
66 |
-
elif method == 'tsne':
|
67 |
-
from sklearn.manifold import TSNE
|
68 |
-
reducer = TSNE(n_components=n_components)
|
69 |
-
elif method == 'pca':
|
70 |
-
from sklearn.decomposition import PCA
|
71 |
-
reducer = PCA(n_components=n_components)
|
72 |
-
else:
|
73 |
-
raise ValueError
|
74 |
-
|
75 |
-
proj = reducer.fit_transform(emb)
|
76 |
-
|
77 |
-
if n_components == 2:
|
78 |
-
df = pd.DataFrame(dict(
|
79 |
-
x=proj[:, 0],
|
80 |
-
y=proj[:, 1],
|
81 |
-
instrument=labels
|
82 |
-
))
|
83 |
-
fig = px.scatter(df, x='x', y='y', color='instrument',
|
84 |
-
title=title+f"_{method}")
|
85 |
-
|
86 |
-
elif n_components == 3:
|
87 |
-
df = pd.DataFrame(dict(
|
88 |
-
x=proj[:, 0],
|
89 |
-
y=proj[:, 1],
|
90 |
-
z=proj[:, 2],
|
91 |
-
instrument=labels
|
92 |
-
))
|
93 |
-
fig = px.scatter_3d(df, x='x', y='y', z='z',
|
94 |
-
color='instrument',
|
95 |
-
title=title)
|
96 |
-
else:
|
97 |
-
raise ValueError("cant plot more than 3 components")
|
98 |
-
|
99 |
-
fig.update_traces(marker=dict(size=6,
|
100 |
-
line=dict(width=1,
|
101 |
-
color='DarkSlateGrey')),
|
102 |
-
selector=dict(mode='markers'))
|
103 |
-
|
104 |
-
return smart_plotly_export(fig, save_path)
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
# per JukeMIR, we want the emebddings from the middle layer?
|
109 |
-
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
|
110 |
-
with torch.inference_mode():
|
111 |
-
# preprocess the signal
|
112 |
-
sig = interface.preprocess(sig)
|
113 |
-
|
114 |
-
# get the coarse vampnet model
|
115 |
-
vampnet = interface.coarse
|
116 |
-
|
117 |
-
# get the tokens
|
118 |
-
z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
|
119 |
-
z_latents = vampnet.embedding.from_codes(z, interface.codec)
|
120 |
-
|
121 |
-
# do a forward pass through the model, get the embeddings
|
122 |
-
_z, embeddings = vampnet(z_latents, return_activations=True)
|
123 |
-
# print(f"got embeddings with shape {embeddings.shape}")
|
124 |
-
# [layer, batch, time, n_dims]
|
125 |
-
# [20, 1, 600ish, 768]
|
126 |
-
|
127 |
-
|
128 |
-
# squeeze batch dim (1 bc layer should be dim 0)
|
129 |
-
assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
|
130 |
-
embeddings = embeddings.squeeze(1)
|
131 |
-
|
132 |
-
num_layers = embeddings.shape[0]
|
133 |
-
assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
|
134 |
-
|
135 |
-
# do meanpooling over the time dimension
|
136 |
-
embeddings = embeddings.mean(dim=-2)
|
137 |
-
# [20, 768]
|
138 |
-
|
139 |
-
# return the embeddings
|
140 |
-
return embeddings
|
141 |
-
|
142 |
-
from dataclasses import dataclass, fields
|
143 |
-
@dataclass
|
144 |
-
class Embedding:
|
145 |
-
genre: str
|
146 |
-
filename: str
|
147 |
-
embedding: np.ndarray
|
148 |
-
|
149 |
-
def save(self, path):
|
150 |
-
"""Save the Embedding object to a given path as a zip file."""
|
151 |
-
with zipfile.ZipFile(path, 'w') as archive:
|
152 |
-
|
153 |
-
# Save numpy array
|
154 |
-
with archive.open('embedding.npy', 'w') as f:
|
155 |
-
np.save(f, self.embedding)
|
156 |
-
|
157 |
-
# Save non-numpy data as json
|
158 |
-
non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
|
159 |
-
with archive.open('data.json', 'w') as f:
|
160 |
-
f.write(json.dumps(non_numpy_data).encode('utf-8'))
|
161 |
-
|
162 |
-
@classmethod
|
163 |
-
def load(cls, path):
|
164 |
-
"""Load the Embedding object from a given zip path."""
|
165 |
-
with zipfile.ZipFile(path, 'r') as archive:
|
166 |
-
|
167 |
-
# Load numpy array
|
168 |
-
with archive.open('embedding.npy') as f:
|
169 |
-
embedding = np.load(f)
|
170 |
-
|
171 |
-
# Load non-numpy data from json
|
172 |
-
with archive.open('data.json') as f:
|
173 |
-
data = json.loads(f.read().decode('utf-8'))
|
174 |
-
|
175 |
-
return cls(embedding=embedding, **data)
|
176 |
-
|
177 |
-
|
178 |
-
@argbind.bind(without_prefix=True)
|
179 |
-
def main(
|
180 |
-
path_to_gtzan: str = None,
|
181 |
-
cache_dir: str = "./.gtzan_emb_cache",
|
182 |
-
output_dir: str = "./gtzan_vampnet_embeddings",
|
183 |
-
layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
|
184 |
-
):
|
185 |
-
path_to_gtzan = Path(path_to_gtzan)
|
186 |
-
assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"
|
187 |
-
|
188 |
-
cache_dir = Path(cache_dir)
|
189 |
-
output_dir = Path(output_dir)
|
190 |
-
output_dir.mkdir(exist_ok=True, parents=True)
|
191 |
-
|
192 |
-
# load our interface
|
193 |
-
# argbind will automatically load the default config,
|
194 |
-
interface = Interface()
|
195 |
-
|
196 |
-
# gtzan should have a folder for each genre, so let's get the list of genres
|
197 |
-
genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
|
198 |
-
print(f"Found {len(genres)} genres")
|
199 |
-
print(f"genres: {genres}")
|
200 |
-
|
201 |
-
# collect audio files, genres, and embeddings
|
202 |
-
data = []
|
203 |
-
for genre in genres:
|
204 |
-
audio_files = list(at.util.find_audio(path_to_gtzan / genre))
|
205 |
-
print(f"Found {len(audio_files)} audio files for genre {genre}")
|
206 |
-
|
207 |
-
for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
|
208 |
-
# check if we have a cached embedding for this file
|
209 |
-
cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
|
210 |
-
if cached_path.exists():
|
211 |
-
# if so, load it
|
212 |
-
if DEBUG:
|
213 |
-
print(f"loading cached embedding for {cached_path.stem}")
|
214 |
-
embedding = Embedding.load(cached_path)
|
215 |
-
data.append(embedding)
|
216 |
-
else:
|
217 |
-
try:
|
218 |
-
sig = AudioSignal(audio_file)
|
219 |
-
except Exception as e:
|
220 |
-
print(f"failed to load {audio_file.name} with error {e}")
|
221 |
-
print(f"skipping {audio_file.name}")
|
222 |
-
continue
|
223 |
-
|
224 |
-
# gets the embedding
|
225 |
-
emb = vampnet_embed(sig, interface).cpu().numpy()
|
226 |
-
|
227 |
-
# create an embedding we can save/load
|
228 |
-
embedding = Embedding(
|
229 |
-
genre=genre,
|
230 |
-
filename=audio_file.name,
|
231 |
-
embedding=emb
|
232 |
-
)
|
233 |
-
|
234 |
-
# cache the embeddings
|
235 |
-
cached_path.parent.mkdir(exist_ok=True, parents=True)
|
236 |
-
embedding.save(cached_path)
|
237 |
-
|
238 |
-
# now, let's do a dim reduction on the embeddings
|
239 |
-
# and visualize them.
|
240 |
-
|
241 |
-
# collect a list of embeddings and labels
|
242 |
-
embeddings = [d.embedding for d in data]
|
243 |
-
labels = [d.genre for d in data]
|
244 |
-
|
245 |
-
# convert the embeddings to a numpy array
|
246 |
-
embeddings = np.stack(embeddings)
|
247 |
-
|
248 |
-
# do dimensionality reduction for each layer we're given
|
249 |
-
for layer in tqdm.tqdm(layers, desc="dim reduction"):
|
250 |
-
dim_reduce(
|
251 |
-
embeddings[:, layer, :], labels,
|
252 |
-
save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
|
253 |
-
n_components=2, method='tsne',
|
254 |
-
title=f'vampnet-gtzan-layer={layer}'
|
255 |
-
)
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
if __name__ == "__main__":
|
261 |
-
args = argbind.parse_args()
|
262 |
-
with argbind.scope(args):
|
263 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|