justinpinkney commited on
Commit
004fadc
·
1 Parent(s): 795ad25

fix bad copy

Browse files
Files changed (1) hide show
  1. modelling_RW.py +5 -0
modelling_RW.py CHANGED
@@ -89,6 +89,11 @@ class RotaryEmbedding(torch.nn.Module):
89
 
90
  return self.cos_cached, self.sin_cached
91
 
 
 
 
 
 
92
 
93
  def _make_causal_mask(
94
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 
89
 
90
  return self.cos_cached, self.sin_cached
91
 
92
+ def forward(self, q, k, start_idx=0):
93
+ batch, seq_len, head_dim = q.shape
94
+ cos, sin = self.cos_sin(seq_len, q.device, q.dtype, start_idx=start_idx)
95
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
96
+
97
 
98
  def _make_causal_mask(
99
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int