Spaces:
Sleeping
Sleeping
import os | |
import subprocess | |
import time | |
import warnings | |
from pathlib import Path | |
from typing import Optional | |
import argbind | |
import audiotools as at | |
import torch | |
import torch.nn as nn | |
from audiotools import AudioSignal | |
from audiotools.data import transforms | |
from einops import rearrange | |
from rich import pretty | |
from rich.traceback import install | |
from tensorboardX import SummaryWriter | |
import vampnet | |
from vampnet.modules.transformer import VampNet | |
from vampnet.util import codebook_unflatten, codebook_flatten | |
from vampnet import mask as pmask | |
# from dac.model.dac import DAC | |
from lac.model.lac import LAC as DAC | |
# Let cuDNN autotune its algorithms to speed up training. Controlled by the
# CUDNN_BENCHMARK environment variable (on by default); the setting can be
# altered by the funcs.seed function.
torch.backends.cudnn.benchmark = int(os.environ.get("CUDNN_BENCHMARK", 1)) != 0

# Cosmetics: hide UserWarnings and install rich pretty-printing / tracebacks.
warnings.filterwarnings("ignore", category=UserWarning)
pretty.install()
install()
# --- optimization -----------------------------------------------------------
# argbind-bound wrappers: their keyword arguments are filled from the parsed
# CLI/config arguments inside an argbind scope.
Accelerator = argbind.bind(at.ml.Accelerator, without_prefix=True)
CrossEntropyLoss = argbind.bind(nn.CrossEntropyLoss)
AdamW = argbind.bind(torch.optim.AdamW)
NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)


# --- transforms -------------------------------------------------------------
def filter_fn(fn):
    """Decide which members of ``audiotools.data.transforms`` get bound.

    Only objects exposing a ``transform`` attribute qualify; the base class and
    the container transforms (``Compose``, ``Choose``) are excluded.
    """
    if not hasattr(fn, "transform"):
        return False
    return fn.__qualname__ not in ("BaseTransform", "Compose", "Choose")


# Transforms bound under the "train" and "val" scopes.
tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)

# --- model ------------------------------------------------------------------
VampNet = argbind.bind(VampNet)

# --- data -------------------------------------------------------------------
AudioLoader = argbind.bind(at.datasets.AudioLoader)
AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")

# Target value skipped by the loss and the accuracy metrics (-100 is also the
# default ignore_index of nn.CrossEntropyLoss).
IGNORE_INDEX = -100
def build_transform():
    """Build the preprocessing chain applied to every dataset item.

    The chain is ``VolumeNorm(("const", -24))`` followed by ``RescaleAudio()``.
    The -24 is presumably a fixed loudness target in dB; confirm against
    audiotools. Both transforms come from the argbind-bound ``tfm`` module, so
    their remaining arguments follow the active argbind scope.
    """
    steps = [
        tfm.VolumeNorm(("const", -24)),
        tfm.RescaleAudio(),
    ]
    return transforms.Compose(*steps)
def apply_transform(transform_fn, batch):
    """Run ``transform_fn`` on a copy of the batch's signal and return the result.

    ``batch["signal"]`` (an ``AudioSignal`` in this script) is cloned first, so
    the caller's signal is never handed to the transform. ``batch["transform_args"]``
    supplies the keyword arguments for ``transform_fn``.
    """
    source = batch["signal"]
    transform_kwargs = batch["transform_args"]
    return transform_fn(source.clone(), **transform_kwargs)
def build_datasets(args, sample_rate: int):
    """Create the train, val and test ``AudioDataset``s.

    Each dataset is built inside ``argbind.scope(args, <split>)`` so that its
    ``AudioLoader`` and dataset arguments come from that split's config. Every
    dataset gets its own transform chain from ``build_transform``.

    Returns:
        ``(train_data, val_data, test_data)``.
    """
    datasets = []
    for split in ("train", "val", "test"):
        with argbind.scope(args, split):
            datasets.append(
                AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
            )
    train_data, val_data, test_data = datasets
    return train_data, val_data, test_data
def rand_float(shape, low, high, rng):
    """Draw ``shape`` values from ``rng`` rescaled from [0, 1) to [low, high).

    ``shape`` is passed to ``rng.draw`` as the number of points; only the
    first column of the draw is used, giving a 1-D tensor.
    """
    unit = rng.draw(shape)[:, 0]
    span = high - low
    return low + unit * span
def flip_coin(shape, p, rng):
    """Return a boolean tensor that is True where a draw from ``rng`` is below ``p``.

    ``shape`` is passed to ``rng.draw`` as the number of points; only the
    first column of the draw is compared.
    """
    draws = rng.draw(shape)[:, 0]
    return p > draws
@argbind.bind(without_prefix=True)
def load(
    args,
    accel: at.ml.Accelerator,
    save_path: str,
    resume: bool = False,
    tag: str = "latest",
    load_weights: bool = False,
    fine_tune_checkpoint: Optional[str] = None,
):
    """Load the codec, the VampNet model and its optimizer/scheduler state.

    Bound with argbind (without prefix) so ``resume``, ``tag``, ``load_weights``
    and ``fine_tune_checkpoint`` can be set from the CLI/config. ``train``
    only passes ``args``, ``accel`` and ``save_path``; without the binding
    these options could never be changed from their defaults.

    Args:
        args: parsed argbind arguments; must contain ``"codec_ckpt"`` and may
            contain ``"fine_tune"``.
        accel: accelerator used to wrap the model (``prepare_model``).
        save_path: root folder of this run's checkpoints.
        resume: resume from the checkpoint folder ``{save_path}/{tag}``.
        tag: checkpoint sub-folder to resume from.
        load_weights: when resuming, forwarded as ``package=not load_weights``
            to ``VampNet.load_from_folder``.
        fine_tune_checkpoint: VampNet checkpoint to start fine-tuning from.
            Required when ``args["fine_tune"]`` is set and nothing was resumed.

    Returns:
        dict with keys ``model``, ``codec``, ``optimizer``, ``scheduler`` and
        ``trainer_state``.

    Raises:
        ValueError: ``resume`` is set but ``{save_path}/{tag}/vampnet`` is missing.
    """
    codec = DAC.load(args["codec_ckpt"], map_location="cpu")
    codec.eval()

    model, v_extra = None, {}

    if resume:
        folder = f"{save_path}/{tag}"
        if not (Path(folder) / "vampnet").exists():
            raise ValueError(f"Could not find a VampNet checkpoint in {folder}")
        model, v_extra = VampNet.load_from_folder(
            folder=folder,
            map_location="cpu",
            package=not load_weights,
        )

    # Start fine-tuning from the given checkpoint, but only when nothing was
    # resumed. The optimizer/scheduler/trainer state in v_extra belongs to the
    # resumed weights, so replacing them with the base checkpoint would
    # silently discard training progress.
    if args.get("fine_tune", False) and model is None:
        assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
        model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu")

    if model is None:
        model = VampNet()
    model = accel.prepare_model(model)

    # The model's token vocabulary must match the codec's codebook size.
    codebook_size = codec.quantizer.quantizers[0].codebook_size
    assert accel.unwrap(model).vocab_size == codebook_size, (
        f"VampNet vocab_size ({accel.unwrap(model).vocab_size}) does not match "
        f"the codec codebook_size ({codebook_size})"
    )

    optimizer = AdamW(model.parameters(), use_zero=accel.use_ddp)
    scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
    # Advance the schedule once before training. Presumably this keeps the
    # first update off the step-0 learning rate; a resumed scheduler state
    # loaded below overrides it anyway.
    scheduler.step()

    trainer_state = {"state_dict": None, "start_idx": 0}
    if "optimizer.pth" in v_extra:
        optimizer.load_state_dict(v_extra["optimizer.pth"])
    if "scheduler.pth" in v_extra:
        scheduler.load_state_dict(v_extra["scheduler.pth"])
    if "trainer.pth" in v_extra:
        trainer_state = v_extra["trainer.pth"]

    return {
        "model": model,
        "codec": codec,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "trainer_state": trainer_state,
    }
def num_params_hook(o, p):
    """Append a parameter count ``p``, in millions, to the ``extra_repr`` string ``o``."""
    suffix = f" {p / 1e6:<.3f}M params."
    return o + suffix
def add_num_params_repr_hook(model):
    """Make every submodule's ``extra_repr`` also report its parameter count.

    The count for a module includes all parameters of its own submodules,
    because ``m.parameters()`` recurses. Each module's ``extra_repr`` is
    replaced on the instance by a partial of ``num_params_hook``, bound to the
    module's previous ``extra_repr()`` text and its count.
    """
    from functools import partial

    for _, module in model.named_modules():
        base_repr = module.extra_repr()
        n_params = sum(param.numel() for param in module.parameters())
        setattr(module, "extra_repr", partial(num_params_hook, o=base_repr, p=n_params))
def accuracy(
    preds: torch.Tensor,
    target: torch.Tensor,
    top_k: int = 1,
    ignore_index: Optional[int] = None,
) -> torch.Tensor:
    """Top-k classification accuracy over every (batch, position) pair.

    Args:
        preds: scores of shape ``(batch, n_class, seq)``.
        target: class indices of shape ``(batch, seq)``.
        top_k: a position is correct when its target is among the ``top_k``
            highest-scoring classes. Values above ``n_class`` are clamped:
            every target is then trivially in the top-k, so the accuracy is 1
            instead of ``torch.topk`` raising.
        ignore_index: positions whose target equals this value are excluded.

    Returns:
        0-dim float tensor in [0, 1]. It is NaN when no position remains after
        ``ignore_index`` filtering (mean of an empty tensor).
    """
    batch, n_class, seq = preds.shape
    # (b, p, s) -> (b * s, p) and (b, s) -> (b * s,): one row per position.
    preds = preds.permute(0, 2, 1).reshape(batch * seq, n_class)
    target = target.reshape(batch * seq)

    if ignore_index is not None:
        keep = target != ignore_index
        preds = preds[keep]
        target = target[keep]

    k = min(top_k, n_class)
    _, pred_indices = torch.topk(preds, k=k, dim=-1)
    # topk indices are distinct, so a row matches its target at most once.
    hits = (pred_indices == target.unsqueeze(1)).any(dim=1)
    return hits.float().mean()
@argbind.bind(without_prefix=True)
def train(
    args,
    accel: at.ml.Accelerator,
    codec_ckpt: str = None,
    seed: int = 0,
    save_path: str = "ckpt",
    max_epochs: int = int(100e3),
    epoch_length: int = 1000,
    save_audio_epochs: int = 2,
    save_epochs: list = [10, 50, 100, 200, 300, 400,],
    batch_size: int = 48,
    grad_acc_steps: int = 1,
    val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    num_workers: int = 10,
    detect_anomaly: bool = False,
    grad_clip_val: float = 5.0,
    fine_tune: bool = False,
    quiet: bool = False,
):
    """Train (or LoRA fine-tune) VampNet on codec tokens.

    Bound with argbind (without prefix) so every keyword argument can be set
    from the CLI/config. The ``__main__`` block only passes ``args`` and
    ``accel``; without the binding ``codec_ckpt`` would always be ``None`` and
    the assertion below would always fail.

    The list defaults (``save_epochs``, ``val_idx``) are only read, never
    mutated, and double as argbind's CLI defaults.

    Args:
        args: parsed argbind arguments (dumped to ``{save_path}/args.yml``).
        accel: audiotools accelerator (device placement, DDP, AMP).
        codec_ckpt: codec checkpoint; required. ``load`` reads it back from
            ``args["codec_ckpt"]``.
        seed: base seed; each process adds its local rank.
        save_path: folder for logs, args, model summary and checkpoints.
        max_epochs: number of epochs passed to ``trainer.run``.
        epoch_length: iterations per epoch.
        save_audio_epochs: log audio samples every this many epochs.
        save_epochs: epochs at which an extra ``epoch=N`` checkpoint is written.
        batch_size: batch size of both dataloaders.
        grad_acc_steps: number of batches to accumulate gradients over.
        val_idx: indices of validation items used for audio samples.
        num_workers: dataloader workers.
        detect_anomaly: forwarded to ``trainer.run``.
        grad_clip_val: max gradient norm for clipping.
        fine_tune: train only LoRA parameters (imports ``loralib``).
        quiet: forwarded to the trainer.
    """
    assert codec_ckpt is not None, "codec_ckpt is required"
    seed = seed + accel.local_rank
    at.util.seed(seed)
    writer = None

    if accel.local_rank == 0:
        writer = SummaryWriter(log_dir=f"{save_path}/logs/")
        argbind.dump_args(args, f"{save_path}/args.yml")

    # load the codec, model, optimizer, scheduler and trainer state
    loaded = load(args, accel, save_path)
    model = loaded["model"]
    codec = loaded["codec"]
    optimizer = loaded["optimizer"]
    scheduler = loaded["scheduler"]
    trainer_state = loaded["trainer_state"]

    sample_rate = codec.sample_rate

    # a better rng for sampling from our schedule
    rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=seed)

    # log a model summary w/ num params
    if accel.local_rank == 0:
        add_num_params_repr_hook(accel.unwrap(model))
        with open(f"{save_path}/model.txt", "w") as f:
            f.write(repr(accel.unwrap(model)))

    # load the datasets
    train_data, val_data, _ = build_datasets(args, sample_rate)
    train_dataloader = accel.prepare_dataloader(
        train_data,
        start_idx=trainer_state["start_idx"],
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=train_data.collate,
    )
    val_dataloader = accel.prepare_dataloader(
        val_data,
        start_idx=0,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=val_data.collate,
    )

    criterion = CrossEntropyLoss()

    if fine_tune:
        import loralib as lora

        lora.mark_only_lora_as_trainable(model)

    class Trainer(at.ml.BaseTrainer):
        _last_grad_norm = 0.0

        def _metrics(self, vn, z_hat, r, target, flat_mask, output):
            """Add top-1/top-25 accuracies, split by r range, to ``output``.

            "masked" scores only positions where ``flat_mask`` is set;
            "unmasked" scores only the others.
            """
            # These views do not depend on the r range, so build them once.
            unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
            masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
            assert target.shape[0] == r.shape[0]

            for r_range in [(0, 0.5), (0.5, 1.0)]:
                # batch items whose r value falls in [start, end)
                r_idx = (r >= r_range[0]) & (r < r_range[1])

                r_unmasked_target = unmasked_target[r_idx]
                r_masked_target = masked_target[r_idx]
                r_z_hat = z_hat[r_idx]

                for topk in (1, 25):
                    s, e = r_range
                    tag = f"accuracy-{s}-{e}/top{topk}"

                    output[f"{tag}/unmasked"] = accuracy(
                        preds=r_z_hat,
                        target=r_unmasked_target,
                        ignore_index=IGNORE_INDEX,
                        top_k=topk,
                    )
                    output[f"{tag}/masked"] = accuracy(
                        preds=r_z_hat,
                        target=r_masked_target,
                        ignore_index=IGNORE_INDEX,
                        top_k=topk,
                    )

        def train_loop(self, engine, batch):
            """One training iteration; steps the optimizer every ``grad_acc_steps``."""
            model.train()
            batch = at.util.prepare_batch(batch, accel.device)
            signal = apply_transform(train_data.transform, batch)

            output = {}
            vn = accel.unwrap(model)
            with accel.autocast():
                # The codec is not trained: encode without recording a graph.
                with torch.inference_mode():
                    codec.to(accel.device)
                    z = codec.encode(signal.samples, signal.sample_rate)["codes"]
                    z = z[:, : vn.n_codebooks, :]

                n_batch = z.shape[0]
                # one r value per batch item, drawn from the Sobol sequence
                r = rng.draw(n_batch)[:, 0].to(accel.device)

                mask = pmask.random(z, r)
                mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
                z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

                z_mask_latent = vn.embedding.from_codes(z_mask, codec)

                dtype = torch.bfloat16 if accel.amp else None
                with accel.autocast(dtype=dtype):
                    z_hat = model(z_mask_latent, r)

                target = codebook_flatten(
                    z[:, vn.n_conditioning_codebooks :, :],
                )
                flat_mask = codebook_flatten(
                    mask[:, vn.n_conditioning_codebooks :, :],
                )

                # Loss only on positions where flat_mask is set (the masked
                # tokens): every other target becomes IGNORE_INDEX.
                t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
                output["loss"] = criterion(z_hat, t_masked)

                self._metrics(
                    vn=vn,
                    r=r,
                    z_hat=z_hat,
                    target=target,
                    flat_mask=flat_mask,
                    output=output,
                )

            accel.backward(output["loss"] / grad_acc_steps)

            output["other/learning_rate"] = optimizer.param_groups[0]["lr"]
            output["other/batch_size"] = z.shape[0]

            if (
                (engine.state.iteration % grad_acc_steps == 0)
                or (engine.state.iteration % epoch_length == 0)
                or (engine.state.iteration % epoch_length == 1)
            ):  # (or we reached the end of the epoch)
                accel.scaler.unscale_(optimizer)
                output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), grad_clip_val
                )
                self._last_grad_norm = output["other/grad_norm"]
                accel.step(optimizer)
                optimizer.zero_grad()
                scheduler.step()
                accel.update()
            else:
                output["other/grad_norm"] = self._last_grad_norm

            return {k: v for k, v in sorted(output.items())}

        # Validation never backpropagates; without this every forward pass
        # (codec + model) would record an autograd graph.
        @torch.inference_mode()
        def val_loop(self, engine, batch):
            """One validation iteration: same loss/metrics as training, no update."""
            model.eval()
            codec.eval()
            batch = at.util.prepare_batch(batch, accel.device)
            signal = apply_transform(val_data.transform, batch)

            vn = accel.unwrap(model)
            z = codec.encode(signal.samples, signal.sample_rate)["codes"]
            z = z[:, : vn.n_codebooks, :]

            n_batch = z.shape[0]
            r = rng.draw(n_batch)[:, 0].to(accel.device)

            mask = pmask.random(z, r)
            mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
            z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

            z_mask_latent = vn.embedding.from_codes(z_mask, codec)

            z_hat = model(z_mask_latent, r)

            target = codebook_flatten(
                z[:, vn.n_conditioning_codebooks :, :],
            )
            flat_mask = codebook_flatten(
                mask[:, vn.n_conditioning_codebooks :, :]
            )

            output = {}
            # Loss only on positions where flat_mask is set (the masked tokens).
            t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
            output["loss"] = criterion(z_hat, t_masked)

            self._metrics(
                vn=vn,
                r=r,
                z_hat=z_hat,
                target=target,
                flat_mask=flat_mask,
                output=output,
            )

            return output

        def checkpoint(self, engine):
            """Save audio samples and checkpoints; only rank 0 writes anything.

            Always writes ``latest``. Adds ``epoch=N`` for epochs in
            ``save_epochs``, and ``best`` when ``self.is_best`` reports an
            improvement of ``loss_key``.
            """
            if accel.local_rank != 0:
                print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
                return

            metadata = {"logs": dict(engine.state.logs["epoch"])}

            if self.state.epoch % save_audio_epochs == 0:
                self.save_samples()

            tags = ["latest"]
            loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
            # Report where checkpoints actually go, not the working directory.
            self.print(f"Saving to {Path(save_path).absolute()}")

            if self.state.epoch in save_epochs:
                tags.append(f"epoch={self.state.epoch}")

            if self.is_best(engine, loss_key):
                self.print("Best model so far")
                tags.append("best")

            if fine_tune:
                for tag in tags:
                    # save the lora weights separately
                    (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
                    torch.save(
                        lora.lora_state_dict(accel.unwrap(model)),
                        f"{save_path}/{tag}/lora.pth",
                    )

            for tag in tags:
                model_extra = {
                    "optimizer.pth": optimizer.state_dict(),
                    "scheduler.pth": scheduler.state_dict(),
                    "trainer.pth": {
                        "start_idx": self.state.iteration * batch_size,
                        "state_dict": self.state_dict(),
                    },
                    "metadata.pth": metadata,
                }

                accel.unwrap(model).metadata = metadata
                accel.unwrap(model).save_to_folder(
                    f"{save_path}/{tag}", model_extra,
                )

        def save_sampled(self, z):
            """Log one ``generate``d sample per item of ``z``, seeded with that item."""
            num_samples = z.shape[0]

            for i in range(num_samples):
                sampled = accel.unwrap(model).generate(
                    codec=codec,
                    time_steps=z.shape[-1],
                    start_tokens=z[i : i + 1],
                )
                sampled.cpu().write_audio_to_tb(
                    f"sampled/{i}",
                    self.writer,
                    step=self.state.epoch,
                    plot_fn=None,
                )

        def save_imputation(self, z: torch.Tensor):
            """Log inpainting results: the first/last 25% of frames are kept."""
            n_prefix = int(z.shape[-1] * 0.25)
            n_suffix = int(z.shape[-1] * 0.25)

            vn = accel.unwrap(model)

            mask = pmask.inpaint(z, n_prefix, n_suffix)
            mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
            z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

            imputed_noisy = vn.to_signal(z_mask, codec)
            imputed_true = vn.to_signal(z, codec)

            imputed = []
            for i in range(len(z)):
                imputed.append(
                    vn.generate(
                        codec=codec,
                        time_steps=z.shape[-1],
                        start_tokens=z[i][None, ...],
                        mask=mask[i][None, ...],
                    )
                )
            imputed = AudioSignal.batch(imputed)

            for i in range(len(z)):
                imputed_noisy[i].cpu().write_audio_to_tb(
                    f"imputed_noisy/{i}",
                    self.writer,
                    step=self.state.epoch,
                    plot_fn=None,
                )
                imputed[i].cpu().write_audio_to_tb(
                    f"imputed/{i}",
                    self.writer,
                    step=self.state.epoch,
                    plot_fn=None,
                )
                imputed_true[i].cpu().write_audio_to_tb(
                    f"imputed_true/{i}",
                    self.writer,
                    step=self.state.epoch,
                    plot_fn=None,
                )

        # Sampling only: no gradients needed (also covers save_sampled and
        # save_imputation, which are called from here).
        @torch.no_grad()
        def save_samples(self):
            """Log original/masked/generated/reconstructed audio for ``val_idx``.

            Uses mask ratios spread linearly from 0.1 to 0.95 across the items.
            """
            model.eval()
            codec.eval()
            vn = accel.unwrap(model)

            batch = [val_data[i] for i in val_idx]
            batch = at.util.prepare_batch(val_data.collate(batch), accel.device)

            signal = apply_transform(val_data.transform, batch)

            z = codec.encode(signal.samples, signal.sample_rate)["codes"]
            z = z[:, : vn.n_codebooks, :]

            r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)

            mask = pmask.random(z, r)
            mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
            z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

            z_mask_latent = vn.embedding.from_codes(z_mask, codec)

            z_hat = model(z_mask_latent, r)

            z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
            z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
            z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)

            generated = vn.to_signal(z_pred, codec)
            reconstructed = vn.to_signal(z, codec)
            masked = vn.to_signal(z_mask.squeeze(1), codec)

            for i in range(generated.batch_size):
                audio_dict = {
                    "original": signal[i],
                    "masked": masked[i],
                    "generated": generated[i],
                    "reconstructed": reconstructed[i],
                }
                for k, v in audio_dict.items():
                    v.cpu().write_audio_to_tb(
                        f"samples/_{i}.r={r[i]:0.2f}/{k}",
                        self.writer,
                        step=self.state.epoch,
                        plot_fn=None,
                    )

            self.save_sampled(z)
            self.save_imputation(z)

    trainer = Trainer(writer=writer, quiet=quiet)

    if trainer_state["state_dict"] is not None:
        trainer.load_state_dict(trainer_state["state_dict"])
    if hasattr(train_dataloader.sampler, "set_epoch"):
        train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch)

    trainer.run(
        train_dataloader,
        val_dataloader,
        num_epochs=max_epochs,
        epoch_length=epoch_length,
        detect_anomaly=detect_anomaly,
    )
if __name__ == "__main__":
    args = argbind.parse_args()
    # Set ``args.debug`` only for local rank 0 (an unset LOCAL_RANK counts as 0).
    on_rank_zero = int(os.getenv("LOCAL_RANK", 0)) == 0
    args["args.debug"] = on_rank_zero
    # The scope is entered before Accelerator() is constructed, so the bound
    # Accelerator picks up its arguments from ``args``.
    with argbind.scope(args), Accelerator() as accel:
        train(args, accel)