zhiqu22 committed
Commit 317d82a · 1 Parent(s): 74025f2
Files changed (1): modeling_mitre.py (+10 -6)
modeling_mitre.py CHANGED
@@ -433,7 +433,6 @@ class MitreDecoder(MitrePreTrainedModel):
         # ensure contiguous
         expanded_src_tokens = self.check_contiguous(expanded_src_tokens)
         source_tokens = self.check_contiguous(source_tokens)
-        src_length = source_tokens.shape[1]
 
         # get embeds with positions for source tokens (input_ids + registers)
         inputs_embeds = self.source_tokens_embedding_with_positions(expanded_src_tokens, total_token_nums, batch_indices, token_indices)
@@ -444,11 +443,11 @@ class MitreDecoder(MitrePreTrainedModel):
         decoder_input_ids[:, 0] = source_tokens[:, -1]
 
         tokens = torch.cat([source_tokens, decoder_input_ids], dim=1)
+        src_length = source_tokens.shape[1]
 
         decoder_inputs_embeds = self.embed_tokens(decoder_input_ids)
         decoder_inputs_embeds = decoder_inputs_embeds + self.tgt_embed_positions(decoder_input_ids, past_key_values_length, src_length=src_length)
-        # if past_key_values_length > 0:
-        #     raise ValueError()
+
         if past_key_values_length == 0:
             hidden_states = torch.cat([inputs_embeds, decoder_inputs_embeds], dim=1)
         else:
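Context for the two hunks above: `src_length` feeds the target-side positional embedding, and it is now computed right before its single use, after `tokens` is assembled; judging from the visible context, nothing in between changes the shape of `source_tokens`, so the value is the same. As a rough sketch, a position lookup with the same signature as `tgt_embed_positions` might offset target positions by the source length like this (the body below is illustrative, not the model's actual implementation):

import torch

def tgt_positions_sketch(decoder_input_ids, past_key_values_length, src_length):
    # Assumed convention: target token t sits at position
    # src_length + past_key_values_length + t, i.e. target positions
    # continue where the source sequence (plus registers) ends.
    batch_size, tgt_len = decoder_input_ids.shape
    start = src_length + past_key_values_length
    positions = torch.arange(start, start + tgt_len)
    return positions.unsqueeze(0).expand(batch_size, tgt_len)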
@@ -759,8 +758,8 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
         return reordered_past
 
     @staticmethod
-    def _reorder_register_nums(register_nums, beam_idx):
-        return register_nums.index_select(0, beam_idx.to(register_nums.device))
+    def _reorder_register_cache(t, beam_idx):
+        return t.index_select(dim=0, index=beam_idx.to(t.device))
 
     @staticmethod
     def _expand_inputs_for_generation(
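The helper is renamed from `_reorder_register_nums` to `_reorder_register_cache` and generalized: an `index_select` along dim 0 reorders any beam-indexed tensor, not just the 1-D `register_nums`, which lets the same helper handle the cached attention mask below. A quick illustration with hypothetical shapes:

import torch

register_nums = torch.tensor([3, 5])      # (num_beams,)
attention_mask = torch.zeros(2, 1, 4, 4)  # (num_beams, 1, tgt_len, src_len)
beam_idx = torch.tensor([0, 0])           # both next beams descend from beam 0

# dim 0 is the beam dimension in both cases, so one helper covers both.
print(register_nums.index_select(dim=0, index=beam_idx))         # tensor([3, 3])
print(attention_mask.index_select(dim=0, index=beam_idx).shape)  # torch.Size([2, 1, 4, 4])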
@@ -844,6 +843,9 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
         if past_key_values is not None:
             decoder_input_ids_for_generation = decoder_input_ids[:, -1:]
             attention_mask = registering_cache["attention_mask"]
+            # Take the mask from the cache the first time the KV cache is used.
+            # After that, we can simply repeat its last column (0.) to get the next mask.
+            # As a result, we avoid generating the mask from scratch at every cached step and save memory.
             if attention_mask is not None:
                 attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
             else:
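The new comments document the incremental mask update: with a KV cache, each step appends exactly one query row, and every previously visible key stays visible, so the cached mask can be extended by duplicating its last column instead of being rebuilt. A self-contained example with a hypothetical additive mask (0. means visible, -inf means blocked):

import torch

# (batch, heads, tgt_len, src_len) additive mask for one cached query row.
mask = torch.tensor([[[[0., 0., float("-inf"), 0.]]]])

# One decoding step later, the row attends to one more key; repeating
# the last column (0.) is equivalent to regenerating the whole mask.
next_mask = torch.cat((mask, mask[..., -1:]), dim=-1)
print(next_mask.shape)  # torch.Size([1, 1, 1, 5])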
@@ -907,7 +909,9 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
             del outputs
 
             past_key_values = self._reorder_cache(past_key_values, beam_idx)
-            registering_cache["register_nums"] = self._reorder_register_nums(registering_cache["register_nums"], beam_idx)
+            registering_cache["register_nums"] = self._reorder_register_cache(registering_cache["register_nums"], beam_idx)
+            if registering_cache["attention_mask"] is not None:
+                registering_cache["attention_mask"] = self._reorder_register_cache(registering_cache["attention_mask"], beam_idx)
 
             cur_len = cur_len + 1
 
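The last hunk keeps `registering_cache` consistent under beam search: `register_nums` and the cached `attention_mask` are gathered with the same `beam_idx` that reorders `past_key_values`, so a surviving beam never inherits another beam's register count or mask. A toy reorder step with made-up shapes:

import torch

beam_idx = torch.tensor([2, 2, 0])  # beams 2 and 0 survive, beam 1 is dropped

registering_cache = {
    "register_nums": torch.tensor([1, 2, 3]),
    "attention_mask": torch.zeros(3, 1, 1, 6),
}

# Mirror of the committed change: every beam-indexed entry is gathered
# with the same beam_idx used for past_key_values.
for key in ("register_nums", "attention_mask"):
    if registering_cache[key] is not None:
        registering_cache[key] = registering_cache[key].index_select(0, beam_idx)

print(registering_cache["register_nums"])  # tensor([3, 3, 1])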