Spaces:
Sleeping
Sleeping
Hugo Flores Garcia
commited on
Commit
·
457f9d1
1
Parent(s):
fca3233
update audiotools version, update recipe
Browse files- .gitignore +1 -2
- conf/lora/lora.yml +4 -6
- conf/vampnet.yml +8 -18
- scripts/exp/fine_tune.py +5 -5
- scripts/exp/train.py +468 -418
- setup.py +1 -1
- vampnet/modules/__init__.py +2 -0
.gitignore
CHANGED
@@ -180,5 +180,4 @@ samples*/
|
|
180 |
models-all/
|
181 |
models.zip
|
182 |
.git-old
|
183 |
-
|
184 |
-
conf/generated/*
|
|
|
180 |
models-all/
|
181 |
models.zip
|
182 |
.git-old
|
183 |
+
conf/generated/*
|
|
conf/lora/lora.yml
CHANGED
@@ -3,20 +3,18 @@ $include:
|
|
3 |
|
4 |
fine_tune: True
|
5 |
|
6 |
-
train/AudioDataset.n_examples:
|
7 |
-
|
8 |
-
val/AudioDataset.n_examples: 10
|
9 |
|
10 |
|
11 |
NoamScheduler.warmup: 500
|
12 |
|
13 |
batch_size: 7
|
14 |
num_workers: 7
|
15 |
-
|
16 |
-
save_audio_epochs: 10
|
17 |
|
18 |
AdamW.lr: 0.0001
|
19 |
|
20 |
# let's us organize sound classes into folders and choose from those sound classes uniformly
|
21 |
AudioDataset.without_replacement: False
|
22 |
-
|
|
|
3 |
|
4 |
fine_tune: True
|
5 |
|
6 |
+
train/AudioDataset.n_examples: 100000000
|
7 |
+
val/AudioDataset.n_examples: 100
|
|
|
8 |
|
9 |
|
10 |
NoamScheduler.warmup: 500
|
11 |
|
12 |
batch_size: 7
|
13 |
num_workers: 7
|
14 |
+
save_iters: [100000, 200000, 300000, 4000000, 500000]
|
|
|
15 |
|
16 |
AdamW.lr: 0.0001
|
17 |
|
18 |
# let's us organize sound classes into folders and choose from those sound classes uniformly
|
19 |
AudioDataset.without_replacement: False
|
20 |
+
num_iters: 500000
|
conf/vampnet.yml
CHANGED
@@ -1,21 +1,17 @@
|
|
1 |
|
2 |
-
codec_ckpt: ./models/
|
3 |
save_path: ckpt
|
4 |
-
max_epochs: 1000
|
5 |
-
epoch_length: 1000
|
6 |
-
save_audio_epochs: 2
|
7 |
-
val_idx: [0,1,2,3,4,5,6,7,8,9]
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
13 |
|
14 |
batch_size: 8
|
15 |
num_workers: 10
|
16 |
|
17 |
# Optimization
|
18 |
-
detect_anomaly: false
|
19 |
amp: false
|
20 |
|
21 |
CrossEntropyLoss.label_smoothing: 0.1
|
@@ -25,9 +21,6 @@ AdamW.lr: 0.001
|
|
25 |
NoamScheduler.factor: 2.0
|
26 |
NoamScheduler.warmup: 10000
|
27 |
|
28 |
-
PitchShift.shift_amount: [const, 0]
|
29 |
-
PitchShift.prob: 0.0
|
30 |
-
|
31 |
VampNet.vocab_size: 1024
|
32 |
VampNet.n_codebooks: 4
|
33 |
VampNet.n_conditioning_codebooks: 0
|
@@ -48,12 +41,9 @@ AudioDataset.duration: 10.0
|
|
48 |
|
49 |
train/AudioDataset.n_examples: 10000000
|
50 |
train/AudioLoader.sources:
|
51 |
-
- /
|
52 |
|
53 |
val/AudioDataset.n_examples: 2000
|
54 |
val/AudioLoader.sources:
|
55 |
-
- /
|
56 |
|
57 |
-
test/AudioDataset.n_examples: 1000
|
58 |
-
test/AudioLoader.sources:
|
59 |
-
- /data/spotdl/audio/test
|
|
|
1 |
|
2 |
+
codec_ckpt: ./models/vampnet/codec.pth
|
3 |
save_path: ckpt
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
num_iters: 1000000000
|
6 |
+
save_iters: [10000, 50000, 100000, 300000, 500000]
|
7 |
+
val_idx: [0,1,2,3,4,5,6,7,8,9]
|
8 |
+
sample_freq: 10000
|
9 |
+
val_freq: 1000
|
10 |
|
11 |
batch_size: 8
|
12 |
num_workers: 10
|
13 |
|
14 |
# Optimization
|
|
|
15 |
amp: false
|
16 |
|
17 |
CrossEntropyLoss.label_smoothing: 0.1
|
|
|
21 |
NoamScheduler.factor: 2.0
|
22 |
NoamScheduler.warmup: 10000
|
23 |
|
|
|
|
|
|
|
24 |
VampNet.vocab_size: 1024
|
25 |
VampNet.n_codebooks: 4
|
26 |
VampNet.n_conditioning_codebooks: 0
|
|
|
41 |
|
42 |
train/AudioDataset.n_examples: 10000000
|
43 |
train/AudioLoader.sources:
|
44 |
+
- /media/CHONK/hugo/spotdl/audio-train
|
45 |
|
46 |
val/AudioDataset.n_examples: 2000
|
47 |
val/AudioLoader.sources:
|
48 |
+
- /media/CHONK/hugo/spotdl/audio-val
|
49 |
|
|
|
|
|
|
scripts/exp/fine_tune.py
CHANGED
@@ -35,7 +35,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
35 |
"AudioDataset.duration": 3.0,
|
36 |
"AudioDataset.loudness_cutoff": -40.0,
|
37 |
"save_path": f"./runs/{name}/c2f",
|
38 |
-
"fine_tune_checkpoint": "./models/
|
39 |
}
|
40 |
|
41 |
finetune_coarse_conf = {
|
@@ -44,17 +44,17 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
44 |
"train/AudioLoader.sources": audio_files_or_folders,
|
45 |
"val/AudioLoader.sources": audio_files_or_folders,
|
46 |
"save_path": f"./runs/{name}/coarse",
|
47 |
-
"fine_tune_checkpoint": "./models/
|
48 |
}
|
49 |
|
50 |
interface_conf = {
|
51 |
-
"Interface.coarse_ckpt": f"./models/
|
52 |
"Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
|
53 |
|
54 |
-
"Interface.coarse2fine_ckpt": f"./models/
|
55 |
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
56 |
|
57 |
-
"Interface.codec_ckpt": "./models/
|
58 |
"AudioLoader.sources": [audio_files_or_folders],
|
59 |
}
|
60 |
|
|
|
35 |
"AudioDataset.duration": 3.0,
|
36 |
"AudioDataset.loudness_cutoff": -40.0,
|
37 |
"save_path": f"./runs/{name}/c2f",
|
38 |
+
"fine_tune_checkpoint": "./models/vampnet/c2f.pth"
|
39 |
}
|
40 |
|
41 |
finetune_coarse_conf = {
|
|
|
44 |
"train/AudioLoader.sources": audio_files_or_folders,
|
45 |
"val/AudioLoader.sources": audio_files_or_folders,
|
46 |
"save_path": f"./runs/{name}/coarse",
|
47 |
+
"fine_tune_checkpoint": "./models/vampnet/coarse.pth"
|
48 |
}
|
49 |
|
50 |
interface_conf = {
|
51 |
+
"Interface.coarse_ckpt": f"./models/vampnet/coarse.pth",
|
52 |
"Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
|
53 |
|
54 |
+
"Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
|
55 |
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
56 |
|
57 |
+
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
58 |
"AudioLoader.sources": [audio_files_or_folders],
|
59 |
}
|
60 |
|
scripts/exp/train.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import os
|
2 |
-
import
|
3 |
-
import time
|
4 |
import warnings
|
5 |
from pathlib import Path
|
6 |
from typing import Optional
|
|
|
7 |
|
8 |
import argbind
|
9 |
import audiotools as at
|
@@ -23,6 +23,12 @@ from vampnet import mask as pmask
|
|
23 |
# from dac.model.dac import DAC
|
24 |
from lac.model.lac import LAC as DAC
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
# Enable cudnn autotuner to speed up training
|
28 |
# (can be altered by the funcs.seed function)
|
@@ -85,11 +91,7 @@ def build_datasets(args, sample_rate: int):
|
|
85 |
)
|
86 |
with argbind.scope(args, "val"):
|
87 |
val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
|
88 |
-
|
89 |
-
test_data = AudioDataset(
|
90 |
-
AudioLoader(), sample_rate, transform=build_transform()
|
91 |
-
)
|
92 |
-
return train_data, val_data, test_data
|
93 |
|
94 |
|
95 |
def rand_float(shape, low, high, rng):
|
@@ -100,16 +102,393 @@ def flip_coin(shape, p, rng):
|
|
100 |
return rng.draw(shape)[:, 0] < p
|
101 |
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
@argbind.bind(without_prefix=True)
|
104 |
def load(
|
105 |
args,
|
106 |
accel: at.ml.Accelerator,
|
|
|
107 |
save_path: str,
|
108 |
resume: bool = False,
|
109 |
tag: str = "latest",
|
110 |
load_weights: bool = False,
|
111 |
fine_tune_checkpoint: Optional[str] = None,
|
112 |
-
|
|
|
113 |
codec = DAC.load(args["codec_ckpt"], map_location="cpu")
|
114 |
codec.eval()
|
115 |
|
@@ -121,6 +500,7 @@ def load(
|
|
121 |
"map_location": "cpu",
|
122 |
"package": not load_weights,
|
123 |
}
|
|
|
124 |
if (Path(kwargs["folder"]) / "vampnet").exists():
|
125 |
model, v_extra = VampNet.load_from_folder(**kwargs)
|
126 |
else:
|
@@ -147,89 +527,57 @@ def load(
|
|
147 |
scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
|
148 |
scheduler.step()
|
149 |
|
150 |
-
trainer_state = {"state_dict": None, "start_idx": 0}
|
151 |
-
|
152 |
if "optimizer.pth" in v_extra:
|
153 |
optimizer.load_state_dict(v_extra["optimizer.pth"])
|
154 |
-
if "scheduler.pth" in v_extra:
|
155 |
scheduler.load_state_dict(v_extra["scheduler.pth"])
|
156 |
-
if "
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
"model": model,
|
161 |
-
"codec": codec,
|
162 |
-
"optimizer": optimizer,
|
163 |
-
"scheduler": scheduler,
|
164 |
-
"trainer_state": trainer_state,
|
165 |
-
}
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
def num_params_hook(o, p):
|
170 |
-
return o + f" {p/1e6:<.3f}M params."
|
171 |
-
|
172 |
-
|
173 |
-
def add_num_params_repr_hook(model):
|
174 |
-
import numpy as np
|
175 |
-
from functools import partial
|
176 |
-
|
177 |
-
for n, m in model.named_modules():
|
178 |
-
o = m.extra_repr()
|
179 |
-
p = sum([np.prod(p.size()) for p in m.parameters()])
|
180 |
-
|
181 |
-
setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
|
182 |
-
|
183 |
-
|
184 |
-
def accuracy(
|
185 |
-
preds: torch.Tensor,
|
186 |
-
target: torch.Tensor,
|
187 |
-
top_k: int = 1,
|
188 |
-
ignore_index: Optional[int] = None,
|
189 |
-
) -> torch.Tensor:
|
190 |
-
# Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
|
191 |
-
preds = rearrange(preds, "b p s -> (b s) p")
|
192 |
-
target = rearrange(target, "b s -> (b s)")
|
193 |
-
|
194 |
-
# return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
|
195 |
-
if ignore_index is not None:
|
196 |
-
# Create a mask for the ignored index
|
197 |
-
mask = target != ignore_index
|
198 |
-
# Apply the mask to the target and predictions
|
199 |
-
preds = preds[mask]
|
200 |
-
target = target[mask]
|
201 |
|
202 |
-
|
203 |
-
_, pred_indices = torch.topk(preds, k=top_k, dim=-1)
|
204 |
|
205 |
-
#
|
206 |
-
|
207 |
|
208 |
-
#
|
209 |
-
|
|
|
|
|
|
|
210 |
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
|
214 |
@argbind.bind(without_prefix=True)
|
215 |
def train(
|
216 |
args,
|
217 |
accel: at.ml.Accelerator,
|
218 |
-
codec_ckpt: str = None,
|
219 |
seed: int = 0,
|
|
|
220 |
save_path: str = "ckpt",
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
batch_size: int =
|
226 |
-
grad_acc_steps: int = 1,
|
227 |
val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
228 |
num_workers: int = 10,
|
229 |
-
detect_anomaly: bool = False,
|
230 |
-
grad_clip_val: float = 5.0,
|
231 |
fine_tune: bool = False,
|
232 |
-
quiet: bool = False,
|
233 |
):
|
234 |
assert codec_ckpt is not None, "codec_ckpt is required"
|
235 |
|
@@ -241,376 +589,76 @@ def train(
|
|
241 |
writer = SummaryWriter(log_dir=f"{save_path}/logs/")
|
242 |
argbind.dump_args(args, f"{save_path}/args.yml")
|
243 |
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
codec = loaded["codec"]
|
248 |
-
optimizer = loaded["optimizer"]
|
249 |
-
scheduler = loaded["scheduler"]
|
250 |
-
trainer_state = loaded["trainer_state"]
|
251 |
-
|
252 |
-
sample_rate = codec.sample_rate
|
253 |
|
254 |
-
#
|
255 |
-
|
|
|
|
|
|
|
|
|
256 |
|
257 |
-
# log a model summary w/ num params
|
258 |
-
if accel.local_rank == 0:
|
259 |
-
add_num_params_repr_hook(accel.unwrap(model))
|
260 |
-
with open(f"{save_path}/model.txt", "w") as f:
|
261 |
-
f.write(repr(accel.unwrap(model)))
|
262 |
|
263 |
-
# load the datasets
|
264 |
-
train_data, val_data, _ = build_datasets(args, sample_rate)
|
265 |
train_dataloader = accel.prepare_dataloader(
|
266 |
-
train_data,
|
267 |
-
start_idx=
|
268 |
num_workers=num_workers,
|
269 |
batch_size=batch_size,
|
270 |
-
collate_fn=train_data.collate,
|
271 |
)
|
272 |
val_dataloader = accel.prepare_dataloader(
|
273 |
-
val_data,
|
274 |
start_idx=0,
|
275 |
num_workers=num_workers,
|
276 |
batch_size=batch_size,
|
277 |
-
collate_fn=val_data.collate,
|
|
|
278 |
)
|
279 |
|
280 |
-
|
281 |
|
282 |
if fine_tune:
|
283 |
-
|
284 |
-
lora.mark_only_lora_as_trainable(model)
|
285 |
-
|
286 |
-
|
287 |
-
class Trainer(at.ml.BaseTrainer):
|
288 |
-
_last_grad_norm = 0.0
|
289 |
-
|
290 |
-
def _metrics(self, vn, z_hat, r, target, flat_mask, output):
|
291 |
-
for r_range in [(0, 0.5), (0.5, 1.0)]:
|
292 |
-
unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
|
293 |
-
masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
294 |
-
|
295 |
-
assert target.shape[0] == r.shape[0]
|
296 |
-
# grab the indices of the r values that are in the range
|
297 |
-
r_idx = (r >= r_range[0]) & (r < r_range[1])
|
298 |
-
|
299 |
-
# grab the target and z_hat values that are in the range
|
300 |
-
r_unmasked_target = unmasked_target[r_idx]
|
301 |
-
r_masked_target = masked_target[r_idx]
|
302 |
-
r_z_hat = z_hat[r_idx]
|
303 |
-
|
304 |
-
for topk in (1, 25):
|
305 |
-
s, e = r_range
|
306 |
-
tag = f"accuracy-{s}-{e}/top{topk}"
|
307 |
-
|
308 |
-
output[f"{tag}/unmasked"] = accuracy(
|
309 |
-
preds=r_z_hat,
|
310 |
-
target=r_unmasked_target,
|
311 |
-
ignore_index=IGNORE_INDEX,
|
312 |
-
top_k=topk,
|
313 |
-
)
|
314 |
-
output[f"{tag}/masked"] = accuracy(
|
315 |
-
preds=r_z_hat,
|
316 |
-
target=r_masked_target,
|
317 |
-
ignore_index=IGNORE_INDEX,
|
318 |
-
top_k=topk,
|
319 |
-
)
|
320 |
-
|
321 |
-
def train_loop(self, engine, batch):
|
322 |
-
model.train()
|
323 |
-
batch = at.util.prepare_batch(batch, accel.device)
|
324 |
-
signal = apply_transform(train_data.transform, batch)
|
325 |
-
|
326 |
-
output = {}
|
327 |
-
vn = accel.unwrap(model)
|
328 |
-
with accel.autocast():
|
329 |
-
with torch.inference_mode():
|
330 |
-
codec.to(accel.device)
|
331 |
-
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
332 |
-
z = z[:, : vn.n_codebooks, :]
|
333 |
-
|
334 |
-
n_batch = z.shape[0]
|
335 |
-
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
336 |
-
|
337 |
-
mask = pmask.random(z, r)
|
338 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
339 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
340 |
-
|
341 |
-
z_mask_latent = vn.embedding.from_codes(z_mask, codec)
|
342 |
-
|
343 |
-
dtype = torch.bfloat16 if accel.amp else None
|
344 |
-
with accel.autocast(dtype=dtype):
|
345 |
-
z_hat = model(z_mask_latent, r)
|
346 |
-
|
347 |
-
target = codebook_flatten(
|
348 |
-
z[:, vn.n_conditioning_codebooks :, :],
|
349 |
-
)
|
350 |
-
|
351 |
-
flat_mask = codebook_flatten(
|
352 |
-
mask[:, vn.n_conditioning_codebooks :, :],
|
353 |
-
)
|
354 |
-
|
355 |
-
# replace target with ignore index for masked tokens
|
356 |
-
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
357 |
-
output["loss"] = criterion(z_hat, t_masked)
|
358 |
-
|
359 |
-
self._metrics(
|
360 |
-
vn=vn,
|
361 |
-
r=r,
|
362 |
-
z_hat=z_hat,
|
363 |
-
target=target,
|
364 |
-
flat_mask=flat_mask,
|
365 |
-
output=output,
|
366 |
-
)
|
367 |
-
|
368 |
-
|
369 |
-
accel.backward(output["loss"] / grad_acc_steps)
|
370 |
-
|
371 |
-
output["other/learning_rate"] = optimizer.param_groups[0]["lr"]
|
372 |
-
output["other/batch_size"] = z.shape[0]
|
373 |
-
|
374 |
-
if (
|
375 |
-
(engine.state.iteration % grad_acc_steps == 0)
|
376 |
-
or (engine.state.iteration % epoch_length == 0)
|
377 |
-
or (engine.state.iteration % epoch_length == 1)
|
378 |
-
): # (or we reached the end of the epoch)
|
379 |
-
accel.scaler.unscale_(optimizer)
|
380 |
-
output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
|
381 |
-
model.parameters(), grad_clip_val
|
382 |
-
)
|
383 |
-
self._last_grad_norm = output["other/grad_norm"]
|
384 |
-
|
385 |
-
accel.step(optimizer)
|
386 |
-
optimizer.zero_grad()
|
387 |
-
|
388 |
-
scheduler.step()
|
389 |
-
accel.update()
|
390 |
-
else:
|
391 |
-
output["other/grad_norm"] = self._last_grad_norm
|
392 |
-
|
393 |
-
return {k: v for k, v in sorted(output.items())}
|
394 |
-
|
395 |
-
@torch.no_grad()
|
396 |
-
def val_loop(self, engine, batch):
|
397 |
-
model.eval()
|
398 |
-
codec.eval()
|
399 |
-
batch = at.util.prepare_batch(batch, accel.device)
|
400 |
-
signal = apply_transform(val_data.transform, batch)
|
401 |
-
|
402 |
-
vn = accel.unwrap(model)
|
403 |
-
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
404 |
-
z = z[:, : vn.n_codebooks, :]
|
405 |
-
|
406 |
-
n_batch = z.shape[0]
|
407 |
-
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
408 |
-
|
409 |
-
mask = pmask.random(z, r)
|
410 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
411 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
412 |
|
413 |
-
|
|
|
|
|
414 |
|
415 |
-
|
|
|
|
|
|
|
|
|
416 |
|
417 |
-
|
418 |
-
|
419 |
-
)
|
420 |
|
421 |
-
|
422 |
-
|
423 |
-
)
|
424 |
|
425 |
-
|
426 |
-
|
427 |
-
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
428 |
-
output["loss"] = criterion(z_hat, t_masked)
|
429 |
-
|
430 |
-
self._metrics(
|
431 |
-
vn=vn,
|
432 |
-
r=r,
|
433 |
-
z_hat=z_hat,
|
434 |
-
target=target,
|
435 |
-
flat_mask=flat_mask,
|
436 |
-
output=output,
|
437 |
)
|
438 |
|
439 |
-
|
|
|
440 |
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
if self.state.epoch % save_audio_epochs == 0:
|
449 |
-
self.save_samples()
|
450 |
-
|
451 |
-
tags = ["latest"]
|
452 |
-
loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
|
453 |
-
self.print(f"Saving to {str(Path('.').absolute())}")
|
454 |
-
|
455 |
-
if self.state.epoch in save_epochs:
|
456 |
-
tags.append(f"epoch={self.state.epoch}")
|
457 |
-
|
458 |
-
if self.is_best(engine, loss_key):
|
459 |
-
self.print(f"Best model so far")
|
460 |
-
tags.append("best")
|
461 |
-
|
462 |
-
if fine_tune:
|
463 |
-
for tag in tags:
|
464 |
-
# save the lora model
|
465 |
-
(Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
|
466 |
-
torch.save(
|
467 |
-
lora.lora_state_dict(accel.unwrap(model)),
|
468 |
-
f"{save_path}/{tag}/lora.pth"
|
469 |
-
)
|
470 |
-
|
471 |
-
for tag in tags:
|
472 |
-
model_extra = {
|
473 |
-
"optimizer.pth": optimizer.state_dict(),
|
474 |
-
"scheduler.pth": scheduler.state_dict(),
|
475 |
-
"trainer.pth": {
|
476 |
-
"start_idx": self.state.iteration * batch_size,
|
477 |
-
"state_dict": self.state_dict(),
|
478 |
-
},
|
479 |
-
"metadata.pth": metadata,
|
480 |
-
}
|
481 |
-
|
482 |
-
accel.unwrap(model).metadata = metadata
|
483 |
-
accel.unwrap(model).save_to_folder(
|
484 |
-
f"{save_path}/{tag}", model_extra,
|
485 |
-
)
|
486 |
-
|
487 |
-
def save_sampled(self, z):
|
488 |
-
num_samples = z.shape[0]
|
489 |
-
|
490 |
-
for i in range(num_samples):
|
491 |
-
sampled = accel.unwrap(model).generate(
|
492 |
-
codec=codec,
|
493 |
-
time_steps=z.shape[-1],
|
494 |
-
start_tokens=z[i : i + 1],
|
495 |
-
)
|
496 |
-
sampled.cpu().write_audio_to_tb(
|
497 |
-
f"sampled/{i}",
|
498 |
-
self.writer,
|
499 |
-
step=self.state.epoch,
|
500 |
-
plot_fn=None,
|
501 |
-
)
|
502 |
-
|
503 |
-
|
504 |
-
def save_imputation(self, z: torch.Tensor):
|
505 |
-
n_prefix = int(z.shape[-1] * 0.25)
|
506 |
-
n_suffix = int(z.shape[-1] * 0.25)
|
507 |
-
|
508 |
-
vn = accel.unwrap(model)
|
509 |
-
|
510 |
-
mask = pmask.inpaint(z, n_prefix, n_suffix)
|
511 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
512 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
513 |
-
|
514 |
-
imputed_noisy = vn.to_signal(z_mask, codec)
|
515 |
-
imputed_true = vn.to_signal(z, codec)
|
516 |
-
|
517 |
-
imputed = []
|
518 |
-
for i in range(len(z)):
|
519 |
-
imputed.append(
|
520 |
-
vn.generate(
|
521 |
-
codec=codec,
|
522 |
-
time_steps=z.shape[-1],
|
523 |
-
start_tokens=z[i][None, ...],
|
524 |
-
mask=mask[i][None, ...],
|
525 |
-
)
|
526 |
-
)
|
527 |
-
imputed = AudioSignal.batch(imputed)
|
528 |
-
|
529 |
-
for i in range(len(val_idx)):
|
530 |
-
imputed_noisy[i].cpu().write_audio_to_tb(
|
531 |
-
f"imputed_noisy/{i}",
|
532 |
-
self.writer,
|
533 |
-
step=self.state.epoch,
|
534 |
-
plot_fn=None,
|
535 |
-
)
|
536 |
-
imputed[i].cpu().write_audio_to_tb(
|
537 |
-
f"imputed/{i}",
|
538 |
-
self.writer,
|
539 |
-
step=self.state.epoch,
|
540 |
-
plot_fn=None,
|
541 |
-
)
|
542 |
-
imputed_true[i].cpu().write_audio_to_tb(
|
543 |
-
f"imputed_true/{i}",
|
544 |
-
self.writer,
|
545 |
-
step=self.state.epoch,
|
546 |
-
plot_fn=None,
|
547 |
-
)
|
548 |
-
|
549 |
-
@torch.no_grad()
|
550 |
-
def save_samples(self):
|
551 |
-
model.eval()
|
552 |
-
codec.eval()
|
553 |
-
vn = accel.unwrap(model)
|
554 |
-
|
555 |
-
batch = [val_data[i] for i in val_idx]
|
556 |
-
batch = at.util.prepare_batch(val_data.collate(batch), accel.device)
|
557 |
-
|
558 |
-
signal = apply_transform(val_data.transform, batch)
|
559 |
-
|
560 |
-
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
561 |
-
z = z[:, : vn.n_codebooks, :]
|
562 |
-
|
563 |
-
r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
|
564 |
-
|
565 |
-
|
566 |
-
mask = pmask.random(z, r)
|
567 |
-
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
568 |
-
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
569 |
-
|
570 |
-
z_mask_latent = vn.embedding.from_codes(z_mask, codec)
|
571 |
-
|
572 |
-
z_hat = model(z_mask_latent, r)
|
573 |
|
574 |
-
|
575 |
-
|
576 |
-
z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
|
577 |
|
578 |
-
|
579 |
-
|
580 |
-
masked = vn.to_signal(z_mask.squeeze(1), codec)
|
581 |
-
|
582 |
-
for i in range(generated.batch_size):
|
583 |
-
audio_dict = {
|
584 |
-
"original": signal[i],
|
585 |
-
"masked": masked[i],
|
586 |
-
"generated": generated[i],
|
587 |
-
"reconstructed": reconstructed[i],
|
588 |
-
}
|
589 |
-
for k, v in audio_dict.items():
|
590 |
-
v.cpu().write_audio_to_tb(
|
591 |
-
f"samples/_{i}.r={r[i]:0.2f}/{k}",
|
592 |
-
self.writer,
|
593 |
-
step=self.state.epoch,
|
594 |
-
plot_fn=None,
|
595 |
-
)
|
596 |
-
|
597 |
-
self.save_sampled(z)
|
598 |
-
self.save_imputation(z)
|
599 |
-
|
600 |
-
trainer = Trainer(writer=writer, quiet=quiet)
|
601 |
-
|
602 |
-
if trainer_state["state_dict"] is not None:
|
603 |
-
trainer.load_state_dict(trainer_state["state_dict"])
|
604 |
-
if hasattr(train_dataloader.sampler, "set_epoch"):
|
605 |
-
train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch)
|
606 |
-
|
607 |
-
trainer.run(
|
608 |
-
train_dataloader,
|
609 |
-
val_dataloader,
|
610 |
-
num_epochs=max_epochs,
|
611 |
-
epoch_length=epoch_length,
|
612 |
-
detect_anomaly=detect_anomaly,
|
613 |
-
)
|
614 |
|
615 |
|
616 |
if __name__ == "__main__":
|
@@ -618,4 +666,6 @@ if __name__ == "__main__":
|
|
618 |
args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
619 |
with argbind.scope(args):
|
620 |
with Accelerator() as accel:
|
|
|
|
|
621 |
train(args, accel)
|
|
|
1 |
import os
|
2 |
+
import sys
|
|
|
3 |
import warnings
|
4 |
from pathlib import Path
|
5 |
from typing import Optional
|
6 |
+
from dataclasses import dataclass
|
7 |
|
8 |
import argbind
|
9 |
import audiotools as at
|
|
|
23 |
# from dac.model.dac import DAC
|
24 |
from lac.model.lac import LAC as DAC
|
25 |
|
26 |
+
from audiotools.ml.decorators import (
|
27 |
+
timer, Tracker, when
|
28 |
+
)
|
29 |
+
|
30 |
+
import loralib as lora
|
31 |
+
|
32 |
|
33 |
# Enable cudnn autotuner to speed up training
|
34 |
# (can be altered by the funcs.seed function)
|
|
|
91 |
)
|
92 |
with argbind.scope(args, "val"):
|
93 |
val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
|
94 |
+
return train_data, val_data
|
|
|
|
|
|
|
|
|
95 |
|
96 |
|
97 |
def rand_float(shape, low, high, rng):
|
|
|
102 |
return rng.draw(shape)[:, 0] < p
|
103 |
|
104 |
|
105 |
+
def num_params_hook(o, p):
|
106 |
+
return o + f" {p/1e6:<.3f}M params."
|
107 |
+
|
108 |
+
|
109 |
+
def add_num_params_repr_hook(model):
|
110 |
+
import numpy as np
|
111 |
+
from functools import partial
|
112 |
+
|
113 |
+
for n, m in model.named_modules():
|
114 |
+
o = m.extra_repr()
|
115 |
+
p = sum([np.prod(p.size()) for p in m.parameters()])
|
116 |
+
|
117 |
+
setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
|
118 |
+
|
119 |
+
|
120 |
+
def accuracy(
|
121 |
+
preds: torch.Tensor,
|
122 |
+
target: torch.Tensor,
|
123 |
+
top_k: int = 1,
|
124 |
+
ignore_index: Optional[int] = None,
|
125 |
+
) -> torch.Tensor:
|
126 |
+
# Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
|
127 |
+
preds = rearrange(preds, "b p s -> (b s) p")
|
128 |
+
target = rearrange(target, "b s -> (b s)")
|
129 |
+
|
130 |
+
# return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
|
131 |
+
if ignore_index is not None:
|
132 |
+
# Create a mask for the ignored index
|
133 |
+
mask = target != ignore_index
|
134 |
+
# Apply the mask to the target and predictions
|
135 |
+
preds = preds[mask]
|
136 |
+
target = target[mask]
|
137 |
+
|
138 |
+
# Get the top-k predicted classes and their indices
|
139 |
+
_, pred_indices = torch.topk(preds, k=top_k, dim=-1)
|
140 |
+
|
141 |
+
# Determine if the true target is in the top-k predicted classes
|
142 |
+
correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
|
143 |
+
|
144 |
+
# Calculate the accuracy
|
145 |
+
accuracy = torch.mean(correct.float())
|
146 |
+
|
147 |
+
return accuracy
|
148 |
+
|
149 |
+
def _metrics(z_hat, r, target, flat_mask, output):
|
150 |
+
for r_range in [(0, 0.5), (0.5, 1.0)]:
|
151 |
+
unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
|
152 |
+
masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
153 |
+
|
154 |
+
assert target.shape[0] == r.shape[0]
|
155 |
+
# grab the indices of the r values that are in the range
|
156 |
+
r_idx = (r >= r_range[0]) & (r < r_range[1])
|
157 |
+
|
158 |
+
# grab the target and z_hat values that are in the range
|
159 |
+
r_unmasked_target = unmasked_target[r_idx]
|
160 |
+
r_masked_target = masked_target[r_idx]
|
161 |
+
r_z_hat = z_hat[r_idx]
|
162 |
+
|
163 |
+
for topk in (1, 25):
|
164 |
+
s, e = r_range
|
165 |
+
tag = f"accuracy-{s}-{e}/top{topk}"
|
166 |
+
|
167 |
+
output[f"{tag}/unmasked"] = accuracy(
|
168 |
+
preds=r_z_hat,
|
169 |
+
target=r_unmasked_target,
|
170 |
+
ignore_index=IGNORE_INDEX,
|
171 |
+
top_k=topk,
|
172 |
+
)
|
173 |
+
output[f"{tag}/masked"] = accuracy(
|
174 |
+
preds=r_z_hat,
|
175 |
+
target=r_masked_target,
|
176 |
+
ignore_index=IGNORE_INDEX,
|
177 |
+
top_k=topk,
|
178 |
+
)
|
179 |
+
|
180 |
+
|
181 |
+
@dataclass
|
182 |
+
class State:
|
183 |
+
model: VampNet
|
184 |
+
codec: DAC
|
185 |
+
|
186 |
+
optimizer: AdamW
|
187 |
+
scheduler: NoamScheduler
|
188 |
+
criterion: CrossEntropyLoss
|
189 |
+
grad_clip_val: float
|
190 |
+
|
191 |
+
rng: torch.quasirandom.SobolEngine
|
192 |
+
|
193 |
+
train_data: AudioDataset
|
194 |
+
val_data: AudioDataset
|
195 |
+
|
196 |
+
tracker: Tracker
|
197 |
+
|
198 |
+
|
199 |
+
@timer()
|
200 |
+
def train_loop(state: State, batch: dict, accel: Accelerator):
|
201 |
+
state.model.train()
|
202 |
+
batch = at.util.prepare_batch(batch, accel.device)
|
203 |
+
signal = apply_transform(state.train_data.transform, batch)
|
204 |
+
|
205 |
+
output = {}
|
206 |
+
vn = accel.unwrap(state.model)
|
207 |
+
with accel.autocast():
|
208 |
+
with torch.inference_mode():
|
209 |
+
state.codec.to(accel.device)
|
210 |
+
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
211 |
+
z = z[:, : vn.n_codebooks, :]
|
212 |
+
|
213 |
+
n_batch = z.shape[0]
|
214 |
+
r = state.rng.draw(n_batch)[:, 0].to(accel.device)
|
215 |
+
|
216 |
+
mask = pmask.random(z, r)
|
217 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
218 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
219 |
+
|
220 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
221 |
+
|
222 |
+
dtype = torch.bfloat16 if accel.amp else None
|
223 |
+
with accel.autocast(dtype=dtype):
|
224 |
+
z_hat = state.model(z_mask_latent, r)
|
225 |
+
|
226 |
+
target = codebook_flatten(
|
227 |
+
z[:, vn.n_conditioning_codebooks :, :],
|
228 |
+
)
|
229 |
+
|
230 |
+
flat_mask = codebook_flatten(
|
231 |
+
mask[:, vn.n_conditioning_codebooks :, :],
|
232 |
+
)
|
233 |
+
|
234 |
+
# replace target with ignore index for masked tokens
|
235 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
236 |
+
output["loss"] = state.criterion(z_hat, t_masked)
|
237 |
+
|
238 |
+
_metrics(
|
239 |
+
r=r,
|
240 |
+
z_hat=z_hat,
|
241 |
+
target=target,
|
242 |
+
flat_mask=flat_mask,
|
243 |
+
output=output,
|
244 |
+
)
|
245 |
+
|
246 |
+
|
247 |
+
accel.backward(output["loss"])
|
248 |
+
|
249 |
+
output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
|
250 |
+
output["other/batch_size"] = z.shape[0]
|
251 |
+
|
252 |
+
|
253 |
+
accel.scaler.unscale_(state.optimizer)
|
254 |
+
output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
|
255 |
+
state.model.parameters(), state.grad_clip_val
|
256 |
+
)
|
257 |
+
|
258 |
+
accel.step(state.optimizer)
|
259 |
+
state.optimizer.zero_grad()
|
260 |
+
|
261 |
+
state.scheduler.step()
|
262 |
+
accel.update()
|
263 |
+
|
264 |
+
|
265 |
+
return {k: v for k, v in sorted(output.items())}
|
266 |
+
|
267 |
+
|
268 |
+
@timer()
|
269 |
+
@torch.no_grad()
|
270 |
+
def val_loop(state: State, batch: dict, accel: Accelerator):
|
271 |
+
state.model.eval()
|
272 |
+
state.codec.eval()
|
273 |
+
batch = at.util.prepare_batch(batch, accel.device)
|
274 |
+
signal = apply_transform(state.val_data.transform, batch)
|
275 |
+
|
276 |
+
vn = accel.unwrap(state.model)
|
277 |
+
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
278 |
+
z = z[:, : vn.n_codebooks, :]
|
279 |
+
|
280 |
+
n_batch = z.shape[0]
|
281 |
+
r = state.rng.draw(n_batch)[:, 0].to(accel.device)
|
282 |
+
|
283 |
+
mask = pmask.random(z, r)
|
284 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
285 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
286 |
+
|
287 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
288 |
+
|
289 |
+
z_hat = state.model(z_mask_latent, r)
|
290 |
+
|
291 |
+
target = codebook_flatten(
|
292 |
+
z[:, vn.n_conditioning_codebooks :, :],
|
293 |
+
)
|
294 |
+
|
295 |
+
flat_mask = codebook_flatten(
|
296 |
+
mask[:, vn.n_conditioning_codebooks :, :]
|
297 |
+
)
|
298 |
+
|
299 |
+
output = {}
|
300 |
+
# replace target with ignore index for masked tokens
|
301 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
302 |
+
output["loss"] = state.criterion(z_hat, t_masked)
|
303 |
+
|
304 |
+
_metrics(
|
305 |
+
r=r,
|
306 |
+
z_hat=z_hat,
|
307 |
+
target=target,
|
308 |
+
flat_mask=flat_mask,
|
309 |
+
output=output,
|
310 |
+
)
|
311 |
+
|
312 |
+
return output
|
313 |
+
|
314 |
+
|
315 |
+
def validate(state, val_dataloader, accel):
|
316 |
+
for batch in val_dataloader:
|
317 |
+
output = val_loop(state, batch, accel)
|
318 |
+
# Consolidate state dicts if using ZeroRedundancyOptimizer
|
319 |
+
if hasattr(state.optimizer, "consolidate_state_dict"):
|
320 |
+
state.optimizer.consolidate_state_dict()
|
321 |
+
return output
|
322 |
+
|
323 |
+
|
324 |
+
def checkpoint(state, save_iters, save_path, fine_tune):
|
325 |
+
if accel.local_rank != 0:
|
326 |
+
state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
|
327 |
+
return
|
328 |
+
|
329 |
+
metadata = {"logs": dict(state.tracker.history)}
|
330 |
+
|
331 |
+
tags = ["latest"]
|
332 |
+
state.tracker.print(f"Saving to {str(Path('.').absolute())}")
|
333 |
+
|
334 |
+
if state.tracker.step in save_iters:
|
335 |
+
tags.append(f"{state.tracker.step // 1000}k")
|
336 |
+
|
337 |
+
if state.tracker.is_best("val", "loss"):
|
338 |
+
state.tracker.print(f"Best model so far")
|
339 |
+
tags.append("best")
|
340 |
+
|
341 |
+
if fine_tune:
|
342 |
+
for tag in tags:
|
343 |
+
# save the lora model
|
344 |
+
(Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
|
345 |
+
torch.save(
|
346 |
+
lora.lora_state_dict(accel.unwrap(state.model)),
|
347 |
+
f"{save_path}/{tag}/lora.pth"
|
348 |
+
)
|
349 |
+
|
350 |
+
for tag in tags:
|
351 |
+
model_extra = {
|
352 |
+
"optimizer.pth": state.optimizer.state_dict(),
|
353 |
+
"scheduler.pth": state.scheduler.state_dict(),
|
354 |
+
"tracker.pth": state.tracker.state_dict(),
|
355 |
+
"metadata.pth": metadata,
|
356 |
+
}
|
357 |
+
|
358 |
+
accel.unwrap(state.model).metadata = metadata
|
359 |
+
accel.unwrap(state.model).save_to_folder(
|
360 |
+
f"{save_path}/{tag}", model_extra, package=False
|
361 |
+
)
|
362 |
+
|
363 |
+
|
364 |
+
def save_sampled(state, z, writer):
|
365 |
+
num_samples = z.shape[0]
|
366 |
+
|
367 |
+
for i in range(num_samples):
|
368 |
+
sampled = accel.unwrap(state.model).generate(
|
369 |
+
codec=state.codec,
|
370 |
+
time_steps=z.shape[-1],
|
371 |
+
start_tokens=z[i : i + 1],
|
372 |
+
)
|
373 |
+
sampled.cpu().write_audio_to_tb(
|
374 |
+
f"sampled/{i}",
|
375 |
+
writer,
|
376 |
+
step=state.tracker.step,
|
377 |
+
plot_fn=None,
|
378 |
+
)
|
379 |
+
|
380 |
+
|
381 |
+
def save_imputation(state, z, val_idx, writer):
|
382 |
+
n_prefix = int(z.shape[-1] * 0.25)
|
383 |
+
n_suffix = int(z.shape[-1] * 0.25)
|
384 |
+
|
385 |
+
vn = accel.unwrap(state.model)
|
386 |
+
|
387 |
+
mask = pmask.inpaint(z, n_prefix, n_suffix)
|
388 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
389 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
390 |
+
|
391 |
+
imputed_noisy = vn.to_signal(z_mask, state.codec)
|
392 |
+
imputed_true = vn.to_signal(z, state.codec)
|
393 |
+
|
394 |
+
imputed = []
|
395 |
+
for i in range(len(z)):
|
396 |
+
imputed.append(
|
397 |
+
vn.generate(
|
398 |
+
codec=state.codec,
|
399 |
+
time_steps=z.shape[-1],
|
400 |
+
start_tokens=z[i][None, ...],
|
401 |
+
mask=mask[i][None, ...],
|
402 |
+
)
|
403 |
+
)
|
404 |
+
imputed = AudioSignal.batch(imputed)
|
405 |
+
|
406 |
+
for i in range(len(val_idx)):
|
407 |
+
imputed_noisy[i].cpu().write_audio_to_tb(
|
408 |
+
f"imputed_noisy/{i}",
|
409 |
+
writer,
|
410 |
+
step=state.tracker.step,
|
411 |
+
plot_fn=None,
|
412 |
+
)
|
413 |
+
imputed[i].cpu().write_audio_to_tb(
|
414 |
+
f"imputed/{i}",
|
415 |
+
writer,
|
416 |
+
step=state.tracker.step,
|
417 |
+
plot_fn=None,
|
418 |
+
)
|
419 |
+
imputed_true[i].cpu().write_audio_to_tb(
|
420 |
+
f"imputed_true/{i}",
|
421 |
+
writer,
|
422 |
+
step=state.tracker.step,
|
423 |
+
plot_fn=None,
|
424 |
+
)
|
425 |
+
|
426 |
+
|
427 |
+
@torch.no_grad()
|
428 |
+
def save_samples(state: State, val_idx: int, writer: SummaryWriter):
|
429 |
+
state.model.eval()
|
430 |
+
state.codec.eval()
|
431 |
+
vn = accel.unwrap(state.model)
|
432 |
+
|
433 |
+
batch = [state.val_data[i] for i in val_idx]
|
434 |
+
batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)
|
435 |
+
|
436 |
+
signal = apply_transform(state.val_data.transform, batch)
|
437 |
+
|
438 |
+
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
439 |
+
z = z[:, : vn.n_codebooks, :]
|
440 |
+
|
441 |
+
r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
|
442 |
+
|
443 |
+
|
444 |
+
mask = pmask.random(z, r)
|
445 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
446 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
447 |
+
|
448 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
449 |
+
|
450 |
+
z_hat = state.model(z_mask_latent, r)
|
451 |
+
|
452 |
+
z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
|
453 |
+
z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
|
454 |
+
z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
|
455 |
+
|
456 |
+
generated = vn.to_signal(z_pred, state.codec)
|
457 |
+
reconstructed = vn.to_signal(z, state.codec)
|
458 |
+
masked = vn.to_signal(z_mask.squeeze(1), state.codec)
|
459 |
+
|
460 |
+
for i in range(generated.batch_size):
|
461 |
+
audio_dict = {
|
462 |
+
"original": signal[i],
|
463 |
+
"masked": masked[i],
|
464 |
+
"generated": generated[i],
|
465 |
+
"reconstructed": reconstructed[i],
|
466 |
+
}
|
467 |
+
for k, v in audio_dict.items():
|
468 |
+
v.cpu().write_audio_to_tb(
|
469 |
+
f"samples/_{i}.r={r[i]:0.2f}/{k}",
|
470 |
+
writer,
|
471 |
+
step=state.tracker.step,
|
472 |
+
plot_fn=None,
|
473 |
+
)
|
474 |
+
|
475 |
+
save_sampled(state=state, z=z, writer=writer)
|
476 |
+
save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
|
477 |
+
|
478 |
+
|
479 |
+
|
480 |
@argbind.bind(without_prefix=True)
|
481 |
def load(
|
482 |
args,
|
483 |
accel: at.ml.Accelerator,
|
484 |
+
tracker: Tracker,
|
485 |
save_path: str,
|
486 |
resume: bool = False,
|
487 |
tag: str = "latest",
|
488 |
load_weights: bool = False,
|
489 |
fine_tune_checkpoint: Optional[str] = None,
|
490 |
+
grad_clip_val: float = 5.0,
|
491 |
+
) -> State:
|
492 |
codec = DAC.load(args["codec_ckpt"], map_location="cpu")
|
493 |
codec.eval()
|
494 |
|
|
|
500 |
"map_location": "cpu",
|
501 |
"package": not load_weights,
|
502 |
}
|
503 |
+
tracker.print(f"Loading checkpoint from {kwargs['folder']}")
|
504 |
if (Path(kwargs["folder"]) / "vampnet").exists():
|
505 |
model, v_extra = VampNet.load_from_folder(**kwargs)
|
506 |
else:
|
|
|
527 |
scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
|
528 |
scheduler.step()
|
529 |
|
|
|
|
|
530 |
if "optimizer.pth" in v_extra:
|
531 |
optimizer.load_state_dict(v_extra["optimizer.pth"])
|
|
|
532 |
scheduler.load_state_dict(v_extra["scheduler.pth"])
|
533 |
+
if "tracker.pth" in v_extra:
|
534 |
+
tracker.load_state_dict(v_extra["tracker.pth"])
|
535 |
+
|
536 |
+
criterion = CrossEntropyLoss()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
537 |
|
538 |
+
sample_rate = codec.sample_rate
|
|
|
539 |
|
540 |
+
# a better rng for sampling from our schedule
|
541 |
+
rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
|
542 |
|
543 |
+
# log a model summary w/ num params
|
544 |
+
if accel.local_rank == 0:
|
545 |
+
add_num_params_repr_hook(accel.unwrap(model))
|
546 |
+
with open(f"{save_path}/model.txt", "w") as f:
|
547 |
+
f.write(repr(accel.unwrap(model)))
|
548 |
|
549 |
+
# load the datasets
|
550 |
+
train_data, val_data = build_datasets(args, sample_rate)
|
551 |
+
|
552 |
+
return State(
|
553 |
+
tracker=tracker,
|
554 |
+
model=model,
|
555 |
+
codec=codec,
|
556 |
+
optimizer=optimizer,
|
557 |
+
scheduler=scheduler,
|
558 |
+
criterion=criterion,
|
559 |
+
rng=rng,
|
560 |
+
train_data=train_data,
|
561 |
+
val_data=val_data,
|
562 |
+
grad_clip_val=grad_clip_val,
|
563 |
+
)
|
564 |
|
565 |
|
566 |
@argbind.bind(without_prefix=True)
|
567 |
def train(
|
568 |
args,
|
569 |
accel: at.ml.Accelerator,
|
|
|
570 |
seed: int = 0,
|
571 |
+
codec_ckpt: str = None,
|
572 |
save_path: str = "ckpt",
|
573 |
+
num_iters: int = int(1000e6),
|
574 |
+
save_iters: list = [10000, 50000, 100000, 300000, 500000,],
|
575 |
+
sample_freq: int = 10000,
|
576 |
+
val_freq: int = 1000,
|
577 |
+
batch_size: int = 12,
|
|
|
578 |
val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
579 |
num_workers: int = 10,
|
|
|
|
|
580 |
fine_tune: bool = False,
|
|
|
581 |
):
|
582 |
assert codec_ckpt is not None, "codec_ckpt is required"
|
583 |
|
|
|
589 |
writer = SummaryWriter(log_dir=f"{save_path}/logs/")
|
590 |
argbind.dump_args(args, f"{save_path}/args.yml")
|
591 |
|
592 |
+
tracker = Tracker(
|
593 |
+
writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
|
594 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
595 |
|
596 |
+
# load the codec model
|
597 |
+
state: State = load(
|
598 |
+
args=args,
|
599 |
+
accel=accel,
|
600 |
+
tracker=tracker,
|
601 |
+
save_path=save_path)
|
602 |
|
|
|
|
|
|
|
|
|
|
|
603 |
|
|
|
|
|
604 |
train_dataloader = accel.prepare_dataloader(
|
605 |
+
state.train_data,
|
606 |
+
start_idx=state.tracker.step * batch_size,
|
607 |
num_workers=num_workers,
|
608 |
batch_size=batch_size,
|
609 |
+
collate_fn=state.train_data.collate,
|
610 |
)
|
611 |
val_dataloader = accel.prepare_dataloader(
|
612 |
+
state.val_data,
|
613 |
start_idx=0,
|
614 |
num_workers=num_workers,
|
615 |
batch_size=batch_size,
|
616 |
+
collate_fn=state.val_data.collate,
|
617 |
+
persistent_workers=True,
|
618 |
)
|
619 |
|
620 |
+
|
621 |
|
622 |
if fine_tune:
|
623 |
+
lora.mark_only_lora_as_trainable(state.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
624 |
|
625 |
+
# Wrap the functions so that they neatly track in TensorBoard + progress bars
|
626 |
+
# and only run when specific conditions are met.
|
627 |
+
global train_loop, val_loop, validate, save_samples, checkpoint
|
628 |
|
629 |
+
train_loop = tracker.log("train", "value", history=False)(
|
630 |
+
tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
|
631 |
+
)
|
632 |
+
val_loop = tracker.track("val", len(val_dataloader))(val_loop)
|
633 |
+
validate = tracker.log("val", "mean")(validate)
|
634 |
|
635 |
+
save_samples = when(lambda: accel.local_rank == 0)(save_samples)
|
636 |
+
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
|
|
637 |
|
638 |
+
with tracker.live:
|
639 |
+
for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
|
640 |
+
train_loop(state, batch, accel)
|
641 |
|
642 |
+
last_iter = (
|
643 |
+
tracker.step == num_iters - 1 if num_iters is not None else False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
644 |
)
|
645 |
|
646 |
+
if tracker.step % sample_freq == 0 or last_iter:
|
647 |
+
save_samples(state, val_idx, writer)
|
648 |
|
649 |
+
if tracker.step % val_freq == 0 or last_iter:
|
650 |
+
validate(state, val_dataloader, accel)
|
651 |
+
checkpoint(
|
652 |
+
state=state,
|
653 |
+
save_iters=save_iters,
|
654 |
+
save_path=save_path,
|
655 |
+
fine_tune=fine_tune)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
656 |
|
657 |
+
# Reset validation progress bar, print summary since last validation.
|
658 |
+
tracker.done("val", f"Iteration {tracker.step}")
|
|
|
659 |
|
660 |
+
if last_iter:
|
661 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
662 |
|
663 |
|
664 |
if __name__ == "__main__":
|
|
|
666 |
args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
667 |
with argbind.scope(args):
|
668 |
with Accelerator() as accel:
|
669 |
+
if accel.local_rank != 0:
|
670 |
+
sys.tracebacklimit = 0
|
671 |
train(args, accel)
|
setup.py
CHANGED
@@ -31,7 +31,7 @@ setup(
|
|
31 |
"numpy==1.22",
|
32 |
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
|
33 |
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
34 |
-
"audiotools @ git+https://github.com/
|
35 |
"gradio",
|
36 |
"tensorboardX",
|
37 |
"loralib",
|
|
|
31 |
"numpy==1.22",
|
32 |
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
|
33 |
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
34 |
+
"descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
|
35 |
"gradio",
|
36 |
"tensorboardX",
|
37 |
"loralib",
|
vampnet/modules/__init__.py
CHANGED
@@ -2,3 +2,5 @@ import audiotools
|
|
2 |
|
3 |
audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
|
4 |
audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
|
|
|
|
|
|
2 |
|
3 |
audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
|
4 |
audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
|
5 |
+
|
6 |
+
from .transformer import VampNet
|