maxall4 committed (verified)
Commit b07b34f · 1 Parent(s): e84e201

Update model.py

Files changed (1)
model.py +3 -4
model.py CHANGED
@@ -73,13 +73,12 @@ class AttentionBlock(nn.Module):
         ):  # workaround for masking bug in FA. This works because Wqkv does not have bias
             # and attention scores will be also automatically zeroed.
             u = u * padding_mask[..., None]
-        u = (
-            self.inner_mha_cls(
+        w = self.inner_mha_cls(
                 self.pre_norm(u),
                 inference_params=inference_params,
-            )
-            + u
         )
+        self.filter_output = w
+        u = w + u
         if type(padding_mask) == torch.Tensor:  # guard against bias
             u = u * padding_mask[..., None]
         u = self.mlp(self.post_norm(u)) + u
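The change splits the residual update so the attention sub-layer's pre-residual output is kept on the module as `self.filter_output` before the skip connection is added. Below is a minimal, self-contained sketch of that caching pattern; `ToyBlock`, its stand-in `nn.Linear` mixer, and the tensor shapes are illustrative assumptions, not the repository's actual classes or config.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Hypothetical block mirroring the pattern in the diff above."""

    def __init__(self, d_model: int = 8):
        super().__init__()
        self.pre_norm = nn.LayerNorm(d_model)
        self.mixer = nn.Linear(d_model, d_model)  # stand-in for inner_mha_cls
        self.filter_output = None

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        w = self.mixer(self.pre_norm(u))  # pre-residual sub-layer output
        self.filter_output = w            # cache it on the module, as in the commit
        return w + u                      # residual connection

block = ToyBlock()
x = torch.randn(2, 4, 8)                  # (batch, seq_len, d_model)
y = block(x)
# After the forward pass, the cached pre-residual output can be inspected:
assert torch.allclose(y - x, block.filter_output)
```

After a forward pass, reading `block.filter_output` gives the sub-layer output without the residual term, which appears to be what this commit exposes for inspection.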