pcuenq committed
Commit fe2d5bd · 1 Parent(s): 2a6da03

Changes for Core ML conversion

Files changed (1):
  1. modelling_RW.py +3 -3
modelling_RW.py CHANGED
@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
 # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
 class Linear(nn.Linear):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        ret = input @ self.weight.T
+        ret = input @ self.weight.permute(1, 0) #transpose(0, 1) #.T
         if self.bias is None:
             return ret
         else:
@@ -68,7 +68,7 @@ class RotaryEmbedding(torch.nn.Module):
         self,
         seq_len: int,
         device="cuda",
-        dtype=torch.bfloat16,
+        dtype=torch.float16,
     ) -> torch.Tensor:
         if seq_len != self.seq_len_cached:
             self.seq_len_cached = seq_len
@@ -89,7 +89,7 @@ class RotaryEmbedding(torch.nn.Module):
 
     def forward(self, q, k):
         batch, seq_len, head_dim = q.shape
-        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
+        cos, sin = self.cos_sin(seq_len, q.device)
         return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
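For reference, a minimal sketch (not part of the commit) of why these edits are safe for export: for a 2-D weight, `permute(1, 0)` is numerically identical to `.T`, so `Linear.forward` is unchanged, and the switch to `float16` for the rotary cos/sin cache is presumably because Core ML has no bfloat16 tensor type. The `PatchedLinear` toy module, the example shapes, and the use of `torch.jit.trace` plus `coremltools.convert` below are assumptions for illustration; the commit itself does not include the export script.

# Hypothetical check, not part of the commit: verify the matmul rewrite is
# equivalent and that a traced module using it converts with coremltools.
import torch
import coremltools as ct  # assumes coremltools is installed

class PatchedLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        ret = input @ self.weight.permute(1, 0)  # same math as input @ self.weight.T
        return ret if self.bias is None else ret + self.bias

layer = PatchedLinear(4, 8).eval()
x = torch.randn(1, 4)

# permute(1, 0) and .T give the same view of a 2-D weight, so results match exactly.
assert torch.equal(x @ layer.weight.T, x @ layer.weight.permute(1, 0))

# Trace the toy layer and convert it; matmul, permute, and add are supported ops.
traced = torch.jit.trace(layer, x)
mlmodel = ct.convert(traced, inputs=[ct.TensorType(name="input", shape=x.shape)])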