# ---
# license: apache-2.0
# ---
# A Whisper-like ASR model with some advanced ideas. Experimental.
# Full script: just install the dependencies and run.
# The included model is *not* trained: it is a blank (tabula rasa), newly
# initialized "medium"-sized version of the script.
# I'm experimenting with some of the newer ideas from the vision-LLM
# community, applied to audio. Here is a super cool paper:
# https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2022.949142/full
#
# Update: fixed the issues with the hybrid attention and tensor sharing.
#
# Drop-in enhanced Givens rotary block -- it's like a Rubik's cube of
# embeddings :) Think of regular RoPE embeddings as a 3D block rotating in
# space. Now add columns and rows that rotate, and then rotate the faces of
# each resulting cube:


class CombinedRotaryEmbedding(nn.Module):
    """Rotary position embedding preceded by learnable Givens rotations.

    Every head vector (size ``h_dim = n_state // n_head``) is transformed in
    three stages:

    1. ``num_rotations`` Givens rotations, the k-th acting in the plane
       ``rotation_pairs[k]`` (cast to integer indices) by the learnable angle
       ``thetas[k] * theta_scale``;
    2. multiplication by the learnable ``rotation_matrix`` (identity at
       construction; ``reset_parameters`` re-initializes it orthogonally);
    3. a standard RoPE rotation of (even, odd) coordinate pairs by
       ``position * inv_freq`` (``inv_freq`` is itself a learnable Parameter).

    Args:
        n_state: model width; must be divisible by ``n_head``.
        n_head: number of heads; ``n_state // n_head`` must be even (RoPE
            rotates coordinate pairs).
        num_rotations: number of Givens rotations applied per head vector.
        base: RoPE frequency base.
        checkpointing: if True, ``forward`` runs under activation
            checkpointing.

    Raises:
        ValueError: if ``n_state`` is not divisible by ``n_head`` or the
            resulting head dimension is odd.
    """

    def __init__(self, n_state, n_head, num_rotations, base=10000, checkpointing=False):
        super().__init__()
        if n_state % n_head != 0:
            raise ValueError(f"n_state {n_state} must be divisible by n_head {n_head}")
        if (n_state // n_head) % 2 != 0:
            raise ValueError(f"head dimension {n_state // n_head} must be even for rotary embedding")
        self.n_state = n_state
        self.n_head = n_head
        self.h_dim = n_state // n_head
        self.num_rotations = num_rotations
        self.base = base
        self.checkpointing = checkpointing
        self.thetas = nn.Parameter(torch.zeros(num_rotations))
        # Only used as integer plane indices (via .long()), so this Parameter
        # never receives gradients; kept as a Parameter for state_dict/API
        # compatibility.
        self.rotation_pairs = nn.Parameter(data=torch.rand(num_rotations, 2) * self.h_dim)
        self.theta_scale = nn.Parameter(data=torch.ones(1))
        self.rotation_matrix = nn.Parameter(data=torch.eye(n=self.h_dim))
        self.inv_freq = nn.Parameter(data=self._compute_inv_freq(self.base))

    def _compute_inv_freq(self, base, device=None):
        """Return the RoPE inverse frequencies ``base ** (-2k / h_dim)``."""
        exponents = torch.arange(start=0, end=self.h_dim, step=2, device=device).float() / self.h_dim
        return 1.0 / (base ** exponents)

    def givens_rotation_matrix(self, n_state, i, j, theta):
        """Return an ``n_state x n_state`` Givens rotation in the (i, j) plane.

        The matrix is built from ``torch.cos``/``torch.sin`` so gradients flow
        back to ``theta`` (``math.cos`` would convert it to a Python float and
        detach it from autograd).

        Args:
            n_state: size of the square matrix.
            i, j: plane indices (ints or integer scalar tensors).
            theta: single-element tensor holding the rotation angle.

        Returns:
            The rotation matrix, on ``theta``'s device and in its dtype. If
            ``i == j`` there is no rotation plane and the identity is returned
            (writing the four entries would otherwise collapse onto one cell
            and produce a non-orthogonal scaling by ``cos(theta)``).
        """
        theta = theta.reshape(())
        i, j = int(i), int(j)
        G = torch.eye(n_state, device=theta.device, dtype=theta.dtype)
        if i == j:
            return G
        cos, sin = torch.cos(theta), torch.sin(theta)
        G[i, i] = cos
        G[i, j] = -sin
        G[j, i] = sin
        G[j, j] = cos
        return G

    def update_base(self, new_base):
        """Set a new RoPE base and recompute ``inv_freq`` in place.

        The existing ``inv_freq`` Parameter is overwritten in place, so it keeps
        its device and dtype and any optimizer holding it stays in sync.
        """
        self.base = float(new_base)
        with torch.no_grad():
            self.inv_freq.copy_(self._compute_inv_freq(self.base, device=self.inv_freq.device))

    def reset_parameters(self):
        """Re-initialize ``rotation_matrix`` orthogonally and zero ``thetas``."""
        nn.init.orthogonal_(tensor=self.rotation_matrix)
        nn.init.zeros_(tensor=self.thetas)

    def forward(self, x):
        """Apply the embedding, under activation checkpointing if enabled."""
        if self.checkpointing:
            return checkpoint(self._forward, x)
        return self._forward(x)

    def _forward(self, x):
        """Rotate ``x`` and return it as ``(batch, seq, n_state)``.

        Args:
            x: ``(batch, seq, n_state)`` or ``(batch, seq, n_head, h_dim)``.

        Returns:
            Tensor of shape ``(batch, seq, n_state)``. Within each head the
            RoPE output stores the rotated even-index coordinates first and
            the rotated odd-index coordinates second (not re-interleaved).

        Raises:
            ValueError: on wrong rank or mismatched width/head sizes.
        """
        if x.dim() not in [3, 4]:
            raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")
        if x.dim() == 3:
            batch_size, seq_len, n_state = x.size()
            if n_state != self.n_state:
                raise ValueError(f"Expected n_state {self.n_state}, but got n_state {n_state}")
        else:
            batch_size, seq_len, n_head, h_dim = x.size()
            if n_head != self.n_head or h_dim != self.h_dim:
                raise ValueError(f"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}")

        # One row per head vector; reshape (not view) tolerates non-contiguous input.
        x = x.reshape(-1, self.h_dim)
        for k in range(self.num_rotations):
            i, j = self.rotation_pairs[k].long().tolist()
            theta = self.thetas[k] * self.theta_scale
            G = self.givens_rotation_matrix(n_state=self.h_dim, i=i, j=j, theta=theta)
            # Match the activation dtype so e.g. half-precision inputs work.
            x = torch.matmul(x, G.to(dtype=x.dtype))
        x = torch.matmul(x, self.rotation_matrix.to(dtype=x.dtype))
        x = x.view(batch_size, seq_len, self.n_head, self.h_dim)

        # Float positions: avoid mixing an int64 arange with float frequencies.
        inv_freq = self.inv_freq.to(device=x.device)
        positions = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype)
        sinusoid_inp = torch.outer(positions, inv_freq)
        sin = sinusoid_inp.sin()[None, :, None, :]
        cos = sinusoid_inp.cos()[None, :, None, :]
        x1, x2 = x[..., ::2], x[..., 1::2]
        x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return x.reshape(batch_size, seq_len, self.n_state)