JunhuiJi commited on
Commit
1c9cb67
·
1 Parent(s): a0da06b

Upload 2 files

Browse files
Files changed (2) hide show
  1. generate_title.py +185 -0
  2. model.py +70 -0
generate_title.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 文件说明:
3
+ 根据训练好的模型,进行新闻标题生成,预测文件
4
+ """
5
+
6
+ import torch
7
+ import os
8
+ import argparse
9
+ from model import GPT2LMHeadModel
10
+ from transformers import BertTokenizer
11
+ import torch.nn.functional as F
12
+ import copy
13
+
14
+
15
+ def set_args():
16
+ """设置模型预测所需参数"""
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument('--device', default='0', type=str, help='设置预测时使用的显卡,使用CPU设置成-1即可')
19
+ parser.add_argument('--model_path', default='output_dir/checkpoint-139805', type=str, help='模型文件路径')
20
+ parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, help='词表,该词表为小词表,并增加了一些新的标记')
21
+ parser.add_argument('--batch_size', default=3, type=int, help='生成标题的个数')
22
+ parser.add_argument('--generate_max_len', default=32, type=int, help='生成标题的最大长度')
23
+ parser.add_argument('--repetition_penalty', default=1.2, type=float, help='重复处罚率')
24
+ parser.add_argument('--top_k', default=5, type=float, help='解码时保留概率最高的多少个标记')
25
+ parser.add_argument('--top_p', default=0.95, type=float, help='解码时保留概率累加大于多少的标记')
26
+ parser.add_argument('--max_len', type=int, default=512, help='输入模型的最大长度,要比config中n_ctx小')
27
+ return parser.parse_args()
28
+
29
+
30
+ def top_k_top_p_filtering(logits, top_k, top_p, filter_value=-float("Inf")):
31
+ """
32
+ top_k或top_p解码策略,仅保留top_k个或累积概率到达top_p的标记,其他标记设为filter_value,后续在选取标记的过程中会取不到值设为无穷小。
33
+ Args:
34
+ logits: 预测结果,即预测成为词典中每个词的分数
35
+ top_k: 只保留概率最高的top_k个标记
36
+ top_p: 只保留概率累积达到top_p的标记
37
+ filter_value: 过滤标记值
38
+
39
+ Returns:
40
+
41
+ """
42
+ # logits的维度必须为2,即size:[batch_size, vocab_size]
43
+ assert logits.dim() == 2
44
+ # 获取top_k和字典大小中较小的一个,也就是说,如果top_k大于字典大小,则取字典大小个标记
45
+ top_k = min(top_k, logits[0].size(-1))
46
+ # 如果top_k不为0,则将在logits中保留top_k个标记
47
+ if top_k > 0:
48
+ # 由于有batch_size个预测结果,因此对其遍历,选取每个预测结果的top_k标记
49
+ for logit in logits:
50
+ indices_to_remove = logit < torch.topk(logit, top_k)[0][..., -1, None]
51
+ logit[indices_to_remove] = filter_value
52
+ # 如果top_p不为0,则将在logits中保留概率值累积达到top_p的标记
53
+ if top_p > 0.0:
54
+ # 对logits进行递减排序
55
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
56
+ # 对排序后的结果使用softmax归一化,再获取累积概率序列
57
+ # 例如:原始序列[0.1, 0.2, 0.3, 0.4],则变为:[0.1, 0.3, 0.6, 1.0]
58
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
59
+ # 删除累积概率高于top_p的标记
60
+ sorted_indices_to_remove = cumulative_probs > top_p
61
+ # 将索引向右移动,使第一个标记也保持在top_p之上
62
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
63
+ sorted_indices_to_remove[..., 0] = 0
64
+ for index, logit in enumerate(logits):
65
+ # 由于有batch_size个预测结果,因此对其遍历,选取每个预测结果的累积概率达到top_p的标记
66
+ indices_to_remove = sorted_indices[index][sorted_indices_to_remove[index]]
67
+ logit[indices_to_remove] = filter_value
68
+ return logits
69
+
70
+
71
+ def predict_one_sample(model, tokenizer, device, args, content):
72
+ """
73
+ 对单个样本进行预测
74
+ Args:
75
+ model: 模型
76
+ tokenizer: 分词器
77
+ device: 设备信息
78
+ args: 配置项信息
79
+ content: 新闻正文
80
+
81
+ Returns:
82
+
83
+ """
84
+ # 对新闻正文进行预处理,并判断如果超长则进行截断
85
+ content_tokens = tokenizer.tokenize(content)
86
+ if len(content_tokens) > args.max_len - 3 - args.generate_max_len:
87
+ content_tokens = content_tokens[:args.max_len - 3 - args.generate_max_len]
88
+ # 获取content_id、title_id、unk_id、sep_id值
89
+ content_id = tokenizer.convert_tokens_to_ids("[Content]")
90
+ title_id = tokenizer.convert_tokens_to_ids("[Title]")
91
+ unk_id = tokenizer.convert_tokens_to_ids("[UNK]")
92
+ sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
93
+ # 将tokens索引化,变成模型所需格式
94
+ content_tokens = ["[CLS]"] + content_tokens + ["[SEP]"]
95
+ input_ids = tokenizer.convert_tokens_to_ids(content_tokens)
96
+ # 将input_ids和token_type_ids进行扩充,扩充到需要预测标题的个数,即batch_size
97
+ input_ids = [copy.deepcopy(input_ids) for _ in range(args.batch_size)]
98
+ token_type_ids = [[content_id] * len(content_tokens) for _ in range(args.batch_size)]
99
+ # 将input_ids和token_type_ids变成tensor
100
+ input_tensors = torch.tensor(input_ids).long().to(device)
101
+ token_type_tensors = torch.tensor(token_type_ids).long().to(device)
102
+ next_token_type = torch.tensor([[title_id] for _ in range(args.batch_size)]).long().to(device)
103
+ # 用于存放每一步解码的结果
104
+ generated = []
105
+ # 用于存放,完成解码序列的序号
106
+ finish_set = set()
107
+ with torch.no_grad():
108
+ # 遍历生成标题最大长度
109
+ for _ in range(args.generate_max_len):
110
+ outputs = model(input_ids=input_tensors, token_type_ids=token_type_tensors)
111
+ # 获取预测结果序列的最后一个标记,next_token_logits size:[batch_size, vocab_size]
112
+ next_token_logits = outputs[0][:, -1, :]
113
+ # 对batch_size进行遍历,将词表中出现在序列中的词的概率进行惩罚
114
+ for index in range(args.batch_size):
115
+ for token_id in set([token_ids[index] for token_ids in generated]):
116
+ next_token_logits[index][token_id] /= args.repetition_penalty
117
+ # 对batch_size进行遍历,将词表中的UNK的值设为无穷小
118
+ for next_token_logit in next_token_logits:
119
+ next_token_logit[unk_id] = -float("Inf")
120
+ # 使用top_k_top_p_filtering函数,按照top_k和top_p的值,对预测结果进行筛选
121
+ filter_logits = top_k_top_p_filtering(next_token_logits, top_k=args.top_k, top_p=args.top_p)
122
+ # 对filter_logits的每一行做一次取值,输出结果是每一次取值时filter_logits对应行的下标,即词表位置(词的id)
123
+ # filter_logits中的越大的值,越容易被选中
124
+ next_tokens = torch.multinomial(F.softmax(filter_logits, dim=-1), num_samples=1)
125
+ # 判断如果哪个序列的预测标记为sep_id时,则加入到finish_set
126
+ for index, token_id in enumerate(next_tokens[:, 0]):
127
+ if token_id == sep_id:
128
+ finish_set.add(index)
129
+ # 判断,如果finish_set包含全部的序列序号,则停止预测;否则继续预测
130
+ finish_flag = True
131
+ for index in range(args.batch_size):
132
+ if index not in finish_set:
133
+ finish_flag = False
134
+ break
135
+ if finish_flag:
136
+ break
137
+ # 将预测标记添加到generated中
138
+ generated.append([token.item() for token in next_tokens[:, 0]])
139
+ # 将预测结果拼接到input_tensors和token_type_tensors上,继续下一次预测
140
+ input_tensors = torch.cat((input_tensors, next_tokens), dim=-1)
141
+ token_type_tensors = torch.cat((token_type_tensors, next_token_type), dim=-1)
142
+ # 用于存储预测结果
143
+ candidate_responses = []
144
+ # 对batch_size进行遍历,并将token_id变成对应汉字
145
+ for index in range(args.batch_size):
146
+ responses = []
147
+ for token_index in range(len(generated)):
148
+ # 判断,当出现sep_id时,停止在该序列中添加token
149
+ if generated[token_index][index] != sep_id:
150
+ responses.append(generated[token_index][index])
151
+ else:
152
+ break
153
+ # 将token_id序列变成汉字序列,去除"##",并将[Space]替换成空格
154
+ candidate_responses.append(
155
+ "".join(tokenizer.convert_ids_to_tokens(responses)).replace("##", "").replace("[space]", " "))
156
+ return candidate_responses
157
+
158
+
159
+ def main():
160
+ """主函数"""
161
+ # 设置预测的配置参数
162
+ args = set_args()
163
+ # 获取设备信息
164
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
165
+ os.environ["CUDA_VISIBLE_DEVICE"] = args.device
166
+ device = torch.device("cuda" if torch.cuda.is_available() and int(args.device) >= 0 else "cpu")
167
+ # 实例化tokenizer和model
168
+ tokenizer = BertTokenizer.from_pretrained(args.vocab_path, do_lower_case=True)
169
+ model = GPT2LMHeadModel.from_pretrained(args.model_path)
170
+ model.to(device)
171
+ model.eval()
172
+ print('开始对新闻生成标题,输入CTRL + Z,则退出')
173
+ try:
174
+ while True:
175
+ content = input("输入的新闻正文为:")
176
+ titles = predict_one_sample(model, tokenizer, device, args, content)
177
+ for i, title in enumerate(titles):
178
+ print("生成的第{}个标题为:{}".format(i + 1, title))
179
+ except:
180
+ pass
181
+
182
+
183
+ if __name__ == '__main__':
184
+ main()
185
+
model.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 文件说明:
3
+ GPT2模型文件,主要对transformers包中GPT2LMHeadModel的重写,修改计算loss部分,只计算预测title部分的loss
4
+ """
5
+
6
+ from torch.nn import CrossEntropyLoss
7
+ import torch.nn as nn
8
+ from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2Model
9
+
10
+
11
+ class GPT2LMHeadModel(GPT2PreTrainedModel):
12
+ """GPT2模型"""
13
+ def __init__(self, config):
14
+ """
15
+ 初始化函数
16
+ Args:
17
+ config: 配置参数
18
+ """
19
+ super().__init__(config)
20
+ self.transformer = GPT2Model(config)
21
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
22
+ self.init_weights()
23
+
24
+ def forward(self, input_ids=None, past=None, token_type_ids=None, labels=None, title_id=None):
25
+ """
26
+ 前向函数,计算GPT2预测结果值
27
+ Args:
28
+ input_ids: 输入序列在词表中的索引序列,size:[batch_size, sequence_length]
29
+ past: 包含由模型预先计算好的隐藏状态,一般使用在预测阶段,用于加速顺序解码,防止重复计算前面计算过的token
30
+ token_type_ids: 用于区分输入序列中content和title的分隔符序列,size:[batch_size, sequence_length]
31
+ labels: 标签序列,size:[batch_size, sequence_length],一般情况下,与input_ids相同
32
+ title_id: title部分分隔符的id
33
+ Returns:
34
+
35
+ """
36
+ # 获取GPT2模型的输出结果
37
+ transformer_outputs = self.transformer(input_ids, past_key_values=past, token_type_ids=token_type_ids)
38
+ # 获取GPT2模型的最后一层的隐层节点状态,size:[batch_size, sequence_length, config.n_embd]
39
+ hidden_states = transformer_outputs[0]
40
+ # 预测隐层节点状态中的每一个token的下一个token,size:[batch_size, sequence_length, config.vocab_size]
41
+ lm_logits = self.lm_head(hidden_states)
42
+ # 拼接输出结果
43
+ outputs = (lm_logits,) + transformer_outputs[1:]
44
+ # 如果labels不为None时,计算损失值loss,并拼接到输出结果中
45
+ if labels is not None:
46
+ # 计算loss时,title_id不可以为None,因为需要title_id找到title的部分
47
+ if title_id is None or token_type_ids is None:
48
+ raise Exception("当labels不为None时, title_id和token_type_ids均不可以为None。")
49
+ # 获取mask值,如果token_type_ids中等于title_id的部分需要计算loss,标记为1;否则为0。
50
+ # size:[batch_size, sequence_length]
51
+ mask = (token_type_ids == title_id).long()
52
+ # 获取新的标签,size:[batch_size, sequence_length]
53
+ labels = labels * mask
54
+ # 对预测结果和标签进行偏移操作
55
+ # GPT2的生成机制为通过前面的token,预测下一个token;并且labels与input_ids相同,
56
+ # 因此input_ids中的第一个token的预测结果,实际上是标签中的第二个token,以此类推,最终仅计算sequence_length-1个token的loss
57
+ shift_logits = lm_logits[..., :-1, :].contiguous()
58
+ shift_labels = labels[..., 1:].contiguous()
59
+
60
+ # 定义损失函数CrossEntropyLoss,并且设置忽略计算loss的索引,以及返回loss的形式
61
+ # 忽略shift_labels中为0的loss,也就是仅计算title部分的损失值
62
+ # 对loss的计算方式设为sum,由于我们仅计算了itle部分的损失值,如果使用mean,会使loss变小(实际除的是sequence_length-1,不是title部分的真实长度)
63
+ loss_fct = CrossEntropyLoss(ignore_index=0, reduction="sum")
64
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
65
+ # 获取title部分的真实长度,并计算真实loss
66
+ num = shift_labels.ne(0).long().sum().item()
67
+ loss = loss / num
68
+ outputs = (loss,) + outputs
69
+ return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
70
+