Update README.md
Browse files
README.md
CHANGED
@@ -5,4 +5,78 @@ Whisper like ASR model but with some advanced ideas. Experimental. Full script j
|
|
5 |
I'm experimenting with some of the new stuff from the vision llm people but with audio.. Here is a super cool paper:
|
6 |
https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2022.949142/full
|
7 |
|
8 |
-
Updated. Was having some issues there with the hybrid attention and tensor sharing.. fixed.!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
I'm experimenting with some of the new stuff from the vision llm people but with audio.. Here is a super cool paper:
|
6 |
https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2022.949142/full
|
7 |
|
8 |
+
Updated. Was having some issues there with the hybrid attention and tensor sharing.. fixed.!
|
9 |
+
|
10 |
+
Drop-in enhanced givens rotary block -- Its like a rubiks cube of embbedings :)
|
11 |
+
|
12 |
+
|
13 |
+
class CombinedRotaryEmbedding(nn.Module):
|
14 |
+
def __init__(self, n_state, n_head, num_rotations, base=10000, checkpointing=False):
|
15 |
+
super().__init__()
|
16 |
+
self.n_state = n_state
|
17 |
+
self.n_head = n_head
|
18 |
+
self.h_dim = n_state // n_head
|
19 |
+
self.num_rotations = num_rotations
|
20 |
+
self.base = base
|
21 |
+
self.checkpointing = checkpointing
|
22 |
+
|
23 |
+
self.thetas = nn.Parameter(torch.zeros(num_rotations))
|
24 |
+
self.rotation_pairs = nn.Parameter(data=torch.rand(num_rotations, 2) * self.h_dim)
|
25 |
+
self.theta_scale = nn.Parameter(data=torch.ones(1))
|
26 |
+
self.rotation_matrix = nn.Parameter(data=torch.eye(n=self.h_dim))
|
27 |
+
self.inv_freq = nn.Parameter(data=1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim)))
|
28 |
+
|
29 |
+
def givens_rotation_matrix(self, n_state, i, j, theta):
|
30 |
+
G = torch.eye(n_state, device=theta.device)
|
31 |
+
G[i, i] = math.cos(theta)
|
32 |
+
G[i, j] = -math.sin(theta)
|
33 |
+
G[j, i] = math.sin(theta)
|
34 |
+
G[j, j] = math.cos(theta)
|
35 |
+
return G
|
36 |
+
|
37 |
+
def update_base(self, new_base):
|
38 |
+
self.base = float(new_base)
|
39 |
+
self.base = new_base
|
40 |
+
self.inv_freq = nn.Parameter(data=1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim)))
|
41 |
+
|
42 |
+
def reset_parameters(self):
|
43 |
+
nn.init.orthogonal_(tensor=self.rotation_matrix)
|
44 |
+
nn.init.zeros_(tensor=self.thetas)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
if self.checkpointing:
|
48 |
+
return checkpoint(self._forward, x)
|
49 |
+
else:
|
50 |
+
return self._forward(x)
|
51 |
+
|
52 |
+
def _forward(self, x):
|
53 |
+
if x.dim() not in [3, 4]:
|
54 |
+
raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")
|
55 |
+
|
56 |
+
if x.dim() == 3:
|
57 |
+
batch_size, seq_len, n_state = x.size()
|
58 |
+
x = x.view(batch_size, seq_len, self.n_head, self.h_dim)
|
59 |
+
else:
|
60 |
+
batch_size, seq_len, n_head, h_dim = x.size()
|
61 |
+
if n_head != self.n_head or h_dim != self.h_dim:
|
62 |
+
raise ValueError(f"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}")
|
63 |
+
|
64 |
+
x = x.reshape(-1, self.h_dim)
|
65 |
+
|
66 |
+
for k in range(self.num_rotations):
|
67 |
+
i, j = self.rotation_pairs[k].long()
|
68 |
+
theta = self.thetas[k] * self.theta_scale
|
69 |
+
G = self.givens_rotation_matrix(n_state=self.h_dim, i=i, j=j, theta=theta)
|
70 |
+
x = torch.matmul(input=x, other=G)
|
71 |
+
|
72 |
+
x = torch.matmul(input=x, other=self.rotation_matrix)
|
73 |
+
x = x.view(batch_size, seq_len, self.n_head, self.h_dim)
|
74 |
+
|
75 |
+
sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(end=seq_len, device=x.device), self.inv_freq.to(device=x.device))
|
76 |
+
sin = sinusoid_inp.sin()[None, :, None, :]
|
77 |
+
cos = sinusoid_inp.cos()[None, :, None, :]
|
78 |
+
|
79 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
80 |
+
x = torch.cat(tensors=[x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
|
81 |
+
x = x.view(batch_size, seq_len, self.n_state)
|
82 |
+
return x
|