zhiqu22 committed · Commit 317d82a · Parent(s): 74025f2
update
Browse files · modeling_mitre.py (+10 -6)

modeling_mitre.py CHANGED
@@ -433,7 +433,6 @@ class MitreDecoder(MitrePreTrainedModel):
         # ensure contiguous
         expanded_src_tokens = self.check_contiguous(expanded_src_tokens)
         source_tokens = self.check_contiguous(source_tokens)
-        src_length = source_tokens.shape[1]
 
         # get embeds with positions for source tokens (input_ids + registers)
         inputs_embeds = self.source_tokens_embedding_with_positions(expanded_src_tokens, total_token_nums, batch_indices, token_indices)
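For context, check_contiguous is only called, never defined, in this diff; a minimal sketch of what such a guard usually looks like (the body below is an assumption, not the repository's implementation):

    import torch

    def check_contiguous(t: torch.Tensor) -> torch.Tensor:
        # advanced indexing and transposes can return non-contiguous views;
        # .contiguous() copies only when needed, so this is cheap in the common case
        return t if t.is_contiguous() else t.contiguous()

    x = torch.ones(4, 4).t()                      # transpose -> non-contiguous view
    print(check_contiguous(x).is_contiguous())    # True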
@@ -444,11 +443,11 @@ class MitreDecoder(MitrePreTrainedModel):
         decoder_input_ids[:, 0] = source_tokens[:, -1]
 
         tokens = torch.cat([source_tokens, decoder_input_ids], dim=1)
+        src_length = source_tokens.shape[1]
 
         decoder_inputs_embeds = self.embed_tokens(decoder_input_ids)
         decoder_inputs_embeds = decoder_inputs_embeds + self.tgt_embed_positions(decoder_input_ids, past_key_values_length, src_length=src_length)
-
-        # raise ValueError()
+
         if past_key_values_length == 0:
             hidden_states = torch.cat([inputs_embeds, decoder_inputs_embeds], dim=1)
         else:
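The relocated src_length line exists only to feed tgt_embed_positions, whose src_length argument shifts target positions past the concatenated source segment. A self-contained sketch of that shift (the sizes are made up, and the internals of tgt_embed_positions are not shown in this diff, so treat this as an assumption about the offset rather than the repository's code):

    import torch

    past_key_values_length, tgt_len, src_length = 0, 3, 7   # made-up sizes
    positions = torch.arange(past_key_values_length,
                             past_key_values_length + tgt_len) + src_length
    print(positions)  # tensor([7, 8, 9]): target slots start after the source/register slots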
@@ -759,8 +758,8 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
         return reordered_past
 
     @staticmethod
-    def
-    return
+    def _reorder_register_cache(t, beam_idx):
+        return t.index_select(dim=0, index=beam_idx.to(t.device))
 
     @staticmethod
     def _expand_inputs_for_generation(
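index_select along dim 0 is the standard beam-search reorder: row i of the result is row beam_idx[i] of the input, so any per-beam tensor can be permuted in one call. A quick illustration with made-up values:

    import torch

    register_nums = torch.tensor([4, 4, 6])   # e.g. one register count per beam
    beam_idx = torch.tensor([2, 0, 0])        # beams kept at this step
    print(register_nums.index_select(dim=0, index=beam_idx))  # tensor([6, 4, 4])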
@@ -844,6 +843,9 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
         if past_key_values is not None:
             decoder_input_ids_for_generation = decoder_input_ids[:, -1:]
             attention_mask = registering_cache["attention_mask"]
+            # Take the mask from the cache the first time the kv cache is used.
+            # After that, we can simply repeat 0. (the last column of the mask) to get the next mask.
+            # As a result, we avoid generating the mask from scratch under the kv cache and save memory.
             if attention_mask is not None:
                 attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
             else:
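The reasoning in the new comments: each freshly generated token attends to everything the previous one could, plus itself, so appending a copy of the mask's last column (0., i.e. visible) is equivalent to rebuilding the additive mask at the new length. A toy check of the shape arithmetic (the 4-D shape is an assumption about the mask layout):

    import torch

    attention_mask = torch.zeros(2, 1, 1, 5)  # (batch, heads, q_len=1, kv_len), 0. = attend
    attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
    print(attention_mask.shape)  # torch.Size([2, 1, 1, 6]); kv_len grew by one token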
@@ -907,7 +909,9 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
                 del outputs
 
                 past_key_values = self._reorder_cache(past_key_values, beam_idx)
-                registering_cache["register_nums"] = self.
+                registering_cache["register_nums"] = self._reorder_register_cache(registering_cache["register_nums"], beam_idx)
+                if registering_cache["attention_mask"] is not None:
+                    registering_cache["attention_mask"] = self._reorder_register_cache(registering_cache["attention_mask"], beam_idx)
 
                 cur_len = cur_len + 1
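With the helper in place, registering_cache is permuted by the same beam_idx as past_key_values; skipping either reorder would pair a hypothesis with another beam's register count or mask. A minimal sketch of that invariant, with made-up shapes and the same None guard as the diff:

    import torch

    beam_idx = torch.tensor([1, 1, 0])
    register_nums = torch.tensor([2, 3, 4])    # per-beam register counts
    attention_mask = torch.zeros(3, 1, 1, 6)   # per-beam additive mask; may be None

    register_nums = register_nums.index_select(0, beam_idx)  # tensor([3, 3, 2])
    if attention_mask is not None:
        attention_mask = attention_mask.index_select(0, beam_idx)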