Sin2pi committed on
Commit e3cfe0c · verified · 1 Parent(s): 22a6a6b

Update README.md

Files changed (1):
  1. README.md +75 -1
README.md CHANGED
@@ -5,4 +5,78 @@ Whisper like ASR model but with some advanced ideas. Experimental. Full script j
  I'm experimenting with some of the new stuff from the vision llm people but with audio.. Here is a super cool paper:
  https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2022.949142/full

- Updated. Was having some issues there with the hybrid attention and tensor sharing.. fixed.!
+ Updated. Fixed some issues with the hybrid attention and tensor sharing.
+
+ Drop-in enhanced Givens rotary block -- it's like a Rubik's cube of embeddings :)
+
+
+ import torch
+ from torch import nn
+ from torch.utils.checkpoint import checkpoint
+
+
+ class CombinedRotaryEmbedding(nn.Module):
+     def __init__(self, n_state, n_head, num_rotations, base=10000, checkpointing=False):
+         super().__init__()
+         self.n_state = n_state
+         self.n_head = n_head
+         self.h_dim = n_state // n_head
+         self.num_rotations = num_rotations
+         self.base = base
+         self.checkpointing = checkpointing
+
+         # Learned Givens rotation angles, the (i, j) plane for each rotation,
+         # and a global scale on the angles.
+         self.thetas = nn.Parameter(torch.zeros(num_rotations))
+         self.rotation_pairs = nn.Parameter(torch.rand(num_rotations, 2) * self.h_dim)
+         self.theta_scale = nn.Parameter(torch.ones(1))
+         # Extra learned mixing matrix applied after the Givens rotations (identity at init).
+         self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))
+         # Standard rotary (RoPE) inverse frequencies.
+         self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))
+
+     def givens_rotation_matrix(self, n_state, i, j, theta):
+         # Identity matrix with a 2D rotation by theta embedded in the (i, j) plane.
+         G = torch.eye(n_state, device=theta.device)
+         cos_t, sin_t = torch.cos(theta).squeeze(), torch.sin(theta).squeeze()
+         G[i, i] = cos_t
+         G[i, j] = -sin_t
+         G[j, i] = sin_t
+         G[j, j] = cos_t
+         return G
+
+     def update_base(self, new_base):
+         self.base = float(new_base)
+         self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))
+
+     def reset_parameters(self):
+         nn.init.orthogonal_(self.rotation_matrix)
+         nn.init.zeros_(self.thetas)
+
+     def forward(self, x):
+         if self.checkpointing:
+             return checkpoint(self._forward, x)
+         else:
+             return self._forward(x)
+
+     def _forward(self, x):
+         if x.dim() not in [3, 4]:
+             raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")
+
+         if x.dim() == 3:
+             batch_size, seq_len, n_state = x.size()
+             x = x.view(batch_size, seq_len, self.n_head, self.h_dim)
+         else:
+             batch_size, seq_len, n_head, h_dim = x.size()
+             if n_head != self.n_head or h_dim != self.h_dim:
+                 raise ValueError(f"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}")
+
+         x = x.reshape(-1, self.h_dim)
+
+         # Apply each learned Givens rotation in turn. The pair indices come from
+         # rotation_pairs rounded down to integers, so gradients flow through the
+         # angles but not through the choice of plane.
+         for k in range(self.num_rotations):
+             i, j = self.rotation_pairs[k].long()
+             theta = self.thetas[k] * self.theta_scale
+             G = self.givens_rotation_matrix(n_state=self.h_dim, i=i, j=j, theta=theta)
+             x = torch.matmul(x, G)
+
+         # Learned mixing matrix, then back to (batch, seq, head, head_dim).
+         x = torch.matmul(x, self.rotation_matrix)
+         x = x.view(batch_size, seq_len, self.n_head, self.h_dim)
+
+         # Standard rotary position embedding on top of the rotated features.
+         positions = torch.arange(seq_len, device=x.device).float()
+         sinusoid_inp = torch.einsum('i,j->ij', positions, self.inv_freq.to(device=x.device))
+         sin = sinusoid_inp.sin()[None, :, None, :]
+         cos = sinusoid_inp.cos()[None, :, None, :]
+
+         x1, x2 = x[..., ::2], x[..., 1::2]
+         x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
+         x = x.view(batch_size, seq_len, self.n_state)
+         return x
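
A minimal smoke test for the block above, assuming only that it is used as a drop-in module on (batch, time, n_state) activations; the sizes here (n_state=384, n_head=6, num_rotations=4) are illustrative, not taken from the actual model config:

import torch

# Illustrative sizes only -- not from the repo's config.
n_state, n_head, num_rotations = 384, 6, 4
rotary = CombinedRotaryEmbedding(n_state=n_state, n_head=n_head, num_rotations=num_rotations)

x = torch.randn(2, 100, n_state)      # (batch, time, n_state), e.g. encoder hidden states
out = rotary(x)                       # same shape back: (2, 100, 384)
print(out.shape)

# With thetas initialized to zero and rotation_matrix to the identity, the block
# starts out as plain RoPE, which only rotates pairs of coordinates, so each
# vector's norm is preserved; the Givens rotations and the mixing matrix begin
# to reshuffle the head dimensions only as they are trained.
print(x.norm(), out.norm())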