"""1-D CNN onset picker with Masksembles layers and its Lightning training wrapper."""

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchmetrics import MeanAbsoluteError
from torch.optim.lr_scheduler import ReduceLROnPlateau
import lightning as pl
from masksembles import common


class BlurPool1D(nn.Module):
    """Blur-then-subsample pooling for ``(batch, channels, length)`` tensors.

    The input is padded, convolved per channel (depthwise) with a fixed,
    normalised binomial kernel of length ``filt_size`` and subsampled with
    ``stride``.

    Args:
        channels: Number of input channels; one copy of the kernel is stored
            per channel.
        pad_type: Padding mode understood by :func:`get_pad_layer_1d`.
        filt_size: Length of the binomial kernel (any integer >= 1).
        stride: Subsampling stride.
        pad_off: Extra padding added to both sides.

    Raises:
        ValueError: If ``filt_size`` is not a positive integer or ``pad_type``
            is not recognised.
    """

    def __init__(self, channels, pad_type="reflect", filt_size=3, stride=2, pad_off=0):
        super().__init__()
        if filt_size < 1 or int(filt_size) != filt_size:
            raise ValueError(
                "filt_size must be a positive integer, got %r" % (filt_size,)
            )
        self.filt_size = filt_size
        self.pad_off = pad_off
        # Left/right padding; the right side gets the extra sample when the
        # kernel length is even.
        self.pad_sizes = [
            int(1.0 * (filt_size - 1) / 2) + pad_off,
            int(np.ceil(1.0 * (filt_size - 1) / 2)) + pad_off,
        ]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.0)
        self.channels = channels

        # Row (filt_size - 1) of Pascal's triangle, built by repeatedly
        # convolving with [1, 1]: 1 -> [1], 3 -> [1, 2, 1], 5 -> [1, 4, 6, 4, 1].
        coeffs = np.ones(1)
        for _ in range(int(filt_size) - 1):
            coeffs = np.convolve(coeffs, [1.0, 1.0])

        filt = torch.Tensor(coeffs)
        filt = filt / torch.sum(filt)
        # Shape (channels, 1, filt_size): one kernel per channel for the
        # grouped convolution in ``forward``.
        self.register_buffer("filt", filt[None, None, :].repeat((self.channels, 1, 1)))
        self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)

    def forward(self, inp):
        """Return the blurred and subsampled version of ``inp``."""
        if self.filt_size == 1:
            # A length-1 kernel is the identity, so only subsample.
            if self.pad_off == 0:
                return inp[:, :, :: self.stride]
            return self.pad(inp)[:, :, :: self.stride]
        return F.conv1d(
            self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]
        )


def get_pad_layer_1d(pad_type):
    """Return the 1-D padding layer class selected by ``pad_type``.

    Args:
        pad_type: ``"refl"``/``"reflect"``, ``"repl"``/``"replicate"`` or
            ``"zero"``.

    Returns:
        The (uninstantiated) ``nn`` padding class.

    Raises:
        ValueError: If ``pad_type`` is not one of the names above.
    """
    # Looked up lazily so that an ``nn`` class missing from the installed
    # torch version only matters when it is actually requested.
    if pad_type in ["refl", "reflect"]:
        return nn.ReflectionPad1d
    if pad_type in ["repl", "replicate"]:
        return nn.ReplicationPad1d
    if pad_type == "zero":
        return nn.ZeroPad1d
    raise ValueError("Pad type [%s] not recognized" % pad_type)


class Masksembles1D(nn.Module):
    """Masksembles layer for ``(batch, channels, length)`` inputs.

    The batch is split into ``n`` consecutive, equally sized groups and every
    group is multiplied by its own fixed channel mask.

    Args:
        channels: Number of channels of the incoming feature map.
        n: Number of masks (and therefore of batch groups).
        scale: Scale argument forwarded to
            ``masksembles.common.generation_wrapper``.
    """

    def __init__(self, channels: int, n: int, scale: float):
        super().__init__()
        self.channels = channels
        self.n = n
        self.scale = scale

        masks = common.generation_wrapper(channels, n, scale)
        masks = torch.from_numpy(masks)
        # Stored as a frozen parameter so the masks are part of the
        # state_dict and follow the module across devices.
        self.masks = torch.nn.Parameter(masks, requires_grad=False)

    def forward(self, inputs):
        """Apply mask ``i`` to the ``i``-th of ``n`` consecutive batch groups.

        Raises:
            ValueError: If the batch size is not a multiple of ``n``.
        """
        batch = inputs.shape[0]
        if batch % self.n != 0:
            raise ValueError(
                "Masksembles1D needs a batch size divisible by n=%d, got %d"
                % (self.n, batch)
            )
        # (batch, 1, C, L) -> n chunks of (batch // n, 1, C, L)
        x = torch.split(inputs.unsqueeze(1), batch // self.n, dim=0)
        # -> (n, batch // n, C, L)
        x = torch.cat(x, dim=1).permute([1, 0, 2, 3])
        # Masks are expected to be (n, C); broadcast over group and length.
        x = x * self.masks.unsqueeze(1).unsqueeze(-1)
        # Back to (batch, C, L) in the original sample order.
        x = torch.cat(torch.split(x, 1, dim=0), dim=1)
        # The product is promoted to the masks' dtype; cast back to the input's.
        return x.squeeze(0).type(inputs.dtype)


class BasicBlock(nn.Module):
    """Residual block: two conv + batch-norm layers and a 1x1 conv shortcut.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution and of the shortcut. Every
            convolution uses ``padding="same"``, which PyTorch only supports
            for ``stride=1``.
        kernel_size: Kernel size of the two main convolutions.
        groups: Accepted for interface compatibility but currently unused;
            all convolutions are dense.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, kernel_size=7, groups=1):
        super().__init__()
        self.conv1 = nn.Conv1d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            padding="same",
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(
            planes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding="same",
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(planes)

        # The shortcut is always a projection (1x1 conv + BN), even when the
        # channel counts already match.
        self.shortcut = nn.Sequential(
            nn.Conv1d(
                in_planes,
                self.expansion * planes,
                kernel_size=1,
                stride=stride,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm1d(self.expansion * planes),
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Updated_onset_picker(nn.Module):
    """Nine-stage 1-D CNN that regresses two sigmoid-bounded values per trace.

    Input is ``(batch, 3, length)``. Stages 1-8 each halve the length with a
    ``BlurPool1D``; stages 6-9 start with a ``Masksembles1D`` layer using 128
    masks, so the batch size must be a multiple of 128. The final
    ``Linear(3072, 2)`` requires 128 channels x 24 samples after stage 9
    (e.g. an input length of 6000).
    """

    def __init__(
        self,
    ):
        super().__init__()
        self.n_masks = 128

        self.block1 = nn.Sequential(
            BasicBlock(3, 8, kernel_size=7, groups=1),
            nn.GELU(),
            BlurPool1D(8, filt_size=3, stride=2),
            nn.GroupNorm(2, 8),
        )
        self.block2 = nn.Sequential(
            BasicBlock(8, 16, kernel_size=7, groups=8),
            nn.GELU(),
            BlurPool1D(16, filt_size=3, stride=2),
            nn.GroupNorm(2, 16),
        )
        self.block3 = nn.Sequential(
            BasicBlock(16, 32, kernel_size=7, groups=16),
            nn.GELU(),
            BlurPool1D(32, filt_size=3, stride=2),
            nn.GroupNorm(2, 32),
        )
        self.block4 = nn.Sequential(
            BasicBlock(32, 64, kernel_size=7, groups=32),
            nn.GELU(),
            BlurPool1D(64, filt_size=3, stride=2),
            nn.GroupNorm(2, 64),
        )
        self.block5 = nn.Sequential(
            BasicBlock(64, 128, kernel_size=7, groups=64),
            nn.GELU(),
            BlurPool1D(128, filt_size=3, stride=2),
            nn.GroupNorm(2, 128),
        )
        self.block6 = nn.Sequential(
            Masksembles1D(128, self.n_masks, 2.0),
            BasicBlock(128, 256, kernel_size=7, groups=128),
            nn.GELU(),
            BlurPool1D(256, filt_size=3, stride=2),
            nn.GroupNorm(2, 256),
        )
        # NOTE: blocks 7 and 8 pool before the GELU, unlike blocks 1-6. The
        # order is kept as-is so that state_dict keys stay unchanged.
        self.block7 = nn.Sequential(
            Masksembles1D(256, self.n_masks, 2.0),
            BasicBlock(256, 512, kernel_size=7, groups=256),
            BlurPool1D(512, filt_size=3, stride=2),
            nn.GELU(),
            nn.GroupNorm(2, 512),
        )
        self.block8 = nn.Sequential(
            Masksembles1D(512, self.n_masks, 2.0),
            BasicBlock(512, 1024, kernel_size=7, groups=512),
            BlurPool1D(1024, filt_size=3, stride=2),
            nn.GELU(),
            nn.GroupNorm(2, 1024),
        )
        self.block9 = nn.Sequential(
            Masksembles1D(1024, self.n_masks, 2.0),
            BasicBlock(1024, 128, kernel_size=7, groups=128),
        )

        self.out = nn.Sequential(nn.Linear(3072, 2), nn.Sigmoid())

    def forward(self, x):
        """Return a ``(batch, 2)`` tensor of values in ``(0, 1)``."""
        # Feature extraction
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)

        # Regressor
        x = x.flatten(start_dim=1)
        x = self.out(x)
        return x


class Onset_picker(pl.LightningModule):
    """Lightning wrapper that trains ``picker`` with an L1 loss on two picks.

    Batches are ``(x, y_p, y_s)``. Column 0 of the picker output is compared
    with ``y_p`` and column 1 with ``y_s``; a label equal to 0 marks a missing
    pick and is excluded from the loss.

    Args:
        picker: Module mapping ``x`` to a ``(batch, 2)`` tensor.
        learning_rate: Learning rate for the Adam optimiser.
    """

    def __init__(self, picker, learning_rate):
        super().__init__()
        self.picker = picker
        self.learning_rate = learning_rate
        self.save_hyperparameters(ignore=['picker'])
        self.mae = MeanAbsoluteError()

    def compute_loss(self, y, pick, mae_name=False):
        """Return the mean absolute error over samples whose label is non-zero.

        Args:
            y: Labels; entries equal to 0 are treated as "no pick".
            pick: Predictions aligned with ``y``.
            mae_name: When truthy, also log the MAE (times 60) under
                ``MAE/{mae_name}_val``.

        Returns:
            A scalar tensor. When no sample is labelled this is a zero that
            stays attached to ``pick``'s graph, so the step can still be
            back-propagated (the original returned the Python int ``0``,
            which Lightning cannot optimise when both phases are empty).
        """
        labelled = y != 0
        target = y[labelled]
        if target.numel() == 0:
            return pick.sum() * 0.0

        pred = pick[labelled].flatten()
        loss = F.l1_loss(pred, target)
        if mae_name:
            # NOTE(review): the factor 60 presumably rescales normalised picks
            # to seconds of a 60 s window - confirm against the data loader.
            mae_phase = self.mae(pred, target) * 60
            self.log(f'MAE/{mae_name}_val', mae_phase, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def _shared_step(self, batch, p_name=False, s_name=False):
        """Run the picker on ``batch`` and return the mean of the P and S losses."""
        x, y_p, y_s = batch
        picks = self.picker(x)
        p_loss = self.compute_loss(y_p, picks[:, 0], mae_name=p_name)
        s_loss = self.compute_loss(y_s, picks[:, 1], mae_name=s_name)
        return (p_loss + s_loss) / 2

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._shared_step(batch)
        self.log('Loss/train', loss, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and per-phase MAE for one batch."""
        loss = self._shared_step(batch, p_name='P', s_name='S')
        self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler and its monitored metric."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, cooldown=10, threshold=1e-3)
        # NOTE(review): 'Loss/train' is logged per step only, so the scheduler
        # sees a single-step training loss; 'Loss/val' may be the intended
        # metric - confirm before changing.
        monitor = 'Loss/train'
        return {"optimizer": optimizer, "lr_scheduler": scheduler, 'monitor': monitor}

    def forward(self, x):
        """Return the picker's predictions for ``x``."""
        return self.picker(x)