Implement Longformer sliding window
#9
by
alaeddine-13
- opened
- modeling_bert.py +240 -214
modeling_bert.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16 |
# limitations under the License.
|
17 |
"""PyTorch BERT model."""
|
18 |
|
19 |
-
|
20 |
import math
|
21 |
import os
|
22 |
import warnings
|
@@ -96,6 +95,15 @@ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
|
|
96 |
_SEQ_CLASS_EXPECTED_LOSS = 0.01
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
100 |
"""Load tf checkpoints in a pytorch model."""
|
101 |
try:
|
@@ -126,15 +134,15 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|
126 |
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
127 |
# which are not required for using pretrained model
|
128 |
if any(
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
):
|
139 |
logger.info(f"Skipping {'/'.join(name)}")
|
140 |
continue
|
@@ -214,12 +222,12 @@ class JinaBertEmbeddings(nn.Module):
|
|
214 |
)
|
215 |
|
216 |
def forward(
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
) -> torch.Tensor:
|
224 |
if input_ids is not None:
|
225 |
input_shape = input_ids.size()
|
@@ -230,8 +238,8 @@ class JinaBertEmbeddings(nn.Module):
|
|
230 |
|
231 |
if position_ids is None:
|
232 |
position_ids = self.position_ids[
|
233 |
-
|
234 |
-
|
235 |
|
236 |
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
237 |
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
@@ -265,13 +273,13 @@ class JinaBertSelfAttention(nn.Module):
|
|
265 |
def __init__(self, config: JinaBertConfig, position_embedding_type=None):
|
266 |
super().__init__()
|
267 |
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
268 |
-
|
269 |
):
|
270 |
raise ValueError(
|
271 |
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
272 |
f"heads ({config.num_attention_heads})"
|
273 |
)
|
274 |
-
|
275 |
self.attn_implementation = config.attn_implementation
|
276 |
self.num_attention_heads = config.num_attention_heads
|
277 |
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
@@ -286,8 +294,8 @@ class JinaBertSelfAttention(nn.Module):
|
|
286 |
config, "position_embedding_type", "absolute"
|
287 |
)
|
288 |
if (
|
289 |
-
|
290 |
-
|
291 |
):
|
292 |
self.max_position_embeddings = config.max_position_embeddings
|
293 |
self.distance_embedding = nn.Embedding(
|
@@ -305,15 +313,16 @@ class JinaBertSelfAttention(nn.Module):
|
|
305 |
return x.permute(0, 2, 1, 3)
|
306 |
|
307 |
def forward(
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
|
|
317 |
) -> Tuple[torch.Tensor]:
|
318 |
mixed_query_layer = self.query(hidden_states)
|
319 |
|
@@ -364,8 +373,8 @@ class JinaBertSelfAttention(nn.Module):
|
|
364 |
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
365 |
|
366 |
if (
|
367 |
-
|
368 |
-
|
369 |
):
|
370 |
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
371 |
if use_cache:
|
@@ -401,9 +410,9 @@ class JinaBertSelfAttention(nn.Module):
|
|
401 |
"bhrd,lrd->bhlr", key_layer, positional_embedding
|
402 |
)
|
403 |
attention_scores = (
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
)
|
408 |
|
409 |
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
@@ -414,6 +423,10 @@ class JinaBertSelfAttention(nn.Module):
|
|
414 |
# Normalize the attention scores to probabilities.
|
415 |
attention_probs = nn.functional.softmax(attention_scores + bias, dim=-1)
|
416 |
|
|
|
|
|
|
|
|
|
417 |
# This is actually dropping out entire tokens to attend to, which might
|
418 |
# seem a bit unusual, but is taken from the original Transformer paper.
|
419 |
attention_probs = self.dropout(attention_probs)
|
@@ -445,7 +458,7 @@ class JinaBertSelfOutput(nn.Module):
|
|
445 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
446 |
|
447 |
def forward(
|
448 |
-
|
449 |
) -> torch.Tensor:
|
450 |
hidden_states = self.dense(hidden_states)
|
451 |
hidden_states = self.dropout(hidden_states)
|
@@ -481,20 +494,21 @@ class JinaBertAttention(nn.Module):
|
|
481 |
# Update hyper params and store pruned heads
|
482 |
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
483 |
self.self.all_head_size = (
|
484 |
-
|
485 |
)
|
486 |
self.pruned_heads = self.pruned_heads.union(heads)
|
487 |
|
488 |
def forward(
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
|
|
498 |
) -> Tuple[torch.Tensor]:
|
499 |
self_outputs = self.self(
|
500 |
hidden_states,
|
@@ -505,11 +519,12 @@ class JinaBertAttention(nn.Module):
|
|
505 |
past_key_value,
|
506 |
output_attentions,
|
507 |
bias,
|
|
|
508 |
)
|
509 |
attention_output = self.output(self_outputs[0], hidden_states)
|
510 |
outputs = (attention_output,) + self_outputs[
|
511 |
-
|
512 |
-
|
513 |
return outputs
|
514 |
|
515 |
|
@@ -536,7 +551,7 @@ class JinaBertOutput(nn.Module):
|
|
536 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
537 |
|
538 |
def forward(
|
539 |
-
|
540 |
) -> torch.Tensor:
|
541 |
hidden_states = self.dense(hidden_states)
|
542 |
hidden_states = self.dropout(hidden_states)
|
@@ -568,7 +583,7 @@ class JinaBertGLUMLP(nn.Module):
|
|
568 |
# compute the activation
|
569 |
hidden_states = self.gated_layers(hidden_states)
|
570 |
gated = hidden_states[:, :, : self.config.intermediate_size]
|
571 |
-
non_gated = hidden_states[:, :, self.config.intermediate_size
|
572 |
hidden_states = self.act(gated) * non_gated
|
573 |
hidden_states = self.dropout(hidden_states)
|
574 |
# multiply by the second matrix
|
@@ -602,15 +617,16 @@ class JinaBertLayer(nn.Module):
|
|
602 |
self.output = JinaBertOutput(config)
|
603 |
|
604 |
def forward(
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
|
|
614 |
) -> Tuple[torch.Tensor]:
|
615 |
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
616 |
self_attn_past_key_value = (
|
@@ -623,6 +639,7 @@ class JinaBertLayer(nn.Module):
|
|
623 |
output_attentions=output_attentions,
|
624 |
past_key_value=self_attn_past_key_value,
|
625 |
bias=bias,
|
|
|
626 |
)
|
627 |
attention_output = self_attention_outputs[0]
|
628 |
|
@@ -632,8 +649,8 @@ class JinaBertLayer(nn.Module):
|
|
632 |
present_key_value = self_attention_outputs[-1]
|
633 |
else:
|
634 |
outputs = self_attention_outputs[
|
635 |
-
|
636 |
-
|
637 |
|
638 |
cross_attn_present_key_value = None
|
639 |
if self.is_decoder and encoder_hidden_states is not None:
|
@@ -658,7 +675,7 @@ class JinaBertLayer(nn.Module):
|
|
658 |
)
|
659 |
attention_output = cross_attention_outputs[0]
|
660 |
outputs = (
|
661 |
-
|
662 |
) # add cross attentions if we output attention weights
|
663 |
|
664 |
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
@@ -704,7 +721,7 @@ class JinaBertEncoder(nn.Module):
|
|
704 |
)
|
705 |
|
706 |
def rebuild_alibi_tensor(
|
707 |
-
|
708 |
):
|
709 |
# Alibi
|
710 |
# Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
|
@@ -717,7 +734,7 @@ class JinaBertEncoder(nn.Module):
|
|
717 |
def get_slopes_power_of_2(n):
|
718 |
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
719 |
ratio = start
|
720 |
-
return [start * ratio**i for i in range(n)]
|
721 |
|
722 |
if math.log2(n_heads).is_integer():
|
723 |
return get_slopes_power_of_2(
|
@@ -728,10 +745,10 @@ class JinaBertEncoder(nn.Module):
|
|
728 |
math.log2(n_heads)
|
729 |
) # when the number of heads is not a power of 2, we use this workaround.
|
730 |
return (
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
)
|
736 |
|
737 |
context_position = torch.arange(size, device=device)[:, None]
|
@@ -749,17 +766,18 @@ class JinaBertEncoder(nn.Module):
|
|
749 |
return alibi
|
750 |
|
751 |
def forward(
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
|
|
763 |
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
764 |
all_hidden_states = () if output_hidden_states else None
|
765 |
all_self_attentions = () if output_attentions else None
|
@@ -828,6 +846,7 @@ class JinaBertEncoder(nn.Module):
|
|
828 |
alibi_bias,
|
829 |
past_key_value,
|
830 |
output_attentions,
|
|
|
831 |
)
|
832 |
|
833 |
hidden_states = layer_outputs[0]
|
@@ -1117,16 +1136,17 @@ class JinaBertModel(JinaBertPreTrainedModel):
|
|
1117 |
|
1118 |
@torch.inference_mode()
|
1119 |
def encode(
|
1120 |
-
|
1121 |
-
|
1122 |
-
|
1123 |
-
|
1124 |
-
|
1125 |
-
|
1126 |
-
|
1127 |
-
|
1128 |
-
|
1129 |
-
|
|
|
1130 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
1131 |
"""
|
1132 |
Computes sentence embeddings
|
@@ -1172,8 +1192,8 @@ class JinaBertModel(JinaBertPreTrainedModel):
|
|
1172 |
|
1173 |
if show_progress_bar is None:
|
1174 |
show_progress_bar = (
|
1175 |
-
|
1176 |
-
|
1177 |
)
|
1178 |
|
1179 |
if convert_to_tensor:
|
@@ -1215,11 +1235,11 @@ class JinaBertModel(JinaBertPreTrainedModel):
|
|
1215 |
|
1216 |
for i in range_iter:
|
1217 |
encoded_input = self.tokenizer(
|
1218 |
-
sentences[i
|
1219 |
return_tensors='pt',
|
1220 |
**tokenizer_kwargs,
|
1221 |
).to(self.device)
|
1222 |
-
token_embs = self.forward(**encoded_input)[0]
|
1223 |
|
1224 |
# Accumulate in fp32 to avoid overflow
|
1225 |
token_embs = token_embs.float()
|
@@ -1254,7 +1274,7 @@ class JinaBertModel(JinaBertPreTrainedModel):
|
|
1254 |
return all_embeddings
|
1255 |
|
1256 |
def mean_pooling(
|
1257 |
-
|
1258 |
):
|
1259 |
input_mask_expanded = (
|
1260 |
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
@@ -1286,20 +1306,21 @@ class JinaBertModel(JinaBertPreTrainedModel):
|
|
1286 |
config_class=_CONFIG_FOR_DOC,
|
1287 |
)
|
1288 |
def forward(
|
1289 |
-
|
1290 |
-
|
1291 |
-
|
1292 |
-
|
1293 |
-
|
1294 |
-
|
1295 |
-
|
1296 |
-
|
1297 |
-
|
1298 |
-
|
1299 |
-
|
1300 |
-
|
1301 |
-
|
1302 |
-
|
|
|
1303 |
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
1304 |
r"""
|
1305 |
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
@@ -1425,6 +1446,7 @@ class JinaBertModel(JinaBertPreTrainedModel):
|
|
1425 |
output_attentions=output_attentions,
|
1426 |
output_hidden_states=output_hidden_states,
|
1427 |
return_dict=return_dict,
|
|
|
1428 |
)
|
1429 |
sequence_output = encoder_outputs[0]
|
1430 |
pooled_output = (
|
@@ -1476,18 +1498,19 @@ class JinaBertForPreTraining(JinaBertPreTrainedModel):
|
|
1476 |
output_type=JinaBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
|
1477 |
)
|
1478 |
def forward(
|
1479 |
-
|
1480 |
-
|
1481 |
-
|
1482 |
-
|
1483 |
-
|
1484 |
-
|
1485 |
-
|
1486 |
-
|
1487 |
-
|
1488 |
-
|
1489 |
-
|
1490 |
-
|
|
|
1491 |
) -> Union[Tuple[torch.Tensor], JinaBertForPreTrainingOutput]:
|
1492 |
r"""
|
1493 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
@@ -1519,6 +1542,7 @@ class JinaBertForPreTraining(JinaBertPreTrainedModel):
|
|
1519 |
output_attentions=output_attentions,
|
1520 |
output_hidden_states=output_hidden_states,
|
1521 |
return_dict=return_dict,
|
|
|
1522 |
)
|
1523 |
|
1524 |
sequence_output, pooled_output = outputs[:2]
|
@@ -1586,21 +1610,21 @@ class JinaBertLMHeadModel(JinaBertPreTrainedModel):
|
|
1586 |
config_class=_CONFIG_FOR_DOC,
|
1587 |
)
|
1588 |
def forward(
|
1589 |
-
|
1590 |
-
|
1591 |
-
|
1592 |
-
|
1593 |
-
|
1594 |
-
|
1595 |
-
|
1596 |
-
|
1597 |
-
|
1598 |
-
|
1599 |
-
|
1600 |
-
|
1601 |
-
|
1602 |
-
|
1603 |
-
|
1604 |
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
1605 |
r"""
|
1606 |
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
@@ -1676,12 +1700,12 @@ class JinaBertLMHeadModel(JinaBertPreTrainedModel):
|
|
1676 |
)
|
1677 |
|
1678 |
def prepare_inputs_for_generation(
|
1679 |
-
|
1680 |
-
|
1681 |
-
|
1682 |
-
|
1683 |
-
|
1684 |
-
|
1685 |
):
|
1686 |
input_shape = input_ids.shape
|
1687 |
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
@@ -1748,19 +1772,20 @@ class JinaBertForMaskedLM(JinaBertPreTrainedModel):
|
|
1748 |
expected_loss=0.88,
|
1749 |
)
|
1750 |
def forward(
|
1751 |
-
|
1752 |
-
|
1753 |
-
|
1754 |
-
|
1755 |
-
|
1756 |
-
|
1757 |
-
|
1758 |
-
|
1759 |
-
|
1760 |
-
|
1761 |
-
|
1762 |
-
|
1763 |
-
|
|
|
1764 |
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
1765 |
r"""
|
1766 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
@@ -1785,6 +1810,7 @@ class JinaBertForMaskedLM(JinaBertPreTrainedModel):
|
|
1785 |
output_attentions=output_attentions,
|
1786 |
output_hidden_states=output_hidden_states,
|
1787 |
return_dict=return_dict,
|
|
|
1788 |
)
|
1789 |
|
1790 |
sequence_output = outputs[0]
|
@@ -1811,7 +1837,7 @@ class JinaBertForMaskedLM(JinaBertPreTrainedModel):
|
|
1811 |
)
|
1812 |
|
1813 |
def prepare_inputs_for_generation(
|
1814 |
-
|
1815 |
):
|
1816 |
input_shape = input_ids.shape
|
1817 |
effective_batch_size = input_shape[0]
|
@@ -1856,18 +1882,18 @@ class JinaBertForNextSentencePrediction(JinaBertPreTrainedModel):
|
|
1856 |
output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
|
1857 |
)
|
1858 |
def forward(
|
1859 |
-
|
1860 |
-
|
1861 |
-
|
1862 |
-
|
1863 |
-
|
1864 |
-
|
1865 |
-
|
1866 |
-
|
1867 |
-
|
1868 |
-
|
1869 |
-
|
1870 |
-
|
1871 |
) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
|
1872 |
r"""
|
1873 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
@@ -1967,17 +1993,17 @@ class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
|
|
1967 |
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
1968 |
)
|
1969 |
def forward(
|
1970 |
-
|
1971 |
-
|
1972 |
-
|
1973 |
-
|
1974 |
-
|
1975 |
-
|
1976 |
-
|
1977 |
-
|
1978 |
-
|
1979 |
-
|
1980 |
-
|
1981 |
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
1982 |
r"""
|
1983 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
@@ -2012,7 +2038,7 @@ class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
|
|
2012 |
if self.num_labels == 1:
|
2013 |
self.config.problem_type = "regression"
|
2014 |
elif self.num_labels > 1 and (
|
2015 |
-
|
2016 |
):
|
2017 |
self.config.problem_type = "single_label_classification"
|
2018 |
else:
|
@@ -2074,17 +2100,17 @@ class JinaBertForMultipleChoice(JinaBertPreTrainedModel):
|
|
2074 |
config_class=_CONFIG_FOR_DOC,
|
2075 |
)
|
2076 |
def forward(
|
2077 |
-
|
2078 |
-
|
2079 |
-
|
2080 |
-
|
2081 |
-
|
2082 |
-
|
2083 |
-
|
2084 |
-
|
2085 |
-
|
2086 |
-
|
2087 |
-
|
2088 |
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
|
2089 |
r"""
|
2090 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
@@ -2193,17 +2219,17 @@ class JinaBertForTokenClassification(JinaBertPreTrainedModel):
|
|
2193 |
expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
|
2194 |
)
|
2195 |
def forward(
|
2196 |
-
|
2197 |
-
|
2198 |
-
|
2199 |
-
|
2200 |
-
|
2201 |
-
|
2202 |
-
|
2203 |
-
|
2204 |
-
|
2205 |
-
|
2206 |
-
|
2207 |
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
2208 |
r"""
|
2209 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
@@ -2278,18 +2304,18 @@ class JinaBertForQuestionAnswering(JinaBertPreTrainedModel):
|
|
2278 |
expected_loss=_QA_EXPECTED_LOSS,
|
2279 |
)
|
2280 |
def forward(
|
2281 |
-
|
2282 |
-
|
2283 |
-
|
2284 |
-
|
2285 |
-
|
2286 |
-
|
2287 |
-
|
2288 |
-
|
2289 |
-
|
2290 |
-
|
2291 |
-
|
2292 |
-
|
2293 |
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
2294 |
r"""
|
2295 |
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
|
16 |
# limitations under the License.
|
17 |
"""PyTorch BERT model."""
|
18 |
|
|
|
19 |
import math
|
20 |
import os
|
21 |
import warnings
|
|
|
95 |
_SEQ_CLASS_EXPECTED_LOSS = 0.01
|
96 |
|
97 |
|
98 |
+
def create_k_diag_mask(k, n, device=None):
    """Build the boolean mask that blanks out positions outside a sliding window.

    Entry ``[i, j]`` of the result is ``True`` when ``abs(i - j) >= k`` (the
    pair lies outside the band of ``2 * k - 1`` diagonals centred on the main
    diagonal) and ``False`` otherwise.

    Args:
        k: Half-width of the band; positions with ``abs(i - j) < k`` are kept
            (``False``).
        n: Side length of the square mask.
        device: Optional device on which to allocate the mask. Defaults to
            ``None``, which uses PyTorch's default device, as before.

    Returns:
        A ``torch.bool`` tensor of shape ``(n, n)``.
    """
    positions = torch.arange(n, device=device)
    # Broadcasting a column against a row yields every pairwise |i - j| at once,
    # replacing the O(n^2) Python-level double loop with a single tensor op.
    distance = (positions[:, None] - positions[None, :]).abs()
    return distance >= k
|
105 |
+
|
106 |
+
|
107 |
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
108 |
"""Load tf checkpoints in a pytorch model."""
|
109 |
try:
|
|
|
134 |
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
135 |
# which are not required for using pretrained model
|
136 |
if any(
|
137 |
+
n
|
138 |
+
in [
|
139 |
+
"adam_v",
|
140 |
+
"adam_m",
|
141 |
+
"AdamWeightDecayOptimizer",
|
142 |
+
"AdamWeightDecayOptimizer_1",
|
143 |
+
"global_step",
|
144 |
+
]
|
145 |
+
for n in name
|
146 |
):
|
147 |
logger.info(f"Skipping {'/'.join(name)}")
|
148 |
continue
|
|
|
222 |
)
|
223 |
|
224 |
def forward(
|
225 |
+
self,
|
226 |
+
input_ids: Optional[torch.LongTensor] = None,
|
227 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
228 |
+
position_ids: Optional[torch.LongTensor] = None,
|
229 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
230 |
+
past_key_values_length: int = 0,
|
231 |
) -> torch.Tensor:
|
232 |
if input_ids is not None:
|
233 |
input_shape = input_ids.size()
|
|
|
238 |
|
239 |
if position_ids is None:
|
240 |
position_ids = self.position_ids[
|
241 |
+
:, past_key_values_length: seq_length + past_key_values_length
|
242 |
+
]
|
243 |
|
244 |
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
245 |
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
|
|
273 |
def __init__(self, config: JinaBertConfig, position_embedding_type=None):
|
274 |
super().__init__()
|
275 |
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
276 |
+
config, "embedding_size"
|
277 |
):
|
278 |
raise ValueError(
|
279 |
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
280 |
f"heads ({config.num_attention_heads})"
|
281 |
)
|
282 |
+
|
283 |
self.attn_implementation = config.attn_implementation
|
284 |
self.num_attention_heads = config.num_attention_heads
|
285 |
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
|
|
294 |
config, "position_embedding_type", "absolute"
|
295 |
)
|
296 |
if (
|
297 |
+
self.position_embedding_type == "relative_key"
|
298 |
+
or self.position_embedding_type == "relative_key_query"
|
299 |
):
|
300 |
self.max_position_embeddings = config.max_position_embeddings
|
301 |
self.distance_embedding = nn.Embedding(
|
|
|
313 |
return x.permute(0, 2, 1, 3)
|
314 |
|
315 |
def forward(
|
316 |
+
self,
|
317 |
+
hidden_states: torch.Tensor,
|
318 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
319 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
320 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
321 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
322 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
323 |
+
output_attentions: Optional[bool] = False,
|
324 |
+
bias: Optional[torch.FloatTensor] = None,
|
325 |
+
sliding_window: Optional[int] = None,
|
326 |
) -> Tuple[torch.Tensor]:
|
327 |
mixed_query_layer = self.query(hidden_states)
|
328 |
|
|
|
373 |
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
374 |
|
375 |
if (
|
376 |
+
self.position_embedding_type == "relative_key"
|
377 |
+
or self.position_embedding_type == "relative_key_query"
|
378 |
):
|
379 |
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
380 |
if use_cache:
|
|
|
410 |
"bhrd,lrd->bhlr", key_layer, positional_embedding
|
411 |
)
|
412 |
attention_scores = (
|
413 |
+
attention_scores
|
414 |
+
+ relative_position_scores_query
|
415 |
+
+ relative_position_scores_key
|
416 |
)
|
417 |
|
418 |
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
|
|
423 |
# Normalize the attention scores to probabilities.
|
424 |
attention_probs = nn.functional.softmax(attention_scores + bias, dim=-1)
|
425 |
|
426 |
+
if sliding_window is not None:
|
427 |
+
mask = create_k_diag_mask(sliding_window, int(attention_scores.size(dim=2)))
|
428 |
+
attention_probs.masked_fill_(mask, 0)
|
429 |
+
|
430 |
# This is actually dropping out entire tokens to attend to, which might
|
431 |
# seem a bit unusual, but is taken from the original Transformer paper.
|
432 |
attention_probs = self.dropout(attention_probs)
|
|
|
458 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
459 |
|
460 |
def forward(
|
461 |
+
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
|
462 |
) -> torch.Tensor:
|
463 |
hidden_states = self.dense(hidden_states)
|
464 |
hidden_states = self.dropout(hidden_states)
|
|
|
494 |
# Update hyper params and store pruned heads
|
495 |
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
496 |
self.self.all_head_size = (
|
497 |
+
self.self.attention_head_size * self.self.num_attention_heads
|
498 |
)
|
499 |
self.pruned_heads = self.pruned_heads.union(heads)
|
500 |
|
501 |
def forward(
|
502 |
+
self,
|
503 |
+
hidden_states: torch.Tensor,
|
504 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
505 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
506 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
507 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
508 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
509 |
+
output_attentions: Optional[bool] = False,
|
510 |
+
bias: Optional[torch.FloatTensor] = None,
|
511 |
+
sliding_window: Optional[int] = None,
|
512 |
) -> Tuple[torch.Tensor]:
|
513 |
self_outputs = self.self(
|
514 |
hidden_states,
|
|
|
519 |
past_key_value,
|
520 |
output_attentions,
|
521 |
bias,
|
522 |
+
sliding_window=sliding_window
|
523 |
)
|
524 |
attention_output = self.output(self_outputs[0], hidden_states)
|
525 |
outputs = (attention_output,) + self_outputs[
|
526 |
+
1:
|
527 |
+
] # add attentions if we output them
|
528 |
return outputs
|
529 |
|
530 |
|
|
|
551 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
552 |
|
553 |
def forward(
|
554 |
+
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
|
555 |
) -> torch.Tensor:
|
556 |
hidden_states = self.dense(hidden_states)
|
557 |
hidden_states = self.dropout(hidden_states)
|
|
|
583 |
# compute the activation
|
584 |
hidden_states = self.gated_layers(hidden_states)
|
585 |
gated = hidden_states[:, :, : self.config.intermediate_size]
|
586 |
+
non_gated = hidden_states[:, :, self.config.intermediate_size:]
|
587 |
hidden_states = self.act(gated) * non_gated
|
588 |
hidden_states = self.dropout(hidden_states)
|
589 |
# multiply by the second matrix
|
|
|
617 |
self.output = JinaBertOutput(config)
|
618 |
|
619 |
def forward(
|
620 |
+
self,
|
621 |
+
hidden_states: torch.Tensor,
|
622 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
623 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
624 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
625 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
626 |
+
bias: Optional[torch.FloatTensor] = None,
|
627 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
628 |
+
output_attentions: Optional[bool] = False,
|
629 |
+
sliding_window: Optional[int] = None,
|
630 |
) -> Tuple[torch.Tensor]:
|
631 |
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
632 |
self_attn_past_key_value = (
|
|
|
639 |
output_attentions=output_attentions,
|
640 |
past_key_value=self_attn_past_key_value,
|
641 |
bias=bias,
|
642 |
+
sliding_window=sliding_window
|
643 |
)
|
644 |
attention_output = self_attention_outputs[0]
|
645 |
|
|
|
649 |
present_key_value = self_attention_outputs[-1]
|
650 |
else:
|
651 |
outputs = self_attention_outputs[
|
652 |
+
1:
|
653 |
+
] # add self attentions if we output attention weights
|
654 |
|
655 |
cross_attn_present_key_value = None
|
656 |
if self.is_decoder and encoder_hidden_states is not None:
|
|
|
675 |
)
|
676 |
attention_output = cross_attention_outputs[0]
|
677 |
outputs = (
|
678 |
+
outputs + cross_attention_outputs[1:-1]
|
679 |
) # add cross attentions if we output attention weights
|
680 |
|
681 |
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
|
|
721 |
)
|
722 |
|
723 |
def rebuild_alibi_tensor(
|
724 |
+
self, size: int, device: Optional[Union[torch.device, str]] = None
|
725 |
):
|
726 |
# Alibi
|
727 |
# Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
|
|
|
734 |
def get_slopes_power_of_2(n):
|
735 |
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
736 |
ratio = start
|
737 |
+
return [start * ratio ** i for i in range(n)]
|
738 |
|
739 |
if math.log2(n_heads).is_integer():
|
740 |
return get_slopes_power_of_2(
|
|
|
745 |
math.log2(n_heads)
|
746 |
) # when the number of heads is not a power of 2, we use this workaround.
|
747 |
return (
|
748 |
+
get_slopes_power_of_2(closest_power_of_2)
|
749 |
+
+ _get_alibi_head_slopes(2 * closest_power_of_2)[0::2][
|
750 |
+
: n_heads - closest_power_of_2
|
751 |
+
]
|
752 |
)
|
753 |
|
754 |
context_position = torch.arange(size, device=device)[:, None]
|
|
|
766 |
return alibi
|
767 |
|
768 |
def forward(
|
769 |
+
self,
|
770 |
+
hidden_states: torch.Tensor,
|
771 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
772 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
773 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
774 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
775 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
776 |
+
use_cache: Optional[bool] = None,
|
777 |
+
output_attentions: Optional[bool] = False,
|
778 |
+
output_hidden_states: Optional[bool] = False,
|
779 |
+
return_dict: Optional[bool] = True,
|
780 |
+
sliding_window: Optional[int] = None,
|
781 |
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
782 |
all_hidden_states = () if output_hidden_states else None
|
783 |
all_self_attentions = () if output_attentions else None
|
|
|
846 |
alibi_bias,
|
847 |
past_key_value,
|
848 |
output_attentions,
|
849 |
+
sliding_window
|
850 |
)
|
851 |
|
852 |
hidden_states = layer_outputs[0]
|
|
|
1136 |
|
1137 |
@torch.inference_mode()
|
1138 |
def encode(
|
1139 |
+
self: 'JinaBertModel',
|
1140 |
+
sentences: Union[str, List[str]],
|
1141 |
+
batch_size: int = 32,
|
1142 |
+
show_progress_bar: Optional[bool] = None,
|
1143 |
+
output_value: str = 'sentence_embedding',
|
1144 |
+
convert_to_numpy: bool = True,
|
1145 |
+
convert_to_tensor: bool = False,
|
1146 |
+
device: Optional[torch.device] = None,
|
1147 |
+
normalize_embeddings: bool = False,
|
1148 |
+
sliding_window: Optional[int] = None,
|
1149 |
+
**tokenizer_kwargs,
|
1150 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
1151 |
"""
|
1152 |
Computes sentence embeddings
|
|
|
1192 |
|
1193 |
if show_progress_bar is None:
|
1194 |
show_progress_bar = (
|
1195 |
+
logger.getEffectiveLevel() == logging.INFO
|
1196 |
+
or logger.getEffectiveLevel() == logging.DEBUG
|
1197 |
)
|
1198 |
|
1199 |
if convert_to_tensor:
|
|
|
1235 |
|
1236 |
for i in range_iter:
|
1237 |
encoded_input = self.tokenizer(
|
1238 |
+
sentences[i: i + batch_size],
|
1239 |
return_tensors='pt',
|
1240 |
**tokenizer_kwargs,
|
1241 |
).to(self.device)
|
1242 |
+
token_embs = self.forward(sliding_window=sliding_window, **encoded_input)[0]
|
1243 |
|
1244 |
# Accumulate in fp32 to avoid overflow
|
1245 |
token_embs = token_embs.float()
|
|
|
1274 |
return all_embeddings
|
1275 |
|
1276 |
def mean_pooling(
|
1277 |
+
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
1278 |
):
|
1279 |
input_mask_expanded = (
|
1280 |
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
|
|
1306 |
config_class=_CONFIG_FOR_DOC,
|
1307 |
)
|
1308 |
def forward(
|
1309 |
+
self,
|
1310 |
+
input_ids: Optional[torch.Tensor] = None,
|
1311 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1312 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1313 |
+
position_ids: Optional[torch.Tensor] = None,
|
1314 |
+
head_mask: Optional[torch.Tensor] = None,
|
1315 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1316 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1317 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1318 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1319 |
+
use_cache: Optional[bool] = None,
|
1320 |
+
output_attentions: Optional[bool] = None,
|
1321 |
+
output_hidden_states: Optional[bool] = None,
|
1322 |
+
return_dict: Optional[bool] = None,
|
1323 |
+
sliding_window: Optional[int] = None,
|
1324 |
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
1325 |
r"""
|
1326 |
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
|
1446 |
output_attentions=output_attentions,
|
1447 |
output_hidden_states=output_hidden_states,
|
1448 |
return_dict=return_dict,
|
1449 |
+
sliding_window=sliding_window
|
1450 |
)
|
1451 |
sequence_output = encoder_outputs[0]
|
1452 |
pooled_output = (
|
|
|
1498 |
output_type=JinaBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
|
1499 |
)
|
1500 |
def forward(
|
1501 |
+
self,
|
1502 |
+
input_ids: Optional[torch.Tensor] = None,
|
1503 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1504 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1505 |
+
position_ids: Optional[torch.Tensor] = None,
|
1506 |
+
head_mask: Optional[torch.Tensor] = None,
|
1507 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1508 |
+
labels: Optional[torch.Tensor] = None,
|
1509 |
+
next_sentence_label: Optional[torch.Tensor] = None,
|
1510 |
+
output_attentions: Optional[bool] = None,
|
1511 |
+
output_hidden_states: Optional[bool] = None,
|
1512 |
+
return_dict: Optional[bool] = None,
|
1513 |
+
sliding_window: Optional[int] = None,
|
1514 |
) -> Union[Tuple[torch.Tensor], JinaBertForPreTrainingOutput]:
|
1515 |
r"""
|
1516 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
|
1542 |
output_attentions=output_attentions,
|
1543 |
output_hidden_states=output_hidden_states,
|
1544 |
return_dict=return_dict,
|
1545 |
+
sliding_window=sliding_window
|
1546 |
)
|
1547 |
|
1548 |
sequence_output, pooled_output = outputs[:2]
|
|
|
1610 |
config_class=_CONFIG_FOR_DOC,
|
1611 |
)
|
1612 |
def forward(
|
1613 |
+
self,
|
1614 |
+
input_ids: Optional[torch.Tensor] = None,
|
1615 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1616 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1617 |
+
position_ids: Optional[torch.Tensor] = None,
|
1618 |
+
head_mask: Optional[torch.Tensor] = None,
|
1619 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1620 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1621 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1622 |
+
labels: Optional[torch.Tensor] = None,
|
1623 |
+
past_key_values: Optional[List[torch.Tensor]] = None,
|
1624 |
+
use_cache: Optional[bool] = None,
|
1625 |
+
output_attentions: Optional[bool] = None,
|
1626 |
+
output_hidden_states: Optional[bool] = None,
|
1627 |
+
return_dict: Optional[bool] = None,
|
1628 |
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
1629 |
r"""
|
1630 |
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
|
1700 |
)
|
1701 |
|
1702 |
def prepare_inputs_for_generation(
|
1703 |
+
self,
|
1704 |
+
input_ids,
|
1705 |
+
past_key_values=None,
|
1706 |
+
attention_mask=None,
|
1707 |
+
use_cache=True,
|
1708 |
+
**model_kwargs,
|
1709 |
):
|
1710 |
input_shape = input_ids.shape
|
1711 |
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
|
|
1772 |
expected_loss=0.88,
|
1773 |
)
|
1774 |
def forward(
|
1775 |
+
self,
|
1776 |
+
input_ids: Optional[torch.Tensor] = None,
|
1777 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1778 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1779 |
+
position_ids: Optional[torch.Tensor] = None,
|
1780 |
+
head_mask: Optional[torch.Tensor] = None,
|
1781 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1782 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1783 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1784 |
+
labels: Optional[torch.Tensor] = None,
|
1785 |
+
output_attentions: Optional[bool] = None,
|
1786 |
+
output_hidden_states: Optional[bool] = None,
|
1787 |
+
return_dict: Optional[bool] = None,
|
1788 |
+
sliding_window: Optional[int] = None,
|
1789 |
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
1790 |
r"""
|
1791 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
|
1810 |
output_attentions=output_attentions,
|
1811 |
output_hidden_states=output_hidden_states,
|
1812 |
return_dict=return_dict,
|
1813 |
+
sliding_window=sliding_window
|
1814 |
)
|
1815 |
|
1816 |
sequence_output = outputs[0]
|
|
|
1837 |
)
|
1838 |
|
1839 |
def prepare_inputs_for_generation(
|
1840 |
+
self, input_ids, attention_mask=None, **model_kwargs
|
1841 |
):
|
1842 |
input_shape = input_ids.shape
|
1843 |
effective_batch_size = input_shape[0]
|
|
|
1882 |
output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
|
1883 |
)
|
1884 |
def forward(
|
1885 |
+
self,
|
1886 |
+
input_ids: Optional[torch.Tensor] = None,
|
1887 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1888 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1889 |
+
position_ids: Optional[torch.Tensor] = None,
|
1890 |
+
head_mask: Optional[torch.Tensor] = None,
|
1891 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1892 |
+
labels: Optional[torch.Tensor] = None,
|
1893 |
+
output_attentions: Optional[bool] = None,
|
1894 |
+
output_hidden_states: Optional[bool] = None,
|
1895 |
+
return_dict: Optional[bool] = None,
|
1896 |
+
**kwargs,
|
1897 |
) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
|
1898 |
r"""
|
1899 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
|
1993 |
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
1994 |
)
|
1995 |
def forward(
|
1996 |
+
self,
|
1997 |
+
input_ids: Optional[torch.Tensor] = None,
|
1998 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1999 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
2000 |
+
position_ids: Optional[torch.Tensor] = None,
|
2001 |
+
head_mask: Optional[torch.Tensor] = None,
|
2002 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
2003 |
+
labels: Optional[torch.Tensor] = None,
|
2004 |
+
output_attentions: Optional[bool] = None,
|
2005 |
+
output_hidden_states: Optional[bool] = None,
|
2006 |
+
return_dict: Optional[bool] = None,
|
2007 |
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
2008 |
r"""
|
2009 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
|
2038 |
if self.num_labels == 1:
|
2039 |
self.config.problem_type = "regression"
|
2040 |
elif self.num_labels > 1 and (
|
2041 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
2042 |
):
|
2043 |
self.config.problem_type = "single_label_classification"
|
2044 |
else:
|
|
|
2100 |
config_class=_CONFIG_FOR_DOC,
|
2101 |
)
|
2102 |
def forward(
|
2103 |
+
self,
|
2104 |
+
input_ids: Optional[torch.Tensor] = None,
|
2105 |
+
attention_mask: Optional[torch.Tensor] = None,
|
2106 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
2107 |
+
position_ids: Optional[torch.Tensor] = None,
|
2108 |
+
head_mask: Optional[torch.Tensor] = None,
|
2109 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
2110 |
+
labels: Optional[torch.Tensor] = None,
|
2111 |
+
output_attentions: Optional[bool] = None,
|
2112 |
+
output_hidden_states: Optional[bool] = None,
|
2113 |
+
return_dict: Optional[bool] = None,
|
2114 |
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
|
2115 |
r"""
|
2116 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
|
2219 |
expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
|
2220 |
)
|
2221 |
def forward(
|
2222 |
+
self,
|
2223 |
+
input_ids: Optional[torch.Tensor] = None,
|
2224 |
+
attention_mask: Optional[torch.Tensor] = None,
|
2225 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
2226 |
+
position_ids: Optional[torch.Tensor] = None,
|
2227 |
+
head_mask: Optional[torch.Tensor] = None,
|
2228 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
2229 |
+
labels: Optional[torch.Tensor] = None,
|
2230 |
+
output_attentions: Optional[bool] = None,
|
2231 |
+
output_hidden_states: Optional[bool] = None,
|
2232 |
+
return_dict: Optional[bool] = None,
|
2233 |
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
2234 |
r"""
|
2235 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
|
2304 |
expected_loss=_QA_EXPECTED_LOSS,
|
2305 |
)
|
2306 |
def forward(
|
2307 |
+
self,
|
2308 |
+
input_ids: Optional[torch.Tensor] = None,
|
2309 |
+
attention_mask: Optional[torch.Tensor] = None,
|
2310 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
2311 |
+
position_ids: Optional[torch.Tensor] = None,
|
2312 |
+
head_mask: Optional[torch.Tensor] = None,
|
2313 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
2314 |
+
start_positions: Optional[torch.Tensor] = None,
|
2315 |
+
end_positions: Optional[torch.Tensor] = None,
|
2316 |
+
output_attentions: Optional[bool] = None,
|
2317 |
+
output_hidden_states: Optional[bool] = None,
|
2318 |
+
return_dict: Optional[bool] = None,
|
2319 |
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
2320 |
r"""
|
2321 |
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|