"""Defines the audio model for pitch estimation.""" import torch import torch.nn as nn import einops import math import numpy as np import einops import pytorch_lightning as pl import shared.utils as su class TimeEncodingDiscreteSinusoidal(nn.Module): def __init__(self, d, v=10000, rate=49, scale_factor=0.01): """ Args: d (int): Dimension rate (int): discretisation rate (frames per second) this means that each [1/49.] of a second will be encoded with a unique vector """ super().__init__() self.d = d self.rate = rate self.v = v self.scale_factor = scale_factor def forward(self, t): """ Takes in timestamps t (seconds) and outputs vectors that represent these. Args: t (torch.tensor): time stamps in seconds, [B, N] """ B, N = t.shape # Discretise time i = (t * self.rate).to(int) pe = torch.zeros(B, N, self.d).to(t.device) div_term = torch.exp( (torch.arange(0, self.d, 2, dtype=torch.float) * -(math.log(self.v) / self.d)) ) div_term = div_term.to(t.device) pe[:, :, 0::2] = torch.sin(i[:, :, None].float() * div_term) pe[:, :, 1::2] = torch.cos(i[:, :, None].float() * div_term) pe = pe * self.scale_factor return pe class Wav2Vec2WithTimeEncoding(nn.Module): def __init__( self, model_name="facebook/wav2vec2-base-960h", use_time=True, d=512, v=10000, rate=49, scale_factor=0.01, layer_norm=False, ): super().__init__() su.log.print_update( f" [:::] Loading backbone Wav2Vec 2.0 ", pos="left", fillchar=".", color="cyan", ) # Load pre-trained Wav2Vec 2.0 model from transformers import Wav2Vec2Model self.net = Wav2Vec2Model.from_pretrained(model_name) self.d = d self.v = v self.rate = rate self.sr = 16000 self.use_time = use_time if self.use_time: self.time_encoding = TimeEncodingDiscreteSinusoidal( d=d, v=v, rate=rate, scale_factor=scale_factor, ) else: print(" [:::] Not using time encoding.") self.time_encoding = None # Have a layer norm for the time encoding if layer_norm: self.layer_norm = nn.LayerNorm(d) else: self.layer_norm = nn.Identity() def forward(self, x, t): """ Args: x (torch.tensor): audio input, [B, NC, C, NS], NC: n.o. clips, NS: n.o. samples t (torch.tensor): time stamps in seconds, [B, NC, 2], start and end times for each clip """ B, T, C, NS = x.shape assert C == 1, "Require a single-channel input." assert t.shape[1] == T, \ "Number of timestamps should match number of clips." assert t.shape[0] == B, \ "Batch size should match." assert t.shape[2] == 2, \ "Timestamps should have start and end times." 

        # # Compute number of frames
        # NF = int((NS / self.sr) * self.rate)

        # Process inputs
        x = einops.rearrange(x, "B T 1 NS -> (B T) NS")
        t = einops.rearrange(t, "B T L -> (B T) L")

        # This forward is based on Huggingface's implementation of Wav2Vec2:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_wav2vec2.py

        # Encode through the CNN
        extract_features = self.net.feature_extractor(x)
        extract_features = extract_features.transpose(1, 2)

        if self.use_time:
            # Process timestamps: get a timestamp for each frame
            # within each clip (fps=49)
            NF = extract_features.shape[1]
            t_dense = []
            # Loop over all B * T clips (x and t were flattened above)
            for i in range(B * T):
                start, end = t[i]
                t_dense.append(torch.linspace(start, end, NF))
            t_dense = torch.stack(t_dense).to(extract_features.device)

            # Compute the time encoding (already scaled to match the
            # magnitude of the features) and add it to the features
            t_dense_enc = self.time_encoding(t_dense)
            extract_features = extract_features + t_dense_enc

        # Apply layer norm
        extract_features = self.layer_norm(extract_features)

        # Project into the feature space
        hidden_states, extract_features = self.net.feature_projection(
            extract_features
        )

        # Pass through the transformer encoder
        encoder_outputs = self.net.encoder(
            hidden_states,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        z = encoder_outputs[0]
        # z = self.backbone(x).last_hidden_state
        z = einops.rearrange(z, "(B T) F D -> B T F D", B=B, T=T)

        return z


def recursive_attr(module, attr):
    if "." in attr:
        m, a = attr.split(".", 1)
        return recursive_attr(getattr(module, m), a)
    return getattr(module, attr)


class WavelengthWithTime(pl.LightningModule):
    def __init__(
        self,
        backbone,
        feat_dim=768,
        axial=True,
        axial_bins=512,
        radial=True,
        radial_bins=512,
        freeze_backbone=True,
        train_backbone_modules=[10, 11],
        prediction_head_hidden=[],
        act="softmax",
        criterion="kl_div",
        cfg_opt=dict(name="Adam", args=dict(lr=1e-4)),
    ):
        super().__init__()

        su.log.print_update(
            " [:::] Loading model WavelengthWithTime ",
            color="cyan",
            pos="left",
            fillchar=".",
        )

        # By default, freeze the entire backbone
        if freeze_backbone:
            self.freeze(backbone)

        # Unfreeze specific modules
        train_backbone_modules = [
            backbone.net.encoder.layers[int(m)] for m in train_backbone_modules
        ]
        for module in train_backbone_modules:
            self.unfreeze(module)

        # Make the layer norm in the backbone trainable
        print("[>>>] Unfreezing layer norm in backbone")
        for param in backbone.layer_norm.parameters():
            param.requires_grad = True

        su.misc.num_trainable_params(backbone)
        self.backbone = backbone
        self.feat_dim = feat_dim

        # Add some intermediate layers before the prediction heads
        if len(prediction_head_hidden) > 0:
            layers = []
            in_dim = feat_dim
            for out_dim in prediction_head_hidden:
                layers.append(nn.Linear(in_dim, out_dim))
                layers.append(nn.ReLU())
                in_dim = out_dim
            self.intermediate_layers = nn.Sequential(*layers)
        else:
            self.intermediate_layers = torch.nn.Identity()
            out_dim = feat_dim
        su.misc.num_trainable_params(self.intermediate_layers)

        assert axial or radial, \
            "At least one of axial or radial heads must be enabled."
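
        # Note: each enabled head below is a single linear layer that maps
        # per-frame features to `*_bins` values; with act == "softmax" these
        # become per-frame probability distributions over the discretised
        # bins, which is the form expected by the KL-divergence criterion.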

        # Define axial head
        self.axial_head = None
        if axial:
            self.axial_head = nn.Linear(out_dim, axial_bins)
            su.misc.num_trainable_params(self.axial_head)

        # Define radial head
        self.radial_head = None
        if radial:
            self.radial_head = nn.Linear(out_dim, radial_bins)
            su.misc.num_trainable_params(self.radial_head)

        self.act = (
            torch.nn.Softmax(dim=-1) if act == "softmax" else torch.nn.Identity()
        )

        # Set criterion
        self.define_criterion(criterion)

        # Define optimization config
        self.cfg_opt = cfg_opt

        # Save hyperparameters
        self.save_hyperparameters(ignore=["backbone"])

    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False

    def define_criterion(self, criterion):
        if criterion == "kl_div":
            self.criterion = nn.KLDivLoss()
        elif criterion == "ce":
            self.criterion = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError(f"Criterion {criterion} not implemented.")

    def freeze(self, net):
        for p in net.parameters():
            p.requires_grad = False

    def unfreeze(self, module):
        module_name = type(module).__name__
        print(f"[>>>] Unfreezing {module_name}")
        for p in module.parameters():
            p.requires_grad = True

    def forward(self, x, t):
        """
        Args:
            x (torch.Tensor): [B, T, C, NS], T: number of clips
            t (torch.Tensor): [B, T, 2], clip start and end times
        """
        B, T, C, NS = x.shape
        z = self.backbone.forward(x, t)
        # assert C == 1, "Require a single-channel input."
        # x = einops.rearrange(x, "B T 1 NS -> (B T) NS")
        # z = self.backbone(x).last_hidden_state
        # z = einops.rearrange(z, "(B T) F D -> B T F D", B=B, D=self.feat_dim)

        # Intermediate layers
        h = self.intermediate_layers(z)

        # Prediction heads
        y_pred = dict()
        if self.axial_head is not None:
            axial = self.act(self.axial_head(h))
            y_pred["axial"] = axial
        if self.radial_head is not None:
            radial = self.act(self.radial_head(h))
            y_pred["radial"] = radial

        return y_pred

    def compute_loss(self, y_pred: dict, y_true: dict):
        loss = dict()
        total_loss = 0.
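
        # Targets arrive as [B, T, D, F] (bins before frames) and are
        # rearranged to [B, T, F, D] to match the predictions. KLDivLoss
        # needs log-probabilities as input (hence the .log() below), while
        # CrossEntropyLoss operates on rows flattened over (batch, clip, frame).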
        for key in y_pred:
            yt = y_true[key]
            yt = einops.rearrange(yt, "b t d f -> b t f d")
            yp = y_pred[key]
            if isinstance(self.criterion, nn.KLDivLoss):
                # Need to pass log-probabilities if the criterion is KLDivLoss
                yp = yp.log()
                loss[key] = self.criterion(yp, yt)
            elif isinstance(self.criterion, nn.CrossEntropyLoss):
                yp = einops.rearrange(yp, "b t f d -> (b t f) d")
                yt = einops.rearrange(yt, "b t f d -> (b t f) d")
                loss[key] = self.criterion(yp, yt)
            else:
                raise NotImplementedError(
                    f"Criterion {self.criterion} not implemented."
                )
            # For now, use hardcoded loss weights of 1/K,
            # where K is the number of losses
            total_loss += loss[key] / len(y_pred)
        loss["total"] = total_loss
        return loss

    def step(self, batch, mode, log=True):
        x = batch["audio_clips"]
        t = batch["clips"]
        y_true = {**batch["targets"], **batch["metadata"]}
        y_pred = self.forward(x, t)
        losses = self.compute_loss(y_pred, y_true)
        loss = losses["total"]
        if log:
            self.log(
                f"batch/{mode}/loss_net", loss, prog_bar=True, sync_dist=True
            )
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "valid")

    def configure_optimizers(self):
        function = getattr(torch.optim, self.cfg_opt["name"])
        optimizer = function(self.parameters(), **self.cfg_opt["args"])
        return optimizer


if __name__ == "__main__":
    import os

    # Test backbone
    backbone = Wav2Vec2WithTimeEncoding()
    su.misc.num_params(backbone)

    # Test on a real audio clip
    path = "./media_assets/pouring_water_in_a_glass.wav"
    import torchaudio
    waveform, sr = torchaudio.load(path)
    waveform = torchaudio.functional.resample(waveform, sr, 16000)
    sr = 16000
    waveform = waveform.mean(dim=0, keepdim=True)

    # Forward pass an entire audio clip
    from transformers import Wav2Vec2Processor
    model_name = "facebook/wav2vec2-base-960h"
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    s, e = 8, 22
    x = processor(
        waveform[:, int(s * sr):int(e * sr)],
        sampling_rate=16000,
        return_tensors="pt",
    ).input_values.unsqueeze(0)
    duration = waveform.shape[-1] / sr
    t = torch.tensor([[s, e]]).unsqueeze(0)
    z = backbone(x, t)

    # Look at the t-SNE of the features
    z_flat = einops.rearrange(z, "B T F D -> (B T F) D")
    import matplotlib.pyplot as plt

    # Use a serif font
    plt.rcParams["font.family"] = "serif"
    su.visualize.show_temporal_tsne(z_flat.detach().numpy(), show=False)
    plt.savefig("./media_assets/tsne.png")
    plt.close()

    # Test model
    cfg_model = {
        "name": "WavelengthWithTime",
        "args": {
            "axial": True,
            "axial_bins": 64,
            "radial": True,
            "radial_bins": 64,
            "freeze_backbone": True,
            "train_backbone_modules": [6, 7, 8, 9, 10, 11],
            "act": "softmax",
            "criterion": "kl_div",
        },
    }
    model = eval(cfg_model["name"])(backbone=backbone, **cfg_model["args"])
    su.misc.num_trainable_params(model)

    # Load pre-trained checkpoint
    ckpt_dir = "/work/piyush/pretrained_checkpoints/SoundOfWater"
    ckpt_path = os.path.join(
        ckpt_dir,
        "dsr9mf13_ep100_step12423_real_finetuned_with_cosupervision.pth",
    )
    assert os.path.exists(ckpt_path), \
        f"Checkpoint not found at {ckpt_path}."
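
    # load_state_dict is strict by default, so the printed message should
    # report no missing or unexpected keys if the checkpoint matches the
    # architecture configured above; map_location="cpu" keeps this check
    # runnable without a GPU.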
print("Loading checkpoint from: ", ckpt_path) ckpt = torch.load(ckpt_path, map_location="cpu") msg = model.load_state_dict(ckpt) print(msg) # Check forward pass x_random = torch.randn(2, 1, 1, 16000) t_random = torch.tensor([[[0, 1]], [[2, 3]]]) y_pred = model(x_random, t_random) print("Input: ", x_random.shape) for key in y_pred: print(key, y_pred[key].shape) # Plot features with the trained backbone and save as tsne_trained.png z = model.backbone(x, t) z_flat = einops.rearrange(z, "B T F D -> (B T F) D") su.visualize.show_temporal_tsne(z_flat.detach().numpy(), show=False) plt.savefig("./media_assets/tsne_trained.png") plt.close()