""" 文件说明: GPT2模型文件,主要对transformers包中GPT2LMHeadModel的重写,修改计算loss部分,只计算预测title部分的loss """ from torch.nn import CrossEntropyLoss import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2Model class GPT2LMHeadModel(GPT2PreTrainedModel): """GPT2模型""" def __init__(self, config): """ 初始化函数 Args: config: 配置参数 """ super().__init__(config) self.transformer = GPT2Model(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.init_weights() def forward(self, input_ids=None, past=None, token_type_ids=None, labels=None, title_id=None): """ 前向函数,计算GPT2预测结果值 Args: input_ids: 输入序列在词表中的索引序列,size:[batch_size, sequence_length] past: 包含由模型预先计算好的隐藏状态,一般使用在预测阶段,用于加速顺序解码,防止重复计算前面计算过的token token_type_ids: 用于区分输入序列中content和title的分隔符序列,size:[batch_size, sequence_length] labels: 标签序列,size:[batch_size, sequence_length],一般情况下,与input_ids相同 title_id: title部分分隔符的id Returns: """ # 获取GPT2模型的输出结果 transformer_outputs = self.transformer(input_ids, past_key_values=past, token_type_ids=token_type_ids) # 获取GPT2模型的最后一层的隐层节点状态,size:[batch_size, sequence_length, config.n_embd] hidden_states = transformer_outputs[0] # 预测隐层节点状态中的每一个token的下一个token,size:[batch_size, sequence_length, config.vocab_size] lm_logits = self.lm_head(hidden_states) # 拼接输出结果 outputs = (lm_logits,) + transformer_outputs[1:] # 如果labels不为None时,计算损失值loss,并拼接到输出结果中 if labels is not None: # 计算loss时,title_id不可以为None,因为需要title_id找到title的部分 if title_id is None or token_type_ids is None: raise Exception("当labels不为None时, title_id和token_type_ids均不可以为None。") # 获取mask值,如果token_type_ids中等于title_id的部分需要计算loss,标记为1;否则为0。 # size:[batch_size, sequence_length] mask = (token_type_ids == title_id).long() # 获取新的标签,size:[batch_size, sequence_length] labels = labels * mask # 对预测结果和标签进行偏移操作 # GPT2的生成机制为通过前面的token,预测下一个token;并且labels与input_ids相同, # 因此input_ids中的第一个token的预测结果,实际上是标签中的第二个token,以此类推,最终仅计算sequence_length-1个token的loss shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # 定义损失函数CrossEntropyLoss,并且设置忽略计算loss的索引,以及返回loss的形式 # 忽略shift_labels中为0的loss,也就是仅计算title部分的损失值 # 对loss的计算方式设为sum,由于我们仅计算了itle部分的损失值,如果使用mean,会使loss变小(实际除的是sequence_length-1,不是title部分的真实长度) loss_fct = CrossEntropyLoss(ignore_index=0, reduction="sum") loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # 获取title部分的真实长度,并计算真实loss num = shift_labels.ne(0).long().sum().item() loss = loss / num outputs = (loss,) + outputs return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)