Update model.py
model.py CHANGED
@@ -73,13 +73,12 @@ class AttentionBlock(nn.Module):
         ):  # workaround for masking bug in FA. This works because Wqkv does not have bias
             # and attention scores will be also automatically zeroed.
             u = u * padding_mask[..., None]
-        u = (
-            self.inner_mha_cls(
+        w = self.inner_mha_cls(
             self.pre_norm(u),
             inference_params=inference_params,
-            )
-            + u
         )
+        self.filter_output = w
+        u = w + u
         if type(padding_mask) == torch.Tensor:  # guard against bias
             u = u * padding_mask[..., None]
         u = self.mlp(self.post_norm(u)) + u
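In effect, the hunk unpacks the fused residual expression into an explicit intermediate: the attention output is computed into w, cached on the module as self.filter_output, and only then added back to the residual stream, so the pre-residual activations remain inspectable after forward(). Below is a minimal self-contained sketch of the same pattern; everything except the filter_output attribute name is an illustrative assumption (TinyBlock and the use of nn.MultiheadAttention stand in for the repo's AttentionBlock and inner_mha_cls).

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Illustrative stand-in for AttentionBlock, showing the caching pattern."""

    def __init__(self, d: int):
        super().__init__()
        self.pre_norm = nn.LayerNorm(d)
        # Stand-in mixer; the real model uses its own inner_mha_cls.
        self.inner_mha_cls = nn.MultiheadAttention(d, num_heads=4, batch_first=True)

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        x = self.pre_norm(u)
        w, _ = self.inner_mha_cls(x, x, x)  # attention output, pre-residual
        self.filter_output = w              # cached on the module, as in the commit
        return w + u                        # residual add, unchanged behavior

block = TinyBlock(64)
u = torch.randn(2, 8, 64)                  # (batch, seq_len, hidden)
out = block(u)
print(block.filter_output.shape)           # torch.Size([2, 8, 64]), pre-residual activations

One side effect worth noting: because the tensor is stored as a module attribute, it stays alive (and, under autograd, keeps its graph reachable) until the next forward pass overwrites it, so this costs some memory in exchange for inspectability.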