File size: 4,066 Bytes
1c9cb67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
    文件说明:
    GPT2模型文件,主要对transformers包中GPT2LMHeadModel的重写,修改计算loss部分,只计算预测title部分的loss
"""

from torch.nn import CrossEntropyLoss
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2Model


class GPT2LMHeadModel(GPT2PreTrainedModel):
    """GPT2模型"""
    def __init__(self, config):
        """
        初始化函数
        Args:
            config: 配置参数
        """
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.init_weights()

    def forward(self, input_ids=None, past=None, token_type_ids=None, labels=None, title_id=None):
        """
        前向函数,计算GPT2预测结果值
        Args:
            input_ids: 输入序列在词表中的索引序列,size:[batch_size, sequence_length]
            past: 包含由模型预先计算好的隐藏状态,一般使用在预测阶段,用于加速顺序解码,防止重复计算前面计算过的token
            token_type_ids: 用于区分输入序列中content和title的分隔符序列,size:[batch_size, sequence_length]
            labels: 标签序列,size:[batch_size, sequence_length],一般情况下,与input_ids相同
            title_id: title部分分隔符的id
        Returns:

        """
        # 获取GPT2模型的输出结果
        transformer_outputs = self.transformer(input_ids, past_key_values=past, token_type_ids=token_type_ids)
        # 获取GPT2模型的最后一层的隐层节点状态,size:[batch_size, sequence_length, config.n_embd]
        hidden_states = transformer_outputs[0]
        # 预测隐层节点状态中的每一个token的下一个token,size:[batch_size, sequence_length, config.vocab_size]
        lm_logits = self.lm_head(hidden_states)
        # 拼接输出结果
        outputs = (lm_logits,) + transformer_outputs[1:]
        # 如果labels不为None时,计算损失值loss,并拼接到输出结果中
        if labels is not None:
            # 计算loss时,title_id不可以为None,因为需要title_id找到title的部分
            if title_id is None or token_type_ids is None:
                raise Exception("当labels不为None时, title_id和token_type_ids均不可以为None。")
            # 获取mask值,如果token_type_ids中等于title_id的部分需要计算loss,标记为1;否则为0。
            # size:[batch_size, sequence_length]
            mask = (token_type_ids == title_id).long()
            # 获取新的标签,size:[batch_size, sequence_length]
            labels = labels * mask
            # 对预测结果和标签进行偏移操作
            # GPT2的生成机制为通过前面的token,预测下一个token;并且labels与input_ids相同,
            # 因此input_ids中的第一个token的预测结果,实际上是标签中的第二个token,以此类推,最终仅计算sequence_length-1个token的loss
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # 定义损失函数CrossEntropyLoss,并且设置忽略计算loss的索引,以及返回loss的形式
            # 忽略shift_labels中为0的loss,也就是仅计算title部分的损失值
            # 对loss的计算方式设为sum,由于我们仅计算了itle部分的损失值,如果使用mean,会使loss变小(实际除的是sequence_length-1,不是title部分的真实长度)
            loss_fct = CrossEntropyLoss(ignore_index=0, reduction="sum")
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # 获取title部分的真实长度,并计算真实loss
            num = shift_labels.ne(0).long().sum().item()
            loss = loss / num
            outputs = (loss,) + outputs
        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)