Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- generate_title.py +185 -0
- model.py +70 -0
generate_title.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
文件说明:
|
3 |
+
根据训练好的模型,进行新闻标题生成,预测文件
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
+
import argparse
|
9 |
+
from model import GPT2LMHeadModel
|
10 |
+
from transformers import BertTokenizer
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import copy
|
13 |
+
|
14 |
+
|
15 |
+
def set_args():
|
16 |
+
"""设置模型预测所需参数"""
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument('--device', default='0', type=str, help='设置预测时使用的显卡,使用CPU设置成-1即可')
|
19 |
+
parser.add_argument('--model_path', default='output_dir/checkpoint-139805', type=str, help='模型文件路径')
|
20 |
+
parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, help='词表,该词表为小词表,并增加了一些新的标记')
|
21 |
+
parser.add_argument('--batch_size', default=3, type=int, help='生成标题的个数')
|
22 |
+
parser.add_argument('--generate_max_len', default=32, type=int, help='生成标题的最大长度')
|
23 |
+
parser.add_argument('--repetition_penalty', default=1.2, type=float, help='重复处罚率')
|
24 |
+
parser.add_argument('--top_k', default=5, type=float, help='解码时保留概率最高的多少个标记')
|
25 |
+
parser.add_argument('--top_p', default=0.95, type=float, help='解码时保留概率累加大于多少的标记')
|
26 |
+
parser.add_argument('--max_len', type=int, default=512, help='输入模型的最大长度,要比config中n_ctx小')
|
27 |
+
return parser.parse_args()
|
28 |
+
|
29 |
+
|
30 |
+
def top_k_top_p_filtering(logits, top_k, top_p, filter_value=-float("Inf")):
|
31 |
+
"""
|
32 |
+
top_k或top_p解码策略,仅保留top_k个或累积概率到达top_p的标记,其他标记设为filter_value,后续在选取标记的过程中会取不到值设为无穷小。
|
33 |
+
Args:
|
34 |
+
logits: 预测结果,即预测成为词典中每个词的分数
|
35 |
+
top_k: 只保留概率最高的top_k个标记
|
36 |
+
top_p: 只保留概率累积达到top_p的标记
|
37 |
+
filter_value: 过滤标记值
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
|
41 |
+
"""
|
42 |
+
# logits的维度必须为2,即size:[batch_size, vocab_size]
|
43 |
+
assert logits.dim() == 2
|
44 |
+
# 获取top_k和字典大小中较小的一个,也就是说,如果top_k大于字典大小,则取字典大小个标记
|
45 |
+
top_k = min(top_k, logits[0].size(-1))
|
46 |
+
# 如果top_k不为0,则将在logits中保留top_k个标记
|
47 |
+
if top_k > 0:
|
48 |
+
# 由于有batch_size个预测结果,因此对其遍历,选取每个预测结果的top_k标记
|
49 |
+
for logit in logits:
|
50 |
+
indices_to_remove = logit < torch.topk(logit, top_k)[0][..., -1, None]
|
51 |
+
logit[indices_to_remove] = filter_value
|
52 |
+
# 如果top_p不为0,则将在logits中保留概率值累积达到top_p的标记
|
53 |
+
if top_p > 0.0:
|
54 |
+
# 对logits进行递减排序
|
55 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
56 |
+
# 对排序后的结果使用softmax归一化,再获取累积概率序列
|
57 |
+
# 例如:原始序列[0.1, 0.2, 0.3, 0.4],则变为:[0.1, 0.3, 0.6, 1.0]
|
58 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
59 |
+
# 删除累积概率高于top_p的标记
|
60 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
61 |
+
# 将索引向右移动,使第一个标记也保持在top_p之上
|
62 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
63 |
+
sorted_indices_to_remove[..., 0] = 0
|
64 |
+
for index, logit in enumerate(logits):
|
65 |
+
# 由于有batch_size个预测结果,因此对其遍历,选取每个预测结果的累积概率达到top_p的标记
|
66 |
+
indices_to_remove = sorted_indices[index][sorted_indices_to_remove[index]]
|
67 |
+
logit[indices_to_remove] = filter_value
|
68 |
+
return logits
|
69 |
+
|
70 |
+
|
71 |
+
def predict_one_sample(model, tokenizer, device, args, content):
|
72 |
+
"""
|
73 |
+
对单个样本进行预测
|
74 |
+
Args:
|
75 |
+
model: 模型
|
76 |
+
tokenizer: 分词器
|
77 |
+
device: 设备信息
|
78 |
+
args: 配置项信息
|
79 |
+
content: 新闻正文
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
|
83 |
+
"""
|
84 |
+
# 对新闻正文进行预处理,并判断如果超长则进行截断
|
85 |
+
content_tokens = tokenizer.tokenize(content)
|
86 |
+
if len(content_tokens) > args.max_len - 3 - args.generate_max_len:
|
87 |
+
content_tokens = content_tokens[:args.max_len - 3 - args.generate_max_len]
|
88 |
+
# 获取content_id、title_id、unk_id、sep_id值
|
89 |
+
content_id = tokenizer.convert_tokens_to_ids("[Content]")
|
90 |
+
title_id = tokenizer.convert_tokens_to_ids("[Title]")
|
91 |
+
unk_id = tokenizer.convert_tokens_to_ids("[UNK]")
|
92 |
+
sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
|
93 |
+
# 将tokens索引化,变成模型所需格式
|
94 |
+
content_tokens = ["[CLS]"] + content_tokens + ["[SEP]"]
|
95 |
+
input_ids = tokenizer.convert_tokens_to_ids(content_tokens)
|
96 |
+
# 将input_ids和token_type_ids进行扩充,扩充到需要预测标题的个数,即batch_size
|
97 |
+
input_ids = [copy.deepcopy(input_ids) for _ in range(args.batch_size)]
|
98 |
+
token_type_ids = [[content_id] * len(content_tokens) for _ in range(args.batch_size)]
|
99 |
+
# 将input_ids和token_type_ids变成tensor
|
100 |
+
input_tensors = torch.tensor(input_ids).long().to(device)
|
101 |
+
token_type_tensors = torch.tensor(token_type_ids).long().to(device)
|
102 |
+
next_token_type = torch.tensor([[title_id] for _ in range(args.batch_size)]).long().to(device)
|
103 |
+
# 用于存放每一步解码的结果
|
104 |
+
generated = []
|
105 |
+
# 用于存放,完成解码序列的序号
|
106 |
+
finish_set = set()
|
107 |
+
with torch.no_grad():
|
108 |
+
# 遍历生成标题最大长度
|
109 |
+
for _ in range(args.generate_max_len):
|
110 |
+
outputs = model(input_ids=input_tensors, token_type_ids=token_type_tensors)
|
111 |
+
# 获取预测结果序列的最后一个标记,next_token_logits size:[batch_size, vocab_size]
|
112 |
+
next_token_logits = outputs[0][:, -1, :]
|
113 |
+
# 对batch_size进行遍历,将词表中出现在序列中的词的概率进行惩罚
|
114 |
+
for index in range(args.batch_size):
|
115 |
+
for token_id in set([token_ids[index] for token_ids in generated]):
|
116 |
+
next_token_logits[index][token_id] /= args.repetition_penalty
|
117 |
+
# 对batch_size进行遍历,将词表中的UNK的值设为无穷小
|
118 |
+
for next_token_logit in next_token_logits:
|
119 |
+
next_token_logit[unk_id] = -float("Inf")
|
120 |
+
# 使用top_k_top_p_filtering函数,按照top_k和top_p的值,对预测结果进行筛选
|
121 |
+
filter_logits = top_k_top_p_filtering(next_token_logits, top_k=args.top_k, top_p=args.top_p)
|
122 |
+
# 对filter_logits的每一行做一次取值,输出结果是每一次取值时filter_logits对应行的下标,即词表位置(词的id)
|
123 |
+
# filter_logits中的越大的值,越容易被选中
|
124 |
+
next_tokens = torch.multinomial(F.softmax(filter_logits, dim=-1), num_samples=1)
|
125 |
+
# 判断如果哪个序列的预测标记为sep_id时,则加入到finish_set
|
126 |
+
for index, token_id in enumerate(next_tokens[:, 0]):
|
127 |
+
if token_id == sep_id:
|
128 |
+
finish_set.add(index)
|
129 |
+
# 判断,如果finish_set包含全部的序列序号,则停止预测;否则继续预测
|
130 |
+
finish_flag = True
|
131 |
+
for index in range(args.batch_size):
|
132 |
+
if index not in finish_set:
|
133 |
+
finish_flag = False
|
134 |
+
break
|
135 |
+
if finish_flag:
|
136 |
+
break
|
137 |
+
# 将预测标记添加到generated中
|
138 |
+
generated.append([token.item() for token in next_tokens[:, 0]])
|
139 |
+
# 将预测结果拼接到input_tensors和token_type_tensors上,继续下一次预测
|
140 |
+
input_tensors = torch.cat((input_tensors, next_tokens), dim=-1)
|
141 |
+
token_type_tensors = torch.cat((token_type_tensors, next_token_type), dim=-1)
|
142 |
+
# 用于存储预测结果
|
143 |
+
candidate_responses = []
|
144 |
+
# 对batch_size进行遍历,并将token_id变成对应汉字
|
145 |
+
for index in range(args.batch_size):
|
146 |
+
responses = []
|
147 |
+
for token_index in range(len(generated)):
|
148 |
+
# 判断,当出现sep_id时,停止在该序列中添加token
|
149 |
+
if generated[token_index][index] != sep_id:
|
150 |
+
responses.append(generated[token_index][index])
|
151 |
+
else:
|
152 |
+
break
|
153 |
+
# 将token_id序列变成汉字序列,去除"##",并将[Space]替换成空格
|
154 |
+
candidate_responses.append(
|
155 |
+
"".join(tokenizer.convert_ids_to_tokens(responses)).replace("##", "").replace("[space]", " "))
|
156 |
+
return candidate_responses
|
157 |
+
|
158 |
+
|
159 |
+
def main():
|
160 |
+
"""主函数"""
|
161 |
+
# 设置预测的配置参数
|
162 |
+
args = set_args()
|
163 |
+
# 获取设备信息
|
164 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
165 |
+
os.environ["CUDA_VISIBLE_DEVICE"] = args.device
|
166 |
+
device = torch.device("cuda" if torch.cuda.is_available() and int(args.device) >= 0 else "cpu")
|
167 |
+
# 实例化tokenizer和model
|
168 |
+
tokenizer = BertTokenizer.from_pretrained(args.vocab_path, do_lower_case=True)
|
169 |
+
model = GPT2LMHeadModel.from_pretrained(args.model_path)
|
170 |
+
model.to(device)
|
171 |
+
model.eval()
|
172 |
+
print('开始对新闻生成标题,输入CTRL + Z,则退出')
|
173 |
+
try:
|
174 |
+
while True:
|
175 |
+
content = input("输入的新闻正文为:")
|
176 |
+
titles = predict_one_sample(model, tokenizer, device, args, content)
|
177 |
+
for i, title in enumerate(titles):
|
178 |
+
print("生成的第{}个标题为:{}".format(i + 1, title))
|
179 |
+
except:
|
180 |
+
pass
|
181 |
+
|
182 |
+
|
183 |
+
if __name__ == '__main__':
|
184 |
+
main()
|
185 |
+
|
model.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
文件说明:
|
3 |
+
GPT2模型文件,主要对transformers包中GPT2LMHeadModel的重写,修改计算loss部分,只计算预测title部分的loss
|
4 |
+
"""
|
5 |
+
|
6 |
+
from torch.nn import CrossEntropyLoss
|
7 |
+
import torch.nn as nn
|
8 |
+
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2Model
|
9 |
+
|
10 |
+
|
11 |
+
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
12 |
+
"""GPT2模型"""
|
13 |
+
def __init__(self, config):
|
14 |
+
"""
|
15 |
+
初始化函数
|
16 |
+
Args:
|
17 |
+
config: 配置参数
|
18 |
+
"""
|
19 |
+
super().__init__(config)
|
20 |
+
self.transformer = GPT2Model(config)
|
21 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
22 |
+
self.init_weights()
|
23 |
+
|
24 |
+
def forward(self, input_ids=None, past=None, token_type_ids=None, labels=None, title_id=None):
|
25 |
+
"""
|
26 |
+
前向函数,计算GPT2预测结果值
|
27 |
+
Args:
|
28 |
+
input_ids: 输入序列在词表中的索引序列,size:[batch_size, sequence_length]
|
29 |
+
past: 包含由模型预先计算好的隐藏状态,一般使用在预测阶段,用于加速顺序解码,防止重复计算前面计算过的token
|
30 |
+
token_type_ids: 用于区分输入序列中content和title的分隔符序列,size:[batch_size, sequence_length]
|
31 |
+
labels: 标签序列,size:[batch_size, sequence_length],一般情况下,与input_ids相同
|
32 |
+
title_id: title部分分隔符的id
|
33 |
+
Returns:
|
34 |
+
|
35 |
+
"""
|
36 |
+
# 获取GPT2模型的输出结果
|
37 |
+
transformer_outputs = self.transformer(input_ids, past_key_values=past, token_type_ids=token_type_ids)
|
38 |
+
# 获取GPT2模型的最后一层的隐层节点状态,size:[batch_size, sequence_length, config.n_embd]
|
39 |
+
hidden_states = transformer_outputs[0]
|
40 |
+
# 预测隐层节点状态中的每一个token的下一个token,size:[batch_size, sequence_length, config.vocab_size]
|
41 |
+
lm_logits = self.lm_head(hidden_states)
|
42 |
+
# 拼接输出结果
|
43 |
+
outputs = (lm_logits,) + transformer_outputs[1:]
|
44 |
+
# 如果labels不为None时,计算损失值loss,并拼接到输出结果中
|
45 |
+
if labels is not None:
|
46 |
+
# 计算loss时,title_id不可以为None,因为需要title_id找到title的部分
|
47 |
+
if title_id is None or token_type_ids is None:
|
48 |
+
raise Exception("当labels不为None时, title_id和token_type_ids均不可以为None。")
|
49 |
+
# 获取mask值,如果token_type_ids中等于title_id的部分需要计算loss,标记为1;否则为0。
|
50 |
+
# size:[batch_size, sequence_length]
|
51 |
+
mask = (token_type_ids == title_id).long()
|
52 |
+
# 获取新的标签,size:[batch_size, sequence_length]
|
53 |
+
labels = labels * mask
|
54 |
+
# 对预测结果和标签进行偏移操作
|
55 |
+
# GPT2的生成机制为通过前面的token,预测下一个token;并且labels与input_ids相同,
|
56 |
+
# 因此input_ids中的第一个token的预测结果,实际上是标签中的第二个token,以此类推,最终仅计算sequence_length-1个token的loss
|
57 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
58 |
+
shift_labels = labels[..., 1:].contiguous()
|
59 |
+
|
60 |
+
# 定义损失函数CrossEntropyLoss,并且设置忽略计算loss的索引,以及返回loss的形式
|
61 |
+
# 忽略shift_labels中为0的loss,也就是仅计算title部分的损失值
|
62 |
+
# 对loss的计算方式设为sum,由于我们仅计算了itle部分的损失值,如果使用mean,会使loss变小(实际除的是sequence_length-1,不是title部分的真实长度)
|
63 |
+
loss_fct = CrossEntropyLoss(ignore_index=0, reduction="sum")
|
64 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
65 |
+
# 获取title部分的真实长度,并计算真实loss
|
66 |
+
num = shift_labels.ne(0).long().sum().item()
|
67 |
+
loss = loss / num
|
68 |
+
outputs = (loss,) + outputs
|
69 |
+
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
|
70 |
+
|