zhiqu22
commited on
Commit
·
1d88f72
1
Parent(s):
317d82a
implement early stop
Browse files- modeling_mitre.py +103 -28
modeling_mitre.py
CHANGED
@@ -11,8 +11,6 @@ from transformers.utils import logging
|
|
11 |
from transformers.generation import GenerationMixin
|
12 |
from transformers.modeling_utils import PreTrainedModel
|
13 |
from transformers.activations import ACT2FN
|
14 |
-
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
15 |
-
from transformers.integrations.fsdp import is_fsdp_managed_module
|
16 |
from transformers.modeling_outputs import (
|
17 |
BaseModelOutputWithPastAndCrossAttentions,
|
18 |
Seq2SeqLMOutput,
|
@@ -75,10 +73,6 @@ class MitreSdpaAttention(nn.Module):
|
|
75 |
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
76 |
attention_mask: Optional[torch.Tensor] = None,
|
77 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
78 |
-
"""
|
79 |
-
Input shape: Batch x Time x Channel
|
80 |
-
Output objects: attn_output, attn_weights (always be None), past_key_value
|
81 |
-
"""
|
82 |
"""
|
83 |
1. MitreModel is using MitreSdpaAttention, which is modifed from M2M100SdpaAttention.
|
84 |
Notabley, both of them do not support `output_attentions=True` or `layer_head_mask` not None,
|
@@ -360,6 +354,8 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
360 |
|
361 |
elif past_key_values_length > 0:
|
362 |
# in generation
|
|
|
|
|
363 |
mask = torch.zeros(past_key_values_length + 1)
|
364 |
mask = mask.to(embeds, copy=True)
|
365 |
batch_mask = mask.unsqueeze(0).expand(b, -1).clone().contiguous()
|
@@ -374,7 +370,6 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
374 |
batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
|
375 |
return batch_mask
|
376 |
|
377 |
-
|
378 |
def forward(
|
379 |
self,
|
380 |
input_ids: Optional[torch.Tensor] = None,
|
@@ -531,7 +526,6 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
531 |
cache_value[:, :, src_length - max_register_num:, :]
|
532 |
)
|
533 |
next_decoder_cache += (clipped_rep,)
|
534 |
-
|
535 |
|
536 |
if past_key_values_length == 0:
|
537 |
hidden_states = hidden_states[:,src_length:,:]
|
@@ -759,6 +753,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
759 |
|
760 |
@staticmethod
|
761 |
def _reorder_register_cache(t, beam_idx):
|
|
|
762 |
return t.index_select(dim=0, index=beam_idx.to(t.device))
|
763 |
|
764 |
@staticmethod
|
@@ -782,15 +777,32 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
782 |
):
|
783 |
"""
|
784 |
Inference with beam search.
|
785 |
-
This code is
|
786 |
-
|
787 |
-
|
788 |
-
|
789 |
-
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
794 |
"""
|
795 |
if generation_config != None:
|
796 |
assert type(generation_config) is GenerationConfig
|
@@ -831,13 +843,18 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
831 |
past_key_values = None
|
832 |
registering_cache= None
|
833 |
attention_mask = None
|
|
|
|
|
|
|
834 |
|
|
|
|
|
835 |
logits_processor = LogitsProcessorList()
|
836 |
stopping_criteria = StoppingCriteriaList()
|
837 |
|
838 |
beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input_ids.device)
|
839 |
beam_scores[:, 1:] = -1e9
|
840 |
-
beam_scores = beam_scores.view((batch_size * beam_size,))
|
841 |
while not this_peer_finished:
|
842 |
|
843 |
if past_key_values is not None:
|
@@ -850,7 +867,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
850 |
attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
|
851 |
else:
|
852 |
decoder_input_ids_for_generation = decoder_input_ids
|
853 |
-
|
854 |
outputs = self(
|
855 |
input_ids,
|
856 |
decoder_input_ids_for_generation,
|
@@ -859,21 +876,43 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
859 |
use_cache=True,
|
860 |
registering_cache=registering_cache
|
861 |
)
|
862 |
-
|
863 |
del input_ids
|
864 |
input_ids = None
|
865 |
|
866 |
past_key_values = outputs.past_key_values
|
867 |
registering_cache = outputs.registering_cache
|
868 |
-
|
869 |
next_token_logits = outputs.logits[:, -1, :].clone().float()
|
870 |
-
|
871 |
|
|
|
872 |
next_token_scores = nn.functional.log_softmax(
|
873 |
next_token_logits, dim=-1
|
874 |
) # (batch_size * num_beams, vocab_size)
|
875 |
|
876 |
next_token_scores_processed = logits_processor(decoder_input_ids, next_token_scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
877 |
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
878 |
next_token_scores_processed
|
879 |
)
|
@@ -892,6 +931,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
892 |
|
893 |
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
894 |
next_tokens = next_tokens % vocab_size
|
|
|
895 |
beam_outputs = beam_scorer.process(
|
896 |
decoder_input_ids,
|
897 |
next_token_scores,
|
@@ -904,15 +944,50 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
904 |
beam_scores = beam_outputs["next_beam_scores"]
|
905 |
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
906 |
beam_idx = beam_outputs["next_beam_indices"]
|
907 |
-
decoder_input_ids = torch.cat([decoder_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
908 |
|
909 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
910 |
|
911 |
-
|
912 |
-
|
913 |
-
|
914 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
915 |
|
|
|
|
|
|
|
|
|
|
|
916 |
cur_len = cur_len + 1
|
917 |
|
918 |
if beam_scorer.is_done:
|
|
|
11 |
from transformers.generation import GenerationMixin
|
12 |
from transformers.modeling_utils import PreTrainedModel
|
13 |
from transformers.activations import ACT2FN
|
|
|
|
|
14 |
from transformers.modeling_outputs import (
|
15 |
BaseModelOutputWithPastAndCrossAttentions,
|
16 |
Seq2SeqLMOutput,
|
|
|
73 |
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
74 |
attention_mask: Optional[torch.Tensor] = None,
|
75 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
|
|
|
|
|
|
|
76 |
"""
|
77 |
1. MitreModel is using MitreSdpaAttention, which is modifed from M2M100SdpaAttention.
|
78 |
Notabley, both of them do not support `output_attentions=True` or `layer_head_mask` not None,
|
|
|
354 |
|
355 |
elif past_key_values_length > 0:
|
356 |
# in generation
|
357 |
+
# this block is only used in fairseq and is not used in huggingface,
|
358 |
+
# because we reuse the mask by the cache.
|
359 |
mask = torch.zeros(past_key_values_length + 1)
|
360 |
mask = mask.to(embeds, copy=True)
|
361 |
batch_mask = mask.unsqueeze(0).expand(b, -1).clone().contiguous()
|
|
|
370 |
batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
|
371 |
return batch_mask
|
372 |
|
|
|
373 |
def forward(
|
374 |
self,
|
375 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
526 |
cache_value[:, :, src_length - max_register_num:, :]
|
527 |
)
|
528 |
next_decoder_cache += (clipped_rep,)
|
|
|
529 |
|
530 |
if past_key_values_length == 0:
|
531 |
hidden_states = hidden_states[:,src_length:,:]
|
|
|
753 |
|
754 |
@staticmethod
|
755 |
def _reorder_register_cache(t, beam_idx):
|
756 |
+
""" a costumized reorder method """
|
757 |
return t.index_select(dim=0, index=beam_idx.to(t.device))
|
758 |
|
759 |
@staticmethod
|
|
|
777 |
):
|
778 |
"""
|
779 |
Inference with beam search.
|
780 |
+
This code is improved from 'transformers.generation.utils.GenerationMixin.generate'.
|
781 |
+
There are **two main improved points**:
|
782 |
+
1. 'soft early_stop' in beam search.
|
783 |
+
a) problem in the vanilla version.
|
784 |
+
In multilingual translation model, e.g., NLLB and M2M, they adopt the 'vanilla early_
|
785 |
+
stop' in BeamSearchScorer (the official implementation provided by HuggingFace), i.e.,
|
786 |
+
the sequence, which is labled by 'end', is filled by 'pad(1)' still, in other words,
|
787 |
+
the ended sequence is fed into the model still, resulting in a heavy memory waste.
|
788 |
+
b) our improvement.
|
789 |
+
We implement soft early_stop to resolve the problem. Specifically, we do not change
|
790 |
+
anything in BeamSearchScorer to keep the codes' flexibility, rather we remove the ended
|
791 |
+
sequence from the input. Then, given that the output hidden states' shape is changed,
|
792 |
+
we insert some placeholders to keep the shape of BeamSearchScorer's states.
|
793 |
+
Based on our test, this improvement can decrease the memory cost to half than before.
|
794 |
+
2. mask reusing.
|
795 |
+
a) problem: registers need attention masks in each step.
|
796 |
+
A sequence possibly consists 4 parts, i.e., pads, source tokens, registers, and target
|
797 |
+
tokens. In training, we mask all tokens before registers for the generation of target
|
798 |
+
tokens. As a result, in generation, we cannot allow the target tokens to 'see' pads.
|
799 |
+
So, we need masks in each step, leading to computational resource waste.
|
800 |
+
b) our improvement.
|
801 |
+
First, we turncate the source tokens to save cost.
|
802 |
+
Second, given that there still exists some source tokens playing the role of placeholders,
|
803 |
+
we modify the mask generation compared to our codes in fairseq.
|
804 |
+
Third, in order to avoid re-generating masks, we add the mask into 'registering_cache'.
|
805 |
+
Then, we manage its order as the kv cache in beam search, and add a column of 0. every step.
|
806 |
"""
|
807 |
if generation_config != None:
|
808 |
assert type(generation_config) is GenerationConfig
|
|
|
843 |
past_key_values = None
|
844 |
registering_cache= None
|
845 |
attention_mask = None
|
846 |
+
# done_mask shows the ended sequences.
|
847 |
+
# (~done_mask) shows the running sequences.
|
848 |
+
done_mask = None
|
849 |
|
850 |
+
# we follow the style of M2M and NLLB
|
851 |
+
# so we simplify the initialization of thoes two processors.
|
852 |
logits_processor = LogitsProcessorList()
|
853 |
stopping_criteria = StoppingCriteriaList()
|
854 |
|
855 |
beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input_ids.device)
|
856 |
beam_scores[:, 1:] = -1e9
|
857 |
+
beam_scores = beam_scores.view((batch_size * beam_size,))
|
858 |
while not this_peer_finished:
|
859 |
|
860 |
if past_key_values is not None:
|
|
|
867 |
attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
|
868 |
else:
|
869 |
decoder_input_ids_for_generation = decoder_input_ids
|
870 |
+
|
871 |
outputs = self(
|
872 |
input_ids,
|
873 |
decoder_input_ids_for_generation,
|
|
|
876 |
use_cache=True,
|
877 |
registering_cache=registering_cache
|
878 |
)
|
|
|
879 |
del input_ids
|
880 |
input_ids = None
|
881 |
|
882 |
past_key_values = outputs.past_key_values
|
883 |
registering_cache = outputs.registering_cache
|
|
|
884 |
next_token_logits = outputs.logits[:, -1, :].clone().float()
|
885 |
+
del outputs
|
886 |
|
887 |
+
next_token_logits = next_token_logits.to(device)
|
888 |
next_token_scores = nn.functional.log_softmax(
|
889 |
next_token_logits, dim=-1
|
890 |
) # (batch_size * num_beams, vocab_size)
|
891 |
|
892 |
next_token_scores_processed = logits_processor(decoder_input_ids, next_token_scores)
|
893 |
+
|
894 |
+
# if any sequence is ended, we have to keep the shape of Scorer's states.
|
895 |
+
# Details are described in the head of this function.
|
896 |
+
if done_mask is not None:
|
897 |
+
if done_mask.any():
|
898 |
+
# the placeholder of scores is '0.'
|
899 |
+
restored_tensor = torch.zeros(
|
900 |
+
(batch_size * beam_size, next_token_scores_processed.shape[1]),
|
901 |
+
dtype=next_token_scores_processed.dtype,
|
902 |
+
device=next_token_scores_processed.device
|
903 |
+
)
|
904 |
+
restored_tensor[~done_mask] = next_token_scores_processed
|
905 |
+
next_token_scores_processed = restored_tensor
|
906 |
+
# the placeholder of tokens is 'pad_token_id'
|
907 |
+
restored_tokens = torch.full(
|
908 |
+
(batch_size * beam_size, decoder_input_ids.shape[1]),
|
909 |
+
self.generation_config.pad_token_id,
|
910 |
+
dtype=decoder_input_ids.dtype,
|
911 |
+
device=device
|
912 |
+
)
|
913 |
+
restored_tokens[~done_mask] = decoder_input_ids
|
914 |
+
decoder_input_ids = restored_tokens
|
915 |
+
|
916 |
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
917 |
next_token_scores_processed
|
918 |
)
|
|
|
931 |
|
932 |
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
933 |
next_tokens = next_tokens % vocab_size
|
934 |
+
|
935 |
beam_outputs = beam_scorer.process(
|
936 |
decoder_input_ids,
|
937 |
next_token_scores,
|
|
|
944 |
beam_scores = beam_outputs["next_beam_scores"]
|
945 |
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
946 |
beam_idx = beam_outputs["next_beam_indices"]
|
|
|
947 |
|
948 |
+
# 'last_done_mask' is used for reordering cache
|
949 |
+
# details are described in the next code block
|
950 |
+
if done_mask is not None:
|
951 |
+
last_done_mask = done_mask
|
952 |
+
|
953 |
+
# get the newest status of sequences.
|
954 |
+
# then, filter the beam_idx
|
955 |
+
done_mask = beam_scorer._done.clone().view(-1)
|
956 |
+
done_mask = self._expand_inputs_for_generation(done_mask, beam_size)
|
957 |
+
beam_idx = beam_idx[~done_mask]
|
958 |
|
959 |
+
decoder_input_ids = torch.cat([decoder_input_ids[beam_idx, :], beam_next_tokens[~done_mask].unsqueeze(-1)], dim=-1)
|
960 |
+
|
961 |
+
# different from processing tokens, caches' order is decided by 'tokens', 'done_mask' and
|
962 |
+
# 'beam_idx', simultaneously.
|
963 |
+
if decoder_input_ids_for_generation.shape[0] < beam_next_tokens.shape[0]:
|
964 |
+
# Take carefule! If the running sequences' num is small than the num of input sequences,
|
965 |
+
# it means the Scorer decides to end it, but the cache still follows the last status.
|
966 |
+
# Therefore, we should employ the last done mask rather than newest done mask.
|
967 |
+
if (~done_mask).sum() < decoder_input_ids_for_generation.shape[0]:
|
968 |
+
count_mask = last_done_mask
|
969 |
+
else:
|
970 |
+
count_mask = done_mask
|
971 |
+
# For biasing the beam_idx
|
972 |
+
# Example:
|
973 |
+
# done_mask with beam size of 2: [f, f, t, t, f, f]
|
974 |
+
# beam_idx: [0, 0, 2, 2, 4, 5]
|
975 |
+
# reorder_idx: [0-0, 0-0, 4-2, 5-2]
|
976 |
+
prefix_sum = torch.cat([
|
977 |
+
torch.zeros_like(count_mask[:1], dtype=torch.long),
|
978 |
+
torch.cumsum(count_mask.long(), dim=0)
|
979 |
+
], dim=0)
|
980 |
+
reorder_idx = beam_idx - prefix_sum[beam_idx]
|
981 |
+
not_done = ~done_mask[beam_idx]
|
982 |
+
reorder_idx = reorder_idx[not_done]
|
983 |
+
else:
|
984 |
+
reorder_idx = beam_idx
|
985 |
|
986 |
+
past_key_values = self._reorder_cache(past_key_values, reorder_idx)
|
987 |
+
registering_cache["register_nums"] = self._reorder_register_cache(registering_cache["register_nums"], reorder_idx)
|
988 |
+
if registering_cache["attention_mask"] is not None:
|
989 |
+
registering_cache["attention_mask"] = self._reorder_register_cache(registering_cache["attention_mask"], reorder_idx)
|
990 |
+
|
991 |
cur_len = cur_len + 1
|
992 |
|
993 |
if beam_scorer.is_done:
|