# coding=utf-8 import math from typing import List, Optional, Tuple, Union, Dict, Any import torch from torch import nn from .configuration_mitre import MitreConfig from transformers.utils import logging from transformers.generation import GenerationMixin from transformers.modeling_utils import PreTrainedModel from transformers.activations import ACT2FN from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.fsdp import is_fsdp_managed_module from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) from transformers.generation.configuration_utils import GenerationConfig from transformers.generation.beam_search import BeamSearchScorer from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList logger = logging.get_logger(__name__) def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored. This is modified from fairseq's `utils.make_positions`. """ mask = input_ids.ne(padding_idx).int() incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask return incremental_indices.long() + padding_idx # Modified from transformers.models.m2m_100.modeling_m2m_100.M2M100Attention # and transformers.models.m2m_100.modeling_m2m_100.M2M100SdpaAttention class MitreSdpaAttention(nn.Module): def __init__( self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, config: Optional[MitreConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) self.scaling = self.head_dim**-0.5 self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Input shape: Batch x Time x Channel Output objects: attn_output, attn_weights (always be None), past_key_value """ """ 1. MitreModel is using MitreSdpaAttention, which is modifed from M2M100SdpaAttention. Notabley, both of them do not support `output_attentions=True` or `layer_head_mask` not None, leading to 'attn_weights' always being None in output. The plan of improving this point has a low priority. 2. We plan to improve this code with Flash Attention v2. """ bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) if past_key_value is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) past_key_value = (key_states, value_states) query_states = self._shape(query_states, tgt_len, bsz) attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=False, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value # Modified from transformers.models.m2m_100.modeling_m2m100.M2M100DecoderLayer class MitreDecoderLayer(nn.Module): def __init__(self, config: MitreConfig): super().__init__() self.embed_dim = config.d_model self.self_attn = MitreSdpaAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, use_cache: Optional[bool] = True, ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, _, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states outputs = (hidden_states,) if use_cache: outputs += (present_key_value,) return outputs class MitrePreTrainedModel(PreTrainedModel): config_class = MitreConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MitreDecoderLayer"] # we plan to implement codes for falsh attention v2 _supports_flash_attn_2 = False _supports_sdpa = True def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class MitreDecoder(MitrePreTrainedModel): """ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MitreDecoderLayer`] Args: config: MitreConfig embed_tokens (nn.Embedding): output embedding """ def __init__(self, config: MitreConfig): super().__init__(config) self.dropout = config.dropout self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.embed_tokens = MitreScaledWordEmbedding( config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale ) self.src_embed_positions = MitreSinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, self.padding_idx, ) self.register_embed_positions = MitreSinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, self.padding_idx, ) self.tgt_embed_positions = MitreSinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, self.padding_idx, ) self.layers = nn.ModuleList([MitreDecoderLayer(config) for _ in range(config.decoder_layers)]) if config._attn_implementation != "sdpa": raise NotImplementedError("Other attention mechanism are not implemented yet.") # TODO implement flash atten v2 for MITRE # self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False self._future_mask = torch.empty(0) # Initialize weights and apply final processing self.post_init() def create_registers(self, input_ids): ''' create registers by duplicating the language tag respective to each sentence. length(registers) = length(real_tokens) = length(tokens) - length(pads) ''' register_nums = (~input_ids.eq(self.padding_idx)).sum(dim=1) max_register_nums = register_nums.max().item() total_token_nums = input_ids.size(1) + max_register_nums batch_size = input_ids.size(0) registers = input_ids[range(batch_size), torch.argmax(input_ids, dim=-1)].unsqueeze(1).repeat(1, max_register_nums) return registers, register_nums, total_token_nums def get_token_indices(self, input_ids, total_token_nums, register_nums): ''' return a token_indices for selecting source tokens from expanded_src_tokens ''' token_indices = torch.arange(total_token_nums).expand(input_ids.size(0), -1).to(input_ids.device) token_indices = token_indices + register_nums.unsqueeze(1) return token_indices def get_batch_indices(self, input_ids, token_indices): ''' return a batch_indices for selecting source tokens from expanded_src_tokens ''' batch_indices = torch.arange(input_ids.shape[0]).unsqueeze(1).expand(-1, token_indices.size(1)).contiguous() return batch_indices def combine_src_and_registers(self, input_ids, registers): ''' return a expanded_src_tokens for positional embedding. ''' pads = torch.full_like(registers, self.padding_idx) expanded_src_tokens = torch.cat((pads, input_ids, registers), dim=1) return expanded_src_tokens def source_tokens_embedding_with_positions(self, expanded_src_tokens, total_token_nums, batch_indices, indices): ''' return the embeds of source tokens ''' inputs_embeds = self.embed_tokens(expanded_src_tokens) inputs_embeds_1 = inputs_embeds[:,:total_token_nums,:] + self.src_embed_positions(expanded_src_tokens[:,:total_token_nums]) inputs_embeds_2 = inputs_embeds[:,total_token_nums:,:] + self.register_embed_positions(expanded_src_tokens[:,total_token_nums:]) inputs_embeds = torch.cat((inputs_embeds_1, inputs_embeds_2), dim=1) inputs_embeds = inputs_embeds[batch_indices, indices] return inputs_embeds def fill_with_neg_inf(self, t): return t.float().fill_(float("-inf")).type_as(t) def check_contiguous(self, t: torch.Tensor): return t if t.is_contiguous() else t.contiguous() def build_future_mask(self, embeds, src_length, register_nums, past_key_values_length=0): b = register_nums.size(0) ns = src_length - register_nums if past_key_values_length == 0: # in training # 1. create mask by cache dim = embeds.size(1) if ( self._future_mask.size(0) == 0 or self._future_mask.size(0) < dim ): self._future_mask = torch.triu(self.fill_with_neg_inf(torch.zeros([dim, dim])), 1) if self._future_mask.device == embeds.device: mask = self._future_mask[:dim, :dim].clone() else: mask = self._future_mask[:dim, :dim].to(embeds, copy=True) # 2. bi-directional attention in source tokens and registers mask[ :src_length, :src_length] = 0. # 3. create batch mask batch_mask = mask.unsqueeze(0).expand(b, -1, -1).clone().contiguous() # 4. mask source tokens -> registers # 5. mask target -> source tokens batch_indices = torch.arange(b).to(batch_mask.device).view(-1, 1, 1).expand(b, dim, dim).contiguous() row_indices = torch.arange(dim).to(batch_mask.device).view(1, -1, 1).expand(b, dim, dim).contiguous() col_indices = torch.arange(dim).to(batch_mask.device).view(1, 1, -1).expand(b, dim, dim).contiguous() source_indices = (row_indices < ns.view(-1, 1, 1)) & (col_indices >= ns.view(-1, 1, 1)) & (col_indices < (ns + register_nums).view(-1, 1, 1)).contiguous() target_indices = (row_indices >= (ns + register_nums).view(-1, 1, 1)) & (col_indices < ns.view(-1, 1, 1)).contiguous() # 4 batch_mask[batch_indices[source_indices], row_indices[source_indices], col_indices[source_indices]] = float('-inf') # 5 batch_mask[batch_indices[target_indices], row_indices[target_indices], col_indices[target_indices]] = float('-inf') # shape: batch_size, head_num (1 for broadcasting), seq_len, seq_len batch_mask = batch_mask.unsqueeze(1) elif past_key_values_length > 0: # in generation mask = torch.zeros(past_key_values_length + 1) mask = mask.to(embeds, copy=True) batch_mask = mask.unsqueeze(0).expand(b, -1).clone().contiguous() batch_indices = torch.arange(b).view(-1, 1).expand(b, past_key_values_length + 1).to(batch_mask.device) token_indices = torch.arange(past_key_values_length + 1).view(1, -1).expand(b, past_key_values_length + 1).to(batch_mask.device) target_to_source_mask = token_indices < ns.view(-1, 1) batch_mask[batch_indices[target_to_source_mask], token_indices[target_to_source_mask]] = float('-inf') batch_mask = batch_mask.unsqueeze(1) batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1]) return batch_mask def forward( self, input_ids: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, registering_cache: dict = None, ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if past_key_values_length > 0: register_nums = registering_cache["register_nums"] src_length = registering_cache["src_length"] if input_ids is not None and past_key_values_length == 0: # ensure contiguous input_ids = self.check_contiguous(input_ids) decoder_input_ids = self.check_contiguous(decoder_input_ids) if attention_mask is None: # create registers from input_ids registers, register_nums, total_token_nums = self.create_registers(input_ids) # 'expanded_src_tokens' is combined by input_ids, registers, and pads. expanded_src_tokens = self.combine_src_and_registers(input_ids, registers) token_indices = self.get_token_indices(input_ids, total_token_nums, register_nums) batch_indices = self.get_batch_indices(input_ids, token_indices) # source tokens (input_ids + registers) source_tokens = expanded_src_tokens[batch_indices, token_indices] else: # although we do not give the attention mask in training and the 1st step of generation, # we still leave this block here. if registering_cache is None or \ not all(key in registering_cache for key in \ ("register_nums", "total_token_nums", "expanded_src_tokens",\ "batch_indices", "token_indices", "source_tokens")): raise ValueError( "If you generate registers by external codes, \ you must provide 'register_nums', 'total_token_nums', \ 'expanded_src_tokens', 'batch_indices', 'token_indices' \ and 'source_tokens' in 'registering_cache' in the training." ) register_nums, total_token_nums = registering_cache["register_nums"], registering_cache["total_token_nums"] expanded_src_tokens = registering_cache["expanded_src_tokens"] batch_indices, token_indices = registering_cache["batch_indices"], registering_cache["token_indices"] source_tokens = registering_cache["source_tokens"] # ensure contiguous expanded_src_tokens = self.check_contiguous(expanded_src_tokens) source_tokens = self.check_contiguous(source_tokens) src_length = source_tokens.shape[1] # get embeds with positions for source tokens (input_ids + registers) inputs_embeds = self.source_tokens_embedding_with_positions(expanded_src_tokens, total_token_nums, batch_indices, token_indices) # replace the inference trigger with langtok # namely, enc-tgt-dec-tgt strategy if decoder_input_ids[0][0].item() != source_tokens[0][-1].item(): decoder_input_ids[:, 0] = source_tokens[:, -1] tokens = torch.cat([source_tokens, decoder_input_ids], dim=1) decoder_inputs_embeds = self.embed_tokens(decoder_input_ids) decoder_inputs_embeds = decoder_inputs_embeds + self.tgt_embed_positions(decoder_input_ids, past_key_values_length, src_length=src_length) # if past_key_values_length > 0: # raise ValueError() if past_key_values_length == 0: hidden_states = torch.cat([inputs_embeds, decoder_inputs_embeds], dim=1) else: hidden_states = decoder_inputs_embeds # ensure contiguous hidden_states = self.check_contiguous(hidden_states) # if attention_mask is NOT given, we build the attention mask from current hyperparams # if attention_mask is given, check the shape of attention mask if attention_mask is None: attention_mask = self.build_future_mask(hidden_states, src_length, register_nums, past_key_values_length) else: bsz, src_len = hidden_states.shape[0], hidden_states.shape[1] tgt_len = hidden_states.shape[1] if past_key_values_length == 0 else past_key_values_length + 1 if attention_mask.size() != (bsz, 1, src_len, tgt_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, src_len, tgt_len)}, but is {attention_mask.size()}" ) # ensure contiguous attention_mask = self.check_contiguous(attention_mask) # this is a param to turncate kv cache # in training, it's None, namely, unactivated. max_register_num = None # masking pads for attention_mask in the training or the 1st step of generation if past_key_values_length == 0: # if in generation, activate max_register_num = register_nums.max().item() if use_cache else None padding_mask = tokens.eq(self.padding_idx) if padding_mask.any(): padding_mask = padding_mask.unsqueeze(1).unsqueeze(2) attention_mask = attention_mask.masked_fill(padding_mask == 1, float('-inf')) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting" " `use_cache=False`..." ) use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, past_key_value=None, use_cache=use_cache, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, past_key_value=past_key_value, use_cache=use_cache, ) hidden_states = layer_outputs[0] if use_cache: if past_key_values_length > 0: next_decoder_cache += (layer_outputs[1],) else: cache_key, cache_value = layer_outputs[1] clipped_rep = ( cache_key[:, :, src_length - max_register_num:, :], cache_value[:, :, src_length - max_register_num:, :] ) next_decoder_cache += (clipped_rep,) if past_key_values_length == 0: hidden_states = hidden_states[:,src_length:,:] hidden_states = self.layer_norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None model_output = BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, ) # the registering cache used in generation # in the 1st step, we turncate the kv cache to save cost, so we have to change the src_length if use_cache: model_output.registering_cache = { "register_nums": register_nums, "src_length": src_length if past_key_values_length > 0 else max_register_num, "attention_mask": attention_mask if past_key_values_length > 0 else None } else: model_output.registering_cache = None return model_output # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding class MitreScaledWordEmbedding(nn.Embedding): """ This module overrides nn.Embeddings' forward by multiplying with embeddings scale. """ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): super().__init__(num_embeddings, embedding_dim, padding_idx) self.embed_scale = embed_scale def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale class MitreSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): super().__init__() self.offset = 2 self.embedding_dim = embedding_dim self.padding_idx = padding_idx self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) if hasattr(self, "weights"): # in forward put the weights on the correct dtype and device of the param emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) self.register_buffer("weights", emb_weights, persistent=False) @staticmethod def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): """ Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) if padding_idx is not None: emb[padding_idx, :] = 0 return emb.to(torch.get_default_dtype()) @torch.no_grad() def forward( self, input_ids: torch.Tensor = None, past_key_values_length: int = 0, src_length: int = 0 ): bsz, seq_len = input_ids.size() # Create the position ids from the input token ids. Any padded tokens remain padded. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( input_ids.device ) if past_key_values_length > 0 and src_length > 0: position_ids = torch.where(position_ids == 1, position_ids, position_ids - src_length) # expand embeddings if needed max_pos = self.padding_idx + 1 + seq_len + past_key_values_length if max_pos > self.weights.size(0): self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() class MitreModel(MitrePreTrainedModel): _tied_weights_keys = ["decoder.embed_tokens.weight"] def __init__(self, config: MitreConfig): super().__init__(config) self.decoder = MitreDecoder(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.decoder.embed_tokens def get_decoder(self): return self.decoder def forward( self, input_ids: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, registering_cache: dict = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache decoder_outputs = self.decoder( input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states, registering_cache=registering_cache ) model_output = Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, ) model_output.registering_cache = decoder_outputs.registering_cache return model_output class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"] def __init__(self, config: MitreConfig): super().__init__(config) self.model = MitreModel(config) self.lm_head = nn.Linear(config.d_model, self.model.decoder.embed_tokens.num_embeddings, bias=False) # Initialize weights and apply final processing self.post_init() def get_decoder(self): return self.model.get_decoder() def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def forward( self, input_ids: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, registering_cache: dict = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: outputs = self.model( input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states, registering_cache=registering_cache, ) lm_logits = self.lm_head(outputs[0]) if labels is not None: raise NotImplementedError("Please implement your loss function here.") model_output = Seq2SeqLMOutput( loss=None, logits=lm_logits, past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, ) model_output.registering_cache = outputs.registering_cache return model_output @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), ) return reordered_past @staticmethod def _reorder_register_nums(register_nums, beam_idx): return register_nums.index_select(0, beam_idx.to(register_nums.device)) @staticmethod def _expand_inputs_for_generation( input_ids: Optional[torch.LongTensor] = None, beam_size: int = 1, ) -> torch.LongTensor: """ Expands input_ids from [batch_size, len(tokens)] to [batch_size * expand_size, , len(tokens)] This is simplified from 'transformers.generation.utils.GenerationMixin._expand_inputs_for_generation' """ if beam_size == 1: return input_ids return input_ids.repeat_interleave(beam_size, dim=0) def generate(self, input_ids: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, **kwargs: Dict ): """ Inference with beam search. This code is simplified from 'transformers.generation.utils.GenerationMixin.generate'. This code follows the style of m2m and nllb. Therefore, there are two points need improvement. TODO 1. early_stop in beam search. Current early_stop is at the beam search level instead of model level. Specficially, although beamscorer generates eos to the sequence, the sequence is filled by 'pad(1)'. As a result, the sequence, which has already finished, will be computed by the model continuously. We plan to remove the finished token as Fairseq's style. """ if generation_config != None: assert type(generation_config) is GenerationConfig self.generation_config = generation_config self.generation_config.update(**kwargs) generation_config = self.generation_config batch_size = input_ids.shape[0] beam_size = generation_config.num_beams device = input_ids.device max_cache_length = generation_config.max_length eos_token_id = torch.Tensor([generation_config.eos_token_id]) # initial the target tokens decoder_input_ids = torch.full( (batch_size, 1), self.generation_config.decoder_start_token_id, dtype=input_ids.dtype, device=device ) beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=beam_size, device=device, length_penalty=self.generation_config.length_penalty, do_early_stopping=self.generation_config.early_stopping, num_beam_hyps_to_keep=self.generation_config.num_return_sequences, max_length=max_cache_length, ) input_ids = self._expand_inputs_for_generation(input_ids, beam_size) decoder_input_ids = self._expand_inputs_for_generation(decoder_input_ids, beam_size) cur_len = decoder_input_ids.shape[1] this_peer_finished = False past_key_values = None registering_cache= None attention_mask = None logits_processor = LogitsProcessorList() stopping_criteria = StoppingCriteriaList() beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * beam_size,)) while not this_peer_finished: if past_key_values is not None: decoder_input_ids_for_generation = decoder_input_ids[:, -1:] attention_mask = registering_cache["attention_mask"] if attention_mask is not None: attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1) else: decoder_input_ids_for_generation = decoder_input_ids outputs = self( input_ids, decoder_input_ids_for_generation, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=True, registering_cache=registering_cache ) del input_ids input_ids = None past_key_values = outputs.past_key_values registering_cache = outputs.registering_cache next_token_logits = outputs.logits[:, -1, :].clone().float() next_token_logits = next_token_logits.to(device) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(decoder_input_ids, next_token_scores) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( next_token_scores_processed ) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, beam_size * vocab_size) # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1 # non eos token per beam. n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 n_tokens_to_keep = max(2, 1 + n_eos_tokens) * beam_size next_token_scores, next_tokens = torch.topk( next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size beam_outputs = beam_scorer.process( decoder_input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, decoder_prompt_len=1, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] decoder_input_ids = torch.cat([decoder_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) del outputs past_key_values = self._reorder_cache(past_key_values, beam_idx) registering_cache["register_nums"] = self._reorder_register_nums(registering_cache["register_nums"], beam_idx) cur_len = cur_len + 1 if beam_scorer.is_done: this_peer_finished = True sequence_outputs = beam_scorer.finalize( decoder_input_ids, beam_scores, next_tokens, next_indices, pad_token_id=generation_config.pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, decoder_prompt_len=1, ) return sequence_outputs["sequences"] MitreForConditionalGeneration.register_for_auto_class("AutoModel")