Sin2pi commited on
Commit
c10a7a1
·
verified ·
1 Parent(s): a9eb017

Update 3_rotations.py

Browse files
Files changed (1) hide show
  1. 3_rotations.py +1 -1
3_rotations.py CHANGED
@@ -61,7 +61,7 @@ class CombinedRotaryEmbedding(nn.Module):
61
  raise ValueError(f"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}")
62
 
63
  # Flatten for rotation
64
- x = x.view(-1, self.h_dim)
65
 
66
  # Apply Givens rotations
67
  for k in range(self.num_rotations):
 
61
  raise ValueError(f"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}")
62
 
63
  # Flatten for rotation
64
+ x = x.reshape(-1, self.h_dim)
65
 
66
  # Apply Givens rotations
67
  for k in range(self.num_rotations):