Update 3_rotations.py
Browse files- 3_rotations.py +1 -1
3_rotations.py
CHANGED
@@ -61,7 +61,7 @@ class CombinedRotaryEmbedding(nn.Module):
|
|
61 |
raise ValueError(f"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}")
|
62 |
|
63 |
# Flatten for rotation
|
64 |
-
x = x.
|
65 |
|
66 |
# Apply Givens rotations
|
67 |
for k in range(self.num_rotations):
|
|
|
61 |
raise ValueError(f"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}")
|
62 |
|
63 |
# Flatten for rotation
|
64 |
+
x = x.reshape(-1, self.h_dim)
|
65 |
|
66 |
# Apply Givens rotations
|
67 |
for k in range(self.num_rotations):
|