Spaces:
Runtime error
Runtime error
""" | |
文件说明: | |
根据训练好的模型,进行新闻标题生成,预测文件 | |
""" | |
import torch | |
import os | |
import argparse | |
from model import GPT2LMHeadModel | |
from transformers import BertTokenizer | |
import torch.nn.functional as F | |
import copy | |
def set_args(): | |
"""设置模型预测所需参数""" | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--device', default='0', type=str, help='设置预测时使用的显卡,使用CPU设置成-1即可') | |
parser.add_argument('--model_path', default='output_dir/checkpoint-139805', type=str, help='模型文件路径') | |
parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, help='词表,该词表为小词表,并增加了一些新的标记') | |
parser.add_argument('--batch_size', default=3, type=int, help='生成标题的个数') | |
parser.add_argument('--generate_max_len', default=32, type=int, help='生成标题的最大长度') | |
parser.add_argument('--repetition_penalty', default=1.2, type=float, help='重复处罚率') | |
parser.add_argument('--top_k', default=5, type=float, help='解码时保留概率最高的多少个标记') | |
parser.add_argument('--top_p', default=0.95, type=float, help='解码时保留概率累加大于多少的标记') | |
parser.add_argument('--max_len', type=int, default=512, help='输入模型的最大长度,要比config中n_ctx小') | |
return parser.parse_args() | |
def top_k_top_p_filtering(logits, top_k, top_p, filter_value=-float("Inf")): | |
""" | |
top_k或top_p解码策略,仅保留top_k个或累积概率到达top_p的标记,其他标记设为filter_value,后续在选取标记的过程中会取不到值设为无穷小。 | |
Args: | |
logits: 预测结果,即预测成为词典中每个词的分数 | |
top_k: 只保留概率最高的top_k个标记 | |
top_p: 只保留概率累积达到top_p的标记 | |
filter_value: 过滤标记值 | |
Returns: | |
""" | |
# logits的维度必须为2,即size:[batch_size, vocab_size] | |
assert logits.dim() == 2 | |
# 获取top_k和字典大小中较小的一个,也就是说,如果top_k大于字典大小,则取字典大小个标记 | |
top_k = min(top_k, logits[0].size(-1)) | |
# 如果top_k不为0,则将在logits中保留top_k个标记 | |
if top_k > 0: | |
# 由于有batch_size个预测结果,因此对其遍历,选取每个预测结果的top_k标记 | |
for logit in logits: | |
indices_to_remove = logit < torch.topk(logit, top_k)[0][..., -1, None] | |
logit[indices_to_remove] = filter_value | |
# 如果top_p不为0,则将在logits中保留概率值累积达到top_p的标记 | |
if top_p > 0.0: | |
# 对logits进行递减排序 | |
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) | |
# 对排序后的结果使用softmax归一化,再获取累积概率序列 | |
# 例如:原始序列[0.1, 0.2, 0.3, 0.4],则变为:[0.1, 0.3, 0.6, 1.0] | |
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | |
# 删除累积概率高于top_p的标记 | |
sorted_indices_to_remove = cumulative_probs > top_p | |
# 将索引向右移动,使第一个标记也保持在top_p之上 | |
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | |
sorted_indices_to_remove[..., 0] = 0 | |
for index, logit in enumerate(logits): | |
# 由于有batch_size个预测结果,因此对其遍历,选取每个预测结果的累积概率达到top_p的标记 | |
indices_to_remove = sorted_indices[index][sorted_indices_to_remove[index]] | |
logit[indices_to_remove] = filter_value | |
return logits | |
def predict_one_sample(model, tokenizer, device, args, content): | |
""" | |
对单个样本进行预测 | |
Args: | |
model: 模型 | |
tokenizer: 分词器 | |
device: 设备信息 | |
args: 配置项信息 | |
content: 新闻正文 | |
Returns: | |
""" | |
# 对新闻正文进行预处理,并判断如果超长则进行截断 | |
content_tokens = tokenizer.tokenize(content) | |
if len(content_tokens) > args.max_len - 3 - args.generate_max_len: | |
content_tokens = content_tokens[:args.max_len - 3 - args.generate_max_len] | |
# 获取content_id、title_id、unk_id、sep_id值 | |
content_id = tokenizer.convert_tokens_to_ids("[Content]") | |
title_id = tokenizer.convert_tokens_to_ids("[Title]") | |
unk_id = tokenizer.convert_tokens_to_ids("[UNK]") | |
sep_id = tokenizer.convert_tokens_to_ids("[SEP]") | |
# 将tokens索引化,变成模型所需格式 | |
content_tokens = ["[CLS]"] + content_tokens + ["[SEP]"] | |
input_ids = tokenizer.convert_tokens_to_ids(content_tokens) | |
# 将input_ids和token_type_ids进行扩充,扩充到需要预测标题的个数,即batch_size | |
input_ids = [copy.deepcopy(input_ids) for _ in range(args.batch_size)] | |
token_type_ids = [[content_id] * len(content_tokens) for _ in range(args.batch_size)] | |
# 将input_ids和token_type_ids变成tensor | |
input_tensors = torch.tensor(input_ids).long().to(device) | |
token_type_tensors = torch.tensor(token_type_ids).long().to(device) | |
next_token_type = torch.tensor([[title_id] for _ in range(args.batch_size)]).long().to(device) | |
# 用于存放每一步解码的结果 | |
generated = [] | |
# 用于存放,完成解码序列的序号 | |
finish_set = set() | |
with torch.no_grad(): | |
# 遍历生成标题最大长度 | |
for _ in range(args.generate_max_len): | |
outputs = model(input_ids=input_tensors, token_type_ids=token_type_tensors) | |
# 获取预测结果序列的最后一个标记,next_token_logits size:[batch_size, vocab_size] | |
next_token_logits = outputs[0][:, -1, :] | |
# 对batch_size进行遍历,将词表中出现在序列中的词的概率进行惩罚 | |
for index in range(args.batch_size): | |
for token_id in set([token_ids[index] for token_ids in generated]): | |
next_token_logits[index][token_id] /= args.repetition_penalty | |
# 对batch_size进行遍历,将词表中的UNK的值设为无穷小 | |
for next_token_logit in next_token_logits: | |
next_token_logit[unk_id] = -float("Inf") | |
# 使用top_k_top_p_filtering函数,按照top_k和top_p的值,对预测结果进行筛选 | |
filter_logits = top_k_top_p_filtering(next_token_logits, top_k=args.top_k, top_p=args.top_p) | |
# 对filter_logits的每一行做一次取值,输出结果是每一次取值时filter_logits对应行的下标,即词表位置(词的id) | |
# filter_logits中的越大的值,越容易被选中 | |
next_tokens = torch.multinomial(F.softmax(filter_logits, dim=-1), num_samples=1) | |
# 判断如果哪个序列的预测标记为sep_id时,则加入到finish_set | |
for index, token_id in enumerate(next_tokens[:, 0]): | |
if token_id == sep_id: | |
finish_set.add(index) | |
# 判断,如果finish_set包含全部的序列序号,则停止预测;否则继续预测 | |
finish_flag = True | |
for index in range(args.batch_size): | |
if index not in finish_set: | |
finish_flag = False | |
break | |
if finish_flag: | |
break | |
# 将预测标记添加到generated中 | |
generated.append([token.item() for token in next_tokens[:, 0]]) | |
# 将预测结果拼接到input_tensors和token_type_tensors上,继续下一次预测 | |
input_tensors = torch.cat((input_tensors, next_tokens), dim=-1) | |
token_type_tensors = torch.cat((token_type_tensors, next_token_type), dim=-1) | |
# 用于存储预测结果 | |
candidate_responses = [] | |
# 对batch_size进行遍历,并将token_id变成对应汉字 | |
for index in range(args.batch_size): | |
responses = [] | |
for token_index in range(len(generated)): | |
# 判断,当出现sep_id时,停止在该序列中添加token | |
if generated[token_index][index] != sep_id: | |
responses.append(generated[token_index][index]) | |
else: | |
break | |
# 将token_id序列变成汉字序列,去除"##",并将[Space]替换成空格 | |
candidate_responses.append( | |
"".join(tokenizer.convert_ids_to_tokens(responses)).replace("##", "").replace("[space]", " ")) | |
return candidate_responses | |
def main(): | |
"""主函数""" | |
# 设置预测的配置参数 | |
args = set_args() | |
# 获取设备信息 | |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |
os.environ["CUDA_VISIBLE_DEVICE"] = args.device | |
device = torch.device("cuda" if torch.cuda.is_available() and int(args.device) >= 0 else "cpu") | |
# 实例化tokenizer和model | |
tokenizer = BertTokenizer.from_pretrained(args.vocab_path, do_lower_case=True) | |
model = GPT2LMHeadModel.from_pretrained(args.model_path) | |
model.to(device) | |
model.eval() | |
print('开始对新闻生成标题,输入CTRL + Z,则退出') | |
try: | |
while True: | |
content = input("输入的新闻正文为:") | |
titles = predict_one_sample(model, tokenizer, device, args, content) | |
for i, title in enumerate(titles): | |
print("生成的第{}个标题为:{}".format(i + 1, title)) | |
except: | |
pass | |
if __name__ == '__main__': | |
main() | |