Jackmin108 committed
Commit 84441b5
Parent(s): 0007991

qk norm

modeling_bert.py CHANGED (+6 -4)
@@ -280,6 +280,8 @@ class JinaBertSelfAttention(nn.Module):
         self.query = nn.Linear(config.hidden_size, self.all_head_size)
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+        self.layer_norm_q = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.layer_norm_k = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.position_embedding_type = position_embedding_type or getattr(
@@ -315,7 +317,7 @@ class JinaBertSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         bias: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)
+        mixed_query_layer = self.layer_norm_q(self.query(hidden_states))
 
         # If this is instantiated as a cross-attention module, the keys
         # and values come from an encoder; the attention mask needs to be
@@ -328,16 +330,16 @@ class JinaBertSelfAttention(nn.Module):
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            key_layer = self.transpose_for_scores(self.layer_norm_k(self.key(encoder_hidden_states)))
             value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            key_layer = self.transpose_for_scores(self.layer_norm_k(self.key(hidden_states)))
             value_layer = self.transpose_for_scores(self.value(hidden_states))
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            key_layer = self.transpose_for_scores(self.layer_norm_k(self.key(hidden_states)))
             value_layer = self.transpose_for_scores(self.value(hidden_states))
 
         query_layer = self.transpose_for_scores(mixed_query_layer)