c2f prompts #4
opened by hugggof

Files changed:
- .gitignore +3 -6
- README.md +2 -11
- app.py +42 -227
- conf/lora/lora.yml +3 -5
- conf/vampnet.yml +1 -1
- requirements.txt +2 -4
- scripts/exp/fine_tune.py +4 -3
- scripts/exp/train.py +15 -23
- scripts/utils/{data/augment.py → augment.py} +24 -38
- scripts/utils/gtzan_embeddings.py +0 -263
- scripts/utils/{data/maestro-reorg.py → maestro-reorg.py} +0 -0
- scripts/utils/remove_quiet_files.py +0 -29
- scripts/utils/split_long_audio_file.py +0 -34
- scripts/utils/xeno-canto-dl.py +0 -234
- setup.py +2 -3
- vampnet/interface.py +6 -5
- vampnet/mask.py +20 -38
- vampnet/modules/transformer.py +51 -138
.gitignore
CHANGED
```diff
@@ -175,14 +175,11 @@ lyrebird-audio-codec
 samples-*/**
 
 gradio-outputs/
-models/
 samples*/
 models-all/
 models.zip
+audiotools/
+descript-audio-codec/
+# *.pth
 .git-old
 conf/generated/*
-runs*/
-
-
-gtzan.zip
-.gtzan_emb_cache
```
README.md
CHANGED
````diff
@@ -7,7 +7,6 @@ sdk: gradio
 sdk_version: 3.36.1
 app_file: app.py
 pinned: false
-python_version: 3.9
 ---
 
 # VampNet
@@ -19,15 +18,7 @@ you can try vampnet in a co-creative looper called unloop. see this link: https:
 
 # Setting up
 
-
-
-you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
-
-(for example, using conda)
-```bash
-conda create -n vampnet python=3.9
-conda activate vampnet
-```
+Requires Python 3.9 or later.
 
 
 install VampNet
@@ -100,7 +91,7 @@ python scripts/exp/train.py --args.load conf/<fine_tune_name>/c2f.yml
 
 launch the interface:
 ```bash
-python
+python demo.py --args.load conf/generated/<fine_tune_name>/interface.yml
 ```
 
 
````
app.py
CHANGED
```diff
@@ -1,12 +1,3 @@
-# huggingface space exclusive
-import os
-
-# print("installing pyharp")
-# os.system('pip install "pyharp@git+https://github.com/audacitorch/pyharp.git"')
-# print("installing madmom")
-os.system('pip install cython')
-os.system('pip install madmom')
-
 from pathlib import Path
 from typing import Tuple
 import yaml
@@ -24,38 +15,27 @@ import gradio as gr
 from vampnet.interface import Interface
 from vampnet import mask as pmask
 
-
-
-
-
-# loader = AudioLoader()
+# Interface = argbind.bind(Interface)
 # AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
 
-
-
-
-
-
-
-
-        shift=interval,
-        sample_rate=signal.sample_rate
-    )
-    return signal
-
-def load_interface():
-    interface = Interface(
-        coarse_ckpt="./models/vampnet/coarse.pth",
-        coarse2fine_ckpt="./models/vampnet/c2f.pth",
-        codec_ckpt="./models/vampnet/codec.pth",
-        wavebeat_ckpt="./models/wavebeat.pth",
-        device="cuda" if torch.cuda.is_available() else "cpu",
-    )
-    return interface
-
+interface = Interface(
+    coarse_ckpt="./models/vampnet/coarse.pth",
+    coarse2fine_ckpt="./models/vampnet/c2f.pth",
+    codec_ckpt="./models/vampnet/codec.pth",
+    wavebeat_ckpt="./models/wavebeat.pth",
+    device="cuda" if torch.cuda.is_available() else "cpu",
+)
 
-
+# loader = AudioLoader()
+print(f"interface device is {interface.device}")
 
+# dataset = at.data.datasets.AudioDataset(
+#     loader,
+#     sample_rate=interface.codec.sample_rate,
+#     duration=interface.coarse.chunk_size_s,
+#     n_examples=5000,
+#     without_replacement=True,
+# )
 
 OUT_DIR = Path("gradio-outputs")
 OUT_DIR.mkdir(exist_ok=True, parents=True)
@@ -70,7 +50,7 @@ def load_audio(file):
     )
     sig = interface.preprocess(sig)
 
-    out_dir = OUT_DIR /
+    out_dir = OUT_DIR / str(uuid.uuid4())
     out_dir.mkdir(parents=True, exist_ok=True)
     sig.write(out_dir / "input.wav")
     return sig.path_to_file
@@ -88,10 +68,6 @@ def _vamp(data, return_mask=False):
     out_dir = OUT_DIR / str(uuid.uuid4())
     out_dir.mkdir()
    sig = at.AudioSignal(data[input_audio])
-    sig = interface.preprocess(sig)
-
-    if data[pitch_shift_amt] != 0:
-        sig = shift_pitch(sig, data[pitch_shift_amt])
 
     z = interface.encode(sig)
 
@@ -131,58 +107,24 @@ def _vamp(data, return_mask=False):
     mask = pmask.codebook_unmask(mask, ncc)
 
 
-    print(f"dropout {data[dropout]}")
-    print(f"masktemp {data[masktemp]}")
-    print(f"sampletemp {data[sampletemp]}")
-    print(f"top_p {data[top_p]}")
-    print(f"prefix_s {data[prefix_s]}")
-    print(f"suffix_s {data[suffix_s]}")
-    print(f"rand_mask_intensity {data[rand_mask_intensity]}")
-    print(f"num_steps {data[num_steps]}")
-    print(f"periodic_p {data[periodic_p]}")
-    print(f"periodic_w {data[periodic_w]}")
-    print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
-    print(f"use_coarse2fine {data[use_coarse2fine]}")
-    print(f"onset_mask_width {data[onset_mask_width]}")
-    print(f"beat_mask_width {data[beat_mask_width]}")
-    print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
-    print(f"stretch_factor {data[stretch_factor]}")
-    print(f"seed {data[seed]}")
-    print(f"pitch_shift_amt {data[pitch_shift_amt]}")
-    print(f"sample_cutoff {data[sample_cutoff]}")
-
-
-    _top_p = data[top_p] if data[top_p] > 0 else None
+    print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}")
     # save the mask as a txt file
     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
-    _seed = data[seed] if data[seed] > 0 else None
     zv, mask_z = interface.coarse_vamp(
         z,
         mask=mask,
         sampling_steps=data[num_steps],
-
-        sampling_temperature=data[sampletemp],
+        temperature=float(data[temp]*10),
         return_mask=True,
         typical_filtering=data[typical_filtering],
         typical_mass=data[typical_mass],
         typical_min_tokens=data[typical_min_tokens],
-        top_p=_top_p,
         gen_fn=interface.coarse.generate,
-        seed=_seed,
-        sample_cutoff=data[sample_cutoff],
     )
 
     if use_coarse2fine:
-        zv = interface.coarse_to_fine(
-            zv,
-            mask_temperature=data[masktemp]*10,
-            sampling_temperature=data[sampletemp],
-            mask=mask,
-            sampling_steps=data[num_steps] // 2,
-            sample_cutoff=data[sample_cutoff],
-            seed=_seed,
-        )
+        zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask)
 
     sig = interface.to_signal(zv).cpu()
     print("done")
@@ -215,9 +157,7 @@ def save_vamp(data):
     sig_out.write(out_dir / "output.wav")
 
     _data = {
-        "
-        "sampletemp": data[sampletemp],
-        "top_p": data[top_p],
+        "temp": data[temp],
         "prefix_s": data[prefix_s],
         "suffix_s": data[suffix_s],
         "rand_mask_intensity": data[rand_mask_intensity],
@@ -228,8 +168,6 @@ def save_vamp(data):
         "n_conditioning_codebooks": data[n_conditioning_codebooks],
         "use_coarse2fine": data[use_coarse2fine],
         "stretch_factor": data[stretch_factor],
-        "seed": data[seed],
-        "samplecutoff": data[sample_cutoff],
     }
 
     # save with yaml
@@ -245,54 +183,13 @@ def save_vamp(data):
     return f"saved! your save code is {out_dir.stem}", zip_path
 
 
-def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
-
-    out_dir = OUT_DIR / str(uuid.uuid4())
-    out_dir.mkdir()
-    sig = at.AudioSignal(_input_audio)
-    sig = interface.preprocess(sig)
-
-    z = interface.encode(sig)
-
-    # build the mask
-    mask = pmask.linear_random(z, 1.0)
-    if _beat_mask_width > 0:
-        beat_mask = interface.make_beat_mask(
-            sig,
-            after_beat_s=(_beat_mask_width/1000),
-        )
-        mask = pmask.mask_and(mask, beat_mask)
-
-    # save the mask as a txt file
-    zv, mask_z = interface.coarse_vamp(
-        z,
-        mask=mask,
-        sampling_temperature=_sampletemp,
-        return_mask=True,
-        gen_fn=interface.coarse.generate,
-    )
-
-
-    zv = interface.coarse_to_fine(
-        zv,
-        sampling_temperature=_sampletemp,
-        mask=mask,
-    )
-
-    sig = interface.to_signal(zv).cpu()
-    print("done")
-
-    sig.write(out_dir / "output.wav")
-
-    return sig.path_to_file
-
 with gr.Blocks() as demo:
 
     with gr.Row():
         with gr.Column():
-            gr.Markdown("# VampNet
+            gr.Markdown("# VampNet")
             gr.Markdown("""## Description:
-            This is a demo of
+            This is a demo of VampNet, a masked generative music model capable of doing music variations.
             You can control the extent and nature of variation with a set of manual controls and presets.
             Use this interface to experiment with different mask settings and explore the audio outputs.
             """)
@@ -300,8 +197,8 @@ with gr.Blocks() as demo:
            gr.Markdown("""
            ## Instructions:
            1. You can start by uploading some audio, or by loading the example audio.
-           2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
-           3. Click the "generate (vamp)!!!" button to
+           2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings. Click the load preset button.
+           3. Click the "generate (vamp)!!!" button to generate audio. Listen to the output audio, and the masked audio to hear the mask hints.
            4. Optionally, you can add some notes and save the result.
           5. You can also use the output as the new input and continue experimenting!
           """)
@@ -352,25 +249,19 @@ with gr.Blocks() as demo:
                 "beat_mask_downbeats": False,
             },
             "slight periodic variation": {
-                "periodic_p":
-                "onset_mask_width":
-                "beat_mask_width": 0,
-                "beat_mask_downbeats": False,
-            },
-            "moderate periodic variation": {
-                "periodic_p": 13,
-                "onset_mask_width": 5,
+                "periodic_p": 7,
+                "onset_mask_width": 0,
                 "beat_mask_width": 0,
                 "beat_mask_downbeats": False,
             },
             "strong periodic variation": {
-                "periodic_p":
+                "periodic_p": 13,
                 "onset_mask_width": 5,
                 "beat_mask_width": 0,
                 "beat_mask_downbeats": False,
             },
             "very strong periodic variation": {
-                "periodic_p":
+                "periodic_p": 17,
                 "onset_mask_width": 5,
                 "beat_mask_width": 0,
                 "beat_mask_downbeats": False,
@@ -378,15 +269,9 @@ with gr.Blocks() as demo:
             "beat-driven variation": {
                 "periodic_p": 0,
                 "onset_mask_width": 0,
-                "beat_mask_width":
+                "beat_mask_width": 20,
                 "beat_mask_downbeats": False,
             },
-            "beat-driven variation (downbeats only)": {
-                "periodic_p": 0,
-                "onset_mask_width": 0,
-                "beat_mask_width": 50,
-                "beat_mask_downbeats": True,
-            },
             "beat-driven variation (downbeats only, strong)": {
                 "periodic_p": 0,
                 "onset_mask_width": 0,
@@ -408,20 +293,20 @@ with gr.Blocks() as demo:
                 minimum=0,
                 maximum=128,
                 step=1,
-                value=
+                value=13,
             )
 
 
             onset_mask_width = gr.Slider(
                 label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
                 minimum=0,
-                maximum=
+                maximum=20,
                 step=1,
                 value=5,
             )
 
             beat_mask_width = gr.Slider(
-                label="beat
+                label="beat mask width (in milliseconds)",
                 minimum=0,
                 maximum=200,
                 value=0,
@@ -433,14 +318,6 @@ with gr.Blocks() as demo:
 
 
             with gr.Accordion("extras ", open=False):
-                pitch_shift_amt = gr.Slider(
-                    label="pitch shift amount (semitones)",
-                    minimum=-12,
-                    maximum=12,
-                    step=1,
-                    value=0,
-                )
-
                 rand_mask_intensity = gr.Slider(
                     label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
                     minimum=0.0,
@@ -500,34 +377,21 @@ with gr.Blocks() as demo:
                     value=0.0
                 )
 
-
-                    label="
+                temp = gr.Slider(
+                    label="temperature",
                     minimum=0.0,
-                    maximum=100.0,
-                    value=1.5
-                )
-                sampletemp = gr.Slider(
-                    label="sample temperature",
-                    minimum=0.1,
                     maximum=10.0,
-                    value=1.
-                    step=0.001
+                    value=1.8
                 )
-
+
 
 
             with gr.Accordion("sampling settings", open=False):
-                top_p = gr.Slider(
-                    label="top p (0.0 = off)",
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=0.9
-                )
                 typical_filtering = gr.Checkbox(
                     label="typical filtering ",
                     value=False
                 )
-                typical_mass = gr.Slider(
+                typical_mass = gr.Slider(
                     label="typical mass (should probably stay between 0.1 and 0.5)",
                     minimum=0.01,
                     maximum=0.99,
@@ -540,13 +404,6 @@ with gr.Blocks() as demo:
                     step=1,
                     value=64
                 )
-                sample_cutoff = gr.Slider(
-                    label="sample cutoff",
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=0.5,
-                    step=0.01
-                )
 
                 use_coarse2fine = gr.Checkbox(
                     label="use coarse2fine",
@@ -571,24 +428,8 @@ with gr.Blocks() as demo:
             )
 
 
-            seed = gr.Number(
-                label="seed (0 for random)",
-                value=0,
-                precision=0,
-            )
-
-
-
         # mask settings
         with gr.Column():
-
-            # lora_choice = gr.Dropdown(
-            #     label="lora choice",
-            #     choices=list(loras.keys()),
-            #     value=LORA_NONE,
-            #     visible=False
-            # )
-
             vamp_button = gr.Button("generate (vamp)!!!")
             output_audio = gr.Audio(
                 label="output audio",
@@ -614,9 +455,7 @@ with gr.Blocks() as demo:
    _inputs = {
            input_audio,
            num_steps,
-
-            sampletemp,
-            top_p,
+            temp,
            prefix_s, suffix_s,
            rand_mask_intensity,
            periodic_p, periodic_w,
@@ -629,11 +468,7 @@ with gr.Blocks() as demo:
            typical_mass,
            typical_min_tokens,
            beat_mask_width,
-            beat_mask_downbeats
-            seed,
-            # lora_choice,
-            pitch_shift_amt,
-            sample_cutoff
+            beat_mask_downbeats
    }
 
    # connect widgets
@@ -663,24 +498,4 @@ with gr.Blocks() as demo:
        outputs=[thank_you, download_file]
    )
 
-
-    harp_inputs = [
-        input_audio,
-        beat_mask_width,
-        sampletemp,
-    ]
-
-    build_endpoint(
-        inputs=harp_inputs,
-        output=output_audio,
-        process_fn=harp_vamp,
-        card=ModelCard(
-            name="vampnet",
-            description="Generate variations on music input, based on small prompts around the beat. NOTE: vampnet's has a maximum context length of 10 seconds. Please split all audio clips into 10 second chunks, or processing will result in an error. ",
-            author="Hugo Flores García",
-            tags=["music", "generative"]
-        ),
-        visible=False
-    )
-
-demo.launch()
+demo.queue().launch()
```
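For orientation: the net effect of the app.py changes is that the old masktemp/sampletemp/top_p/seed/sample_cutoff knobs collapse into a single temp control that feeds both sampling stages, and a module-level Interface replaces load_interface(). A minimal sketch of the resulting generation path, using only names that appear in the diff above (the step count and file paths are illustrative assumptions):

```python
# Sketch of the post-PR vamp path. Names follow the diff above; anything
# else (sampling_steps value, file names) is an illustrative assumption.
import audiotools as at
from vampnet.interface import Interface
from vampnet import mask as pmask

interface = Interface(
    coarse_ckpt="./models/vampnet/coarse.pth",
    coarse2fine_ckpt="./models/vampnet/c2f.pth",
    codec_ckpt="./models/vampnet/codec.pth",
    wavebeat_ckpt="./models/wavebeat.pth",
)

sig = interface.preprocess(at.AudioSignal("input.wav"))
z = interface.encode(sig)             # audio -> codec tokens

temp = 1.8                            # the single temperature knob (demo default)
mask = pmask.linear_random(z, 1.0)    # fully random mask, as in the demo

zv, mask_z = interface.coarse_vamp(
    z,
    mask=mask,
    sampling_steps=36,                # assumed; the demo passes data[num_steps]
    temperature=float(temp * 10),     # the demo scales temp by 10 for this stage
    return_mask=True,
    gen_fn=interface.coarse.generate,
)
zv = interface.coarse_to_fine(zv, temperature=temp, mask=mask)
interface.to_signal(zv).cpu().write("output.wav")
```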
conf/lora/lora.yml
CHANGED
```diff
@@ -4,16 +4,14 @@ $include:
 fine_tune: True
 
 train/AudioDataset.n_examples: 100000000
-val/AudioDataset.n_examples:
+val/AudioDataset.n_examples: 100
 
 
 NoamScheduler.warmup: 500
 
-batch_size:
+batch_size: 7
 num_workers: 7
-save_iters: [
-sample_freq: 1000
-val_freq: 500
+save_iters: [100000, 200000, 300000, 4000000, 500000]
 
 AdamW.lr: 0.0001
 
```
conf/vampnet.yml
CHANGED
```diff
@@ -32,7 +32,7 @@ VampNet.n_heads: 20
 VampNet.flash_attn: false
 VampNet.dropout: 0.1
 
-AudioLoader.relative_path:
+AudioLoader.relative_path: /data/
 AudioDataset.loudness_cutoff: -30.0
 AudioDataset.without_replacement: true
 AudioLoader.shuffle: true
```
requirements.txt
CHANGED
```diff
@@ -1,10 +1,8 @@
 torch
 argbind>=0.3.2
-numpy==1.
+numpy==1.22
 gradio
 loralib
 wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
 lac @ git+https://github.com/hugofloresgarcia/lac.git
-
--e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
-torch_pitch_shift
+audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
```
scripts/exp/fine_tune.py
CHANGED
```diff
@@ -48,10 +48,11 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
     }
 
     interface_conf = {
-        "Interface.coarse_ckpt": f"./
+        "Interface.coarse_ckpt": f"./models/vampnet/coarse.pth",
+        "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
 
-        "Interface.coarse2fine_ckpt": f"./
-        "Interface.
+        "Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
+        "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
 
         "Interface.codec_ckpt": "./models/vampnet/codec.pth",
         "AudioLoader.sources": [audio_files_or_folders],
```
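The interface_conf above now pairs each base checkpoint with the LoRA checkpoint produced by the fine-tuning run. Since argbind maps "Interface.<kwarg>" keys onto the Interface constructor, a fine-tuned interface can be loaded roughly as in this hedged sketch (how Interface resolves the lora checkpoints internally is an assumption, not confirmed by the diff):

```python
# Hedged sketch: consuming the generated interface config with argbind.
# The "Interface.*" keys from the diff become constructor kwargs.
import argbind
from vampnet.interface import Interface

Interface = argbind.bind(Interface)

# launched as: python demo.py --args.load conf/generated/<fine_tune_name>/interface.yml
args = argbind.parse_args()
with argbind.scope(args):
    interface = Interface()  # picks up coarse_ckpt, coarse_lora_ckpt, etc. from the yml
```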
scripts/exp/train.py
CHANGED
```diff
@@ -14,7 +14,7 @@ from audiotools.data import transforms
 from einops import rearrange
 from rich import pretty
 from rich.traceback import install
-from
+from tensorboardX import SummaryWriter
 
 import vampnet
 from vampnet.modules.transformer import VampNet
@@ -29,9 +29,6 @@ from audiotools.ml.decorators import (
 
 import loralib as lora
 
-import torch._dynamo
-torch._dynamo.config.verbose=True
-
 
 # Enable cudnn autotuner to speed up training
 # (can be altered by the funcs.seed function)
@@ -224,7 +221,7 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
 
     dtype = torch.bfloat16 if accel.amp else None
     with accel.autocast(dtype=dtype):
-        z_hat = state.model(z_mask_latent)
+        z_hat = state.model(z_mask_latent, r)
 
     target = codebook_flatten(
         z[:, vn.n_conditioning_codebooks :, :],
@@ -289,7 +286,7 @@ def val_loop(state: State, batch: dict, accel: Accelerator):
 
     z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
 
-    z_hat = state.model(z_mask_latent)
+    z_hat = state.model(z_mask_latent, r)
 
     target = codebook_flatten(
         z[:, vn.n_conditioning_codebooks :, :],
@@ -408,19 +405,19 @@ def save_imputation(state, z, val_idx, writer):
 
     for i in range(len(val_idx)):
         imputed_noisy[i].cpu().write_audio_to_tb(
-            f"
+            f"imputed_noisy/{i}",
             writer,
             step=state.tracker.step,
             plot_fn=None,
         )
         imputed[i].cpu().write_audio_to_tb(
-            f"
+            f"imputed/{i}",
             writer,
             step=state.tracker.step,
             plot_fn=None,
         )
         imputed_true[i].cpu().write_audio_to_tb(
-            f"
+            f"imputed_true/{i}",
             writer,
             step=state.tracker.step,
             plot_fn=None,
@@ -450,7 +447,7 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
 
     z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
 
-    z_hat = state.model(z_mask_latent)
+    z_hat = state.model(z_mask_latent, r)
 
     z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
     z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
@@ -469,7 +466,7 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
     }
     for k, v in audio_dict.items():
         v.cpu().write_audio_to_tb(
-            f"
+            f"samples/_{i}.r={r[i]:0.2f}/{k}",
             writer,
             step=state.tracker.step,
             plot_fn=None,
@@ -488,6 +485,7 @@ def load(
     save_path: str,
     resume: bool = False,
     tag: str = "latest",
+    load_weights: bool = False,
     fine_tune_checkpoint: Optional[str] = None,
     grad_clip_val: float = 5.0,
 ) -> State:
@@ -500,7 +498,7 @@ def load(
     kwargs = {
         "folder": f"{save_path}/{tag}",
         "map_location": "cpu",
-        "package":
+        "package": not load_weights,
     }
     tracker.print(f"Loading checkpoint from {kwargs['folder']}")
     if (Path(kwargs["folder"]) / "vampnet").exists():
@@ -513,14 +511,11 @@ def load(
 
     if args["fine_tune"]:
         assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
-        model =
-            VampNet.load(location=Path(fine_tune_checkpoint),
-            map_location="cpu",
-            )
-        )
+        model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu")
 
 
-    model =
+    model = VampNet() if model is None else model
+
     model = accel.prepare_model(model)
 
     # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
@@ -604,7 +599,7 @@ def train(
         accel=accel,
         tracker=tracker,
         save_path=save_path)
-
+
 
     train_dataloader = accel.prepare_dataloader(
         state.train_data,
@@ -619,15 +614,13 @@ def train(
         num_workers=num_workers,
         batch_size=batch_size,
         collate_fn=state.val_data.collate,
-        persistent_workers=
+        persistent_workers=True,
     )
-    print("initialized dataloader.")
 
 
 
     if fine_tune:
         lora.mark_only_lora_as_trainable(state.model)
-        print("marked only lora as trainable.")
 
     # Wrap the functions so that they neatly track in TensorBoard + progress bars
     # and only run when specific conditions are met.
@@ -642,7 +635,6 @@ def train(
     save_samples = when(lambda: accel.local_rank == 0)(save_samples)
     checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
 
-    print("starting training loop.")
     with tracker.live:
         for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
             train_loop(state, batch, accel)
```
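Two train.py changes are easy to miss: the model forward now also takes the mask ratio r (z_hat = state.model(z_mask_latent, r)), and load() gains a load_weights flag that flips the checkpoint loader between a torch.package'd model and a plain weights load. A sketch of the flag's effect, with the surrounding load() plumbing elided and assumed unchanged:

```python
# Sketch of the checkpoint-loading switch added above. Only the flag's
# effect is shown; the rest of load() is elided and assumed unchanged.
def checkpoint_kwargs(save_path: str, tag: str = "latest", load_weights: bool = False) -> dict:
    return {
        "folder": f"{save_path}/{tag}",
        "map_location": "cpu",
        # package=True restores a torch.package'd model object;
        # load_weights=True requests a plain state-dict load instead.
        "package": not load_weights,
    }
```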
scripts/utils/{data/augment.py → augment.py}
RENAMED
```diff
@@ -5,19 +5,34 @@ from audiotools import AudioSignal
 
 import argbind
 import tqdm
-import torch
 
 
-from
-
+from pedalboard import (
+    Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard
+)
+from pedalboard.io import AudioFile
 
-
+# Read in a whole file, resampling to our desired sample rate:
+samplerate = 44100.0
+with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
+    audio = f.read(f.frames)
+
+# Make a pretty interesting sounding guitar pedalboard:
+board = Pedalboard([
+    Compressor(threshold_db=-50, ratio=25),
+    Gain(gain_db=30),
+    Chorus(),
+    LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
+    Phaser(),
+    Convolution("./guitar_amp.wav", 1.0),
+    Reverb(room_size=0.25),
+])
 
 
 @argbind.bind(without_prefix=True)
 def augment(
-    audio_folder: Path
-    dest_folder: Path
+    audio_folder: Path,
+    dest_folder: Path,
     n_augmentations: int = 10,
 ):
     """
@@ -26,8 +41,7 @@ def augment(
     The dest foler will contain a folder for each of the clean dataset's files.
     Under each of these folders, there will be a clean file and many augmented files.
     """
-
-    assert dest_folder is not None
+
     audio_files = at.util.find_audio(audio_folder)
 
     for audio_file in tqdm.tqdm(audio_files):
@@ -35,33 +49,5 @@ def augment(
         subdir = subtree / audio_file.stem
         subdir.mkdir(parents=True, exist_ok=True)
 
-
-
-
-        for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
-            # apply pedalboard transforms
-            for j in range(n_augmentations):
-                # pitch shift between -7 and 7 semitones
-                import random
-                dst = chunk.clone()
-                dst.samples = pitch_shift(
-                    dst.samples,
-                    shift=random.choice(get_fast_shifts(src.sample_rate,
-                        condition=lambda x: x >= 0.25 and x <= 1.0)),
-                    sample_rate=src.sample_rate
-                )
-                dst.samples = time_stretch(
-                    dst.samples,
-                    stretch=random.choice(get_fast_stretches(src.sample_rate,
-                        condition=lambda x: x >= 0.667 and x <= 1.5, )),
-                    sample_rate=src.sample_rate,
-                )
-
-                dst.cpu().write(subdir / f"{i}-{j}.wav")
-
-
-if __name__ == "__main__":
-    args = argbind.parse_args()
-
-    with argbind.scope(args):
-        augment()
+        # apply pedalboard transforms
+        for i in range(n_augmentations):
```
DELETED
@@ -1,263 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
TODO: train a linear probe
|
3 |
-
usage:
|
4 |
-
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output
|
5 |
-
"""
|
6 |
-
from pathlib import Path
|
7 |
-
from typing import List
|
8 |
-
|
9 |
-
import audiotools as at
|
10 |
-
from audiotools import AudioSignal
|
11 |
-
import argbind
|
12 |
-
import torch
|
13 |
-
import numpy as np
|
14 |
-
import zipfile
|
15 |
-
import json
|
16 |
-
|
17 |
-
from vampnet.interface import Interface
|
18 |
-
import tqdm
|
19 |
-
|
20 |
-
# bind the Interface to argbind
|
21 |
-
Interface = argbind.bind(Interface)
|
22 |
-
|
23 |
-
DEBUG = False
|
24 |
-
|
25 |
-
def smart_plotly_export(fig, save_path):
|
26 |
-
img_format = save_path.split('.')[-1]
|
27 |
-
if img_format == 'html':
|
28 |
-
fig.write_html(save_path)
|
29 |
-
elif img_format == 'bytes':
|
30 |
-
return fig.to_image(format='png')
|
31 |
-
#TODO: come back and make this prettier
|
32 |
-
elif img_format == 'numpy':
|
33 |
-
import io
|
34 |
-
from PIL import Image
|
35 |
-
|
36 |
-
def plotly_fig2array(fig):
|
37 |
-
#convert Plotly fig to an array
|
38 |
-
fig_bytes = fig.to_image(format="png", width=1200, height=700)
|
39 |
-
buf = io.BytesIO(fig_bytes)
|
40 |
-
img = Image.open(buf)
|
41 |
-
return np.asarray(img)
|
42 |
-
|
43 |
-
return plotly_fig2array(fig)
|
44 |
-
elif img_format == 'jpeg' or 'png' or 'webp':
|
45 |
-
fig.write_image(save_path)
|
46 |
-
else:
|
47 |
-
raise ValueError("invalid image format")
|
48 |
-
|
49 |
-
def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
|
50 |
-
"""
|
51 |
-
dimensionality reduction for visualization!
|
52 |
-
saves an html plotly figure to save_path
|
53 |
-
parameters:
|
54 |
-
emb (np.ndarray): the samples to be reduces with shape (samples, features)
|
55 |
-
labels (list): list of labels for embedding
|
56 |
-
save_path (str): path where u wanna save ur figure
|
57 |
-
method (str): umap, tsne, or pca
|
58 |
-
title (str): title for ur figure
|
59 |
-
returns:
|
60 |
-
proj (np.ndarray): projection vector with shape (samples, dimensions)
|
61 |
-
"""
|
62 |
-
import pandas as pd
|
63 |
-
import plotly.express as px
|
64 |
-
if method == 'umap':
|
65 |
-
reducer = umap.UMAP(n_components=n_components)
|
66 |
-
elif method == 'tsne':
|
67 |
-
from sklearn.manifold import TSNE
|
68 |
-
reducer = TSNE(n_components=n_components)
|
69 |
-
elif method == 'pca':
|
70 |
-
from sklearn.decomposition import PCA
|
71 |
-
reducer = PCA(n_components=n_components)
|
72 |
-
else:
|
73 |
-
raise ValueError
|
74 |
-
|
75 |
-
proj = reducer.fit_transform(emb)
|
76 |
-
|
77 |
-
if n_components == 2:
|
78 |
-
df = pd.DataFrame(dict(
|
79 |
-
x=proj[:, 0],
|
80 |
-
y=proj[:, 1],
|
81 |
-
instrument=labels
|
82 |
-
))
|
83 |
-
fig = px.scatter(df, x='x', y='y', color='instrument',
|
84 |
-
title=title+f"_{method}")
|
85 |
-
|
86 |
-
elif n_components == 3:
|
87 |
-
df = pd.DataFrame(dict(
|
88 |
-
x=proj[:, 0],
|
89 |
-
y=proj[:, 1],
|
90 |
-
z=proj[:, 2],
|
91 |
-
instrument=labels
|
92 |
-
))
|
93 |
-
fig = px.scatter_3d(df, x='x', y='y', z='z',
|
94 |
-
color='instrument',
|
95 |
-
title=title)
|
96 |
-
else:
|
97 |
-
raise ValueError("cant plot more than 3 components")
|
98 |
-
|
99 |
-
fig.update_traces(marker=dict(size=6,
|
100 |
-
line=dict(width=1,
|
101 |
-
color='DarkSlateGrey')),
|
102 |
-
selector=dict(mode='markers'))
|
103 |
-
|
104 |
-
return smart_plotly_export(fig, save_path)
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
# per JukeMIR, we want the emebddings from the middle layer?
|
109 |
-
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
|
110 |
-
with torch.inference_mode():
|
111 |
-
# preprocess the signal
|
112 |
-
sig = interface.preprocess(sig)
|
113 |
-
|
114 |
-
# get the coarse vampnet model
|
115 |
-
vampnet = interface.coarse
|
116 |
-
|
117 |
-
# get the tokens
|
118 |
-
z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
|
119 |
-
z_latents = vampnet.embedding.from_codes(z, interface.codec)
|
120 |
-
|
121 |
-
# do a forward pass through the model, get the embeddings
|
122 |
-
_z, embeddings = vampnet(z_latents, return_activations=True)
|
123 |
-
# print(f"got embeddings with shape {embeddings.shape}")
|
124 |
-
# [layer, batch, time, n_dims]
|
125 |
-
# [20, 1, 600ish, 768]
|
126 |
-
|
127 |
-
|
128 |
-
# squeeze batch dim (1 bc layer should be dim 0)
|
129 |
-
assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
|
130 |
-
embeddings = embeddings.squeeze(1)
|
131 |
-
|
132 |
-
num_layers = embeddings.shape[0]
|
133 |
-
assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
|
134 |
-
|
135 |
-
# do meanpooling over the time dimension
|
136 |
-
embeddings = embeddings.mean(dim=-2)
|
137 |
-
# [20, 768]
|
138 |
-
|
139 |
-
# return the embeddings
|
140 |
-
return embeddings
|
141 |
-
|
142 |
-
from dataclasses import dataclass, fields
|
143 |
-
@dataclass
|
144 |
-
class Embedding:
|
145 |
-
genre: str
|
146 |
-
filename: str
|
147 |
-
embedding: np.ndarray
|
148 |
-
|
149 |
-
def save(self, path):
|
150 |
-
"""Save the Embedding object to a given path as a zip file."""
|
151 |
-
with zipfile.ZipFile(path, 'w') as archive:
|
152 |
-
|
153 |
-
# Save numpy array
|
154 |
-
with archive.open('embedding.npy', 'w') as f:
|
155 |
-
np.save(f, self.embedding)
|
156 |
-
|
157 |
-
# Save non-numpy data as json
|
158 |
-
non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
|
159 |
-
with archive.open('data.json', 'w') as f:
|
160 |
-
f.write(json.dumps(non_numpy_data).encode('utf-8'))
|
161 |
-
|
162 |
-
@classmethod
|
163 |
-
def load(cls, path):
|
164 |
-
"""Load the Embedding object from a given zip path."""
|
165 |
-
with zipfile.ZipFile(path, 'r') as archive:
|
166 |
-
|
167 |
-
# Load numpy array
|
168 |
-
with archive.open('embedding.npy') as f:
|
169 |
-
embedding = np.load(f)
|
170 |
-
|
171 |
-
# Load non-numpy data from json
|
172 |
-
with archive.open('data.json') as f:
|
173 |
-
data = json.loads(f.read().decode('utf-8'))
|
174 |
-
|
175 |
-
return cls(embedding=embedding, **data)
|
176 |
-
|
177 |
-
|
178 |
-
@argbind.bind(without_prefix=True)
|
179 |
-
def main(
|
180 |
-
path_to_gtzan: str = None,
|
181 |
-
cache_dir: str = "./.gtzan_emb_cache",
|
182 |
-
output_dir: str = "./gtzan_vampnet_embeddings",
|
183 |
-
layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
|
184 |
-
):
|
185 |
-
path_to_gtzan = Path(path_to_gtzan)
|
186 |
-
assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"
|
187 |
-
|
188 |
-
cache_dir = Path(cache_dir)
|
189 |
-
output_dir = Path(output_dir)
|
190 |
-
output_dir.mkdir(exist_ok=True, parents=True)
|
191 |
-
|
192 |
-
# load our interface
|
193 |
-
# argbind will automatically load the default config,
|
194 |
-
interface = Interface()
|
195 |
-
|
196 |
-
# gtzan should have a folder for each genre, so let's get the list of genres
|
197 |
-
genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
|
198 |
-
print(f"Found {len(genres)} genres")
|
199 |
-
print(f"genres: {genres}")
|
200 |
-
|
201 |
-
# collect audio files, genres, and embeddings
|
202 |
-
data = []
|
203 |
-
for genre in genres:
|
204 |
-
audio_files = list(at.util.find_audio(path_to_gtzan / genre))
|
205 |
-
print(f"Found {len(audio_files)} audio files for genre {genre}")
|
206 |
-
|
207 |
-
for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
|
208 |
-
# check if we have a cached embedding for this file
|
209 |
-
cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
|
210 |
-
if cached_path.exists():
|
211 |
-
# if so, load it
|
212 |
-
if DEBUG:
|
213 |
-
print(f"loading cached embedding for {cached_path.stem}")
|
214 |
-
embedding = Embedding.load(cached_path)
|
215 |
-
data.append(embedding)
|
216 |
-
else:
|
217 |
-
try:
|
218 |
-
sig = AudioSignal(audio_file)
|
219 |
-
except Exception as e:
|
220 |
-
print(f"failed to load {audio_file.name} with error {e}")
|
221 |
-
print(f"skipping {audio_file.name}")
|
222 |
-
continue
|
223 |
-
|
224 |
-
# gets the embedding
|
225 |
-
emb = vampnet_embed(sig, interface).cpu().numpy()
|
226 |
-
|
227 |
-
# create an embedding we can save/load
|
228 |
-
embedding = Embedding(
|
229 |
-
genre=genre,
|
230 |
-
filename=audio_file.name,
|
231 |
-
embedding=emb
|
232 |
-
)
|
233 |
-
|
234 |
-
# cache the embeddings
|
235 |
-
cached_path.parent.mkdir(exist_ok=True, parents=True)
|
236 |
-
embedding.save(cached_path)
|
237 |
-
|
238 |
-
# now, let's do a dim reduction on the embeddings
|
239 |
-
# and visualize them.
|
240 |
-
|
241 |
-
# collect a list of embeddings and labels
|
242 |
-
embeddings = [d.embedding for d in data]
|
243 |
-
labels = [d.genre for d in data]
|
244 |
-
|
245 |
-
# convert the embeddings to a numpy array
|
246 |
-
embeddings = np.stack(embeddings)
|
247 |
-
|
248 |
-
# do dimensionality reduction for each layer we're given
|
249 |
-
for layer in tqdm.tqdm(layers, desc="dim reduction"):
|
250 |
-
dim_reduce(
|
251 |
-
embeddings[:, layer, :], labels,
|
252 |
-
save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
|
253 |
-
n_components=2, method='tsne',
|
254 |
-
title=f'vampnet-gtzan-layer={layer}'
|
255 |
-
)
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
if __name__ == "__main__":
|
261 |
-
args = argbind.parse_args()
|
262 |
-
with argbind.scope(args):
|
263 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/utils/{data/maestro-reorg.py → maestro-reorg.py}
RENAMED
File without changes
scripts/utils/remove_quiet_files.py
DELETED (29 lines)
```python
# removes files with loudness below 24db

from pathlib import Path
import shutil
import audiotools as at
import argbind

@argbind.bind(without_prefix=True)
def remove_quiet_files(
    src_dir: Path = None,
    dest_dir: Path = None,
    min_loudness: float = -30,
):
    # copy src to dest
    dest_dir.mkdir(parents=True, exist_ok=True)
    shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)

    audio_files = at.util.find_audio(dest_dir)
    for audio_file in audio_files:
        sig = at.AudioSignal(audio_file)
        if sig.loudness() < min_loudness:
            audio_file.unlink()
            print(f"removed {audio_file}")

if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        remove_quiet_files()
```
scripts/utils/split_long_audio_file.py
DELETED (34 lines)
```python
from pathlib import Path
import argbind

import audiotools as at
import tqdm


@argbind.bind(without_prefix=True)
def split_long_audio_file(
    file: str = None,
    max_chunk_size_s: int = 60*10
):
    file = Path(file)
    output_dir = file.parent / file.stem
    output_dir.mkdir()

    sig = at.AudioSignal(file)

    # split into chunks
    for i, sig in tqdm.tqdm(enumerate(sig.windows(
        window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
        preprocess=True))
    ):
        sig.write(output_dir / f"{i}.wav")

    print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")

    return output_dir

if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        split_long_audio_file()
```
scripts/utils/xeno-canto-dl.py
DELETED
@@ -1,234 +0,0 @@
|
|
1 |
-
from xenopy import Query
|
2 |
-
|
3 |
-
|
4 |
-
SPECIES = [
|
5 |
-
"American Robin",
|
6 |
-
"Northern Cardinal",
|
7 |
-
"Mourning Dove",
|
8 |
-
"American Crow",
|
9 |
-
"Baltimore Oriole",
|
10 |
-
"Blue Jay",
|
11 |
-
"Eastern Bluebird",
|
12 |
-
"House Finch",
|
13 |
-
"American Goldfinch",
|
14 |
-
"House Sparrow",
|
15 |
-
"Song Sparrow",
|
16 |
-
"Tufted Titmouse",
|
17 |
-
"White-breasted Nuthatch",
|
18 |
-
"European Starling",
|
19 |
-
"American Redstart",
|
20 |
-
"Red-winged Blackbird",
|
21 |
-
"Brown-headed Cowbird",
|
22 |
-
"Common Grackle",
|
23 |
-
"Boat-tailed Grackle",
|
24 |
-
"Common Yellowthroat",
|
25 |
-
"Northern Mockingbird",
|
26 |
-
"Carolina Wren",
|
27 |
-
"Eastern Meadowlark",
|
28 |
-
"Chipping Sparrow",
|
29 |
-
"Tree Swallow",
|
30 |
-
"Barn Swallow",
|
31 |
-
"Cliff Swallow",
|
32 |
-
"Pine Siskin",
|
33 |
-
"Indigo Bunting",
|
34 |
-
"Eastern Towhee",
|
35 |
-
"Carolina Chickadee",
|
36 |
-
"Great Crested Flycatcher",
|
37 |
-
"Eastern Wood-Pewee",
|
38 |
-
"Ovenbird",
|
39 |
-
"Northern Flicker",
|
40 |
-
"Red-eyed Vireo",
|
41 |
-
"American Woodcock",
|
42 |
-
"Eastern Phoebe",
|
43 |
-
"Downy Woodpecker",
|
44 |
-
"Scarlet Tanager",
|
45 |
-
"Yellow Warbler",
|
46 |
-
"White-eyed Vireo",
|
47 |
-
"Common Loon",
|
48 |
-
"White-throated Sparrow",
|
49 |
-
"Yellow-throated Vireo",
|
50 |
-
"Great Blue Heron",
|
51 |
-
"Belted Kingfisher",
|
52 |
-
"Pied-billed Grebe",
|
53 |
-
"Wild Turkey",
|
54 |
-
"Wood Thrush",
|
55 |
-
"Rose-breasted Grosbeak",
|
56 |
-
"Field Sparrow",
|
57 |
-
"Hooded Warbler",
|
58 |
-
"Northern Parula",
|
59 |
-
"Chestnut-sided Warbler",
|
60 |
-
"Blue-winged Warbler",
|
61 |
-
"Red-bellied Woodpecker",
|
62 |
-
"Yellow-billed Cuckoo",
|
63 |
-
"Gray Catbird",
|
64 |
-
"Northern Saw-whet Owl",
|
65 |
-
"Osprey",
|
66 |
-
"Common Nighthawk",
|
67 |
-
"Broad-winged Hawk",
|
68 |
-
"Black-throated Green Warbler",
|
69 |
-
"Great Horned Owl",
|
70 |
-
"Common Raven",
|
71 |
-
"Barred Owl",
|
72 |
-
"Canada Warbler",
|
73 |
-
"Magnolia Warbler",
|
74 |
-
"Black-and-white Warbler",
|
75 |
-
"Eastern Kingbird",
|
76 |
-
"Swainson's Thrush",
|
77 |
-
"Worm-eating Warbler",
|
78 |
-
"Prairie Warbler",
|
79 |
-
"Baltimore Oriole",
|
80 |
-
"Black-throated Blue Warbler",
|
81 |
-
"Louisiana Waterthrush",
|
82 |
-
"Blackburnian Warbler",
|
83 |
-
"Black-capped Chickadee",
|
84 |
-
"Cerulean Warbler",
|
85 |
-
"Red-shouldered Hawk",
|
86 |
-
"Cooper's Hawk",
|
87 |
-
"Yellow-throated Warbler",
|
88 |
-
"Blue-headed Vireo",
|
89 |
-
"Blackpoll Warbler",
|
90 |
-
"Ruffed Grouse",
|
91 |
-
"Kentucky Warbler",
|
92 |
-
"Hermit Thrush",
|
93 |
-
"Cedar Waxwing",
|
94 |
-
"Eastern Screech-Owl",
|
95 |
-
"Northern Goshawk",
|
96 |
-
"Green Heron",
|
97 |
-
"Red-tailed Hawk",
|
98 |
-
"Black Vulture",
|
99 |
-
"Hairy Woodpecker",
|
100 |
-
"Golden-crowned Kinglet",
|
101 |
-
"Ruby-crowned Kinglet",
|
102 |
-
"Bicknell's Thrush",
|
103 |
-
"Blue-gray Gnatcatcher",
|
104 |
-
"Veery",
|
105 |
-
"Pileated Woodpecker",
|
106 |
-
"Purple Finch",
|
107 |
-
"White-crowned Sparrow",
|
108 |
-
"Snow Bunting",
|
109 |
-
"Pine Grosbeak",
|
110 |
-
"American Tree Sparrow",
|
111 |
-
"Dark-eyed Junco",
|
112 |
-
"Snowy Owl",
|
113 |
-
"White-winged Crossbill",
|
114 |
-
"Red Crossbill",
|
115 |
-
"Common Redpoll",
|
116 |
-
"Northern Shrike",
|
117 |
-
"Northern Harrier",
|
118 |
-
"Rough-legged Hawk",
|
119 |
-
"Long-eared Owl",
|
120 |
-
"Evening Grosbeak",
|
121 |
-
"Northern Pintail",
|
122 |
-
"American Black Duck",
|
123 |
-
"Mallard",
|
124 |
-
"Canvasback",
|
125 |
-
"Redhead",
|
126 |
-
"Ring-necked Duck",
|
127 |
-
"Greater Scaup",
|
128 |
-
"Lesser Scaup",
|
129 |
-
"Bufflehead",
|
130 |
-
"Common Goldeneye",
|
131 |
-
"Hooded Merganser",
|
132 |
-
"Common Merganser",
|
133 |
-
"Red-breasted Merganser",
|
134 |
-
"Ruddy Duck",
|
135 |
-
"Wood Duck",
|
136 |
-
"Gadwall",
|
137 |
-
"American Wigeon",
|
138 |
-
"Northern Shoveler",
|
139 |
-
"Green-winged Teal",
|
140 |
-
"Blue-winged Teal",
|
141 |
-
"Cinnamon Teal",
|
142 |
-
"Ringed Teal",
|
143 |
-
"Cape Teal",
|
144 |
-
"Northern Fulmar",
|
145 |
-
"Yellow-billed Loon",
|
146 |
-
"Red-throated Loon",
|
147 |
-
"Arctic Loon",
|
148 |
-
"Pacific Loon",
|
149 |
-
"Horned Grebe",
|
150 |
-
"Red-necked Grebe",
|
151 |
-
"Eared Grebe",
|
152 |
-
"Western Grebe",
|
153 |
-
"Clark's Grebe",
|
154 |
-
"Double-crested Cormorant",
|
155 |
-
"Pelagic Cormorant",
|
156 |
-
"Great Cormorant",
|
157 |
-
"American White Pelican",
|
158 |
-
"Brown Pelican",
|
159 |
-
"Brandt's Cormorant",
|
160 |
-
"Least Bittern",
|
161 |
-
"Great Egret",
|
162 |
-
"Snowy Egret",
|
163 |
-
"Little Blue Heron",
|
164 |
-
"Tricolored Heron",
|
165 |
-
"Reddish Egret",
|
166 |
-
"Black-crowned Night-Heron",
|
167 |
-
"Yellow-crowned Night-Heron",
|
168 |
-
"White Ibis",
|
169 |
-
"Glossy Ibis",
|
170 |
-
"Roseate Spoonbill",
|
171 |
-
"Wood Stork",
|
172 |
-
"Black-bellied Whistling-Duck",
|
173 |
-
"Fulvous Whistling-Duck",
|
174 |
-
"Greater White-fronted Goose",
|
175 |
-
"Snow Goose",
|
176 |
-
"Ross's Goose",
|
177 |
-
"Canada Goose",
|
178 |
-
"Brant",
|
179 |
-
"Mute Swan",
|
180 |
-
"Tundra Swan",
|
181 |
-
"Whooper Swan",
|
182 |
-
"Sandhill Crane",
|
183 |
-
"Black-necked Stilt",
|
184 |
-
"American Avocet",
|
185 |
-
"Northern Jacana",
|
186 |
-
"Greater Yellowlegs",
|
187 |
-
"Lesser Yellowlegs",
|
188 |
-
"Willet",
|
189 |
-
"Spotted Sandpiper",
|
190 |
-
"Upland Sandpiper",
|
191 |
-
"Whimbrel",
|
192 |
-
"Long-billed Curlew",
|
193 |
-
"Marbled Godwit",
|
194 |
-
"Ruddy Turnstone",
|
195 |
-
"Red Knot",
|
196 |
-
"Sanderling",
|
197 |
-
"Semipalmated Sandpiper",
|
198 |
-
"Western Sandpiper",
|
199 |
-
"Least Sandpiper",
|
200 |
-
"White-rumped Sandpiper",
|
201 |
-
"Baird's Sandpiper",
|
202 |
-
"Pectoral Sandpiper",
|
203 |
-
"Dunlin",
|
204 |
-
"Buff-breasted Sandpiper",
|
205 |
-
"Short-billed Dowitcher",
|
206 |
-
"Long-billed Dowitcher",
|
207 |
-
"Common Snipe",
|
208 |
-
"American Woodcock",
|
209 |
-
"Wilson's Phalarope",
|
210 |
-
"Red-necked Phalarope",
|
211 |
-
"Red Phalarope"
|
212 |
-
]
|
213 |
-
|
214 |
-
from pathlib import Path
|
215 |
-
|
216 |
-
def remove_spaces(s):
|
217 |
-
return s.replace(" ", "")
|
218 |
-
|
219 |
-
for species in SPECIES:
|
220 |
-
if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
|
221 |
-
continue
|
222 |
-
try:
|
223 |
-
q = Query(
|
224 |
-
name=species, q="A", length="10-30",
|
225 |
-
)
|
226 |
-
|
227 |
-
# retrieve metadata
|
228 |
-
metafiles = q.retrieve_meta(verbose=True)
|
229 |
-
# retrieve recordings
|
230 |
-
q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
|
231 |
-
|
232 |
-
except:
|
233 |
-
print("Failed to download " + species)
|
234 |
-
continue
|
setup.py
CHANGED
@@ -28,13 +28,12 @@ setup(
     install_requires=[
         "torch",
         "argbind>=0.3.2",
-        "numpy==1.
+        "numpy==1.22",
         "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
         "lac @ git+https://github.com/hugofloresgarcia/lac.git",
         "descript-audiotools @ git+https://github.com/descriptinc/[email protected]",
         "gradio",
+        "tensorboardX",
         "loralib",
-        "torch_pitch_shift",
-        "madmom",
     ],
 )
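
The dependency changes above drop `madmom` and `torch_pitch_shift`, pin `numpy` to 1.22, and add `tensorboardX`, presumably for training-time logging. A minimal sketch of the kind of logging this enables (the log directory and tag names are illustrative, not taken from the training scripts):

```python
# Minimal tensorboardX sketch; log dir and tag names are illustrative only.
from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir="runs/vampnet-finetune")  # hypothetical directory
for step in range(3):
    writer.add_scalar("train/loss", 1.0 / (step + 1), step)  # dummy scalar per step
writer.close()
```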
vampnet/interface.py
CHANGED
@@ -65,7 +65,7 @@ class Interface(torch.nn.Module):
     ):
         super().__init__()
         assert codec_ckpt is not None, "must provide a codec checkpoint"
-        self.codec = DAC.load(
+        self.codec = DAC.load(codec_ckpt)
         self.codec.eval()
         self.codec.to(device)
 
@@ -120,16 +120,17 @@ class Interface(torch.nn.Module):
         if coarse_ckpt is not None:
             self.coarse.to("cpu")
             state_dict = torch.load(coarse_ckpt, map_location="cpu")
-
+
             self.coarse.load_state_dict(state_dict, strict=False)
             self.coarse.to(self.device)
         if c2f_ckpt is not None:
             self.c2f.to("cpu")
             state_dict = torch.load(c2f_ckpt, map_location="cpu")
-
+
             self.c2f.load_state_dict(state_dict, strict=False)
             self.c2f.to(self.device)
 
+
     def s2t(self, seconds: float):
         """seconds to tokens"""
         if isinstance(seconds, np.ndarray):
@@ -193,8 +194,8 @@ class Interface(torch.nn.Module):
 
     def make_beat_mask(self,
         signal: AudioSignal,
-        before_beat_s: float = 0.
-        after_beat_s: float = 0.
+        before_beat_s: float = 0.1,
+        after_beat_s: float = 0.1,
         mask_downbeats: bool = True,
         mask_upbeats: bool = True,
         downbeat_downsample_factor: int = None,
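
With `DAC.load` now taking the checkpoint path directly and the beat-mask padding defaulting to 0.1 s on each side, usage looks roughly like the sketch below. The checkpoint paths and audio file are placeholders; the keyword names come from the signatures in this diff:

```python
# Sketch only: checkpoint paths and input file are placeholders, not repo files.
from audiotools import AudioSignal
from vampnet.interface import Interface

interface = Interface(
    coarse_ckpt="path/to/coarse.pth",  # placeholder
    c2f_ckpt="path/to/c2f.pth",        # placeholder
    codec_ckpt="path/to/codec.pth",    # required; passed straight to DAC.load
)

sig = AudioSignal("example.wav")       # placeholder input audio
# keep 0.1 s of context before and after each detected beat (the new defaults)
mask = interface.make_beat_mask(sig, before_beat_s=0.1, after_beat_s=0.1)
```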
vampnet/mask.py
CHANGED
@@ -191,47 +191,29 @@ def onset_mask(
     width: int = 1
 ):
     import librosa
-    …
-    if onset_indices.shape[0] == 0:
-        mask = empty_mask(z)
-        print(f"no onsets found, returning empty mask")
-    else:
-        torch.set_printoptions(threshold=1000)
-        print("onset indices: ", onset_indices)
-        print("onset times: ", onset_times)
-
-        # create a mask, set onset
-        mask = torch.ones_like(z)
-        n_timesteps = z.shape[-1]
-
-        for onset_index in onset_indices:
-            onset_index = min(onset_index, n_timesteps - 1)
-            onset_index = max(onset_index, 0)
-            mask[:, :, onset_index - width:onset_index + width] = 0.0
-
-        print(mask)
+
+    onset_indices = librosa.onset.onset_detect(
+        y=sig.clone().to_mono().samples.cpu().numpy()[0, 0],
+        sr=sig.sample_rate,
+        hop_length=interface.codec.hop_length,
+        backtrack=True,
+    )
+
+    # create a mask, set onset
+    mask = torch.ones_like(z)
+    n_timesteps = z.shape[-1]
+
+    for onset_index in onset_indices:
+        onset_index = min(onset_index, n_timesteps - 1)
+        onset_index = max(onset_index, 0)
+        mask[:, :, onset_index - width:onset_index + width] = 0.0
+
+    print(mask)
 
     return mask
 
 
 
 if __name__ == "__main__":
-    …
+    torch.set_printoptions(threshold=10000)
+
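
The rewritten `onset_mask` drops the old onset detector in favor of `librosa.onset.onset_detect`, run at the codec's hop length so detected frames line up with token timesteps. Below is a self-contained sketch of the same masking logic, with a plain `hop_length` standing in for `interface.codec.hop_length` and synthetic audio standing in for `sig`:

```python
# Standalone sketch of the librosa-based onset masking above; values are illustrative.
import librosa
import torch

sr, hop_length, width = 44100, 512, 1
y = librosa.tone(440, sr=sr, duration=2.0)   # placeholder mono audio
z = torch.zeros(1, 4, len(y) // hop_length)  # (batch, codebooks, timesteps)

# onset frames at the codec's hop length, so indices align with token timesteps
onset_indices = librosa.onset.onset_detect(
    y=y, sr=sr, hop_length=hop_length, backtrack=True,
)

mask = torch.ones_like(z)
n_timesteps = z.shape[-1]
for onset_index in onset_indices:
    # clamp to the valid range, as in the diff, then zero a window around each onset
    onset_index = max(min(int(onset_index), n_timesteps - 1), 0)
    mask[:, :, onset_index - width : onset_index + width] = 0.0
```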
vampnet/modules/transformer.py
CHANGED
@@ -410,9 +410,7 @@ class TransformerStack(nn.Module):
     def subsequent_mask(self, size):
         return torch.ones(1, size, size).tril().bool()
 
-    def forward(self, x, x_mask, cond=None, src=None, src_mask=None,
-                return_activations: bool = False
-    ):
+    def forward(self, x, x_mask, cond=None, src=None, src_mask=None):
         """Computes a full transformer stack
         Parameters
         ----------
@@ -439,8 +437,6 @@ class TransformerStack(nn.Module):
         encoder_decoder_position_bias = None
 
         # Compute transformer layers
-        if return_activations:
-            activations = []
         for layer in self.layers:
             x, position_bias, encoder_decoder_position_bias = layer(
                 x=x,
@@ -451,15 +447,8 @@ class TransformerStack(nn.Module):
                 position_bias=position_bias,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
             )
-            if return_activations:
-                activations.append(x.detach())
 
-
-        out = self.norm(x) if self.norm is not None else x
-        if return_activations:
-            return out, torch.stack(activations)
-        else:
-            return out
+        return self.norm(x) if self.norm is not None else x
 
 
 class VampNet(at.ml.BaseModel):
@@ -467,7 +456,7 @@ class VampNet(at.ml.BaseModel):
         self,
         n_heads: int = 20,
         n_layers: int = 16,
-        r_cond_dim: int =
+        r_cond_dim: int = 64,
         n_codebooks: int = 9,
         n_conditioning_codebooks: int = 0,
         latent_dim: int = 8,
@@ -478,7 +467,6 @@ class VampNet(at.ml.BaseModel):
         dropout: float = 0.1
     ):
         super().__init__()
-        assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
         self.n_heads = n_heads
         self.n_layers = n_layers
         self.r_cond_dim = r_cond_dim
@@ -525,25 +513,21 @@ class VampNet(at.ml.BaseModel):
             ),
         )
 
-    def forward(self, x,
+    def forward(self, x, cond):
         x = self.embedding(x)
         x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
 
-
-        out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
-        if return_activations:
-            out, activations = out
+        cond = self.r_embed(cond)
+
+        x = rearrange(x, "b d n -> b n d")
+        out = self.transformer(x=x, x_mask=x_mask, cond=cond)
 
         out = rearrange(out, "b n d -> b d n")
 
-        out = self.classifier(out,
+        out = self.classifier(out, cond)
 
         out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
 
-        if return_activations:
-            return out, activations
-        else:
-            return out
+        return out
 
     def r_embed(self, r, max_positions=10000):
         if self.r_cond_dim > 0:
@@ -594,23 +578,22 @@ class VampNet(at.ml.BaseModel):
         self,
         codec,
         time_steps: int = 300,
-        sampling_steps: int =
+        sampling_steps: int = 24,
         start_tokens: Optional[torch.Tensor] = None,
-        sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
-
+        temperature: Union[float, Tuple[float, float]] = 2.5,
         typical_filtering=False,
        typical_mass=0.2,
        typical_min_tokens=1,
-        top_p=None,
         return_signal=True,
-        seed: int = None,
-        sample_cutoff: float = 1.0,
     ):
-        if seed is not None:
-            at.util.seed(seed)
         logging.debug(f"beginning generation with {sampling_steps} steps")
 
+        #####################
+        # resolve temperature #
+        #####################
+        assert isinstance(temperature, float)
+        logging.debug(f"temperature: {temperature}")
 
 
         #####################
@@ -662,6 +645,9 @@ class VampNet(at.ml.BaseModel):
         for i in range(sampling_steps):
             logging.debug(f"step {i} of {sampling_steps}")
 
+            # our current temperature
+            logging.debug(f"temperature: {temperature}")
+
             # our current schedule step
             r = scalar_to_batch_tensor(
                 (i + 1) / sampling_steps,
@@ -676,24 +662,41 @@ class VampNet(at.ml.BaseModel):
 
             # infer from latents
             # NOTE: this collapses the codebook dimension into the sequence dimension
-            logits = self.forward(latents) # b, prob, seq
+            logits = self.forward(latents, r) # b, prob, seq
             logits = logits.permute(0, 2, 1) # b, seq, prob
-
+            if typical_filtering:
+                typical_filter(logits,
+                               typical_mass=typical_mass,
+                               typical_min_tokens=typical_min_tokens
+                               )
+
 
             logging.debug(f"permuted logits with shape: {logits.shape}")
 
-            sampled_z, selected_probs = sample_from_logits(
-                logits, sample=(
-                    (i / sampling_steps) <= sample_cutoff
-                ),
-                temperature=sampling_temperature,
-                typical_filtering=typical_filtering, typical_mass=typical_mass,
-                typical_min_tokens=typical_min_tokens,
-                top_k=None, top_p=top_p, return_probs=True,
-            )
+            # logits2probs
+            probs = torch.softmax(logits, dim=-1)
+            logging.debug(f"computed probs with shape: {probs.shape}")
+
+
+            # sample from logits with multinomial sampling
+            b = probs.shape[0]
+            probs = rearrange(probs, "b seq prob -> (b seq) prob")
+
+            sampled_z = torch.multinomial(probs, 1).squeeze(-1)
+
+            sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
+            probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
 
             logging.debug(f"sampled z with shape: {sampled_z.shape}")
 
+            # get the confidences: which tokens did we sample?
+            selected_probs = (
+                torch.take_along_dim(
+                    probs, sampled_z.long().unsqueeze(-1),
+                    dim=-1
+                ).squeeze(-1)
+            )
+
             # flatten z_masked and mask, so we can deal with the sampling logic
             # we'll unflatten them at the end of the loop for the next forward pass
             # remove conditioning codebooks, we'll add them back at the end
@@ -730,7 +733,7 @@ class VampNet(at.ml.BaseModel):
 
             # get our new mask
             mask = mask_by_random_topk(
-                num_to_mask, selected_probs,
+                num_to_mask, selected_probs, temperature * (1-r)
             )
 
             # update the mask
@@ -763,97 +766,8 @@ class VampNet(at.ml.BaseModel):
         else:
             return sampled_z
 
-def sample_from_logits(
-    logits,
-    sample: bool = True,
-    temperature: float = 1.0,
-    top_k: int = None,
-    top_p: float = None,
-    typical_filtering: bool = False,
-    typical_mass: float = 0.2,
-    typical_min_tokens: int = 1,
-    return_probs: bool = False
-):
-    """Convenience function to sample from a categorial distribution with input as
-    unnormalized logits.
-
-    Parameters
-    ----------
-    logits : Tensor[..., vocab_size]
-    config: SamplingConfig
-        The set of hyperparameters to be used for sampling
-    sample : bool, optional
-        Whether to perform multinomial sampling, by default True
-    temperature : float, optional
-        Scaling parameter when multinomial samping, by default 1.0
-    top_k : int, optional
-        Restricts sampling to only `top_k` values acc. to probability,
-        by default None
-    top_p : float, optional
-        Restricts sampling to only those values with cumulative
-        probability = `top_p`, by default None
-
-    Returns
-    -------
-    Tensor[...]
-        Sampled tokens
-    """
-    shp = logits.shape[:-1]
-
-    if typical_filtering:
-        typical_filter(logits,
-                       typical_mass=typical_mass,
-                       typical_min_tokens=typical_min_tokens
-                       )
-
-    # Apply top_k sampling
-    if top_k is not None:
-        v, _ = logits.topk(top_k)
-        logits[logits < v[..., [-1]]] = -float("inf")
-
-    # Apply top_p (nucleus) sampling
-    if top_p is not None and top_p < 1.0:
-        v, sorted_indices = logits.sort(descending=True)
-        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
-
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Right shift indices_to_remove to keep 1st token over threshold
-        sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
-            ..., :-1
-        ]
-
-        # Compute indices_to_remove in unsorted array
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            -1, sorted_indices, sorted_indices_to_remove
-        )
-
-        logits[indices_to_remove] = -float("inf")
-
-    # Perform multinomial sampling after normalizing logits
-    probs = (
-        F.softmax(logits / temperature, dim=-1)
-        if temperature > 0
-        else logits.softmax(dim=-1)
-    )
-    token = (
-        probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
-        if sample
-        else logits.argmax(-1)
-    )
-
-    if return_probs:
-        token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
-        return token, token_probs
-    else:
-        return token
-
-
-def mask_by_random_topk(
-    num_to_mask: int,
-    probs: torch.Tensor,
-    temperature: float = 1.0,
-):
+def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
     """
     Args:
         num_to_mask (int): number of tokens to mask
@@ -866,8 +780,7 @@ def mask_by_random_topk(
     logging.debug(f"temperature: {temperature}")
     logging.debug("")
 
-    noise = gumbel_noise_like(probs)
-    confidence = torch.log(probs) + temperature * noise
+    confidence = torch.log(probs) + temperature * gumbel_noise_like(probs)
     logging.debug(f"confidence shape: {confidence.shape}")
 
     sorted_confidence, sorted_idx = confidence.sort(dim=-1)
@@ -937,7 +850,7 @@ if __name__ == "__main__":
     z_mask_latent = torch.rand(
         batch_size, model.latent_dim * model.n_codebooks, seq_len
    ).to(device)
-    z_hat = model(z_mask_latent)
+    z_hat = model(z_mask_latent, r)
 
     pred = z_hat.argmax(dim=1)
     pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
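
The inlined sampling path in `generate` replaces the removed `sample_from_logits` helper: softmax the logits, flatten batch and sequence, draw one multinomial sample per position, gather each sampled token's probability as its confidence, then let `mask_by_random_topk` re-mask the low-confidence tokens using Gumbel-perturbed log-probabilities whose noise is annealed by `temperature * (1 - r)`. A compact sketch of that step (shapes and values are illustrative; `gumbel_noise_like` is defined locally here and assumed equivalent to the repo's helper):

```python
# Compact sketch of the new sampling step; shapes and values are illustrative.
import torch
from einops import rearrange

def gumbel_noise_like(t: torch.Tensor) -> torch.Tensor:
    # standard Gumbel(0, 1) noise, assumed equivalent to the repo's helper
    u = torch.zeros_like(t).uniform_(1e-20, 1.0)
    return -torch.log(-torch.log(u))

b, seq, vocab = 2, 10, 1024
logits = torch.randn(b, seq, vocab)

# logits -> probs, then one multinomial draw per (batch, position)
probs = torch.softmax(logits, dim=-1)
flat = rearrange(probs, "b seq prob -> (b seq) prob")
sampled_z = rearrange(torch.multinomial(flat, 1).squeeze(-1), "(b seq) -> b seq", b=b)

# confidence of each sampled token = its probability under the model
selected_probs = torch.take_along_dim(probs, sampled_z.unsqueeze(-1), dim=-1).squeeze(-1)

# gumbel-perturbed log-confidence; the (1 - r) factor anneals the noise to zero
r, temperature = 0.5, 2.5
confidence = torch.log(selected_probs) + temperature * (1 - r) * gumbel_noise_like(selected_probs)
```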