zhiqu22 commited on
Commit
1d88f72
·
1 Parent(s): 317d82a

implement early stop

Browse files
Files changed (1) hide show
  1. modeling_mitre.py +103 -28
modeling_mitre.py CHANGED
@@ -11,8 +11,6 @@ from transformers.utils import logging
11
  from transformers.generation import GenerationMixin
12
  from transformers.modeling_utils import PreTrainedModel
13
  from transformers.activations import ACT2FN
14
- from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
15
- from transformers.integrations.fsdp import is_fsdp_managed_module
16
  from transformers.modeling_outputs import (
17
  BaseModelOutputWithPastAndCrossAttentions,
18
  Seq2SeqLMOutput,
@@ -75,10 +73,6 @@ class MitreSdpaAttention(nn.Module):
75
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
76
  attention_mask: Optional[torch.Tensor] = None,
77
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
78
- """
79
- Input shape: Batch x Time x Channel
80
- Output objects: attn_output, attn_weights (always be None), past_key_value
81
- """
82
  """
83
  1. MitreModel is using MitreSdpaAttention, which is modifed from M2M100SdpaAttention.
84
  Notabley, both of them do not support `output_attentions=True` or `layer_head_mask` not None,
@@ -360,6 +354,8 @@ class MitreDecoder(MitrePreTrainedModel):
360
 
361
  elif past_key_values_length > 0:
362
  # in generation
 
 
363
  mask = torch.zeros(past_key_values_length + 1)
364
  mask = mask.to(embeds, copy=True)
365
  batch_mask = mask.unsqueeze(0).expand(b, -1).clone().contiguous()
@@ -374,7 +370,6 @@ class MitreDecoder(MitrePreTrainedModel):
374
  batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
375
  return batch_mask
376
 
377
-
378
  def forward(
379
  self,
380
  input_ids: Optional[torch.Tensor] = None,
@@ -531,7 +526,6 @@ class MitreDecoder(MitrePreTrainedModel):
531
  cache_value[:, :, src_length - max_register_num:, :]
532
  )
533
  next_decoder_cache += (clipped_rep,)
534
-
535
 
536
  if past_key_values_length == 0:
537
  hidden_states = hidden_states[:,src_length:,:]
@@ -759,6 +753,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
759
 
760
  @staticmethod
761
  def _reorder_register_cache(t, beam_idx):
 
762
  return t.index_select(dim=0, index=beam_idx.to(t.device))
763
 
764
  @staticmethod
@@ -782,15 +777,32 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
782
  ):
783
  """
784
  Inference with beam search.
785
- This code is simplified from 'transformers.generation.utils.GenerationMixin.generate'.
786
- This code follows the style of m2m and nllb.
787
- Therefore, there are two points need improvement.
788
- TODO
789
- 1. early_stop in beam search.
790
- Current early_stop is at the beam search level instead of model level. Specficially,
791
- although beamscorer generates eos to the sequence, the sequence is filled by 'pad(1)'.
792
- As a result, the sequence, which has already finished, will be computed by the model
793
- continuously. We plan to remove the finished token as Fairseq's style.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794
  """
795
  if generation_config != None:
796
  assert type(generation_config) is GenerationConfig
@@ -831,13 +843,18 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
831
  past_key_values = None
832
  registering_cache= None
833
  attention_mask = None
 
 
 
834
 
 
 
835
  logits_processor = LogitsProcessorList()
836
  stopping_criteria = StoppingCriteriaList()
837
 
838
  beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input_ids.device)
839
  beam_scores[:, 1:] = -1e9
840
- beam_scores = beam_scores.view((batch_size * beam_size,))
841
  while not this_peer_finished:
842
 
843
  if past_key_values is not None:
@@ -850,7 +867,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
850
  attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
851
  else:
852
  decoder_input_ids_for_generation = decoder_input_ids
853
-
854
  outputs = self(
855
  input_ids,
856
  decoder_input_ids_for_generation,
@@ -859,21 +876,43 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
859
  use_cache=True,
860
  registering_cache=registering_cache
861
  )
862
-
863
  del input_ids
864
  input_ids = None
865
 
866
  past_key_values = outputs.past_key_values
867
  registering_cache = outputs.registering_cache
868
-
869
  next_token_logits = outputs.logits[:, -1, :].clone().float()
870
- next_token_logits = next_token_logits.to(device)
871
 
 
872
  next_token_scores = nn.functional.log_softmax(
873
  next_token_logits, dim=-1
874
  ) # (batch_size * num_beams, vocab_size)
875
 
876
  next_token_scores_processed = logits_processor(decoder_input_ids, next_token_scores)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877
  next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
878
  next_token_scores_processed
879
  )
@@ -892,6 +931,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
892
 
893
  next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
894
  next_tokens = next_tokens % vocab_size
 
895
  beam_outputs = beam_scorer.process(
896
  decoder_input_ids,
897
  next_token_scores,
@@ -904,15 +944,50 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
904
  beam_scores = beam_outputs["next_beam_scores"]
905
  beam_next_tokens = beam_outputs["next_beam_tokens"]
906
  beam_idx = beam_outputs["next_beam_indices"]
907
- decoder_input_ids = torch.cat([decoder_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
908
 
909
- del outputs
 
 
 
 
 
 
 
 
 
910
 
911
- past_key_values = self._reorder_cache(past_key_values, beam_idx)
912
- registering_cache["register_nums"] = self._reorder_register_cache(registering_cache["register_nums"], beam_idx)
913
- if registering_cache["attention_mask"] is not None:
914
- registering_cache["attention_mask"] = self._reorder_register_cache(registering_cache["attention_mask"], beam_idx)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
915
 
 
 
 
 
 
916
  cur_len = cur_len + 1
917
 
918
  if beam_scorer.is_done:
 
11
  from transformers.generation import GenerationMixin
12
  from transformers.modeling_utils import PreTrainedModel
13
  from transformers.activations import ACT2FN
 
 
14
  from transformers.modeling_outputs import (
15
  BaseModelOutputWithPastAndCrossAttentions,
16
  Seq2SeqLMOutput,
 
73
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
74
  attention_mask: Optional[torch.Tensor] = None,
75
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
 
 
 
76
  """
77
  1. MitreModel is using MitreSdpaAttention, which is modifed from M2M100SdpaAttention.
78
  Notabley, both of them do not support `output_attentions=True` or `layer_head_mask` not None,
 
354
 
355
  elif past_key_values_length > 0:
356
  # in generation
357
+ # this block is only used in fairseq and is not used in huggingface,
358
+ # because we reuse the mask by the cache.
359
  mask = torch.zeros(past_key_values_length + 1)
360
  mask = mask.to(embeds, copy=True)
361
  batch_mask = mask.unsqueeze(0).expand(b, -1).clone().contiguous()
 
370
  batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
371
  return batch_mask
372
 
 
373
  def forward(
374
  self,
375
  input_ids: Optional[torch.Tensor] = None,
 
526
  cache_value[:, :, src_length - max_register_num:, :]
527
  )
528
  next_decoder_cache += (clipped_rep,)
 
529
 
530
  if past_key_values_length == 0:
531
  hidden_states = hidden_states[:,src_length:,:]
 
753
 
754
  @staticmethod
755
  def _reorder_register_cache(t, beam_idx):
756
+ """ a costumized reorder method """
757
  return t.index_select(dim=0, index=beam_idx.to(t.device))
758
 
759
  @staticmethod
 
777
  ):
778
  """
779
  Inference with beam search.
780
+ This code is improved from 'transformers.generation.utils.GenerationMixin.generate'.
781
+ There are **two main improved points**:
782
+ 1. 'soft early_stop' in beam search.
783
+ a) problem in the vanilla version.
784
+ In multilingual translation model, e.g., NLLB and M2M, they adopt the 'vanilla early_
785
+ stop' in BeamSearchScorer (the official implementation provided by HuggingFace), i.e.,
786
+ the sequence, which is labled by 'end', is filled by 'pad(1)' still, in other words,
787
+ the ended sequence is fed into the model still, resulting in a heavy memory waste.
788
+ b) our improvement.
789
+ We implement soft early_stop to resolve the problem. Specifically, we do not change
790
+ anything in BeamSearchScorer to keep the codes' flexibility, rather we remove the ended
791
+ sequence from the input. Then, given that the output hidden states' shape is changed,
792
+ we insert some placeholders to keep the shape of BeamSearchScorer's states.
793
+ Based on our test, this improvement can decrease the memory cost to half than before.
794
+ 2. mask reusing.
795
+ a) problem: registers need attention masks in each step.
796
+ A sequence possibly consists 4 parts, i.e., pads, source tokens, registers, and target
797
+ tokens. In training, we mask all tokens before registers for the generation of target
798
+ tokens. As a result, in generation, we cannot allow the target tokens to 'see' pads.
799
+ So, we need masks in each step, leading to computational resource waste.
800
+ b) our improvement.
801
+ First, we turncate the source tokens to save cost.
802
+ Second, given that there still exists some source tokens playing the role of placeholders,
803
+ we modify the mask generation compared to our codes in fairseq.
804
+ Third, in order to avoid re-generating masks, we add the mask into 'registering_cache'.
805
+ Then, we manage its order as the kv cache in beam search, and add a column of 0. every step.
806
  """
807
  if generation_config != None:
808
  assert type(generation_config) is GenerationConfig
 
843
  past_key_values = None
844
  registering_cache= None
845
  attention_mask = None
846
+ # done_mask shows the ended sequences.
847
+ # (~done_mask) shows the running sequences.
848
+ done_mask = None
849
 
850
+ # we follow the style of M2M and NLLB
851
+ # so we simplify the initialization of thoes two processors.
852
  logits_processor = LogitsProcessorList()
853
  stopping_criteria = StoppingCriteriaList()
854
 
855
  beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input_ids.device)
856
  beam_scores[:, 1:] = -1e9
857
+ beam_scores = beam_scores.view((batch_size * beam_size,))
858
  while not this_peer_finished:
859
 
860
  if past_key_values is not None:
 
867
  attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
868
  else:
869
  decoder_input_ids_for_generation = decoder_input_ids
870
+
871
  outputs = self(
872
  input_ids,
873
  decoder_input_ids_for_generation,
 
876
  use_cache=True,
877
  registering_cache=registering_cache
878
  )
 
879
  del input_ids
880
  input_ids = None
881
 
882
  past_key_values = outputs.past_key_values
883
  registering_cache = outputs.registering_cache
 
884
  next_token_logits = outputs.logits[:, -1, :].clone().float()
885
+ del outputs
886
 
887
+ next_token_logits = next_token_logits.to(device)
888
  next_token_scores = nn.functional.log_softmax(
889
  next_token_logits, dim=-1
890
  ) # (batch_size * num_beams, vocab_size)
891
 
892
  next_token_scores_processed = logits_processor(decoder_input_ids, next_token_scores)
893
+
894
+ # if any sequence is ended, we have to keep the shape of Scorer's states.
895
+ # Details are described in the head of this function.
896
+ if done_mask is not None:
897
+ if done_mask.any():
898
+ # the placeholder of scores is '0.'
899
+ restored_tensor = torch.zeros(
900
+ (batch_size * beam_size, next_token_scores_processed.shape[1]),
901
+ dtype=next_token_scores_processed.dtype,
902
+ device=next_token_scores_processed.device
903
+ )
904
+ restored_tensor[~done_mask] = next_token_scores_processed
905
+ next_token_scores_processed = restored_tensor
906
+ # the placeholder of tokens is 'pad_token_id'
907
+ restored_tokens = torch.full(
908
+ (batch_size * beam_size, decoder_input_ids.shape[1]),
909
+ self.generation_config.pad_token_id,
910
+ dtype=decoder_input_ids.dtype,
911
+ device=device
912
+ )
913
+ restored_tokens[~done_mask] = decoder_input_ids
914
+ decoder_input_ids = restored_tokens
915
+
916
  next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
917
  next_token_scores_processed
918
  )
 
931
 
932
  next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
933
  next_tokens = next_tokens % vocab_size
934
+
935
  beam_outputs = beam_scorer.process(
936
  decoder_input_ids,
937
  next_token_scores,
 
944
  beam_scores = beam_outputs["next_beam_scores"]
945
  beam_next_tokens = beam_outputs["next_beam_tokens"]
946
  beam_idx = beam_outputs["next_beam_indices"]
 
947
 
948
+ # 'last_done_mask' is used for reordering cache
949
+ # details are described in the next code block
950
+ if done_mask is not None:
951
+ last_done_mask = done_mask
952
+
953
+ # get the newest status of sequences.
954
+ # then, filter the beam_idx
955
+ done_mask = beam_scorer._done.clone().view(-1)
956
+ done_mask = self._expand_inputs_for_generation(done_mask, beam_size)
957
+ beam_idx = beam_idx[~done_mask]
958
 
959
+ decoder_input_ids = torch.cat([decoder_input_ids[beam_idx, :], beam_next_tokens[~done_mask].unsqueeze(-1)], dim=-1)
960
+
961
+ # different from processing tokens, caches' order is decided by 'tokens', 'done_mask' and
962
+ # 'beam_idx', simultaneously.
963
+ if decoder_input_ids_for_generation.shape[0] < beam_next_tokens.shape[0]:
964
+ # Take carefule! If the running sequences' num is small than the num of input sequences,
965
+ # it means the Scorer decides to end it, but the cache still follows the last status.
966
+ # Therefore, we should employ the last done mask rather than newest done mask.
967
+ if (~done_mask).sum() < decoder_input_ids_for_generation.shape[0]:
968
+ count_mask = last_done_mask
969
+ else:
970
+ count_mask = done_mask
971
+ # For biasing the beam_idx
972
+ # Example:
973
+ # done_mask with beam size of 2: [f, f, t, t, f, f]
974
+ # beam_idx: [0, 0, 2, 2, 4, 5]
975
+ # reorder_idx: [0-0, 0-0, 4-2, 5-2]
976
+ prefix_sum = torch.cat([
977
+ torch.zeros_like(count_mask[:1], dtype=torch.long),
978
+ torch.cumsum(count_mask.long(), dim=0)
979
+ ], dim=0)
980
+ reorder_idx = beam_idx - prefix_sum[beam_idx]
981
+ not_done = ~done_mask[beam_idx]
982
+ reorder_idx = reorder_idx[not_done]
983
+ else:
984
+ reorder_idx = beam_idx
985
 
986
+ past_key_values = self._reorder_cache(past_key_values, reorder_idx)
987
+ registering_cache["register_nums"] = self._reorder_register_cache(registering_cache["register_nums"], reorder_idx)
988
+ if registering_cache["attention_mask"] is not None:
989
+ registering_cache["attention_mask"] = self._reorder_register_cache(registering_cache["attention_mask"], reorder_idx)
990
+
991
  cur_len = cur_len + 1
992
 
993
  if beam_scorer.is_done: