This view is limited to 50 files because it contains too many changes. See the raw diff for the full set of changes.
Files changed (50)
  1. .gitignore +3 -7
  2. README.md +5 -16
  3. app.py +43 -229
  4. conf/generated-v0/berta-goldman-speech/c2f.yml +15 -0
  5. conf/generated-v0/berta-goldman-speech/coarse.yml +8 -0
  6. conf/generated-v0/berta-goldman-speech/interface.yml +5 -0
  7. conf/generated-v0/gamelan-xeno-canto/c2f.yml +17 -0
  8. conf/generated-v0/gamelan-xeno-canto/coarse.yml +10 -0
  9. conf/generated-v0/gamelan-xeno-canto/interface.yml +6 -0
  10. conf/generated-v0/nasralla/c2f.yml +15 -0
  11. conf/generated-v0/nasralla/coarse.yml +8 -0
  12. conf/generated-v0/nasralla/interface.yml +5 -0
  13. conf/generated/breaks-steps/c2f.yml +15 -0
  14. conf/generated/breaks-steps/coarse.yml +8 -0
  15. conf/generated/breaks-steps/interface.yml +7 -0
  16. conf/generated/bulgarian-tv-choir/c2f.yml +15 -0
  17. conf/generated/bulgarian-tv-choir/coarse.yml +8 -0
  18. conf/generated/bulgarian-tv-choir/interface.yml +7 -0
  19. conf/generated/dariacore/c2f.yml +15 -0
  20. conf/generated/dariacore/coarse.yml +8 -0
  21. conf/generated/dariacore/interface.yml +7 -0
  22. conf/generated/musica-bolero-marimba/c2f.yml +18 -0
  23. conf/generated/musica-bolero-marimba/coarse.yml +11 -0
  24. conf/generated/musica-bolero-marimba/interface.yml +8 -0
  25. conf/generated/panchos/c2f.yml +15 -0
  26. conf/generated/panchos/coarse.yml +8 -0
  27. conf/generated/panchos/interface.yml +7 -0
  28. conf/generated/titi-monkey/c2f.yml +15 -0
  29. conf/generated/titi-monkey/coarse.yml +8 -0
  30. conf/generated/titi-monkey/interface.yml +7 -0
  31. conf/generated/xeno-canto/c2f.yml +15 -0
  32. conf/generated/xeno-canto/coarse.yml +8 -0
  33. conf/generated/xeno-canto/interface.yml +7 -0
  34. conf/lora/birds.yml +10 -0
  35. conf/lora/birdss.yml +12 -0
  36. conf/lora/constructions.yml +10 -0
  37. conf/lora/ella-baila-sola.yml +10 -0
  38. conf/lora/gas-station.yml +10 -0
  39. conf/lora/lora-is-this-charlie-parker.yml +10 -0
  40. conf/lora/lora.yml +7 -7
  41. conf/lora/underworld.yml +10 -0
  42. conf/lora/xeno-canto/c2f.yml +21 -0
  43. conf/lora/xeno-canto/coarse.yml +10 -0
  44. conf/vampnet-musdb-drums.yml +22 -0
  45. conf/vampnet.yml +19 -9
  46. requirements.txt +2 -4
  47. scripts/exp/fine_tune.py +7 -6
  48. scripts/exp/train.py +425 -483
  49. scripts/utils/{data/augment.py → augment.py} +24 -38
  50. scripts/utils/gtzan_embeddings.py +0 -263
.gitignore CHANGED
@@ -175,14 +175,10 @@ lyrebird-audio-codec
 samples-*/**
 
 gradio-outputs/
-models/
 samples*/
 models-all/
 models.zip
+audiotools/
+descript-audio-codec/
+# *.pth
 .git-old
-conf/generated/*
-runs*/
-
-
-gtzan.zip
-.gtzan_emb_cache
README.md CHANGED
@@ -7,27 +7,16 @@ sdk: gradio
 sdk_version: 3.36.1
 app_file: app.py
 pinned: false
-python_version: 3.9
+duplicated_from: hugggof/vampnet
 ---
 
 # VampNet
 
-This repository contains recipes for training generative music models on top of the Descript Audio Codec.
-
-## try `unloop`
-you can try vampnet in a co-creative looper called unloop. see this link: https://github.com/hugofloresgarcia/unloop
+This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec.
 
 # Setting up
 
-**Requires Python 3.9**.
-
-you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
-
-(for example, using conda)
-```bash
-conda create -n vampnet python=3.9
-conda activate vampnet
-```
+Requires Python 3.9 or later.
 
 
 install VampNet
@@ -46,7 +35,7 @@ Config files are stored in the `conf/` folder.
 ### Licensing for Pretrained Models:
 The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml).
 
-Download the pretrained models from [this link](https://zenodo.org/record/8136629). Then, extract the models to the `models/` folder.
+Download the pretrained models from [this link](https://zenodo.org/record/8136545). Then, extract the models to the `models/` folder.
 
 
 # Usage
@@ -100,7 +89,7 @@ python scripts/exp/train.py --args.load conf/<fine_tune_name>/c2f.yml
 
 launch the interface:
 ```bash
-python app.py --args.load conf/generated/<fine_tune_name>/interface.yml
+python demo.py --args.load conf/generated/<fine_tune_name>/interface.yml
 ```
 
 
app.py CHANGED
@@ -1,12 +1,3 @@
-# huggingface space exclusive
-import os
-
-# print("installing pyharp")
-# os.system('pip install "pyharp@git+https://github.com/audacitorch/pyharp.git"')
-# print("installing madmom")
-os.system('pip install cython')
-os.system('pip install madmom')
-
 from pathlib import Path
 from typing import Tuple
 import yaml
@@ -24,38 +15,27 @@ import gradio as gr
 from vampnet.interface import Interface
 from vampnet import mask as pmask
 
-from pyharp import ModelCard, build_endpoint
-
-
-
-# loader = AudioLoader()
+# Interface = argbind.bind(Interface)
 # AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
 
-conf = argbind.parse_args()
-
-
-from torch_pitch_shift import pitch_shift, get_fast_shifts
-def shift_pitch(signal, interval: int):
-    signal.samples = pitch_shift(
-        signal.samples,
-        shift=interval,
-        sample_rate=signal.sample_rate
-    )
-    return signal
-
-def load_interface():
-    interface = Interface(
-        coarse_ckpt="./models/vampnet/coarse.pth",
-        coarse2fine_ckpt="./models/vampnet/c2f.pth",
-        codec_ckpt="./models/vampnet/codec.pth",
-        wavebeat_ckpt="./models/wavebeat.pth",
-        device="cuda" if torch.cuda.is_available() else "cpu",
-    )
-    return interface
-
+interface = Interface(
+    coarse_ckpt="./models/vampnet/coarse.pth",
+    coarse2fine_ckpt="./models/vampnet/c2f.pth",
+    codec_ckpt="./models/vampnet/codec.pth",
+    wavebeat_ckpt="./models/wavebeat.pth",
+    device="cuda" if torch.cuda.is_available() else "cpu",
+)
 
-interface = load_interface()
+# loader = AudioLoader()
+print(f"interface device is {interface.device}")
 
+# dataset = at.data.datasets.AudioDataset(
+#     loader,
+#     sample_rate=interface.codec.sample_rate,
+#     duration=interface.coarse.chunk_size_s,
+#     n_examples=5000,
+#     without_replacement=True,
+# )
 
 OUT_DIR = Path("gradio-outputs")
 OUT_DIR.mkdir(exist_ok=True, parents=True)
@@ -70,7 +50,7 @@ def load_audio(file):
     )
     sig = interface.preprocess(sig)
 
-    out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
+    out_dir = OUT_DIR / str(uuid.uuid4())
     out_dir.mkdir(parents=True, exist_ok=True)
     sig.write(out_dir / "input.wav")
     return sig.path_to_file
@@ -88,10 +68,6 @@ def _vamp(data, return_mask=False):
     out_dir = OUT_DIR / str(uuid.uuid4())
     out_dir.mkdir()
     sig = at.AudioSignal(data[input_audio])
-    sig = interface.preprocess(sig)
-
-    if data[pitch_shift_amt] != 0:
-        sig = shift_pitch(sig, data[pitch_shift_amt])
 
     z = interface.encode(sig)
 
@@ -131,58 +107,24 @@ def _vamp(data, return_mask=False):
     mask = pmask.codebook_unmask(mask, ncc)
 
 
-    print(f"dropout {data[dropout]}")
-    print(f"masktemp {data[masktemp]}")
-    print(f"sampletemp {data[sampletemp]}")
-    print(f"top_p {data[top_p]}")
-    print(f"prefix_s {data[prefix_s]}")
-    print(f"suffix_s {data[suffix_s]}")
-    print(f"rand_mask_intensity {data[rand_mask_intensity]}")
-    print(f"num_steps {data[num_steps]}")
-    print(f"periodic_p {data[periodic_p]}")
-    print(f"periodic_w {data[periodic_w]}")
-    print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
-    print(f"use_coarse2fine {data[use_coarse2fine]}")
-    print(f"onset_mask_width {data[onset_mask_width]}")
-    print(f"beat_mask_width {data[beat_mask_width]}")
-    print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
-    print(f"stretch_factor {data[stretch_factor]}")
-    print(f"seed {data[seed]}")
-    print(f"pitch_shift_amt {data[pitch_shift_amt]}")
-    print(f"sample_cutoff {data[sample_cutoff]}")
-
-
-    _top_p = data[top_p] if data[top_p] > 0 else None
+    print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}")
     # save the mask as a txt file
     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
-    _seed = data[seed] if data[seed] > 0 else None
     zv, mask_z = interface.coarse_vamp(
         z,
         mask=mask,
         sampling_steps=data[num_steps],
-        mask_temperature=data[masktemp]*10,
-        sampling_temperature=data[sampletemp],
+        temperature=float(data[temp]*10),
        return_mask=True,
         typical_filtering=data[typical_filtering],
         typical_mass=data[typical_mass],
         typical_min_tokens=data[typical_min_tokens],
-        top_p=_top_p,
         gen_fn=interface.coarse.generate,
-        seed=_seed,
-        sample_cutoff=data[sample_cutoff],
     )
 
     if use_coarse2fine:
-        zv = interface.coarse_to_fine(
-            zv,
-            mask_temperature=data[masktemp]*10,
-            sampling_temperature=data[sampletemp],
-            mask=mask,
-            sampling_steps=data[num_steps] // 2,
-            sample_cutoff=data[sample_cutoff],
-            seed=_seed,
-        )
+        zv = interface.coarse_to_fine(zv, temperature=data[temp])
 
     sig = interface.to_signal(zv).cpu()
     print("done")
@@ -215,9 +157,7 @@ def save_vamp(data):
     sig_out.write(out_dir / "output.wav")
 
     _data = {
-        "masktemp": data[masktemp],
-        "sampletemp": data[sampletemp],
-        "top_p": data[top_p],
+        "temp": data[temp],
         "prefix_s": data[prefix_s],
         "suffix_s": data[suffix_s],
         "rand_mask_intensity": data[rand_mask_intensity],
@@ -228,8 +168,6 @@ def save_vamp(data):
         "n_conditioning_codebooks": data[n_conditioning_codebooks],
         "use_coarse2fine": data[use_coarse2fine],
         "stretch_factor": data[stretch_factor],
-        "seed": data[seed],
-        "samplecutoff": data[sample_cutoff],
     }
 
     # save with yaml
@@ -245,54 +183,13 @@ def save_vamp(data):
     return f"saved! your save code is {out_dir.stem}", zip_path
 
 
-def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
-
-    out_dir = OUT_DIR / str(uuid.uuid4())
-    out_dir.mkdir()
-    sig = at.AudioSignal(_input_audio)
-    sig = interface.preprocess(sig)
-
-    z = interface.encode(sig)
-
-    # build the mask
-    mask = pmask.linear_random(z, 1.0)
-    if _beat_mask_width > 0:
-        beat_mask = interface.make_beat_mask(
-            sig,
-            after_beat_s=(_beat_mask_width/1000),
-        )
-        mask = pmask.mask_and(mask, beat_mask)
-
-    # save the mask as a txt file
-    zv, mask_z = interface.coarse_vamp(
-        z,
-        mask=mask,
-        sampling_temperature=_sampletemp,
-        return_mask=True,
-        gen_fn=interface.coarse.generate,
-    )
-
-
-    zv = interface.coarse_to_fine(
-        zv,
-        sampling_temperature=_sampletemp,
-        mask=mask,
-    )
-
-    sig = interface.to_signal(zv).cpu()
-    print("done")
-
-    sig.write(out_dir / "output.wav")
-
-    return sig.path_to_file
-
 with gr.Blocks() as demo:
 
     with gr.Row():
         with gr.Column():
-            gr.Markdown("# VampNet Audio Vamping")
+            gr.Markdown("# VampNet")
             gr.Markdown("""## Description:
-            This is a demo of the VampNet, a generative audio model that transforms the input audio based on the chosen settings.
+            This is a demo of VampNet, a masked generative music model capable of doing music variations.
             You can control the extent and nature of variation with a set of manual controls and presets.
            Use this interface to experiment with different mask settings and explore the audio outputs.
             """)
@@ -300,8 +197,8 @@ with gr.Blocks() as demo:
             gr.Markdown("""
            ## Instructions:
            1. You can start by uploading some audio, or by loading the example audio.
-            2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
-            3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio.
+            2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings. Click the load preset button.
+            3. Click the "generate (vamp)!!!" button to generate audio. Listen to the output audio, and the masked audio to hear the mask hints.
            4. Optionally, you can add some notes and save the result.
            5. You can also use the output as the new input and continue experimenting!
            """)
@@ -352,25 +249,19 @@ with gr.Blocks() as demo:
                 "beat_mask_downbeats": False,
             },
             "slight periodic variation": {
-                "periodic_p": 5,
-                "onset_mask_width": 5,
-                "beat_mask_width": 0,
-                "beat_mask_downbeats": False,
-            },
-            "moderate periodic variation": {
-                "periodic_p": 13,
-                "onset_mask_width": 5,
+                "periodic_p": 7,
+                "onset_mask_width": 0,
                 "beat_mask_width": 0,
                 "beat_mask_downbeats": False,
             },
             "strong periodic variation": {
-                "periodic_p": 17,
+                "periodic_p": 13,
                 "onset_mask_width": 5,
                 "beat_mask_width": 0,
                 "beat_mask_downbeats": False,
             },
             "very strong periodic variation": {
-                "periodic_p": 21,
+                "periodic_p": 17,
                 "onset_mask_width": 5,
                 "beat_mask_width": 0,
                 "beat_mask_downbeats": False,
@@ -378,15 +269,9 @@ with gr.Blocks() as demo:
             "beat-driven variation": {
                 "periodic_p": 0,
                 "onset_mask_width": 0,
-                "beat_mask_width": 50,
+                "beat_mask_width": 20,
                 "beat_mask_downbeats": False,
             },
-            "beat-driven variation (downbeats only)": {
-                "periodic_p": 0,
-                "onset_mask_width": 0,
-                "beat_mask_width": 50,
-                "beat_mask_downbeats": True,
-            },
             "beat-driven variation (downbeats only, strong)": {
                 "periodic_p": 0,
                 "onset_mask_width": 0,
@@ -408,20 +293,20 @@ with gr.Blocks() as demo:
                 minimum=0,
                 maximum=128,
                 step=1,
-                value=3,
+                value=13,
             )


            onset_mask_width = gr.Slider(
                label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
                minimum=0,
-                maximum=100,
+                maximum=20,
                step=1,
                value=5,
            )

            beat_mask_width = gr.Slider(
-                label="beat prompt (ms)",
+                label="beat mask width (in milliseconds)",
                minimum=0,
                maximum=200,
                value=0,
@@ -433,14 +318,6 @@ with gr.Blocks() as demo:


            with gr.Accordion("extras ", open=False):
-                pitch_shift_amt = gr.Slider(
-                    label="pitch shift amount (semitones)",
-                    minimum=-12,
-                    maximum=12,
-                    step=1,
-                    value=0,
-                )
-
                rand_mask_intensity = gr.Slider(
                    label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
                    minimum=0.0,
@@ -500,34 +377,21 @@ with gr.Blocks() as demo:
                    value=0.0
                )

-                masktemp = gr.Slider(
-                    label="mask temperature",
+                temp = gr.Slider(
+                    label="temperature",
                    minimum=0.0,
-                    maximum=100.0,
-                    value=1.5
-                )
-                sampletemp = gr.Slider(
-                    label="sample temperature",
-                    minimum=0.1,
                    maximum=10.0,
-                    value=1.0,
-                    step=0.001
+                    value=1.8
                )
-
+


            with gr.Accordion("sampling settings", open=False):
-                top_p = gr.Slider(
-                    label="top p (0.0 = off)",
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=0.9
-                )
                typical_filtering = gr.Checkbox(
                    label="typical filtering ",
                    value=False
                )
-                typical_mass = gr.Slider(
+                typical_mass = gr.Slider(
                    label="typical mass (should probably stay between 0.1 and 0.5)",
                    minimum=0.01,
                    maximum=0.99,
@@ -540,18 +404,10 @@ with gr.Blocks() as demo:
                    step=1,
                    value=64
                )
-                sample_cutoff = gr.Slider(
-                    label="sample cutoff",
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=0.5,
-                    step=0.01
-                )

                use_coarse2fine = gr.Checkbox(
                    label="use coarse2fine",
-                    value=True,
-                    visible=False
+                    value=True
                )

                num_steps = gr.Slider(
@@ -571,24 +427,8 @@ with gr.Blocks() as demo:
                )


-                seed = gr.Number(
-                    label="seed (0 for random)",
-                    value=0,
-                    precision=0,
-                )
-
-
-
        # mask settings
        with gr.Column():
-
-            # lora_choice = gr.Dropdown(
-            #     label="lora choice",
-            #     choices=list(loras.keys()),
-            #     value=LORA_NONE,
-            #     visible=False
-            # )
-
            vamp_button = gr.Button("generate (vamp)!!!")
            output_audio = gr.Audio(
                label="output audio",
@@ -614,9 +454,7 @@ with gr.Blocks() as demo:
    _inputs = {
        input_audio,
        num_steps,
-        masktemp,
-        sampletemp,
-        top_p,
+        temp,
        prefix_s, suffix_s,
        rand_mask_intensity,
        periodic_p, periodic_w,
@@ -629,11 +467,7 @@ with gr.Blocks() as demo:
        typical_mass,
        typical_min_tokens,
        beat_mask_width,
-        beat_mask_downbeats,
-        seed,
-        # lora_choice,
-        pitch_shift_amt,
-        sample_cutoff
+        beat_mask_downbeats
    }

    # connect widgets
@@ -663,24 +497,4 @@ with gr.Blocks() as demo:
        outputs=[thank_you, download_file]
    )

-    # harp stuff
-    harp_inputs = [
-        input_audio,
-        beat_mask_width,
-        sampletemp,
-    ]
-
-    build_endpoint(
-        inputs=harp_inputs,
-        output=output_audio,
-        process_fn=harp_vamp,
-        card=ModelCard(
-            name="vampnet",
-            description="Generate variations on music input, based on small prompts around the beat. NOTE: vampnet's has a maximum context length of 10 seconds. Please split all audio clips into 10 second chunks, or processing will result in an error. ",
-            author="Hugo Flores García",
-            tags=["music", "generative"]
-        ),
-        visible=False
-    )
-
-demo.launch()
+demo.queue().launch()
conf/generated-v0/berta-goldman-speech/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+save_path: ./runs/berta-goldman-speech/c2f
+train/AudioLoader.sources:
+- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
+val/AudioLoader.sources:
+- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
conf/generated-v0/berta-goldman-speech/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+save_path: ./runs/berta-goldman-speech/coarse
+train/AudioLoader.sources:
+- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
+val/AudioLoader.sources:
+- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
conf/generated-v0/berta-goldman-speech/interface.yml ADDED
@@ -0,0 +1,5 @@
+AudioLoader.sources:
+- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3
+Interface.coarse2fine_ckpt: ./runs/berta-goldman-speech/c2f/best/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/berta-goldman-speech/coarse/best/vampnet/weights.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated-v0/gamelan-xeno-canto/c2f.yml ADDED
@@ -0,0 +1,17 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+save_path: ./runs/gamelan-xeno-canto/c2f
+train/AudioLoader.sources:
+- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
+- /media/CHONK/hugo/loras/xeno-canto-2
+val/AudioLoader.sources:
+- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
+- /media/CHONK/hugo/loras/xeno-canto-2
conf/generated-v0/gamelan-xeno-canto/coarse.yml ADDED
@@ -0,0 +1,10 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+save_path: ./runs/gamelan-xeno-canto/coarse
+train/AudioLoader.sources:
+- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
+- /media/CHONK/hugo/loras/xeno-canto-2
+val/AudioLoader.sources:
+- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
+- /media/CHONK/hugo/loras/xeno-canto-2
conf/generated-v0/gamelan-xeno-canto/interface.yml ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3
+- /media/CHONK/hugo/loras/xeno-canto-2
+Interface.coarse2fine_ckpt: ./runs/gamelan-xeno-canto/c2f/best/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/gamelan-xeno-canto/coarse/best/vampnet/weights.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated-v0/nasralla/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+save_path: ./runs/nasralla/c2f
+train/AudioLoader.sources:
+- /media/CHONK/hugo/nasralla
+val/AudioLoader.sources:
+- /media/CHONK/hugo/nasralla
conf/generated-v0/nasralla/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+save_path: ./runs/nasralla/coarse
+train/AudioLoader.sources:
+- /media/CHONK/hugo/nasralla
+val/AudioLoader.sources:
+- /media/CHONK/hugo/nasralla
conf/generated-v0/nasralla/interface.yml ADDED
@@ -0,0 +1,5 @@
+AudioLoader.sources:
+- /media/CHONK/hugo/nasralla
+Interface.coarse2fine_ckpt: ./runs/nasralla/c2f/best/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/nasralla/coarse/best/vampnet/weights.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated/breaks-steps/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/c2f.pth
+save_path: ./runs/breaks-steps/c2f
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/breaks-steps
+val/AudioLoader.sources: *id001
conf/generated/breaks-steps/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/coarse.pth
+save_path: ./runs/breaks-steps/coarse
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/breaks-steps
+val/AudioLoader.sources: *id001
conf/generated/breaks-steps/interface.yml ADDED
@@ -0,0 +1,7 @@
+AudioLoader.sources:
+- - /media/CHONK/hugo/breaks-steps
+Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+Interface.coarse2fine_lora_ckpt: ./runs/breaks-steps/c2f/latest/lora.pth
+Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+Interface.coarse_lora_ckpt: ./runs/breaks-steps/coarse/latest/lora.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated/bulgarian-tv-choir/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/c2f.pth
+save_path: ./runs/bulgarian-tv-choir/c2f
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
+val/AudioLoader.sources: *id001
conf/generated/bulgarian-tv-choir/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/coarse.pth
+save_path: ./runs/bulgarian-tv-choir/coarse
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
+val/AudioLoader.sources: *id001
conf/generated/bulgarian-tv-choir/interface.yml ADDED
@@ -0,0 +1,7 @@
+AudioLoader.sources:
+- - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
+Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+Interface.coarse2fine_lora_ckpt: ./runs/bulgarian-tv-choir/c2f/latest/lora.pth
+Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+Interface.coarse_lora_ckpt: ./runs/bulgarian-tv-choir/coarse/latest/lora.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated/dariacore/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/c2f.pth
+save_path: ./runs/dariacore/c2f
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/dariacore
+val/AudioLoader.sources: *id001
conf/generated/dariacore/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/coarse.pth
+save_path: ./runs/dariacore/coarse
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/dariacore
+val/AudioLoader.sources: *id001
conf/generated/dariacore/interface.yml ADDED
@@ -0,0 +1,7 @@
+AudioLoader.sources:
+- - /media/CHONK/hugo/loras/dariacore
+Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+Interface.coarse2fine_lora_ckpt: ./runs/dariacore/c2f/latest/lora.pth
+Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+Interface.coarse_lora_ckpt: ./runs/dariacore/coarse/latest/lora.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated/musica-bolero-marimba/c2f.yml ADDED
@@ -0,0 +1,18 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/c2f.pth
+save_path: ./runs/musica-bolero-marimba/c2f
+train/AudioLoader.sources:
+- /media/CHONK/hugo/loras/boleros
+- /media/CHONK/hugo/loras/marimba-honduras
+val/AudioLoader.sources:
+- /media/CHONK/hugo/loras/boleros
+- /media/CHONK/hugo/loras/marimba-honduras
conf/generated/musica-bolero-marimba/coarse.yml ADDED
@@ -0,0 +1,11 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/coarse.pth
+save_path: ./runs/musica-bolero-marimba/coarse
+train/AudioLoader.sources:
+- /media/CHONK/hugo/loras/boleros
+- /media/CHONK/hugo/loras/marimba-honduras
+val/AudioLoader.sources:
+- /media/CHONK/hugo/loras/boleros
+- /media/CHONK/hugo/loras/marimba-honduras
conf/generated/musica-bolero-marimba/interface.yml ADDED
@@ -0,0 +1,8 @@
+AudioLoader.sources:
+- /media/CHONK/hugo/loras/boleros
+- /media/CHONK/hugo/loras/marimba-honduras
+Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+Interface.coarse2fine_lora_ckpt: ./runs/musica-bolero-marimba/c2f/latest/lora.pth
+Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+Interface.coarse_lora_ckpt: ./runs/musica-bolero-marimba/coarse/latest/lora.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated/panchos/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/c2f.pth
+save_path: ./runs/panchos/c2f
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/panchos/
+val/AudioLoader.sources: *id001
conf/generated/panchos/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/coarse.pth
+save_path: ./runs/panchos/coarse
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/panchos/
+val/AudioLoader.sources: *id001
conf/generated/panchos/interface.yml ADDED
@@ -0,0 +1,7 @@
+AudioLoader.sources:
+- - /media/CHONK/hugo/loras/panchos/
+Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+Interface.coarse2fine_lora_ckpt: ./runs/panchos/c2f/latest/lora.pth
+Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+Interface.coarse_lora_ckpt: ./runs/panchos/coarse/latest/lora.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated/titi-monkey/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/c2f.pth
+save_path: ./runs/titi-monkey/c2f
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/titi-monkey.mp3
+val/AudioLoader.sources: *id001
conf/generated/titi-monkey/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/coarse.pth
+save_path: ./runs/titi-monkey/coarse
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/titi-monkey.mp3
+val/AudioLoader.sources: *id001
conf/generated/titi-monkey/interface.yml ADDED
@@ -0,0 +1,7 @@
+AudioLoader.sources:
+- - /media/CHONK/hugo/loras/titi-monkey.mp3
+Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+Interface.coarse2fine_lora_ckpt: ./runs/titi-monkey/c2f/latest/lora.pth
+Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+Interface.coarse_lora_ckpt: ./runs/titi-monkey/coarse/latest/lora.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated/xeno-canto/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/c2f.pth
+save_path: ./runs/xeno-canto/c2f
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/xeno-canto-2/
+val/AudioLoader.sources: *id001
conf/generated/xeno-canto/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/coarse.pth
+save_path: ./runs/xeno-canto/coarse
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/xeno-canto-2/
+val/AudioLoader.sources: *id001
conf/generated/xeno-canto/interface.yml ADDED
@@ -0,0 +1,7 @@
+AudioLoader.sources:
+- - /media/CHONK/hugo/loras/xeno-canto-2/
+Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+Interface.coarse2fine_lora_ckpt: ./runs/xeno-canto/c2f/latest/lora.pth
+Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+Interface.coarse_lora_ckpt: ./runs/xeno-canto/coarse/latest/lora.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
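The interface.yml files above only wire checkpoint paths to `Interface`; for orientation, here is a minimal, hedged sketch of driving that interface programmatically, mirroring the calls that appear in app.py earlier in this diff. The checkpoint paths, the input/output filenames, and the temperature value are illustrative placeholders rather than values taken from these configs, and only the call names shown in app.py are assumed.

```python
import audiotools as at
import torch

from vampnet import mask as pmask
from vampnet.interface import Interface

# illustrative checkpoint locations; the generated configs above point at
# ./models/spotdl/ plus per-run lora.pth files instead
interface = Interface(
    coarse_ckpt="./models/vampnet/coarse.pth",
    coarse2fine_ckpt="./models/vampnet/c2f.pth",
    codec_ckpt="./models/vampnet/codec.pth",
    wavebeat_ckpt="./models/wavebeat.pth",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# load a clip and encode it to codec tokens, as _vamp() does in app.py
sig = interface.preprocess(at.AudioSignal("input.wav"))
z = interface.encode(sig)

# mask every token, then leave the first codebook unmasked as a conditioning hint
mask = pmask.linear_random(z, 1.0)
mask = pmask.codebook_unmask(mask, 1)

# coarse vamp followed by coarse-to-fine, using the temperature-style call
# from the updated app.py; other sampling arguments keep their defaults here
zv, mask_z = interface.coarse_vamp(z, mask=mask, temperature=1.8, return_mask=True)
zv = interface.coarse_to_fine(zv, temperature=1.8)

# decode back to audio and write it out
interface.to_signal(zv).cpu().write("output.wav")
```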
conf/lora/birds.yml ADDED
@@ -0,0 +1,10 @@
+$include:
+- conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/birds
+
+val/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/birds
conf/lora/birdss.yml ADDED
@@ -0,0 +1,12 @@
+$include:
+- conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/birds
+- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/
+
+val/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/birds
+- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/
conf/lora/constructions.yml ADDED
@@ -0,0 +1,10 @@
+$include:
+- conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3
+
+val/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3
conf/lora/ella-baila-sola.yml ADDED
@@ -0,0 +1,10 @@
+$include:
+- conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3
+
+val/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3
conf/lora/gas-station.yml ADDED
@@ -0,0 +1,10 @@
+$include:
+- conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
+
+val/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
conf/lora/lora-is-this-charlie-parker.yml ADDED
@@ -0,0 +1,10 @@
+$include:
+- conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3
+
+val/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3
conf/lora/lora.yml CHANGED
@@ -3,20 +3,20 @@ $include:
 
 fine_tune: True
 
-train/AudioDataset.n_examples: 100000000
-val/AudioDataset.n_examples: 500
+train/AudioDataset.n_examples: 10000000
+
+val/AudioDataset.n_examples: 10
 
 
 NoamScheduler.warmup: 500
 
-batch_size: 6
+batch_size: 7
 num_workers: 7
-save_iters: [10000, 20000, 30000, 40000, 50000, 100000]
-sample_freq: 1000
-val_freq: 500
+epoch_length: 100
+save_audio_epochs: 10
 
 AdamW.lr: 0.0001
 
 # let's us organize sound classes into folders and choose from those sound classes uniformly
 AudioDataset.without_replacement: False
-num_iters: 500000
+max_epochs: 500
conf/lora/underworld.yml ADDED
@@ -0,0 +1,10 @@
+$include:
+- conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/underworld.mp3
+
+val/AudioLoader.sources:
+- /media/CHONK/hugo/spotdl/subsets/underworld.mp3
conf/lora/xeno-canto/c2f.yml ADDED
@@ -0,0 +1,21 @@
+$include:
+- conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+- /media/CHONK/hugo/xeno-canto-2
+
+val/AudioLoader.sources:
+- /media/CHONK/hugo/xeno-canto-2
+
+
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+
+VampNet.embedding_dim: 1280
+VampNet.n_layers: 16
+VampNet.n_heads: 20
+
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
conf/lora/xeno-canto/coarse.yml ADDED
@@ -0,0 +1,10 @@
+$include:
+- conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+- /media/CHONK/hugo/xeno-canto-2
+
+val/AudioLoader.sources:
+- /media/CHONK/hugo/xeno-canto-2
conf/vampnet-musdb-drums.yml ADDED
@@ -0,0 +1,22 @@
+$include:
+- conf/vampnet.yml
+
+VampNet.embedding_dim: 512
+VampNet.n_layers: 12
+VampNet.n_heads: 8
+
+AudioDataset.duration: 12.0
+
+train/AudioDataset.n_examples: 10000000
+train/AudioLoader.sources:
+- /data/musdb18hq/train/**/*drums.wav
+
+
+val/AudioDataset.n_examples: 500
+val/AudioLoader.sources:
+- /data/musdb18hq/test/**/*drums.wav
+
+
+test/AudioDataset.n_examples: 1000
+test/AudioLoader.sources:
+- /data/musdb18hq/test/**/*drums.wav
conf/vampnet.yml CHANGED
@@ -1,17 +1,21 @@
 
-codec_ckpt: ./models/vampnet/codec.pth
+codec_ckpt: ./models/spotdl/codec.pth
 save_path: ckpt
-
-num_iters: 1000000000
-save_iters: [10000, 50000, 100000, 300000, 500000]
+max_epochs: 1000
+epoch_length: 1000
+save_audio_epochs: 2
 val_idx: [0,1,2,3,4,5,6,7,8,9]
-sample_freq: 10000
-val_freq: 1000
+
+prefix_amt: 0.0
+suffix_amt: 0.0
+prefix_dropout: 0.1
+suffix_dropout: 0.1
 
 batch_size: 8
 num_workers: 10
 
 # Optimization
+detect_anomaly: false
 amp: false
 
 CrossEntropyLoss.label_smoothing: 0.1
@@ -21,6 +25,9 @@ AdamW.lr: 0.001
 NoamScheduler.factor: 2.0
 NoamScheduler.warmup: 10000
 
+PitchShift.shift_amount: [const, 0]
+PitchShift.prob: 0.0
+
 VampNet.vocab_size: 1024
 VampNet.n_codebooks: 4
 VampNet.n_conditioning_codebooks: 0
@@ -32,7 +39,7 @@ VampNet.n_heads: 20
 VampNet.flash_attn: false
 VampNet.dropout: 0.1
 
-AudioLoader.relative_path: ""
+AudioLoader.relative_path: /data/
 AudioDataset.loudness_cutoff: -30.0
 AudioDataset.without_replacement: true
 AudioLoader.shuffle: true
@@ -41,9 +48,12 @@ AudioDataset.duration: 10.0
 
 train/AudioDataset.n_examples: 10000000
 train/AudioLoader.sources:
-- /media/CHONK/hugo/spotdl/audio-train
+- /data/spotdl/audio/train
 
 val/AudioDataset.n_examples: 2000
 val/AudioLoader.sources:
-- /media/CHONK/hugo/spotdl/audio-val
+- /data/spotdl/audio/val
 
+test/AudioDataset.n_examples: 1000
+test/AudioLoader.sources:
+- /data/spotdl/audio/test
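The `train/`, `val/`, and `test/` prefixes in this config are argbind scopes; `build_datasets()` in scripts/exp/train.py (further down in this diff) consumes them roughly as sketched below. The `test` scope, the omitted transform, and the exact constructor defaults are assumptions inferred from the prefixed keys above and from the build_datasets code shown later.

```python
import argbind
import audiotools as at

# bind the loader/dataset classes with the scope prefixes used in conf/vampnet.yml,
# so train/AudioLoader.sources and val/AudioLoader.sources resolve per scope
AudioLoader = argbind.bind(at.data.datasets.AudioLoader, "train", "val", "test")
AudioDataset = argbind.bind(at.data.datasets.AudioDataset, "train", "val", "test")

def build_datasets(args, sample_rate: int):
    # rough mirror of build_datasets() in scripts/exp/train.py below
    with argbind.scope(args, "train"):
        train_data = AudioDataset(AudioLoader(), sample_rate)
    with argbind.scope(args, "val"):
        val_data = AudioDataset(AudioLoader(), sample_rate)
    return train_data, val_data

if __name__ == "__main__":
    # e.g. python build_datasets_sketch.py --args.load conf/vampnet.yml
    args = argbind.parse_args()
    train_data, val_data = build_datasets(args, sample_rate=44100)
```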
requirements.txt CHANGED
@@ -1,10 +1,8 @@
 torch
 argbind>=0.3.2
-numpy==1.23
+numpy==1.22
 gradio
 loralib
 wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
 lac @ git+https://github.com/hugofloresgarcia/lac.git
-descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2
--e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
-torch_pitch_shift
+audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
scripts/exp/fine_tune.py CHANGED
@@ -35,7 +35,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
         "AudioDataset.duration": 3.0,
         "AudioDataset.loudness_cutoff": -40.0,
         "save_path": f"./runs/{name}/c2f",
-        "fine_tune_checkpoint": "./models/vampnet/c2f.pth"
+        "fine_tune_checkpoint": "./models/spotdl/c2f.pth"
     }
 
     finetune_coarse_conf = {
@@ -44,16 +44,17 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
         "train/AudioLoader.sources": audio_files_or_folders,
         "val/AudioLoader.sources": audio_files_or_folders,
         "save_path": f"./runs/{name}/coarse",
-        "fine_tune_checkpoint": "./models/vampnet/coarse.pth"
+        "fine_tune_checkpoint": "./models/spotdl/coarse.pth"
     }
 
     interface_conf = {
-        "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
+        "Interface.coarse_ckpt": f"./models/spotdl/coarse.pth",
+        "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
 
-        "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
-        "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
+        "Interface.coarse2fine_ckpt": f"./models/spotdl/c2f.pth",
+        "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
 
-        "Interface.codec_ckpt": "./models/vampnet/codec.pth",
+        "Interface.codec_ckpt": "./models/spotdl/codec.pth",
         "AudioLoader.sources": [audio_files_or_folders],
     }
 
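fine_tune.py writes the three configs above, and the README launches the interface config with `--args.load`. As a rough, assumption-laden sketch of what that flag does, the argbind pattern already hinted at in app.py (`# Interface = argbind.bind(Interface)`) and used in train.py could be spelled out like this; the script name is hypothetical, and the no-argument `Interface()` call relies on argbind filling the `Interface.*` keys from the loaded YAML.

```python
# hypothetical load_interface.py, run as:
#   python load_interface.py --args.load conf/generated/<name>/interface.yml
import argbind

from vampnet.interface import Interface

# binding exposes Interface.* keys (coarse_ckpt, coarse_lora_ckpt, codec_ckpt, ...)
# as configurable arguments
Interface = argbind.bind(Interface)

args = argbind.parse_args()

with argbind.scope(args):
    # constructor arguments are supplied from the YAML written by fine_tune()
    interface = Interface()

print(f"interface device is {interface.device}")
```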
 
scripts/exp/train.py CHANGED
@@ -1,9 +1,9 @@
 import os
-import sys
 import warnings
 from pathlib import Path
 from typing import Optional
-from dataclasses import dataclass
 
 import argbind
 import audiotools as at
@@ -14,7 +14,7 @@ from audiotools.data import transforms
 from einops import rearrange
 from rich import pretty
 from rich.traceback import install
-from torch.utils.tensorboard import SummaryWriter
 
 import vampnet
 from vampnet.modules.transformer import VampNet
@@ -23,15 +23,6 @@ from vampnet import mask as pmask
 # from dac.model.dac import DAC
 from lac.model.lac import LAC as DAC
 
-from audiotools.ml.decorators import (
-    timer, Tracker, when
-)
-
-import loralib as lora
-
-import torch._dynamo
-torch._dynamo.config.verbose=True
-
 
 # Enable cudnn autotuner to speed up training
 # (can be altered by the funcs.seed function)
@@ -94,7 +85,11 @@ def build_datasets(args, sample_rate: int):
     )
     with argbind.scope(args, "val"):
         val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
-    return train_data, val_data
 
 
 def rand_float(shape, low, high, rng):
@@ -105,392 +100,16 @@ def flip_coin(shape, p, rng):
     return rng.draw(shape)[:, 0] < p
 
 
-def num_params_hook(o, p):
-    return o + f" {p/1e6:<.3f}M params."
-
-
-def add_num_params_repr_hook(model):
-    import numpy as np
-    from functools import partial
-
-    for n, m in model.named_modules():
-        o = m.extra_repr()
-        p = sum([np.prod(p.size()) for p in m.parameters()])
-
-        setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
-
-
-def accuracy(
-    preds: torch.Tensor,
-    target: torch.Tensor,
-    top_k: int = 1,
-    ignore_index: Optional[int] = None,
-) -> torch.Tensor:
-    # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
-    preds = rearrange(preds, "b p s -> (b s) p")
-    target = rearrange(target, "b s -> (b s)")
-
-    # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
-    if ignore_index is not None:
-        # Create a mask for the ignored index
-        mask = target != ignore_index
-        # Apply the mask to the target and predictions
-        preds = preds[mask]
-        target = target[mask]
-
-    # Get the top-k predicted classes and their indices
-    _, pred_indices = torch.topk(preds, k=top_k, dim=-1)
-
-    # Determine if the true target is in the top-k predicted classes
-    correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
-
-    # Calculate the accuracy
-    accuracy = torch.mean(correct.float())
-
-    return accuracy
-
-def _metrics(z_hat, r, target, flat_mask, output):
-    for r_range in [(0, 0.5), (0.5, 1.0)]:
-        unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
-        masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
-
-        assert target.shape[0] == r.shape[0]
-        # grab the indices of the r values that are in the range
-        r_idx = (r >= r_range[0]) & (r < r_range[1])
-
-        # grab the target and z_hat values that are in the range
-        r_unmasked_target = unmasked_target[r_idx]
-        r_masked_target = masked_target[r_idx]
-        r_z_hat = z_hat[r_idx]
-
-        for topk in (1, 25):
-            s, e = r_range
-            tag = f"accuracy-{s}-{e}/top{topk}"
-
-            output[f"{tag}/unmasked"] = accuracy(
-                preds=r_z_hat,
-                target=r_unmasked_target,
-                ignore_index=IGNORE_INDEX,
-                top_k=topk,
-            )
-            output[f"{tag}/masked"] = accuracy(
-                preds=r_z_hat,
-                target=r_masked_target,
-                ignore_index=IGNORE_INDEX,
-                top_k=topk,
-            )
-
-
-@dataclass
-class State:
-    model: VampNet
-    codec: DAC
-
-    optimizer: AdamW
-    scheduler: NoamScheduler
-    criterion: CrossEntropyLoss
-    grad_clip_val: float
-
-    rng: torch.quasirandom.SobolEngine
-
-    train_data: AudioDataset
-    val_data: AudioDataset
-
-    tracker: Tracker
-
-
-@timer()
-def train_loop(state: State, batch: dict, accel: Accelerator):
-    state.model.train()
-    batch = at.util.prepare_batch(batch, accel.device)
-    signal = apply_transform(state.train_data.transform, batch)
-
-    output = {}
-    vn = accel.unwrap(state.model)
-    with accel.autocast():
-        with torch.inference_mode():
-            state.codec.to(accel.device)
-            z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
-            z = z[:, : vn.n_codebooks, :]
-
-        n_batch = z.shape[0]
-        r = state.rng.draw(n_batch)[:, 0].to(accel.device)
-
-        mask = pmask.random(z, r)
-        mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
-        z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
-
-        z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
-
-        dtype = torch.bfloat16 if accel.amp else None
-        with accel.autocast(dtype=dtype):
-            z_hat = state.model(z_mask_latent)
-
-        target = codebook_flatten(
-            z[:, vn.n_conditioning_codebooks :, :],
-        )
-
-        flat_mask = codebook_flatten(
-            mask[:, vn.n_conditioning_codebooks :, :],
-        )
-
-        # replace target with ignore index for masked tokens
-        t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
-        output["loss"] = state.criterion(z_hat, t_masked)
-
-        _metrics(
-            r=r,
-            z_hat=z_hat,
-            target=target,
-            flat_mask=flat_mask,
-            output=output,
-        )
-
-    accel.backward(output["loss"])
-
-    output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
-    output["other/batch_size"] = z.shape[0]
-
-    accel.scaler.unscale_(state.optimizer)
-    output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
-        state.model.parameters(), state.grad_clip_val
-    )
-
-    accel.step(state.optimizer)
-    state.optimizer.zero_grad()
-
-    state.scheduler.step()
-    accel.update()
-
-    return {k: v for k, v in sorted(output.items())}
-
-
-@timer()
-@torch.no_grad()
-def val_loop(state: State, batch: dict, accel: Accelerator):
-    state.model.eval()
-    state.codec.eval()
-    batch = at.util.prepare_batch(batch, accel.device)
-    signal = apply_transform(state.val_data.transform, batch)
-
-    vn = accel.unwrap(state.model)
-    z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
-    z = z[:, : vn.n_codebooks, :]
-
-    n_batch = z.shape[0]
-    r = state.rng.draw(n_batch)[:, 0].to(accel.device)
-
-    mask = pmask.random(z, r)
-    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
-    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
-
-    z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
-
-    z_hat = state.model(z_mask_latent)
-
-    target = codebook_flatten(
-        z[:, vn.n_conditioning_codebooks :, :],
-    )
-
-    flat_mask = codebook_flatten(
-        mask[:, vn.n_conditioning_codebooks :, :]
-    )
-
-    output = {}
-    # replace target with ignore index for masked tokens
-    t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
-    output["loss"] = state.criterion(z_hat, t_masked)
-
-    _metrics(
-        r=r,
-        z_hat=z_hat,
-        target=target,
-        flat_mask=flat_mask,
-        output=output,
-    )
-
-    return output
-
-
-def validate(state, val_dataloader, accel):
-    for batch in val_dataloader:
-        output = val_loop(state, batch, accel)
-    # Consolidate state dicts if using ZeroRedundancyOptimizer
-    if hasattr(state.optimizer, "consolidate_state_dict"):
-        state.optimizer.consolidate_state_dict()
-    return output
-
-
-def checkpoint(state, save_iters, save_path, fine_tune):
-    if accel.local_rank != 0:
-        state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
-        return
-
-    metadata = {"logs": dict(state.tracker.history)}
-
-    tags = ["latest"]
-    state.tracker.print(f"Saving to {str(Path('.').absolute())}")
-
-    if state.tracker.step in save_iters:
-        tags.append(f"{state.tracker.step // 1000}k")
-
-    if state.tracker.is_best("val", "loss"):
-        state.tracker.print(f"Best model so far")
-        tags.append("best")
-
-    if fine_tune:
-        for tag in tags:
-            # save the lora model
-            (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
-            torch.save(
-                lora.lora_state_dict(accel.unwrap(state.model)),
-                f"{save_path}/{tag}/lora.pth"
-            )
-
-    for tag in tags:
-        model_extra = {
-            "optimizer.pth": state.optimizer.state_dict(),
-            "scheduler.pth": state.scheduler.state_dict(),
-            "tracker.pth": state.tracker.state_dict(),
-            "metadata.pth": metadata,
-        }
-
-        accel.unwrap(state.model).metadata = metadata
-        accel.unwrap(state.model).save_to_folder(
-            f"{save_path}/{tag}", model_extra, package=False
-        )
-
-
-def save_sampled(state, z, writer):
-    num_samples = z.shape[0]
-
-    for i in range(num_samples):
-        sampled = accel.unwrap(state.model).generate(
-            codec=state.codec,
-            time_steps=z.shape[-1],
-            start_tokens=z[i : i + 1],
-        )
-        sampled.cpu().write_audio_to_tb(
-            f"sampled/{i}",
-            writer,
-            step=state.tracker.step,
-            plot_fn=None,
-        )
-
-
-def save_imputation(state, z, val_idx, writer):
-    n_prefix = int(z.shape[-1] * 0.25)
-    n_suffix = int(z.shape[-1] * 0.25)
-
-    vn = accel.unwrap(state.model)
-
-    mask = pmask.inpaint(z, n_prefix, n_suffix)
-    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
-    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
-
-    imputed_noisy = vn.to_signal(z_mask, state.codec)
-    imputed_true = vn.to_signal(z, state.codec)
-
-    imputed = []
-    for i in range(len(z)):
-        imputed.append(
-            vn.generate(
-                codec=state.codec,
-                time_steps=z.shape[-1],
-                start_tokens=z[i][None, ...],
-                mask=mask[i][None, ...],
-            )
-        )
-    imputed = AudioSignal.batch(imputed)
-
-    for i in range(len(val_idx)):
-        imputed_noisy[i].cpu().write_audio_to_tb(
-            f"inpainted_prompt/{i}",
-            writer,
-            step=state.tracker.step,
-            plot_fn=None,
-        )
-        imputed[i].cpu().write_audio_to_tb(
-            f"inpainted_middle/{i}",
-            writer,
-            step=state.tracker.step,
-            plot_fn=None,
-        )
-        imputed_true[i].cpu().write_audio_to_tb(
-            f"reconstructed/{i}",
-            writer,
-            step=state.tracker.step,
-            plot_fn=None,
-        )
-
-
-@torch.no_grad()
-def save_samples(state: State, val_idx: int, writer: SummaryWriter):
-    state.model.eval()
-    state.codec.eval()
-    vn = accel.unwrap(state.model)
-
-    batch = [state.val_data[i] for i in val_idx]
-    batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)
-
-    signal = apply_transform(state.val_data.transform, batch)
-
-    z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
-    z = z[:, : vn.n_codebooks, :]
-
-    r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
-
-    mask = pmask.random(z, r)
-    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
449
- z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
450
-
451
- z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
452
-
453
- z_hat = state.model(z_mask_latent)
454
-
455
- z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
456
- z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
457
- z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
458
-
459
- generated = vn.to_signal(z_pred, state.codec)
460
- reconstructed = vn.to_signal(z, state.codec)
461
- masked = vn.to_signal(z_mask.squeeze(1), state.codec)
462
-
463
- for i in range(generated.batch_size):
464
- audio_dict = {
465
- "original": signal[i],
466
- "masked": masked[i],
467
- "generated": generated[i],
468
- "reconstructed": reconstructed[i],
469
- }
470
- for k, v in audio_dict.items():
471
- v.cpu().write_audio_to_tb(
472
- f"onestep/_{i}.r={r[i]:0.2f}/{k}",
473
- writer,
474
- step=state.tracker.step,
475
- plot_fn=None,
476
- )
477
-
478
- save_sampled(state=state, z=z, writer=writer)
479
- save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
480
-
481
-
482
-
483
  @argbind.bind(without_prefix=True)
484
  def load(
485
  args,
486
  accel: at.ml.Accelerator,
487
- tracker: Tracker,
488
  save_path: str,
489
  resume: bool = False,
490
  tag: str = "latest",
 
491
  fine_tune_checkpoint: Optional[str] = None,
492
- grad_clip_val: float = 5.0,
493
- ) -> State:
494
  codec = DAC.load(args["codec_ckpt"], map_location="cpu")
495
  codec.eval()
496
 
@@ -500,9 +119,8 @@ def load(
500
  kwargs = {
501
  "folder": f"{save_path}/{tag}",
502
  "map_location": "cpu",
503
- "package": False,
504
  }
505
- tracker.print(f"Loading checkpoint from {kwargs['folder']}")
506
  if (Path(kwargs["folder"]) / "vampnet").exists():
507
  model, v_extra = VampNet.load_from_folder(**kwargs)
508
  else:
@@ -513,14 +131,11 @@ def load(
513
 
514
  if args["fine_tune"]:
515
  assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
516
- model = torch.compile(
517
- VampNet.load(location=Path(fine_tune_checkpoint),
518
- map_location="cpu",
519
- )
520
- )
521
 
 
522
 
523
- model = torch.compile(VampNet()) if model is None else model
524
  model = accel.prepare_model(model)
525
 
526
  # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
@@ -532,57 +147,89 @@ def load(
532
  scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
533
  scheduler.step()
534
 
 
 
535
  if "optimizer.pth" in v_extra:
536
  optimizer.load_state_dict(v_extra["optimizer.pth"])
 
537
  scheduler.load_state_dict(v_extra["scheduler.pth"])
538
- if "tracker.pth" in v_extra:
539
- tracker.load_state_dict(v_extra["tracker.pth"])
540
-
541
- criterion = CrossEntropyLoss()
542
 
543
- sample_rate = codec.sample_rate
 
 
 
 
 
 
544
 
545
- # a better rng for sampling from our schedule
546
- rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
547
 
548
- # log a model summary w/ num params
549
- if accel.local_rank == 0:
550
- add_num_params_repr_hook(accel.unwrap(model))
551
- with open(f"{save_path}/model.txt", "w") as f:
552
- f.write(repr(accel.unwrap(model)))
553
 
554
- # load the datasets
555
- train_data, val_data = build_datasets(args, sample_rate)
556
-
557
- return State(
558
- tracker=tracker,
559
- model=model,
560
- codec=codec,
561
- optimizer=optimizer,
562
- scheduler=scheduler,
563
- criterion=criterion,
564
- rng=rng,
565
- train_data=train_data,
566
- val_data=val_data,
567
- grad_clip_val=grad_clip_val,
568
- )
 
 
 
 
569
 
570
 
571
  @argbind.bind(without_prefix=True)
572
  def train(
573
  args,
574
  accel: at.ml.Accelerator,
575
- seed: int = 0,
576
  codec_ckpt: str = None,
 
577
  save_path: str = "ckpt",
578
- num_iters: int = int(1000e6),
579
- save_iters: list = [10000, 50000, 100000, 300000, 500000,],
580
- sample_freq: int = 10000,
581
- val_freq: int = 1000,
582
- batch_size: int = 12,
 
583
  val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
584
  num_workers: int = 10,
 
 
585
  fine_tune: bool = False,
 
586
  ):
587
  assert codec_ckpt is not None, "codec_ckpt is required"
588
 
@@ -594,79 +241,376 @@ def train(
594
  writer = SummaryWriter(log_dir=f"{save_path}/logs/")
595
  argbind.dump_args(args, f"{save_path}/args.yml")
596
 
597
- tracker = Tracker(
598
- writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
599
- )
600
-
601
  # load the codec model
602
- state: State = load(
603
- args=args,
604
- accel=accel,
605
- tracker=tracker,
606
- save_path=save_path)
607
- print("initialized state.")
608
 
 
 
 
 
609
  train_dataloader = accel.prepare_dataloader(
610
- state.train_data,
611
- start_idx=state.tracker.step * batch_size,
612
  num_workers=num_workers,
613
  batch_size=batch_size,
614
- collate_fn=state.train_data.collate,
615
  )
616
  val_dataloader = accel.prepare_dataloader(
617
- state.val_data,
618
  start_idx=0,
619
  num_workers=num_workers,
620
  batch_size=batch_size,
621
- collate_fn=state.val_data.collate,
622
- persistent_workers=num_workers > 0,
623
  )
624
- print("initialized dataloader.")
625
 
626
-
627
 
628
  if fine_tune:
629
- lora.mark_only_lora_as_trainable(state.model)
630
- print("marked only lora as trainable.")
 
 
 
 
631
 
632
- # Wrap the functions so that they neatly track in TensorBoard + progress bars
633
- # and only run when specific conditions are met.
634
- global train_loop, val_loop, validate, save_samples, checkpoint
635
 
636
- train_loop = tracker.log("train", "value", history=False)(
637
- tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
638
- )
639
- val_loop = tracker.track("val", len(val_dataloader))(val_loop)
640
- validate = tracker.log("val", "mean")(validate)
641
 
642
- save_samples = when(lambda: accel.local_rank == 0)(save_samples)
643
- checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
644
 
645
- print("starting training loop.")
646
- with tracker.live:
647
- for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
648
- train_loop(state, batch, accel)
649
 
650
- last_iter = (
651
- tracker.step == num_iters - 1 if num_iters is not None else False
652
  )
653
 
654
- if tracker.step % sample_freq == 0 or last_iter:
655
- save_samples(state, val_idx, writer)
 
656
 
657
- if tracker.step % val_freq == 0 or last_iter:
658
- validate(state, val_dataloader, accel)
659
- checkpoint(
660
- state=state,
661
- save_iters=save_iters,
662
- save_path=save_path,
663
- fine_tune=fine_tune)
 
 
 
 
 
 
664
 
665
- # Reset validation progress bar, print summary since last validation.
666
- tracker.done("val", f"Iteration {tracker.step}")
667
 
668
- if last_iter:
669
- break
 
 
 
 
670
 
671
 
672
  if __name__ == "__main__":
@@ -674,6 +618,4 @@ if __name__ == "__main__":
674
  args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
675
  with argbind.scope(args):
676
  with Accelerator() as accel:
677
- if accel.local_rank != 0:
678
- sys.tracebacklimit = 0
679
  train(args, accel)
 
1
  import os
2
+ import subprocess
3
+ import time
4
  import warnings
5
  from pathlib import Path
6
  from typing import Optional
 
7
 
8
  import argbind
9
  import audiotools as at
 
14
  from einops import rearrange
15
  from rich import pretty
16
  from rich.traceback import install
17
+ from tensorboardX import SummaryWriter
18
 
19
  import vampnet
20
  from vampnet.modules.transformer import VampNet
 
23
  # from dac.model.dac import DAC
24
  from lac.model.lac import LAC as DAC
25
 
 
 
 
26
 
27
  # Enable cudnn autotuner to speed up training
28
  # (can be altered by the funcs.seed function)
 
85
  )
86
  with argbind.scope(args, "val"):
87
  val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
88
+ with argbind.scope(args, "test"):
89
+ test_data = AudioDataset(
90
+ AudioLoader(), sample_rate, transform=build_transform()
91
+ )
92
+ return train_data, val_data, test_data
93
 
94
 
95
  def rand_float(shape, low, high, rng):
 
100
  return rng.draw(shape)[:, 0] < p
101
 
102
 
 
 
 
 
103
  @argbind.bind(without_prefix=True)
104
  def load(
105
  args,
106
  accel: at.ml.Accelerator,
 
107
  save_path: str,
108
  resume: bool = False,
109
  tag: str = "latest",
110
+ load_weights: bool = False,
111
  fine_tune_checkpoint: Optional[str] = None,
112
+ ):
 
113
  codec = DAC.load(args["codec_ckpt"], map_location="cpu")
114
  codec.eval()
115
 
 
119
  kwargs = {
120
  "folder": f"{save_path}/{tag}",
121
  "map_location": "cpu",
122
+ "package": not load_weights,
123
  }
 
124
  if (Path(kwargs["folder"]) / "vampnet").exists():
125
  model, v_extra = VampNet.load_from_folder(**kwargs)
126
  else:
 
131
 
132
  if args["fine_tune"]:
133
  assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
134
+ model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu")
135
+
 
 
 
136
 
137
+ model = VampNet() if model is None else model
138
 
 
139
  model = accel.prepare_model(model)
140
 
141
  # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
 
147
  scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
148
  scheduler.step()
149
 
150
+ trainer_state = {"state_dict": None, "start_idx": 0}
151
+
152
  if "optimizer.pth" in v_extra:
153
  optimizer.load_state_dict(v_extra["optimizer.pth"])
154
+ if "scheduler.pth" in v_extra:
155
  scheduler.load_state_dict(v_extra["scheduler.pth"])
156
+ if "trainer.pth" in v_extra:
157
+ trainer_state = v_extra["trainer.pth"]
 
 
158
 
159
+ return {
160
+ "model": model,
161
+ "codec": codec,
162
+ "optimizer": optimizer,
163
+ "scheduler": scheduler,
164
+ "trainer_state": trainer_state,
165
+ }
166
 
 
 
167
 
 
 
 
 
 
168
 
169
+ def num_params_hook(o, p):
170
+ return o + f" {p/1e6:<.3f}M params."
171
+
172
+
173
+ def add_num_params_repr_hook(model):
174
+ import numpy as np
175
+ from functools import partial
176
+
177
+ for n, m in model.named_modules():
178
+ o = m.extra_repr()
179
+ p = sum([np.prod(p.size()) for p in m.parameters()])
180
+
181
+ setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
182
+
183
+
184
+ def accuracy(
185
+ preds: torch.Tensor,
186
+ target: torch.Tensor,
187
+ top_k: int = 1,
188
+ ignore_index: Optional[int] = None,
189
+ ) -> torch.Tensor:
190
+ # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
191
+ preds = rearrange(preds, "b p s -> (b s) p")
192
+ target = rearrange(target, "b s -> (b s)")
193
+
194
+ # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
195
+ if ignore_index is not None:
196
+ # Create a mask for the ignored index
197
+ mask = target != ignore_index
198
+ # Apply the mask to the target and predictions
199
+ preds = preds[mask]
200
+ target = target[mask]
201
+
202
+ # Get the top-k predicted classes and their indices
203
+ _, pred_indices = torch.topk(preds, k=top_k, dim=-1)
204
+
205
+ # Determine if the true target is in the top-k predicted classes
206
+ correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
207
+
208
+ # Calculate the accuracy
209
+ accuracy = torch.mean(correct.float())
210
+
211
+ return accuracy
212
 
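For reference, the helper above computes top-k accuracy while skipping ignored positions. A minimal illustration, assuming `accuracy` as defined above is in scope (the tensors and the ignore value below are made up for the example):

```python
import torch

# Dummy logits shaped (batch, n_classes, seq) and targets shaped (batch, seq),
# matching the "b p s" / "b s" layout the helper rearranges internally.
preds = torch.randn(2, 4, 3)
target = torch.randint(0, 4, (2, 3))
target[0, 0] = -1                      # pretend this position should be ignored

top1 = accuracy(preds, target, top_k=1, ignore_index=-1)
top2 = accuracy(preds, target, top_k=2, ignore_index=-1)
assert top2 >= top1                    # widening k can only add hits on the same positions
```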
213
 
214
  @argbind.bind(without_prefix=True)
215
  def train(
216
  args,
217
  accel: at.ml.Accelerator,
 
218
  codec_ckpt: str = None,
219
+ seed: int = 0,
220
  save_path: str = "ckpt",
221
+ max_epochs: int = int(100e3),
222
+ epoch_length: int = 1000,
223
+ save_audio_epochs: int = 2,
224
+ save_epochs: list = [10, 50, 100, 200, 300, 400,],
225
+ batch_size: int = 48,
226
+ grad_acc_steps: int = 1,
227
  val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
228
  num_workers: int = 10,
229
+ detect_anomaly: bool = False,
230
+ grad_clip_val: float = 5.0,
231
  fine_tune: bool = False,
232
+ quiet: bool = False,
233
  ):
234
  assert codec_ckpt is not None, "codec_ckpt is required"
235
 
 
241
  writer = SummaryWriter(log_dir=f"{save_path}/logs/")
242
  argbind.dump_args(args, f"{save_path}/args.yml")
243
 
 
 
 
 
244
  # load the codec model
245
+ loaded = load(args, accel, save_path)
246
+ model = loaded["model"]
247
+ codec = loaded["codec"]
248
+ optimizer = loaded["optimizer"]
249
+ scheduler = loaded["scheduler"]
250
+ trainer_state = loaded["trainer_state"]
251
 
252
+ sample_rate = codec.sample_rate
253
+
254
+ # a better rng for sampling from our schedule
255
+ rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=seed)
256
+
257
+ # log a model summary w/ num params
258
+ if accel.local_rank == 0:
259
+ add_num_params_repr_hook(accel.unwrap(model))
260
+ with open(f"{save_path}/model.txt", "w") as f:
261
+ f.write(repr(accel.unwrap(model)))
262
+
263
+ # load the datasets
264
+ train_data, val_data, _ = build_datasets(args, sample_rate)
265
  train_dataloader = accel.prepare_dataloader(
266
+ train_data,
267
+ start_idx=trainer_state["start_idx"],
268
  num_workers=num_workers,
269
  batch_size=batch_size,
270
+ collate_fn=train_data.collate,
271
  )
272
  val_dataloader = accel.prepare_dataloader(
273
+ val_data,
274
  start_idx=0,
275
  num_workers=num_workers,
276
  batch_size=batch_size,
277
+ collate_fn=val_data.collate,
 
278
  )
 
279
 
280
+ criterion = CrossEntropyLoss()
281
 
282
  if fine_tune:
283
+ import loralib as lora
284
+ lora.mark_only_lora_as_trainable(model)
285
+
286
+
287
+ class Trainer(at.ml.BaseTrainer):
288
+ _last_grad_norm = 0.0
289
+
290
+ def _metrics(self, vn, z_hat, r, target, flat_mask, output):
291
+ for r_range in [(0, 0.5), (0.5, 1.0)]:
292
+ unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
293
+ masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
294
+
295
+ assert target.shape[0] == r.shape[0]
296
+ # grab the indices of the r values that are in the range
297
+ r_idx = (r >= r_range[0]) & (r < r_range[1])
298
+
299
+ # grab the target and z_hat values that are in the range
300
+ r_unmasked_target = unmasked_target[r_idx]
301
+ r_masked_target = masked_target[r_idx]
302
+ r_z_hat = z_hat[r_idx]
303
+
304
+ for topk in (1, 25):
305
+ s, e = r_range
306
+ tag = f"accuracy-{s}-{e}/top{topk}"
307
+
308
+ output[f"{tag}/unmasked"] = accuracy(
309
+ preds=r_z_hat,
310
+ target=r_unmasked_target,
311
+ ignore_index=IGNORE_INDEX,
312
+ top_k=topk,
313
+ )
314
+ output[f"{tag}/masked"] = accuracy(
315
+ preds=r_z_hat,
316
+ target=r_masked_target,
317
+ ignore_index=IGNORE_INDEX,
318
+ top_k=topk,
319
+ )
320
+
321
+ def train_loop(self, engine, batch):
322
+ model.train()
323
+ batch = at.util.prepare_batch(batch, accel.device)
324
+ signal = apply_transform(train_data.transform, batch)
325
+
326
+ output = {}
327
+ vn = accel.unwrap(model)
328
+ with accel.autocast():
329
+ with torch.inference_mode():
330
+ codec.to(accel.device)
331
+ z = codec.encode(signal.samples, signal.sample_rate)["codes"]
332
+ z = z[:, : vn.n_codebooks, :]
333
+
334
+ n_batch = z.shape[0]
335
+ r = rng.draw(n_batch)[:, 0].to(accel.device)
336
+
337
+ mask = pmask.random(z, r)
338
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
339
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
340
+
341
+ z_mask_latent = vn.embedding.from_codes(z_mask, codec)
342
+
343
+ dtype = torch.bfloat16 if accel.amp else None
344
+ with accel.autocast(dtype=dtype):
345
+ z_hat = model(z_mask_latent, r)
346
+
347
+ target = codebook_flatten(
348
+ z[:, vn.n_conditioning_codebooks :, :],
349
+ )
350
+
351
+ flat_mask = codebook_flatten(
352
+ mask[:, vn.n_conditioning_codebooks :, :],
353
+ )
354
+
355
+ # replace target with ignore index for masked tokens
356
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
357
+ output["loss"] = criterion(z_hat, t_masked)
358
+
359
+ self._metrics(
360
+ vn=vn,
361
+ r=r,
362
+ z_hat=z_hat,
363
+ target=target,
364
+ flat_mask=flat_mask,
365
+ output=output,
366
+ )
367
+
368
+
369
+ accel.backward(output["loss"] / grad_acc_steps)
370
+
371
+ output["other/learning_rate"] = optimizer.param_groups[0]["lr"]
372
+ output["other/batch_size"] = z.shape[0]
373
+
374
+ if (
375
+ (engine.state.iteration % grad_acc_steps == 0)
376
+ or (engine.state.iteration % epoch_length == 0)
377
+ or (engine.state.iteration % epoch_length == 1)
378
+ ): # (or we reached the end of the epoch)
379
+ accel.scaler.unscale_(optimizer)
380
+ output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
381
+ model.parameters(), grad_clip_val
382
+ )
383
+ self._last_grad_norm = output["other/grad_norm"]
384
+
385
+ accel.step(optimizer)
386
+ optimizer.zero_grad()
387
+
388
+ scheduler.step()
389
+ accel.update()
390
+ else:
391
+ output["other/grad_norm"] = self._last_grad_norm
392
+
393
+ return {k: v for k, v in sorted(output.items())}
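The loop above implements a masked-token objective: only positions that were masked out of the input contribute to the loss. A minimal, generic sketch of that idea in plain torch (standing in for the repo's pmask/codebook helpers and custom criterion, which are not reproduced here):

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100                      # sentinel assumed for this sketch; cross_entropy skips these targets

# Hypothetical shapes: logits (batch, vocab, seq) and discrete codes (batch, seq).
logits = torch.randn(2, 1024, 8)
codes = torch.randint(0, 1024, (2, 8))
mask = torch.rand(2, 8) < 0.7            # True where the input token was masked out

# Unmasked positions get the ignore index, so the loss only measures how well
# the model infills the tokens it could not see.
target = codes.masked_fill(~mask, IGNORE_INDEX)
loss = F.cross_entropy(logits, target, ignore_index=IGNORE_INDEX)
```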
394
+
395
+ @torch.no_grad()
396
+ def val_loop(self, engine, batch):
397
+ model.eval()
398
+ codec.eval()
399
+ batch = at.util.prepare_batch(batch, accel.device)
400
+ signal = apply_transform(val_data.transform, batch)
401
+
402
+ vn = accel.unwrap(model)
403
+ z = codec.encode(signal.samples, signal.sample_rate)["codes"]
404
+ z = z[:, : vn.n_codebooks, :]
405
 
406
+ n_batch = z.shape[0]
407
+ r = rng.draw(n_batch)[:, 0].to(accel.device)
 
408
 
409
+ mask = pmask.random(z, r)
410
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
411
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
 
 
412
 
413
+ z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
414
 
415
+ z_hat = model(z_mask_latent, r)
 
 
 
416
 
417
+ target = codebook_flatten(
418
+ z[:, vn.n_conditioning_codebooks :, :],
419
  )
420
 
421
+ flat_mask = codebook_flatten(
422
+ mask[:, vn.n_conditioning_codebooks :, :]
423
+ )
424
 
425
+ output = {}
426
+ # compute loss only on masked tokens: fill unmasked targets with the ignore index
427
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
428
+ output["loss"] = criterion(z_hat, t_masked)
429
+
430
+ self._metrics(
431
+ vn=vn,
432
+ r=r,
433
+ z_hat=z_hat,
434
+ target=target,
435
+ flat_mask=flat_mask,
436
+ output=output,
437
+ )
438
 
439
+ return output
 
440
 
441
+ def checkpoint(self, engine):
442
+ if accel.local_rank != 0:
443
+ print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
444
+ return
445
+
446
+ metadata = {"logs": dict(engine.state.logs["epoch"])}
447
+
448
+ if self.state.epoch % save_audio_epochs == 0:
449
+ self.save_samples()
450
+
451
+ tags = ["latest"]
452
+ loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
453
+ self.print(f"Saving to {str(Path('.').absolute())}")
454
+
455
+ if self.state.epoch in save_epochs:
456
+ tags.append(f"epoch={self.state.epoch}")
457
+
458
+ if self.is_best(engine, loss_key):
459
+ self.print(f"Best model so far")
460
+ tags.append("best")
461
+
462
+ if fine_tune:
463
+ for tag in tags:
464
+ # save the lora model
465
+ (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
466
+ torch.save(
467
+ lora.lora_state_dict(accel.unwrap(model)),
468
+ f"{save_path}/{tag}/lora.pth"
469
+ )
470
+
471
+ for tag in tags:
472
+ model_extra = {
473
+ "optimizer.pth": optimizer.state_dict(),
474
+ "scheduler.pth": scheduler.state_dict(),
475
+ "trainer.pth": {
476
+ "start_idx": self.state.iteration * batch_size,
477
+ "state_dict": self.state_dict(),
478
+ },
479
+ "metadata.pth": metadata,
480
+ }
481
+
482
+ accel.unwrap(model).metadata = metadata
483
+ accel.unwrap(model).save_to_folder(
484
+ f"{save_path}/{tag}", model_extra,
485
+ )
486
+
487
+ def save_sampled(self, z):
488
+ num_samples = z.shape[0]
489
+
490
+ for i in range(num_samples):
491
+ sampled = accel.unwrap(model).generate(
492
+ codec=codec,
493
+ time_steps=z.shape[-1],
494
+ start_tokens=z[i : i + 1],
495
+ )
496
+ sampled.cpu().write_audio_to_tb(
497
+ f"sampled/{i}",
498
+ self.writer,
499
+ step=self.state.epoch,
500
+ plot_fn=None,
501
+ )
502
+
503
+
504
+ def save_imputation(self, z: torch.Tensor):
505
+ n_prefix = int(z.shape[-1] * 0.25)
506
+ n_suffix = int(z.shape[-1] * 0.25)
507
+
508
+ vn = accel.unwrap(model)
509
+
510
+ mask = pmask.inpaint(z, n_prefix, n_suffix)
511
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
512
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
513
+
514
+ imputed_noisy = vn.to_signal(z_mask, codec)
515
+ imputed_true = vn.to_signal(z, codec)
516
+
517
+ imputed = []
518
+ for i in range(len(z)):
519
+ imputed.append(
520
+ vn.generate(
521
+ codec=codec,
522
+ time_steps=z.shape[-1],
523
+ start_tokens=z[i][None, ...],
524
+ mask=mask[i][None, ...],
525
+ )
526
+ )
527
+ imputed = AudioSignal.batch(imputed)
528
+
529
+ for i in range(len(val_idx)):
530
+ imputed_noisy[i].cpu().write_audio_to_tb(
531
+ f"imputed_noisy/{i}",
532
+ self.writer,
533
+ step=self.state.epoch,
534
+ plot_fn=None,
535
+ )
536
+ imputed[i].cpu().write_audio_to_tb(
537
+ f"imputed/{i}",
538
+ self.writer,
539
+ step=self.state.epoch,
540
+ plot_fn=None,
541
+ )
542
+ imputed_true[i].cpu().write_audio_to_tb(
543
+ f"imputed_true/{i}",
544
+ self.writer,
545
+ step=self.state.epoch,
546
+ plot_fn=None,
547
+ )
548
+
549
+ @torch.no_grad()
550
+ def save_samples(self):
551
+ model.eval()
552
+ codec.eval()
553
+ vn = accel.unwrap(model)
554
+
555
+ batch = [val_data[i] for i in val_idx]
556
+ batch = at.util.prepare_batch(val_data.collate(batch), accel.device)
557
+
558
+ signal = apply_transform(val_data.transform, batch)
559
+
560
+ z = codec.encode(signal.samples, signal.sample_rate)["codes"]
561
+ z = z[:, : vn.n_codebooks, :]
562
+
563
+ r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
564
+
565
+
566
+ mask = pmask.random(z, r)
567
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
568
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
569
+
570
+ z_mask_latent = vn.embedding.from_codes(z_mask, codec)
571
+
572
+ z_hat = model(z_mask_latent, r)
573
+
574
+ z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
575
+ z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
576
+ z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
577
+
578
+ generated = vn.to_signal(z_pred, codec)
579
+ reconstructed = vn.to_signal(z, codec)
580
+ masked = vn.to_signal(z_mask.squeeze(1), codec)
581
+
582
+ for i in range(generated.batch_size):
583
+ audio_dict = {
584
+ "original": signal[i],
585
+ "masked": masked[i],
586
+ "generated": generated[i],
587
+ "reconstructed": reconstructed[i],
588
+ }
589
+ for k, v in audio_dict.items():
590
+ v.cpu().write_audio_to_tb(
591
+ f"samples/_{i}.r={r[i]:0.2f}/{k}",
592
+ self.writer,
593
+ step=self.state.epoch,
594
+ plot_fn=None,
595
+ )
596
+
597
+ self.save_sampled(z)
598
+ self.save_imputation(z)
599
+
600
+ trainer = Trainer(writer=writer, quiet=quiet)
601
+
602
+ if trainer_state["state_dict"] is not None:
603
+ trainer.load_state_dict(trainer_state["state_dict"])
604
+ if hasattr(train_dataloader.sampler, "set_epoch"):
605
+ train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch)
606
+
607
+ trainer.run(
608
+ train_dataloader,
609
+ val_dataloader,
610
+ num_epochs=max_epochs,
611
+ epoch_length=epoch_length,
612
+ detect_anomaly=detect_anomaly,
613
+ )
614
 
615
 
616
  if __name__ == "__main__":
 
618
  args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
619
  with argbind.scope(args):
620
  with Accelerator() as accel:
 
 
621
  train(args, accel)
scripts/utils/{data/augment.py → augment.py} RENAMED
@@ -5,19 +5,34 @@ from audiotools import AudioSignal
5
 
6
  import argbind
7
  import tqdm
8
- import torch
9
 
10
 
11
- from torch_pitch_shift import pitch_shift, get_fast_shifts
12
- from torch_time_stretch import time_stretch, get_fast_stretches
 
 
13
 
14
- from audiotools.core.util import sample_from_dist
 
 
 
 
15
 
16
 
17
  @argbind.bind(without_prefix=True)
18
  def augment(
19
- audio_folder: Path = None,
20
- dest_folder: Path = None,
21
  n_augmentations: int = 10,
22
  ):
23
  """
@@ -26,8 +41,7 @@ def augment(
26
  The dest folder will contain a folder for each of the clean dataset's files.
27
  Under each of these folders, there will be a clean file and many augmented files.
28
  """
29
- assert audio_folder is not None
30
- assert dest_folder is not None
31
  audio_files = at.util.find_audio(audio_folder)
32
 
33
  for audio_file in tqdm.tqdm(audio_files):
@@ -35,33 +49,5 @@ def augment(
35
  subdir = subtree / audio_file.stem
36
  subdir.mkdir(parents=True, exist_ok=True)
37
 
38
- src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
39
-
40
-
41
- for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
42
- # apply pedalboard transforms
43
- for j in range(n_augmentations):
44
- # pitch shift between -7 and 7 semitones
45
- import random
46
- dst = chunk.clone()
47
- dst.samples = pitch_shift(
48
- dst.samples,
49
- shift=random.choice(get_fast_shifts(src.sample_rate,
50
- condition=lambda x: x >= 0.25 and x <= 1.0)),
51
- sample_rate=src.sample_rate
52
- )
53
- dst.samples = time_stretch(
54
- dst.samples,
55
- stretch=random.choice(get_fast_stretches(src.sample_rate,
56
- condition=lambda x: x >= 0.667 and x <= 1.5, )),
57
- sample_rate=src.sample_rate,
58
- )
59
-
60
- dst.cpu().write(subdir / f"{i}-{j}.wav")
61
-
62
-
63
- if __name__ == "__main__":
64
- args = argbind.parse_args()
65
-
66
- with argbind.scope(args):
67
- augment()
 
5
 
6
  import argbind
7
  import tqdm
 
8
 
9
 
10
+ from pedalboard import (
11
+ Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard
12
+ )
13
+ from pedalboard.io import AudioFile
14
 
15
+ # Read in a whole file, resampling to our desired sample rate:
16
+ samplerate = 44100.0
17
+ with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
18
+ audio = f.read(f.frames)
19
+
20
+ # Make a pretty interesting sounding guitar pedalboard:
21
+ board = Pedalboard([
22
+ Compressor(threshold_db=-50, ratio=25),
23
+ Gain(gain_db=30),
24
+ Chorus(),
25
+ LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
26
+ Phaser(),
27
+ Convolution("./guitar_amp.wav", 1.0),
28
+ Reverb(room_size=0.25),
29
+ ])
30
 
31
 
32
  @argbind.bind(without_prefix=True)
33
  def augment(
34
+ audio_folder: Path,
35
+ dest_folder: Path,
36
  n_augmentations: int = 10,
37
  ):
38
  """
 
41
  The dest folder will contain a folder for each of the clean dataset's files.
42
  Under each of these folders, there will be a clean file and many augmented files.
43
  """
44
+
 
45
  audio_files = at.util.find_audio(audio_folder)
46
 
47
  for audio_file in tqdm.tqdm(audio_files):
 
49
  subdir = subtree / audio_file.stem
50
  subdir.mkdir(parents=True, exist_ok=True)
51
 
52
+ # apply pedalboard transforms
53
+ for i in range(n_augmentations):
 
 
 
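The body of this augmentation loop is cut off in this view. For orientation only, applying the `board` built above typically looks like the following with pedalboard's API (a sketch; the output filename is hypothetical and the actual loop body is not shown here):

```python
from pedalboard.io import AudioFile

# `audio`, `samplerate`, and `board` come from the snippet above.
effected = board(audio, samplerate)

# Write the processed audio back out (hypothetical output name).
with AudioFile("guitar-output.wav", "w", samplerate, effected.shape[0]) as f:
    f.write(effected)
```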
 
scripts/utils/gtzan_embeddings.py DELETED
@@ -1,263 +0,0 @@
1
- """
2
- TODO: train a linear probe
3
- usage:
4
- python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output
5
- """
6
- from pathlib import Path
7
- from typing import List
8
-
9
- import audiotools as at
10
- from audiotools import AudioSignal
11
- import argbind
12
- import torch
13
- import numpy as np
14
- import zipfile
15
- import json
16
-
17
- from vampnet.interface import Interface
18
- import tqdm
19
-
20
- # bind the Interface to argbind
21
- Interface = argbind.bind(Interface)
22
-
23
- DEBUG = False
24
-
25
- def smart_plotly_export(fig, save_path):
26
- img_format = save_path.split('.')[-1]
27
- if img_format == 'html':
28
- fig.write_html(save_path)
29
- elif img_format == 'bytes':
30
- return fig.to_image(format='png')
31
- #TODO: come back and make this prettier
32
- elif img_format == 'numpy':
33
- import io
34
- from PIL import Image
35
-
36
- def plotly_fig2array(fig):
37
- #convert Plotly fig to an array
38
- fig_bytes = fig.to_image(format="png", width=1200, height=700)
39
- buf = io.BytesIO(fig_bytes)
40
- img = Image.open(buf)
41
- return np.asarray(img)
42
-
43
- return plotly_fig2array(fig)
44
- elif img_format == 'jpeg' or 'png' or 'webp':
45
- fig.write_image(save_path)
46
- else:
47
- raise ValueError("invalid image format")
48
-
49
- def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
50
- """
51
- dimensionality reduction for visualization!
52
- saves an html plotly figure to save_path
53
- parameters:
54
- emb (np.ndarray): the samples to be reduces with shape (samples, features)
55
- labels (list): list of labels for embedding
56
- save_path (str): path where u wanna save ur figure
57
- method (str): umap, tsne, or pca
58
- title (str): title for ur figure
59
- returns:
60
- proj (np.ndarray): projection vector with shape (samples, dimensions)
61
- """
62
- import pandas as pd
63
- import plotly.express as px
64
- if method == 'umap':
65
- reducer = umap.UMAP(n_components=n_components)
66
- elif method == 'tsne':
67
- from sklearn.manifold import TSNE
68
- reducer = TSNE(n_components=n_components)
69
- elif method == 'pca':
70
- from sklearn.decomposition import PCA
71
- reducer = PCA(n_components=n_components)
72
- else:
73
- raise ValueError
74
-
75
- proj = reducer.fit_transform(emb)
76
-
77
- if n_components == 2:
78
- df = pd.DataFrame(dict(
79
- x=proj[:, 0],
80
- y=proj[:, 1],
81
- instrument=labels
82
- ))
83
- fig = px.scatter(df, x='x', y='y', color='instrument',
84
- title=title+f"_{method}")
85
-
86
- elif n_components == 3:
87
- df = pd.DataFrame(dict(
88
- x=proj[:, 0],
89
- y=proj[:, 1],
90
- z=proj[:, 2],
91
- instrument=labels
92
- ))
93
- fig = px.scatter_3d(df, x='x', y='y', z='z',
94
- color='instrument',
95
- title=title)
96
- else:
97
- raise ValueError("cant plot more than 3 components")
98
-
99
- fig.update_traces(marker=dict(size=6,
100
- line=dict(width=1,
101
- color='DarkSlateGrey')),
102
- selector=dict(mode='markers'))
103
-
104
- return smart_plotly_export(fig, save_path)
105
-
106
-
107
-
108
- # per JukeMIR, we want the emebddings from the middle layer?
109
- def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
110
- with torch.inference_mode():
111
- # preprocess the signal
112
- sig = interface.preprocess(sig)
113
-
114
- # get the coarse vampnet model
115
- vampnet = interface.coarse
116
-
117
- # get the tokens
118
- z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
119
- z_latents = vampnet.embedding.from_codes(z, interface.codec)
120
-
121
- # do a forward pass through the model, get the embeddings
122
- _z, embeddings = vampnet(z_latents, return_activations=True)
123
- # print(f"got embeddings with shape {embeddings.shape}")
124
- # [layer, batch, time, n_dims]
125
- # [20, 1, 600ish, 768]
126
-
127
-
128
- # squeeze batch dim (1 bc layer should be dim 0)
129
- assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
130
- embeddings = embeddings.squeeze(1)
131
-
132
- num_layers = embeddings.shape[0]
133
- assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
134
-
135
- # do meanpooling over the time dimension
136
- embeddings = embeddings.mean(dim=-2)
137
- # [20, 768]
138
-
139
- # return the embeddings
140
- return embeddings
141
-
142
- from dataclasses import dataclass, fields
143
- @dataclass
144
- class Embedding:
145
- genre: str
146
- filename: str
147
- embedding: np.ndarray
148
-
149
- def save(self, path):
150
- """Save the Embedding object to a given path as a zip file."""
151
- with zipfile.ZipFile(path, 'w') as archive:
152
-
153
- # Save numpy array
154
- with archive.open('embedding.npy', 'w') as f:
155
- np.save(f, self.embedding)
156
-
157
- # Save non-numpy data as json
158
- non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
159
- with archive.open('data.json', 'w') as f:
160
- f.write(json.dumps(non_numpy_data).encode('utf-8'))
161
-
162
- @classmethod
163
- def load(cls, path):
164
- """Load the Embedding object from a given zip path."""
165
- with zipfile.ZipFile(path, 'r') as archive:
166
-
167
- # Load numpy array
168
- with archive.open('embedding.npy') as f:
169
- embedding = np.load(f)
170
-
171
- # Load non-numpy data from json
172
- with archive.open('data.json') as f:
173
- data = json.loads(f.read().decode('utf-8'))
174
-
175
- return cls(embedding=embedding, **data)
176
-
177
-
178
- @argbind.bind(without_prefix=True)
179
- def main(
180
- path_to_gtzan: str = None,
181
- cache_dir: str = "./.gtzan_emb_cache",
182
- output_dir: str = "./gtzan_vampnet_embeddings",
183
- layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
184
- ):
185
- path_to_gtzan = Path(path_to_gtzan)
186
- assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"
187
-
188
- cache_dir = Path(cache_dir)
189
- output_dir = Path(output_dir)
190
- output_dir.mkdir(exist_ok=True, parents=True)
191
-
192
- # load our interface
193
- # argbind will automatically load the default config,
194
- interface = Interface()
195
-
196
- # gtzan should have a folder for each genre, so let's get the list of genres
197
- genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
198
- print(f"Found {len(genres)} genres")
199
- print(f"genres: {genres}")
200
-
201
- # collect audio files, genres, and embeddings
202
- data = []
203
- for genre in genres:
204
- audio_files = list(at.util.find_audio(path_to_gtzan / genre))
205
- print(f"Found {len(audio_files)} audio files for genre {genre}")
206
-
207
- for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
208
- # check if we have a cached embedding for this file
209
- cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
210
- if cached_path.exists():
211
- # if so, load it
212
- if DEBUG:
213
- print(f"loading cached embedding for {cached_path.stem}")
214
- embedding = Embedding.load(cached_path)
215
- data.append(embedding)
216
- else:
217
- try:
218
- sig = AudioSignal(audio_file)
219
- except Exception as e:
220
- print(f"failed to load {audio_file.name} with error {e}")
221
- print(f"skipping {audio_file.name}")
222
- continue
223
-
224
- # gets the embedding
225
- emb = vampnet_embed(sig, interface).cpu().numpy()
226
-
227
- # create an embedding we can save/load
228
- embedding = Embedding(
229
- genre=genre,
230
- filename=audio_file.name,
231
- embedding=emb
232
- )
233
-
234
- # cache the embeddings
235
- cached_path.parent.mkdir(exist_ok=True, parents=True)
236
- embedding.save(cached_path)
237
-
238
- # now, let's do a dim reduction on the embeddings
239
- # and visualize them.
240
-
241
- # collect a list of embeddings and labels
242
- embeddings = [d.embedding for d in data]
243
- labels = [d.genre for d in data]
244
-
245
- # convert the embeddings to a numpy array
246
- embeddings = np.stack(embeddings)
247
-
248
- # do dimensionality reduction for each layer we're given
249
- for layer in tqdm.tqdm(layers, desc="dim reduction"):
250
- dim_reduce(
251
- embeddings[:, layer, :], labels,
252
- save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
253
- n_components=2, method='tsne',
254
- title=f'vampnet-gtzan-layer={layer}'
255
- )
256
-
257
-
258
-
259
-
260
- if __name__ == "__main__":
261
- args = argbind.parse_args()
262
- with argbind.scope(args):
263
- main()