zhiqu22 committed
Commit 74025f2 · 1 Parent(s): 17468e1

update attention in generate

Files changed (1):
  1. modeling_mitre.py +143 -53

modeling_mitre.py CHANGED
@@ -280,22 +280,48 @@ class MitreDecoder(MitrePreTrainedModel):
         registers = input_ids[range(batch_size), torch.argmax(input_ids, dim=-1)].unsqueeze(1).repeat(1, max_register_nums)
         return registers, register_nums, total_token_nums

-    def combine_src_and_registers(self, input_ids, registers, register_nums, total_token_nums):
+    def get_token_indices(self, input_ids, total_token_nums, register_nums):
+        '''
+        return a token_indices for selecting source tokens from expanded_src_tokens
+        '''
+        token_indices = torch.arange(total_token_nums).expand(input_ids.size(0), -1).to(input_ids.device)
+        token_indices = token_indices + register_nums.unsqueeze(1)
+        return token_indices
+
+    def get_batch_indices(self, input_ids, token_indices):
+        '''
+        return a batch_indices for selecting source tokens from expanded_src_tokens
+        '''
+        batch_indices = torch.arange(input_ids.shape[0]).unsqueeze(1).expand(-1, token_indices.size(1)).contiguous()
+        return batch_indices
+
+    def combine_src_and_registers(self, input_ids, registers):
         '''
         return a expanded_src_tokens for positional embedding.
         '''
         pads = torch.full_like(registers, self.padding_idx)
         expanded_src_tokens = torch.cat((pads, input_ids, registers), dim=1)
-        indices = torch.arange(total_token_nums).expand(input_ids.size(0), -1).to(input_ids.device)
-        indices = indices + register_nums.unsqueeze(1)
-
-        batch_indices = torch.arange(input_ids.shape[0]).unsqueeze(1).expand(-1, indices.size(1)).contiguous()
-        return expanded_src_tokens, batch_indices, indices
+        return expanded_src_tokens
+
+    def source_tokens_embedding_with_positions(self, expanded_src_tokens, total_token_nums, batch_indices, indices):
+        '''
+        return the embeds of source tokens
+        '''
+        inputs_embeds = self.embed_tokens(expanded_src_tokens)
+        inputs_embeds_1 = inputs_embeds[:,:total_token_nums,:] + self.src_embed_positions(expanded_src_tokens[:,:total_token_nums])
+        inputs_embeds_2 = inputs_embeds[:,total_token_nums:,:] + self.register_embed_positions(expanded_src_tokens[:,total_token_nums:])
+        inputs_embeds = torch.cat((inputs_embeds_1, inputs_embeds_2), dim=1)
+        inputs_embeds = inputs_embeds[batch_indices, indices]
+
+        return inputs_embeds

     def fill_with_neg_inf(self, t):
         return t.float().fill_(float("-inf")).type_as(t)
+
+    def check_contiguous(self, t: torch.Tensor):
+        return t if t.is_contiguous() else t.contiguous()

-    def build_future_mask(self, embeds, src_length, register_nums, padding_mask=None, past_key_values_length=0):
+    def build_future_mask(self, embeds, src_length, register_nums, past_key_values_length=0):
         b = register_nums.size(0)
         ns = src_length - register_nums
         if past_key_values_length == 0:
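The hunk above splits the old combine_src_and_registers into smaller helpers (get_token_indices, get_batch_indices, combine_src_and_registers, source_tokens_embedding_with_positions). A minimal standalone sketch of the gather they perform, not part of the commit; the toy ids and per-sample register counts below are assumptions for illustration (the real registers come from create_registers):

import torch

pad = 1
input_ids = torch.tensor([[5, 6, 7, 8],
                          [9, 10, 2, pad]])      # batch of 2; second sample is padded
registers = torch.tensor([[30, 30, 30],
                          [40, 40, 40]])         # max_register_nums = 3 slots per sample
register_nums = torch.tensor([3, 2])             # actual register count per sample
total_token_nums = input_ids.size(1) + registers.size(1)   # 7

# combine_src_and_registers: [pads | input_ids | registers]
pads = torch.full_like(registers, pad)
expanded_src_tokens = torch.cat((pads, input_ids, registers), dim=1)   # (2, 10)

# get_token_indices / get_batch_indices: shift each row right by its register count
token_indices = torch.arange(total_token_nums).expand(input_ids.size(0), -1) + register_nums.unsqueeze(1)
batch_indices = torch.arange(input_ids.size(0)).unsqueeze(1).expand(-1, token_indices.size(1))

# every row ends with exactly its own registers, left-padded to a common length
source_tokens = expanded_src_tokens[batch_indices, token_indices]
print(source_tokens)
# tensor([[ 5,  6,  7,  8, 30, 30, 30],
#         [ 1,  9, 10,  2,  1, 40, 40]])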
@@ -331,11 +357,6 @@ class MitreDecoder(MitrePreTrainedModel):
             batch_mask[batch_indices[target_indices], row_indices[target_indices], col_indices[target_indices]] = float('-inf')
             # shape: batch_size, head_num (1 for broadcasting), seq_len, seq_len
             batch_mask = batch_mask.unsqueeze(1)
-            # 6. masking pads
-            if padding_mask is not None:
-                if padding_mask.any():
-                    padding_mask = padding_mask.to(batch_mask.device).unsqueeze(1).unsqueeze(2)
-                    batch_mask = batch_mask.masked_fill(padding_mask == 1, float('-inf'))

         elif past_key_values_length > 0:
             # in generation
@@ -350,7 +371,6 @@ class MitreDecoder(MitrePreTrainedModel):
             batch_mask[batch_indices[target_to_source_mask], token_indices[target_to_source_mask]] = float('-inf')
             batch_mask = batch_mask.unsqueeze(1)

-        # ensure contiguous
         batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
         return batch_mask

@@ -359,13 +379,12 @@ class MitreDecoder(MitrePreTrainedModel):
         self,
         input_ids: Optional[torch.Tensor] = None,
         decoder_input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         registering_cache: dict = None,
     ):
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
@@ -374,50 +393,98 @@ class MitreDecoder(MitrePreTrainedModel):
         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

-        decoder_input_shape = decoder_input_ids.size()
-        decoder_input_ids = decoder_input_ids.view(-1, decoder_input_shape[-1])
-        padding_mask = None
-
         if past_key_values_length > 0:
             register_nums = registering_cache["register_nums"]
             src_length = registering_cache["src_length"]

         if input_ids is not None and past_key_values_length == 0:
-            # .view() additionally ensure that the memory is contiguous
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
-
-            registers, register_nums, total_token_nums = self.create_registers(input_ids)
-            expanded_src_tokens, batch_indices, indices = self.combine_src_and_registers(input_ids, registers, register_nums, total_token_nums)
-
-            # positional embedding for source tokens and registers
-            inputs_embeds = self.embed_tokens(expanded_src_tokens)
-            inputs_embeds_1 = inputs_embeds[:,:total_token_nums,:] + self.src_embed_positions(expanded_src_tokens[:,:total_token_nums])
-            inputs_embeds_2 = inputs_embeds[:,total_token_nums:,:] + self.register_embed_positions(expanded_src_tokens[:,total_token_nums:])
-            inputs_embeds = torch.cat((inputs_embeds_1, inputs_embeds_2), dim=1)
-            inputs_embeds = inputs_embeds[batch_indices, indices]
+            # ensure contiguous
+            input_ids = self.check_contiguous(input_ids)
+            decoder_input_ids = self.check_contiguous(decoder_input_ids)

+            if attention_mask is None:
+                # create registers from input_ids
+                registers, register_nums, total_token_nums = self.create_registers(input_ids)
+                # 'expanded_src_tokens' is combined by input_ids, registers, and pads.
+                expanded_src_tokens = self.combine_src_and_registers(input_ids, registers)
+                token_indices = self.get_token_indices(input_ids, total_token_nums, register_nums)
+                batch_indices = self.get_batch_indices(input_ids, token_indices)
+                # source tokens (input_ids + registers)
+                source_tokens = expanded_src_tokens[batch_indices, token_indices]

-            # padding mask
-            source_tokens = expanded_src_tokens[batch_indices, indices]
+            else:
+                # although we do not give the attention mask in training and the 1st step of generation,
+                # we still leave this block here.
+                if registering_cache is None or \
+                    not all(key in registering_cache for key in \
+                    ("register_nums", "total_token_nums", "expanded_src_tokens",\
+                    "batch_indices", "token_indices", "source_tokens")):
+                    raise ValueError(
+                        "If you generate registers by external codes, \
+                        you must provide 'register_nums', 'total_token_nums', \
+                        'expanded_src_tokens', 'batch_indices', 'token_indices' \
+                        and 'source_tokens' in 'registering_cache' in the training."
+                    )
+                register_nums, total_token_nums = registering_cache["register_nums"], registering_cache["total_token_nums"]
+                expanded_src_tokens = registering_cache["expanded_src_tokens"]
+                batch_indices, token_indices = registering_cache["batch_indices"], registering_cache["token_indices"]
+                source_tokens = registering_cache["source_tokens"]
+
+            # ensure contiguous
+            expanded_src_tokens = self.check_contiguous(expanded_src_tokens)
+            source_tokens = self.check_contiguous(source_tokens)
             src_length = source_tokens.shape[1]

+            # get embeds with positions for source tokens (input_ids + registers)
+            inputs_embeds = self.source_tokens_embedding_with_positions(expanded_src_tokens, total_token_nums, batch_indices, token_indices)
+
             # replace the inference trigger with langtok
             # namely, enc-tgt-dec-tgt strategy
             if decoder_input_ids[0][0].item() != source_tokens[0][-1].item():
                 decoder_input_ids[:, 0] = source_tokens[:, -1]

             tokens = torch.cat([source_tokens, decoder_input_ids], dim=1)
-            padding_mask = tokens.eq(self.padding_idx)

         decoder_inputs_embeds = self.embed_tokens(decoder_input_ids)
         decoder_inputs_embeds = decoder_inputs_embeds + self.tgt_embed_positions(decoder_input_ids, past_key_values_length, src_length=src_length)
+        # if past_key_values_length > 0:
+        # raise ValueError()
         if past_key_values_length == 0:
             hidden_states = torch.cat([inputs_embeds, decoder_inputs_embeds], dim=1)
         else:
             hidden_states = decoder_inputs_embeds

-        attention_mask = self.build_future_mask(hidden_states, src_length, register_nums, padding_mask, past_key_values_length)
+        # ensure contiguous
+        hidden_states = self.check_contiguous(hidden_states)
+
+        # if attention_mask is NOT given, we build the attention mask from current hyperparams
+        # if attention_mask is given, check the shape of attention mask
+        if attention_mask is None:
+            attention_mask = self.build_future_mask(hidden_states, src_length, register_nums, past_key_values_length)
+        else:
+            bsz, src_len = hidden_states.shape[0], hidden_states.shape[1]
+            tgt_len = hidden_states.shape[1] if past_key_values_length == 0 else past_key_values_length + 1
+            if attention_mask.size() != (bsz, 1, src_len, tgt_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, src_len, tgt_len)}, but is {attention_mask.size()}"
+                )
+
+        # ensure contiguous
+        attention_mask = self.check_contiguous(attention_mask)
+
+        # this is a param to turncate kv cache
+        # in training, it's None, namely, unactivated.
+        max_register_num = None
+        # masking pads for attention_mask in the training or the 1st step of generation
+        if past_key_values_length == 0:
+            # if in generation, activate
+            max_register_num = register_nums.max().item() if use_cache else None
+
+            padding_mask = tokens.eq(self.padding_idx)
+            if padding_mask.any():
+                padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
+                attention_mask = attention_mask.masked_fill(padding_mask == 1, float('-inf'))
+
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

         if self.gradient_checkpointing and self.training:
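With this hunk, pad masking is applied once in forward() on the already-built 4-D mask (the per-call handling inside build_future_mask was removed in an earlier hunk). A small sketch of the broadcast, not part of the commit; the toy sizes and token values are assumptions:

import torch

bsz, seq_len, pad_idx = 2, 5, 1
tokens = torch.tensor([[4, 5, 6, 7, 8],
                       [4, 5, 6, pad_idx, pad_idx]])
attention_mask = torch.zeros(bsz, 1, seq_len, seq_len)   # stand-in for build_future_mask output

padding_mask = tokens.eq(pad_idx)                        # (bsz, seq_len)
if padding_mask.any():
    padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)                 # (bsz, 1, 1, seq_len)
    attention_mask = attention_mask.masked_fill(padding_mask == 1, float('-inf'))

print(attention_mask[1, 0, 0])   # tensor([0., 0., 0., -inf, -inf]): pads are hidden from every query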
@@ -429,8 +496,6 @@ class MitreDecoder(MitrePreTrainedModel):

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if output_attentions else None
         next_decoder_cache = () if use_cache else None

         for idx, decoder_layer in enumerate(self.layers):
@@ -458,7 +523,16 @@ class MitreDecoder(MitrePreTrainedModel):
             hidden_states = layer_outputs[0]

             if use_cache:
-                next_decoder_cache += (layer_outputs[1],)
+                if past_key_values_length > 0:
+                    next_decoder_cache += (layer_outputs[1],)
+                else:
+                    cache_key, cache_value = layer_outputs[1]
+                    clipped_rep = (
+                        cache_key[:, :, src_length - max_register_num:, :],
+                        cache_value[:, :, src_length - max_register_num:, :]
+                    )
+                    next_decoder_cache += (clipped_rep,)
+

         if past_key_values_length == 0:
             hidden_states = hidden_states[:,src_length:,:]
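In the first decoding step the per-layer cache is clipped so that only the register slots (and the new decoder positions) are carried forward; later steps simply append to that shorter cache. A shape-only sketch, not part of the commit, with sizes that are assumptions:

import torch

bsz, heads, head_dim = 2, 4, 16
src_length, max_register_num, tgt_len = 12, 3, 1        # 12 source positions, 3 register slots

cache_key = torch.randn(bsz, heads, src_length + tgt_len, head_dim)
cache_value = torch.randn(bsz, heads, src_length + tgt_len, head_dim)

# drop the plain source-token positions, keep registers + the first decoder step
clipped_rep = (
    cache_key[:, :, src_length - max_register_num:, :],
    cache_value[:, :, src_length - max_register_num:, :],
)
print(clipped_rep[0].shape)   # torch.Size([2, 4, 4, 16])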
@@ -475,13 +549,19 @@ class MitreDecoder(MitrePreTrainedModel):
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
         )
-        model_output.registering_cache = {
-            "register_nums": register_nums,
-            "src_length": src_length
-        }
+
+        # the registering cache used in generation
+        # in the 1st step, we turncate the kv cache to save cost, so we have to change the src_length
+        if use_cache:
+            model_output.registering_cache = {
+                "register_nums": register_nums,
+                "src_length": src_length if past_key_values_length > 0 else max_register_num,
+                "attention_mask": attention_mask if past_key_values_length > 0 else None
+            }
+        else:
+            model_output.registering_cache = None
+
         return model_output

@@ -579,6 +659,7 @@ class MitreModel(MitrePreTrainedModel):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -594,6 +675,7 @@ class MitreModel(MitrePreTrainedModel):
         decoder_outputs = self.decoder(
             input_ids=input_ids,
             decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states,
@@ -634,15 +716,18 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         registering_cache: dict = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+
         outputs = self.model(
             input_ids=input_ids,
             decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states,
@@ -707,11 +792,6 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
            although beamscorer generates eos to the sequence, the sequence is filled by 'pad(1)'.
            As a result, the sequence, which has already finished, will be computed by the model
            continuously. We plan to remove the finished token as Fairseq's style.
-        2. build self-attention mask.
-           Current building happens within the model. Thus, when running beam search, we have to
-           create a mask whose size is (beam_size * batch_size) from scratch. If we create the mask
-           outside of the model, we can create the mask by duplicating beam_size times.
-           Moreover, we can prepare a cache of mask in beam search to avoid create mask many times.
         """
         if generation_config != None:
             assert type(generation_config) is GenerationConfig
@@ -746,12 +826,12 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):

         input_ids = self._expand_inputs_for_generation(input_ids, beam_size)
         decoder_input_ids = self._expand_inputs_for_generation(decoder_input_ids, beam_size)
-        # decoder_input_ids.to(device)
         cur_len = decoder_input_ids.shape[1]

         this_peer_finished = False
         past_key_values = None
-        registering_cache = None
+        registering_cache= None
+        attention_mask = None

         logits_processor = LogitsProcessorList()
         stopping_criteria = StoppingCriteriaList()
@@ -763,10 +843,20 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):

             if past_key_values is not None:
                 decoder_input_ids_for_generation = decoder_input_ids[:, -1:]
+                attention_mask = registering_cache["attention_mask"]
+                if attention_mask is not None:
+                    attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
             else:
                 decoder_input_ids_for_generation = decoder_input_ids

-            outputs = self(input_ids, decoder_input_ids_for_generation, past_key_values=past_key_values, use_cache=True, registering_cache=registering_cache)
+            outputs = self(
+                input_ids,
+                decoder_input_ids_for_generation,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                use_cache=True,
+                registering_cache=registering_cache
+            )

             del input_ids
             input_ids = None
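During generation, the mask built at step 1 is handed back through registering_cache["attention_mask"] and then only extended, never rebuilt: each step appends a copy of the last key column for the newly cached token before calling the model again. A minimal sketch of that growth, not part of the commit; the batch/beam size and key length are assumptions:

import torch

bsz_times_beam, kv_len = 4, 7
# single-query mask as carried in registering_cache during generation
attention_mask = torch.zeros(bsz_times_beam, 1, 1, kv_len)

for _ in range(3):   # three further decoding steps
    attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)

print(attention_mask.shape)   # torch.Size([4, 1, 1, 10])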