Hugo Flores Garcia commited on
Commit
457f9d1
·
1 Parent(s): fca3233

update audiotools version, update recipe

Browse files
.gitignore CHANGED
@@ -180,5 +180,4 @@ samples*/
180
  models-all/
181
  models.zip
182
  .git-old
183
-
184
- conf/generated/*
 
180
  models-all/
181
  models.zip
182
  .git-old
183
+ conf/generated/*
 
conf/lora/lora.yml CHANGED
@@ -3,20 +3,18 @@ $include:
3
 
4
  fine_tune: True
5
 
6
- train/AudioDataset.n_examples: 10000000
7
-
8
- val/AudioDataset.n_examples: 10
9
 
10
 
11
  NoamScheduler.warmup: 500
12
 
13
  batch_size: 7
14
  num_workers: 7
15
- epoch_length: 100
16
- save_audio_epochs: 10
17
 
18
  AdamW.lr: 0.0001
19
 
20
  # let's us organize sound classes into folders and choose from those sound classes uniformly
21
  AudioDataset.without_replacement: False
22
- max_epochs: 500
 
3
 
4
  fine_tune: True
5
 
6
+ train/AudioDataset.n_examples: 100000000
7
+ val/AudioDataset.n_examples: 100
 
8
 
9
 
10
  NoamScheduler.warmup: 500
11
 
12
  batch_size: 7
13
  num_workers: 7
14
+ save_iters: [100000, 200000, 300000, 4000000, 500000]
 
15
 
16
  AdamW.lr: 0.0001
17
 
18
  # let's us organize sound classes into folders and choose from those sound classes uniformly
19
  AudioDataset.without_replacement: False
20
+ num_iters: 500000
conf/vampnet.yml CHANGED
@@ -1,21 +1,17 @@
1
 
2
- codec_ckpt: ./models/spotdl/codec.pth
3
  save_path: ckpt
4
- max_epochs: 1000
5
- epoch_length: 1000
6
- save_audio_epochs: 2
7
- val_idx: [0,1,2,3,4,5,6,7,8,9]
8
 
9
- prefix_amt: 0.0
10
- suffix_amt: 0.0
11
- prefix_dropout: 0.1
12
- suffix_dropout: 0.1
 
13
 
14
  batch_size: 8
15
  num_workers: 10
16
 
17
  # Optimization
18
- detect_anomaly: false
19
  amp: false
20
 
21
  CrossEntropyLoss.label_smoothing: 0.1
@@ -25,9 +21,6 @@ AdamW.lr: 0.001
25
  NoamScheduler.factor: 2.0
26
  NoamScheduler.warmup: 10000
27
 
28
- PitchShift.shift_amount: [const, 0]
29
- PitchShift.prob: 0.0
30
-
31
  VampNet.vocab_size: 1024
32
  VampNet.n_codebooks: 4
33
  VampNet.n_conditioning_codebooks: 0
@@ -48,12 +41,9 @@ AudioDataset.duration: 10.0
48
 
49
  train/AudioDataset.n_examples: 10000000
50
  train/AudioLoader.sources:
51
- - /data/spotdl/audio/train
52
 
53
  val/AudioDataset.n_examples: 2000
54
  val/AudioLoader.sources:
55
- - /data/spotdl/audio/val
56
 
57
- test/AudioDataset.n_examples: 1000
58
- test/AudioLoader.sources:
59
- - /data/spotdl/audio/test
 
1
 
2
+ codec_ckpt: ./models/vampnet/codec.pth
3
  save_path: ckpt
 
 
 
 
4
 
5
+ num_iters: 1000000000
6
+ save_iters: [10000, 50000, 100000, 300000, 500000]
7
+ val_idx: [0,1,2,3,4,5,6,7,8,9]
8
+ sample_freq: 10000
9
+ val_freq: 1000
10
 
11
  batch_size: 8
12
  num_workers: 10
13
 
14
  # Optimization
 
15
  amp: false
16
 
17
  CrossEntropyLoss.label_smoothing: 0.1
 
21
  NoamScheduler.factor: 2.0
22
  NoamScheduler.warmup: 10000
23
 
 
 
 
24
  VampNet.vocab_size: 1024
25
  VampNet.n_codebooks: 4
26
  VampNet.n_conditioning_codebooks: 0
 
41
 
42
  train/AudioDataset.n_examples: 10000000
43
  train/AudioLoader.sources:
44
+ - /media/CHONK/hugo/spotdl/audio-train
45
 
46
  val/AudioDataset.n_examples: 2000
47
  val/AudioLoader.sources:
48
+ - /media/CHONK/hugo/spotdl/audio-val
49
 
 
 
 
scripts/exp/fine_tune.py CHANGED
@@ -35,7 +35,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
35
  "AudioDataset.duration": 3.0,
36
  "AudioDataset.loudness_cutoff": -40.0,
37
  "save_path": f"./runs/{name}/c2f",
38
- "fine_tune_checkpoint": "./models/spotdl/c2f.pth"
39
  }
40
 
41
  finetune_coarse_conf = {
@@ -44,17 +44,17 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
44
  "train/AudioLoader.sources": audio_files_or_folders,
45
  "val/AudioLoader.sources": audio_files_or_folders,
46
  "save_path": f"./runs/{name}/coarse",
47
- "fine_tune_checkpoint": "./models/spotdl/coarse.pth"
48
  }
49
 
50
  interface_conf = {
51
- "Interface.coarse_ckpt": f"./models/spotdl/coarse.pth",
52
  "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
53
 
54
- "Interface.coarse2fine_ckpt": f"./models/spotdl/c2f.pth",
55
  "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
56
 
57
- "Interface.codec_ckpt": "./models/spotdl/codec.pth",
58
  "AudioLoader.sources": [audio_files_or_folders],
59
  }
60
 
 
35
  "AudioDataset.duration": 3.0,
36
  "AudioDataset.loudness_cutoff": -40.0,
37
  "save_path": f"./runs/{name}/c2f",
38
+ "fine_tune_checkpoint": "./models/vampnet/c2f.pth"
39
  }
40
 
41
  finetune_coarse_conf = {
 
44
  "train/AudioLoader.sources": audio_files_or_folders,
45
  "val/AudioLoader.sources": audio_files_or_folders,
46
  "save_path": f"./runs/{name}/coarse",
47
+ "fine_tune_checkpoint": "./models/vampnet/coarse.pth"
48
  }
49
 
50
  interface_conf = {
51
+ "Interface.coarse_ckpt": f"./models/vampnet/coarse.pth",
52
  "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
53
 
54
+ "Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
55
  "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
56
 
57
+ "Interface.codec_ckpt": "./models/vampnet/codec.pth",
58
  "AudioLoader.sources": [audio_files_or_folders],
59
  }
60
 
scripts/exp/train.py CHANGED
@@ -1,9 +1,9 @@
1
  import os
2
- import subprocess
3
- import time
4
  import warnings
5
  from pathlib import Path
6
  from typing import Optional
 
7
 
8
  import argbind
9
  import audiotools as at
@@ -23,6 +23,12 @@ from vampnet import mask as pmask
23
  # from dac.model.dac import DAC
24
  from lac.model.lac import LAC as DAC
25
 
 
 
 
 
 
 
26
 
27
  # Enable cudnn autotuner to speed up training
28
  # (can be altered by the funcs.seed function)
@@ -85,11 +91,7 @@ def build_datasets(args, sample_rate: int):
85
  )
86
  with argbind.scope(args, "val"):
87
  val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
88
- with argbind.scope(args, "test"):
89
- test_data = AudioDataset(
90
- AudioLoader(), sample_rate, transform=build_transform()
91
- )
92
- return train_data, val_data, test_data
93
 
94
 
95
  def rand_float(shape, low, high, rng):
@@ -100,16 +102,393 @@ def flip_coin(shape, p, rng):
100
  return rng.draw(shape)[:, 0] < p
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  @argbind.bind(without_prefix=True)
104
  def load(
105
  args,
106
  accel: at.ml.Accelerator,
 
107
  save_path: str,
108
  resume: bool = False,
109
  tag: str = "latest",
110
  load_weights: bool = False,
111
  fine_tune_checkpoint: Optional[str] = None,
112
- ):
 
113
  codec = DAC.load(args["codec_ckpt"], map_location="cpu")
114
  codec.eval()
115
 
@@ -121,6 +500,7 @@ def load(
121
  "map_location": "cpu",
122
  "package": not load_weights,
123
  }
 
124
  if (Path(kwargs["folder"]) / "vampnet").exists():
125
  model, v_extra = VampNet.load_from_folder(**kwargs)
126
  else:
@@ -147,89 +527,57 @@ def load(
147
  scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
148
  scheduler.step()
149
 
150
- trainer_state = {"state_dict": None, "start_idx": 0}
151
-
152
  if "optimizer.pth" in v_extra:
153
  optimizer.load_state_dict(v_extra["optimizer.pth"])
154
- if "scheduler.pth" in v_extra:
155
  scheduler.load_state_dict(v_extra["scheduler.pth"])
156
- if "trainer.pth" in v_extra:
157
- trainer_state = v_extra["trainer.pth"]
158
-
159
- return {
160
- "model": model,
161
- "codec": codec,
162
- "optimizer": optimizer,
163
- "scheduler": scheduler,
164
- "trainer_state": trainer_state,
165
- }
166
-
167
-
168
-
169
- def num_params_hook(o, p):
170
- return o + f" {p/1e6:<.3f}M params."
171
-
172
-
173
- def add_num_params_repr_hook(model):
174
- import numpy as np
175
- from functools import partial
176
-
177
- for n, m in model.named_modules():
178
- o = m.extra_repr()
179
- p = sum([np.prod(p.size()) for p in m.parameters()])
180
-
181
- setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
182
-
183
-
184
- def accuracy(
185
- preds: torch.Tensor,
186
- target: torch.Tensor,
187
- top_k: int = 1,
188
- ignore_index: Optional[int] = None,
189
- ) -> torch.Tensor:
190
- # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
191
- preds = rearrange(preds, "b p s -> (b s) p")
192
- target = rearrange(target, "b s -> (b s)")
193
-
194
- # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
195
- if ignore_index is not None:
196
- # Create a mask for the ignored index
197
- mask = target != ignore_index
198
- # Apply the mask to the target and predictions
199
- preds = preds[mask]
200
- target = target[mask]
201
 
202
- # Get the top-k predicted classes and their indices
203
- _, pred_indices = torch.topk(preds, k=top_k, dim=-1)
204
 
205
- # Determine if the true target is in the top-k predicted classes
206
- correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
207
 
208
- # Calculate the accuracy
209
- accuracy = torch.mean(correct.float())
 
 
 
210
 
211
- return accuracy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
 
214
  @argbind.bind(without_prefix=True)
215
  def train(
216
  args,
217
  accel: at.ml.Accelerator,
218
- codec_ckpt: str = None,
219
  seed: int = 0,
 
220
  save_path: str = "ckpt",
221
- max_epochs: int = int(100e3),
222
- epoch_length: int = 1000,
223
- save_audio_epochs: int = 2,
224
- save_epochs: list = [10, 50, 100, 200, 300, 400,],
225
- batch_size: int = 48,
226
- grad_acc_steps: int = 1,
227
  val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
228
  num_workers: int = 10,
229
- detect_anomaly: bool = False,
230
- grad_clip_val: float = 5.0,
231
  fine_tune: bool = False,
232
- quiet: bool = False,
233
  ):
234
  assert codec_ckpt is not None, "codec_ckpt is required"
235
 
@@ -241,376 +589,76 @@ def train(
241
  writer = SummaryWriter(log_dir=f"{save_path}/logs/")
242
  argbind.dump_args(args, f"{save_path}/args.yml")
243
 
244
- # load the codec model
245
- loaded = load(args, accel, save_path)
246
- model = loaded["model"]
247
- codec = loaded["codec"]
248
- optimizer = loaded["optimizer"]
249
- scheduler = loaded["scheduler"]
250
- trainer_state = loaded["trainer_state"]
251
-
252
- sample_rate = codec.sample_rate
253
 
254
- # a better rng for sampling from our schedule
255
- rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=seed)
 
 
 
 
256
 
257
- # log a model summary w/ num params
258
- if accel.local_rank == 0:
259
- add_num_params_repr_hook(accel.unwrap(model))
260
- with open(f"{save_path}/model.txt", "w") as f:
261
- f.write(repr(accel.unwrap(model)))
262
 
263
- # load the datasets
264
- train_data, val_data, _ = build_datasets(args, sample_rate)
265
  train_dataloader = accel.prepare_dataloader(
266
- train_data,
267
- start_idx=trainer_state["start_idx"],
268
  num_workers=num_workers,
269
  batch_size=batch_size,
270
- collate_fn=train_data.collate,
271
  )
272
  val_dataloader = accel.prepare_dataloader(
273
- val_data,
274
  start_idx=0,
275
  num_workers=num_workers,
276
  batch_size=batch_size,
277
- collate_fn=val_data.collate,
 
278
  )
279
 
280
- criterion = CrossEntropyLoss()
281
 
282
  if fine_tune:
283
- import loralib as lora
284
- lora.mark_only_lora_as_trainable(model)
285
-
286
-
287
- class Trainer(at.ml.BaseTrainer):
288
- _last_grad_norm = 0.0
289
-
290
- def _metrics(self, vn, z_hat, r, target, flat_mask, output):
291
- for r_range in [(0, 0.5), (0.5, 1.0)]:
292
- unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
293
- masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
294
-
295
- assert target.shape[0] == r.shape[0]
296
- # grab the indices of the r values that are in the range
297
- r_idx = (r >= r_range[0]) & (r < r_range[1])
298
-
299
- # grab the target and z_hat values that are in the range
300
- r_unmasked_target = unmasked_target[r_idx]
301
- r_masked_target = masked_target[r_idx]
302
- r_z_hat = z_hat[r_idx]
303
-
304
- for topk in (1, 25):
305
- s, e = r_range
306
- tag = f"accuracy-{s}-{e}/top{topk}"
307
-
308
- output[f"{tag}/unmasked"] = accuracy(
309
- preds=r_z_hat,
310
- target=r_unmasked_target,
311
- ignore_index=IGNORE_INDEX,
312
- top_k=topk,
313
- )
314
- output[f"{tag}/masked"] = accuracy(
315
- preds=r_z_hat,
316
- target=r_masked_target,
317
- ignore_index=IGNORE_INDEX,
318
- top_k=topk,
319
- )
320
-
321
- def train_loop(self, engine, batch):
322
- model.train()
323
- batch = at.util.prepare_batch(batch, accel.device)
324
- signal = apply_transform(train_data.transform, batch)
325
-
326
- output = {}
327
- vn = accel.unwrap(model)
328
- with accel.autocast():
329
- with torch.inference_mode():
330
- codec.to(accel.device)
331
- z = codec.encode(signal.samples, signal.sample_rate)["codes"]
332
- z = z[:, : vn.n_codebooks, :]
333
-
334
- n_batch = z.shape[0]
335
- r = rng.draw(n_batch)[:, 0].to(accel.device)
336
-
337
- mask = pmask.random(z, r)
338
- mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
339
- z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
340
-
341
- z_mask_latent = vn.embedding.from_codes(z_mask, codec)
342
-
343
- dtype = torch.bfloat16 if accel.amp else None
344
- with accel.autocast(dtype=dtype):
345
- z_hat = model(z_mask_latent, r)
346
-
347
- target = codebook_flatten(
348
- z[:, vn.n_conditioning_codebooks :, :],
349
- )
350
-
351
- flat_mask = codebook_flatten(
352
- mask[:, vn.n_conditioning_codebooks :, :],
353
- )
354
-
355
- # replace target with ignore index for masked tokens
356
- t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
357
- output["loss"] = criterion(z_hat, t_masked)
358
-
359
- self._metrics(
360
- vn=vn,
361
- r=r,
362
- z_hat=z_hat,
363
- target=target,
364
- flat_mask=flat_mask,
365
- output=output,
366
- )
367
-
368
-
369
- accel.backward(output["loss"] / grad_acc_steps)
370
-
371
- output["other/learning_rate"] = optimizer.param_groups[0]["lr"]
372
- output["other/batch_size"] = z.shape[0]
373
-
374
- if (
375
- (engine.state.iteration % grad_acc_steps == 0)
376
- or (engine.state.iteration % epoch_length == 0)
377
- or (engine.state.iteration % epoch_length == 1)
378
- ): # (or we reached the end of the epoch)
379
- accel.scaler.unscale_(optimizer)
380
- output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
381
- model.parameters(), grad_clip_val
382
- )
383
- self._last_grad_norm = output["other/grad_norm"]
384
-
385
- accel.step(optimizer)
386
- optimizer.zero_grad()
387
-
388
- scheduler.step()
389
- accel.update()
390
- else:
391
- output["other/grad_norm"] = self._last_grad_norm
392
-
393
- return {k: v for k, v in sorted(output.items())}
394
-
395
- @torch.no_grad()
396
- def val_loop(self, engine, batch):
397
- model.eval()
398
- codec.eval()
399
- batch = at.util.prepare_batch(batch, accel.device)
400
- signal = apply_transform(val_data.transform, batch)
401
-
402
- vn = accel.unwrap(model)
403
- z = codec.encode(signal.samples, signal.sample_rate)["codes"]
404
- z = z[:, : vn.n_codebooks, :]
405
-
406
- n_batch = z.shape[0]
407
- r = rng.draw(n_batch)[:, 0].to(accel.device)
408
-
409
- mask = pmask.random(z, r)
410
- mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
411
- z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
412
 
413
- z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
 
414
 
415
- z_hat = model(z_mask_latent, r)
 
 
 
 
416
 
417
- target = codebook_flatten(
418
- z[:, vn.n_conditioning_codebooks :, :],
419
- )
420
 
421
- flat_mask = codebook_flatten(
422
- mask[:, vn.n_conditioning_codebooks :, :]
423
- )
424
 
425
- output = {}
426
- # replace target with ignore index for masked tokens
427
- t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
428
- output["loss"] = criterion(z_hat, t_masked)
429
-
430
- self._metrics(
431
- vn=vn,
432
- r=r,
433
- z_hat=z_hat,
434
- target=target,
435
- flat_mask=flat_mask,
436
- output=output,
437
  )
438
 
439
- return output
 
440
 
441
- def checkpoint(self, engine):
442
- if accel.local_rank != 0:
443
- print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
444
- return
445
-
446
- metadata = {"logs": dict(engine.state.logs["epoch"])}
447
-
448
- if self.state.epoch % save_audio_epochs == 0:
449
- self.save_samples()
450
-
451
- tags = ["latest"]
452
- loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
453
- self.print(f"Saving to {str(Path('.').absolute())}")
454
-
455
- if self.state.epoch in save_epochs:
456
- tags.append(f"epoch={self.state.epoch}")
457
-
458
- if self.is_best(engine, loss_key):
459
- self.print(f"Best model so far")
460
- tags.append("best")
461
-
462
- if fine_tune:
463
- for tag in tags:
464
- # save the lora model
465
- (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
466
- torch.save(
467
- lora.lora_state_dict(accel.unwrap(model)),
468
- f"{save_path}/{tag}/lora.pth"
469
- )
470
-
471
- for tag in tags:
472
- model_extra = {
473
- "optimizer.pth": optimizer.state_dict(),
474
- "scheduler.pth": scheduler.state_dict(),
475
- "trainer.pth": {
476
- "start_idx": self.state.iteration * batch_size,
477
- "state_dict": self.state_dict(),
478
- },
479
- "metadata.pth": metadata,
480
- }
481
-
482
- accel.unwrap(model).metadata = metadata
483
- accel.unwrap(model).save_to_folder(
484
- f"{save_path}/{tag}", model_extra,
485
- )
486
-
487
- def save_sampled(self, z):
488
- num_samples = z.shape[0]
489
-
490
- for i in range(num_samples):
491
- sampled = accel.unwrap(model).generate(
492
- codec=codec,
493
- time_steps=z.shape[-1],
494
- start_tokens=z[i : i + 1],
495
- )
496
- sampled.cpu().write_audio_to_tb(
497
- f"sampled/{i}",
498
- self.writer,
499
- step=self.state.epoch,
500
- plot_fn=None,
501
- )
502
-
503
-
504
- def save_imputation(self, z: torch.Tensor):
505
- n_prefix = int(z.shape[-1] * 0.25)
506
- n_suffix = int(z.shape[-1] * 0.25)
507
-
508
- vn = accel.unwrap(model)
509
-
510
- mask = pmask.inpaint(z, n_prefix, n_suffix)
511
- mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
512
- z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
513
-
514
- imputed_noisy = vn.to_signal(z_mask, codec)
515
- imputed_true = vn.to_signal(z, codec)
516
-
517
- imputed = []
518
- for i in range(len(z)):
519
- imputed.append(
520
- vn.generate(
521
- codec=codec,
522
- time_steps=z.shape[-1],
523
- start_tokens=z[i][None, ...],
524
- mask=mask[i][None, ...],
525
- )
526
- )
527
- imputed = AudioSignal.batch(imputed)
528
-
529
- for i in range(len(val_idx)):
530
- imputed_noisy[i].cpu().write_audio_to_tb(
531
- f"imputed_noisy/{i}",
532
- self.writer,
533
- step=self.state.epoch,
534
- plot_fn=None,
535
- )
536
- imputed[i].cpu().write_audio_to_tb(
537
- f"imputed/{i}",
538
- self.writer,
539
- step=self.state.epoch,
540
- plot_fn=None,
541
- )
542
- imputed_true[i].cpu().write_audio_to_tb(
543
- f"imputed_true/{i}",
544
- self.writer,
545
- step=self.state.epoch,
546
- plot_fn=None,
547
- )
548
-
549
- @torch.no_grad()
550
- def save_samples(self):
551
- model.eval()
552
- codec.eval()
553
- vn = accel.unwrap(model)
554
-
555
- batch = [val_data[i] for i in val_idx]
556
- batch = at.util.prepare_batch(val_data.collate(batch), accel.device)
557
-
558
- signal = apply_transform(val_data.transform, batch)
559
-
560
- z = codec.encode(signal.samples, signal.sample_rate)["codes"]
561
- z = z[:, : vn.n_codebooks, :]
562
-
563
- r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
564
-
565
-
566
- mask = pmask.random(z, r)
567
- mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
568
- z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
569
-
570
- z_mask_latent = vn.embedding.from_codes(z_mask, codec)
571
-
572
- z_hat = model(z_mask_latent, r)
573
 
574
- z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
575
- z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
576
- z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
577
 
578
- generated = vn.to_signal(z_pred, codec)
579
- reconstructed = vn.to_signal(z, codec)
580
- masked = vn.to_signal(z_mask.squeeze(1), codec)
581
-
582
- for i in range(generated.batch_size):
583
- audio_dict = {
584
- "original": signal[i],
585
- "masked": masked[i],
586
- "generated": generated[i],
587
- "reconstructed": reconstructed[i],
588
- }
589
- for k, v in audio_dict.items():
590
- v.cpu().write_audio_to_tb(
591
- f"samples/_{i}.r={r[i]:0.2f}/{k}",
592
- self.writer,
593
- step=self.state.epoch,
594
- plot_fn=None,
595
- )
596
-
597
- self.save_sampled(z)
598
- self.save_imputation(z)
599
-
600
- trainer = Trainer(writer=writer, quiet=quiet)
601
-
602
- if trainer_state["state_dict"] is not None:
603
- trainer.load_state_dict(trainer_state["state_dict"])
604
- if hasattr(train_dataloader.sampler, "set_epoch"):
605
- train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch)
606
-
607
- trainer.run(
608
- train_dataloader,
609
- val_dataloader,
610
- num_epochs=max_epochs,
611
- epoch_length=epoch_length,
612
- detect_anomaly=detect_anomaly,
613
- )
614
 
615
 
616
  if __name__ == "__main__":
@@ -618,4 +666,6 @@ if __name__ == "__main__":
618
  args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
619
  with argbind.scope(args):
620
  with Accelerator() as accel:
 
 
621
  train(args, accel)
 
1
  import os
2
+ import sys
 
3
  import warnings
4
  from pathlib import Path
5
  from typing import Optional
6
+ from dataclasses import dataclass
7
 
8
  import argbind
9
  import audiotools as at
 
23
  # from dac.model.dac import DAC
24
  from lac.model.lac import LAC as DAC
25
 
26
+ from audiotools.ml.decorators import (
27
+ timer, Tracker, when
28
+ )
29
+
30
+ import loralib as lora
31
+
32
 
33
  # Enable cudnn autotuner to speed up training
34
  # (can be altered by the funcs.seed function)
 
91
  )
92
  with argbind.scope(args, "val"):
93
  val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
94
+ return train_data, val_data
 
 
 
 
95
 
96
 
97
  def rand_float(shape, low, high, rng):
 
102
  return rng.draw(shape)[:, 0] < p
103
 
104
 
105
+ def num_params_hook(o, p):
106
+ return o + f" {p/1e6:<.3f}M params."
107
+
108
+
109
+ def add_num_params_repr_hook(model):
110
+ import numpy as np
111
+ from functools import partial
112
+
113
+ for n, m in model.named_modules():
114
+ o = m.extra_repr()
115
+ p = sum([np.prod(p.size()) for p in m.parameters()])
116
+
117
+ setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
118
+
119
+
120
+ def accuracy(
121
+ preds: torch.Tensor,
122
+ target: torch.Tensor,
123
+ top_k: int = 1,
124
+ ignore_index: Optional[int] = None,
125
+ ) -> torch.Tensor:
126
+ # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
127
+ preds = rearrange(preds, "b p s -> (b s) p")
128
+ target = rearrange(target, "b s -> (b s)")
129
+
130
+ # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
131
+ if ignore_index is not None:
132
+ # Create a mask for the ignored index
133
+ mask = target != ignore_index
134
+ # Apply the mask to the target and predictions
135
+ preds = preds[mask]
136
+ target = target[mask]
137
+
138
+ # Get the top-k predicted classes and their indices
139
+ _, pred_indices = torch.topk(preds, k=top_k, dim=-1)
140
+
141
+ # Determine if the true target is in the top-k predicted classes
142
+ correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
143
+
144
+ # Calculate the accuracy
145
+ accuracy = torch.mean(correct.float())
146
+
147
+ return accuracy
148
+
149
+ def _metrics(z_hat, r, target, flat_mask, output):
150
+ for r_range in [(0, 0.5), (0.5, 1.0)]:
151
+ unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
152
+ masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
153
+
154
+ assert target.shape[0] == r.shape[0]
155
+ # grab the indices of the r values that are in the range
156
+ r_idx = (r >= r_range[0]) & (r < r_range[1])
157
+
158
+ # grab the target and z_hat values that are in the range
159
+ r_unmasked_target = unmasked_target[r_idx]
160
+ r_masked_target = masked_target[r_idx]
161
+ r_z_hat = z_hat[r_idx]
162
+
163
+ for topk in (1, 25):
164
+ s, e = r_range
165
+ tag = f"accuracy-{s}-{e}/top{topk}"
166
+
167
+ output[f"{tag}/unmasked"] = accuracy(
168
+ preds=r_z_hat,
169
+ target=r_unmasked_target,
170
+ ignore_index=IGNORE_INDEX,
171
+ top_k=topk,
172
+ )
173
+ output[f"{tag}/masked"] = accuracy(
174
+ preds=r_z_hat,
175
+ target=r_masked_target,
176
+ ignore_index=IGNORE_INDEX,
177
+ top_k=topk,
178
+ )
179
+
180
+
181
+ @dataclass
182
+ class State:
183
+ model: VampNet
184
+ codec: DAC
185
+
186
+ optimizer: AdamW
187
+ scheduler: NoamScheduler
188
+ criterion: CrossEntropyLoss
189
+ grad_clip_val: float
190
+
191
+ rng: torch.quasirandom.SobolEngine
192
+
193
+ train_data: AudioDataset
194
+ val_data: AudioDataset
195
+
196
+ tracker: Tracker
197
+
198
+
199
+ @timer()
200
+ def train_loop(state: State, batch: dict, accel: Accelerator):
201
+ state.model.train()
202
+ batch = at.util.prepare_batch(batch, accel.device)
203
+ signal = apply_transform(state.train_data.transform, batch)
204
+
205
+ output = {}
206
+ vn = accel.unwrap(state.model)
207
+ with accel.autocast():
208
+ with torch.inference_mode():
209
+ state.codec.to(accel.device)
210
+ z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
211
+ z = z[:, : vn.n_codebooks, :]
212
+
213
+ n_batch = z.shape[0]
214
+ r = state.rng.draw(n_batch)[:, 0].to(accel.device)
215
+
216
+ mask = pmask.random(z, r)
217
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
218
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
219
+
220
+ z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
221
+
222
+ dtype = torch.bfloat16 if accel.amp else None
223
+ with accel.autocast(dtype=dtype):
224
+ z_hat = state.model(z_mask_latent, r)
225
+
226
+ target = codebook_flatten(
227
+ z[:, vn.n_conditioning_codebooks :, :],
228
+ )
229
+
230
+ flat_mask = codebook_flatten(
231
+ mask[:, vn.n_conditioning_codebooks :, :],
232
+ )
233
+
234
+ # replace target with ignore index for masked tokens
235
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
236
+ output["loss"] = state.criterion(z_hat, t_masked)
237
+
238
+ _metrics(
239
+ r=r,
240
+ z_hat=z_hat,
241
+ target=target,
242
+ flat_mask=flat_mask,
243
+ output=output,
244
+ )
245
+
246
+
247
+ accel.backward(output["loss"])
248
+
249
+ output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
250
+ output["other/batch_size"] = z.shape[0]
251
+
252
+
253
+ accel.scaler.unscale_(state.optimizer)
254
+ output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
255
+ state.model.parameters(), state.grad_clip_val
256
+ )
257
+
258
+ accel.step(state.optimizer)
259
+ state.optimizer.zero_grad()
260
+
261
+ state.scheduler.step()
262
+ accel.update()
263
+
264
+
265
+ return {k: v for k, v in sorted(output.items())}
266
+
267
+
268
+ @timer()
269
+ @torch.no_grad()
270
+ def val_loop(state: State, batch: dict, accel: Accelerator):
271
+ state.model.eval()
272
+ state.codec.eval()
273
+ batch = at.util.prepare_batch(batch, accel.device)
274
+ signal = apply_transform(state.val_data.transform, batch)
275
+
276
+ vn = accel.unwrap(state.model)
277
+ z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
278
+ z = z[:, : vn.n_codebooks, :]
279
+
280
+ n_batch = z.shape[0]
281
+ r = state.rng.draw(n_batch)[:, 0].to(accel.device)
282
+
283
+ mask = pmask.random(z, r)
284
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
285
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
286
+
287
+ z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
288
+
289
+ z_hat = state.model(z_mask_latent, r)
290
+
291
+ target = codebook_flatten(
292
+ z[:, vn.n_conditioning_codebooks :, :],
293
+ )
294
+
295
+ flat_mask = codebook_flatten(
296
+ mask[:, vn.n_conditioning_codebooks :, :]
297
+ )
298
+
299
+ output = {}
300
+ # replace target with ignore index for masked tokens
301
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
302
+ output["loss"] = state.criterion(z_hat, t_masked)
303
+
304
+ _metrics(
305
+ r=r,
306
+ z_hat=z_hat,
307
+ target=target,
308
+ flat_mask=flat_mask,
309
+ output=output,
310
+ )
311
+
312
+ return output
313
+
314
+
315
+ def validate(state, val_dataloader, accel):
316
+ for batch in val_dataloader:
317
+ output = val_loop(state, batch, accel)
318
+ # Consolidate state dicts if using ZeroRedundancyOptimizer
319
+ if hasattr(state.optimizer, "consolidate_state_dict"):
320
+ state.optimizer.consolidate_state_dict()
321
+ return output
322
+
323
+
324
+ def checkpoint(state, save_iters, save_path, fine_tune):
325
+ if accel.local_rank != 0:
326
+ state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
327
+ return
328
+
329
+ metadata = {"logs": dict(state.tracker.history)}
330
+
331
+ tags = ["latest"]
332
+ state.tracker.print(f"Saving to {str(Path('.').absolute())}")
333
+
334
+ if state.tracker.step in save_iters:
335
+ tags.append(f"{state.tracker.step // 1000}k")
336
+
337
+ if state.tracker.is_best("val", "loss"):
338
+ state.tracker.print(f"Best model so far")
339
+ tags.append("best")
340
+
341
+ if fine_tune:
342
+ for tag in tags:
343
+ # save the lora model
344
+ (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
345
+ torch.save(
346
+ lora.lora_state_dict(accel.unwrap(state.model)),
347
+ f"{save_path}/{tag}/lora.pth"
348
+ )
349
+
350
+ for tag in tags:
351
+ model_extra = {
352
+ "optimizer.pth": state.optimizer.state_dict(),
353
+ "scheduler.pth": state.scheduler.state_dict(),
354
+ "tracker.pth": state.tracker.state_dict(),
355
+ "metadata.pth": metadata,
356
+ }
357
+
358
+ accel.unwrap(state.model).metadata = metadata
359
+ accel.unwrap(state.model).save_to_folder(
360
+ f"{save_path}/{tag}", model_extra, package=False
361
+ )
362
+
363
+
364
+ def save_sampled(state, z, writer):
365
+ num_samples = z.shape[0]
366
+
367
+ for i in range(num_samples):
368
+ sampled = accel.unwrap(state.model).generate(
369
+ codec=state.codec,
370
+ time_steps=z.shape[-1],
371
+ start_tokens=z[i : i + 1],
372
+ )
373
+ sampled.cpu().write_audio_to_tb(
374
+ f"sampled/{i}",
375
+ writer,
376
+ step=state.tracker.step,
377
+ plot_fn=None,
378
+ )
379
+
380
+
381
+ def save_imputation(state, z, val_idx, writer):
382
+ n_prefix = int(z.shape[-1] * 0.25)
383
+ n_suffix = int(z.shape[-1] * 0.25)
384
+
385
+ vn = accel.unwrap(state.model)
386
+
387
+ mask = pmask.inpaint(z, n_prefix, n_suffix)
388
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
389
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
390
+
391
+ imputed_noisy = vn.to_signal(z_mask, state.codec)
392
+ imputed_true = vn.to_signal(z, state.codec)
393
+
394
+ imputed = []
395
+ for i in range(len(z)):
396
+ imputed.append(
397
+ vn.generate(
398
+ codec=state.codec,
399
+ time_steps=z.shape[-1],
400
+ start_tokens=z[i][None, ...],
401
+ mask=mask[i][None, ...],
402
+ )
403
+ )
404
+ imputed = AudioSignal.batch(imputed)
405
+
406
+ for i in range(len(val_idx)):
407
+ imputed_noisy[i].cpu().write_audio_to_tb(
408
+ f"imputed_noisy/{i}",
409
+ writer,
410
+ step=state.tracker.step,
411
+ plot_fn=None,
412
+ )
413
+ imputed[i].cpu().write_audio_to_tb(
414
+ f"imputed/{i}",
415
+ writer,
416
+ step=state.tracker.step,
417
+ plot_fn=None,
418
+ )
419
+ imputed_true[i].cpu().write_audio_to_tb(
420
+ f"imputed_true/{i}",
421
+ writer,
422
+ step=state.tracker.step,
423
+ plot_fn=None,
424
+ )
425
+
426
+
427
+ @torch.no_grad()
428
+ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
429
+ state.model.eval()
430
+ state.codec.eval()
431
+ vn = accel.unwrap(state.model)
432
+
433
+ batch = [state.val_data[i] for i in val_idx]
434
+ batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)
435
+
436
+ signal = apply_transform(state.val_data.transform, batch)
437
+
438
+ z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
439
+ z = z[:, : vn.n_codebooks, :]
440
+
441
+ r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
442
+
443
+
444
+ mask = pmask.random(z, r)
445
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
446
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
447
+
448
+ z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
449
+
450
+ z_hat = state.model(z_mask_latent, r)
451
+
452
+ z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
453
+ z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
454
+ z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
455
+
456
+ generated = vn.to_signal(z_pred, state.codec)
457
+ reconstructed = vn.to_signal(z, state.codec)
458
+ masked = vn.to_signal(z_mask.squeeze(1), state.codec)
459
+
460
+ for i in range(generated.batch_size):
461
+ audio_dict = {
462
+ "original": signal[i],
463
+ "masked": masked[i],
464
+ "generated": generated[i],
465
+ "reconstructed": reconstructed[i],
466
+ }
467
+ for k, v in audio_dict.items():
468
+ v.cpu().write_audio_to_tb(
469
+ f"samples/_{i}.r={r[i]:0.2f}/{k}",
470
+ writer,
471
+ step=state.tracker.step,
472
+ plot_fn=None,
473
+ )
474
+
475
+ save_sampled(state=state, z=z, writer=writer)
476
+ save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
477
+
478
+
479
+
480
  @argbind.bind(without_prefix=True)
481
  def load(
482
  args,
483
  accel: at.ml.Accelerator,
484
+ tracker: Tracker,
485
  save_path: str,
486
  resume: bool = False,
487
  tag: str = "latest",
488
  load_weights: bool = False,
489
  fine_tune_checkpoint: Optional[str] = None,
490
+ grad_clip_val: float = 5.0,
491
+ ) -> State:
492
  codec = DAC.load(args["codec_ckpt"], map_location="cpu")
493
  codec.eval()
494
 
 
500
  "map_location": "cpu",
501
  "package": not load_weights,
502
  }
503
+ tracker.print(f"Loading checkpoint from {kwargs['folder']}")
504
  if (Path(kwargs["folder"]) / "vampnet").exists():
505
  model, v_extra = VampNet.load_from_folder(**kwargs)
506
  else:
 
527
  scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
528
  scheduler.step()
529
 
 
 
530
  if "optimizer.pth" in v_extra:
531
  optimizer.load_state_dict(v_extra["optimizer.pth"])
 
532
  scheduler.load_state_dict(v_extra["scheduler.pth"])
533
+ if "tracker.pth" in v_extra:
534
+ tracker.load_state_dict(v_extra["tracker.pth"])
535
+
536
+ criterion = CrossEntropyLoss()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
538
+ sample_rate = codec.sample_rate
 
539
 
540
+ # a better rng for sampling from our schedule
541
+ rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
542
 
543
+ # log a model summary w/ num params
544
+ if accel.local_rank == 0:
545
+ add_num_params_repr_hook(accel.unwrap(model))
546
+ with open(f"{save_path}/model.txt", "w") as f:
547
+ f.write(repr(accel.unwrap(model)))
548
 
549
+ # load the datasets
550
+ train_data, val_data = build_datasets(args, sample_rate)
551
+
552
+ return State(
553
+ tracker=tracker,
554
+ model=model,
555
+ codec=codec,
556
+ optimizer=optimizer,
557
+ scheduler=scheduler,
558
+ criterion=criterion,
559
+ rng=rng,
560
+ train_data=train_data,
561
+ val_data=val_data,
562
+ grad_clip_val=grad_clip_val,
563
+ )
564
 
565
 
566
  @argbind.bind(without_prefix=True)
567
  def train(
568
  args,
569
  accel: at.ml.Accelerator,
 
570
  seed: int = 0,
571
+ codec_ckpt: str = None,
572
  save_path: str = "ckpt",
573
+ num_iters: int = int(1000e6),
574
+ save_iters: list = [10000, 50000, 100000, 300000, 500000,],
575
+ sample_freq: int = 10000,
576
+ val_freq: int = 1000,
577
+ batch_size: int = 12,
 
578
  val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
579
  num_workers: int = 10,
 
 
580
  fine_tune: bool = False,
 
581
  ):
582
  assert codec_ckpt is not None, "codec_ckpt is required"
583
 
 
589
  writer = SummaryWriter(log_dir=f"{save_path}/logs/")
590
  argbind.dump_args(args, f"{save_path}/args.yml")
591
 
592
+ tracker = Tracker(
593
+ writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
594
+ )
 
 
 
 
 
 
595
 
596
+ # load the codec model
597
+ state: State = load(
598
+ args=args,
599
+ accel=accel,
600
+ tracker=tracker,
601
+ save_path=save_path)
602
 
 
 
 
 
 
603
 
 
 
604
  train_dataloader = accel.prepare_dataloader(
605
+ state.train_data,
606
+ start_idx=state.tracker.step * batch_size,
607
  num_workers=num_workers,
608
  batch_size=batch_size,
609
+ collate_fn=state.train_data.collate,
610
  )
611
  val_dataloader = accel.prepare_dataloader(
612
+ state.val_data,
613
  start_idx=0,
614
  num_workers=num_workers,
615
  batch_size=batch_size,
616
+ collate_fn=state.val_data.collate,
617
+ persistent_workers=True,
618
  )
619
 
620
+
621
 
622
  if fine_tune:
623
+ lora.mark_only_lora_as_trainable(state.model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
 
625
+ # Wrap the functions so that they neatly track in TensorBoard + progress bars
626
+ # and only run when specific conditions are met.
627
+ global train_loop, val_loop, validate, save_samples, checkpoint
628
 
629
+ train_loop = tracker.log("train", "value", history=False)(
630
+ tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
631
+ )
632
+ val_loop = tracker.track("val", len(val_dataloader))(val_loop)
633
+ validate = tracker.log("val", "mean")(validate)
634
 
635
+ save_samples = when(lambda: accel.local_rank == 0)(save_samples)
636
+ checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
 
637
 
638
+ with tracker.live:
639
+ for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
640
+ train_loop(state, batch, accel)
641
 
642
+ last_iter = (
643
+ tracker.step == num_iters - 1 if num_iters is not None else False
 
 
 
 
 
 
 
 
 
 
644
  )
645
 
646
+ if tracker.step % sample_freq == 0 or last_iter:
647
+ save_samples(state, val_idx, writer)
648
 
649
+ if tracker.step % val_freq == 0 or last_iter:
650
+ validate(state, val_dataloader, accel)
651
+ checkpoint(
652
+ state=state,
653
+ save_iters=save_iters,
654
+ save_path=save_path,
655
+ fine_tune=fine_tune)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
 
657
+ # Reset validation progress bar, print summary since last validation.
658
+ tracker.done("val", f"Iteration {tracker.step}")
 
659
 
660
+ if last_iter:
661
+ break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
 
663
 
664
  if __name__ == "__main__":
 
666
  args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
667
  with argbind.scope(args):
668
  with Accelerator() as accel:
669
+ if accel.local_rank != 0:
670
+ sys.tracebacklimit = 0
671
  train(args, accel)
setup.py CHANGED
@@ -31,7 +31,7 @@ setup(
31
  "numpy==1.22",
32
  "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
33
  "lac @ git+https://github.com/hugofloresgarcia/lac.git",
34
- "audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git",
35
  "gradio",
36
  "tensorboardX",
37
  "loralib",
 
31
  "numpy==1.22",
32
  "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
33
  "lac @ git+https://github.com/hugofloresgarcia/lac.git",
34
+ "descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
35
  "gradio",
36
  "tensorboardX",
37
  "loralib",
vampnet/modules/__init__.py CHANGED
@@ -2,3 +2,5 @@ import audiotools
2
 
3
  audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
4
  audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
 
 
 
2
 
3
  audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
4
  audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
5
+
6
+ from .transformer import VampNet