Changes for Core ML conversion
modelling_RW.py  (+3, -3)
@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
 # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
 class Linear(nn.Linear):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        ret = input @ self.weight.T
+        ret = input @ self.weight.permute(1, 0) #transpose(0, 1) #.T
         if self.bias is None:
             return ret
         else:
@@ -68,7 +68,7 @@ class RotaryEmbedding(torch.nn.Module):
         self,
         seq_len: int,
         device="cuda",
-        dtype=torch.bfloat16,
+        dtype=torch.float16,
     ) -> torch.Tensor:
         if seq_len != self.seq_len_cached:
             self.seq_len_cached = seq_len
@@ -89,7 +89,7 @@ class RotaryEmbedding(torch.nn.Module):

     def forward(self, q, k):
         batch, seq_len, head_dim = q.shape
-        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
+        cos, sin = self.cos_sin(seq_len, q.device)
         return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
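The matmul rewrite above replaces weight.T with weight.permute(1, 0). For a 2-D weight these produce the same tensor, so the layer's output is unchanged; the rewrite only changes how the operation appears in the traced graph. A quick sanity check of that equivalence, as a standalone sketch (PatchedLinear is an illustrative name, not something from this repo):

import torch
from torch import nn
import torch.nn.functional as F

# Stand-in for the patched Linear above: matmul against a permuted weight
# instead of weight.T, with the same optional-bias handling.
class PatchedLinear(nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        ret = input @ self.weight.permute(1, 0)
        if self.bias is None:
            return ret
        return ret + self.bias

torch.manual_seed(0)
layer = PatchedLinear(8, 4)
x = torch.randn(2, 8)

# permute(1, 0) and .T are identical for a 2-D tensor, so the patched
# forward matches the stock linear computation exactly.
reference = F.linear(x, layer.weight, layer.bias)
assert torch.allclose(layer(x), reference, atol=1e-6)
print("patched Linear matches F.linear")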
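The dtype edit exists because Core ML supports float16 but not bfloat16, so the rotary cos/sin cache now defaults to float16. As a rough sketch of the kind of conversion these changes are meant to unblock (TinyProjection, hidden_states, and the output filename are invented here for illustration; the real workflow would trace the full Falcon model from this repo), one might trace a module and convert it at float16 precision with coremltools:

import coremltools as ct  # assumes coremltools is installed
import torch
from torch import nn

# Toy module using the same permuted-weight matmul pattern as the patched
# Linear, so it traces and converts without relying on the Tensor.T attribute.
class TinyProjection(nn.Module):
    def __init__(self, d_in: int = 16, d_out: int = 16):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(d_out, d_in))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight.permute(1, 0)

model = TinyProjection().eval()
example = torch.randn(1, 8, 16)

# Core ML conversion starts from a traced (or scripted) module.
traced = torch.jit.trace(model, example)

# float16 compute precision mirrors the dtype change in the diff:
# Core ML has no bfloat16 support.
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="hidden_states", shape=example.shape)],
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT16,
)
mlmodel.save("TinyProjection.mlpackage")

With an mlprogram target, FLOAT16 is already the default compute precision, so passing it explicitly mainly documents the intent of matching the float16 changes above.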