Spaces:
Runtime error
Runtime error
import streamlit as st | |
from model import GPT2LMHeadModel | |
from transformers import BertTokenizer | |
import argparse | |
import os | |
import torch | |
import time | |
from generate_title import predict_one_sample | |
# Page-level Streamlit settings: page title, sidebar state and wide layout.
st.set_page_config(
    page_title="Demo",
    layout="wide",
    initial_sidebar_state="auto",
)
# NOTE(review): an earlier `@st.cache_data(allow_output_mutation=True)`
# decorator was disabled here; without any caching, every call loads the
# tokenizer and the model again from `vocab_path` / `model_path`.
def get_model(device, vocab_path, model_path):
    """Load the tokenizer and the GPT-2 LM-head model for inference.

    Args:
        device: Target handed to ``model.to`` (a ``torch.device`` at the call
            site in this file).
        vocab_path: Location passed to ``BertTokenizer.from_pretrained``.
        model_path: Location passed to ``GPT2LMHeadModel.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been moved to
        ``device`` and switched to evaluation mode.
    """
    bert_tokenizer = BertTokenizer.from_pretrained(vocab_path, do_lower_case=True)

    lm_model = GPT2LMHeadModel.from_pretrained(model_path)
    lm_model.to(device)
    lm_model.eval()

    return (bert_tokenizer, lm_model)
# GPU selection. ``device_ids`` is the physical GPU index exposed to this
# process; a negative value selects the CPU.
device_ids = 0
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# The variable CUDA honours is CUDA_VISIBLE_DEVICES (plural); the singular
# spelling is silently ignored and leaves every GPU visible.
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_ids)
# With a single GPU exposed by the mask above, that GPU is always ordinal 0
# inside this process, so "cuda:1" would be an invalid device ordinal.
device = torch.device("cuda:0" if torch.cuda.is_available() and int(device_ids) >= 0 else "cpu")
# Loaded once at module level and shared by writer().
tokenizer, model = get_model(device, "vocab.txt", "checkpoint-55922")
def writer():
    """Render the summarisation demo page.

    Draws the sidebar decoding controls, packs their values into an
    ``argparse`` namespace (the form in which they are handed to
    ``predict_one_sample``), and shows the input text area. When the button
    is pressed it runs generation with the module-level ``model``,
    ``tokenizer`` and ``device``, then lists every returned title together
    with the elapsed time. If the button has not been pressed, the current
    script run is ended with ``st.stop()``.
    """
    st.markdown("\n## Text Summary DEMO\n")

    # Sidebar widgets supply the decoding parameters.
    st.sidebar.subheader("配置参数")
    batch_size = st.sidebar.slider("batch_size", min_value=0, max_value=10, value=3)
    generate_max_len = st.sidebar.number_input(
        "generate_max_len", min_value=0, max_value=64, value=32, step=1
    )
    repetition_penalty = st.sidebar.number_input(
        "repetition_penalty", min_value=0.0, max_value=10.0, value=1.2, step=0.1
    )
    top_k = st.sidebar.slider("top_k", min_value=0, max_value=10, value=3, step=1)
    top_p = st.sidebar.number_input("top_p", min_value=0.0, max_value=1.0, value=0.95, step=0.01)

    # Widget values become the argparse defaults, so command-line flags (if
    # any) still override them. argparse applies ``type`` only to values that
    # arrive as strings, i.e. from the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=batch_size, type=int, help='生成标题的个数')
    parser.add_argument('--generate_max_len', default=generate_max_len, type=int, help='生成标题的最大长度')
    parser.add_argument('--repetition_penalty', default=repetition_penalty, type=float, help='重复处罚率')
    # top_k is a token count: parse it as int so a command-line value matches
    # the integer produced by the slider (it was previously parsed as float).
    parser.add_argument('--top_k', default=top_k, type=int, help='解码时保留概率最高的多少个标记')
    parser.add_argument('--top_p', default=top_p, type=float, help='解码时保留概率累加大于多少的标记')
    parser.add_argument('--max_len', type=int, default=512, help='输入模型的最大长度,要比config中n_ctx小')
    args = parser.parse_args()

    content = st.text_area("输入正文", value="近期美元指数大幅攀升,一度升至二十年高位,对此,市场人士认为,主要投资者对全球经济形势悲观预期升温,加上近日美联储决议鹰派基调,导致美元避险需求抬头。", max_chars=512)

    # Nothing to generate until the button is pressed; st.stop() ends this run.
    if not st.button("一键生成摘要"):
        st.stop()

    start_message = st.empty()
    start_message.write("正在抽取,请等待...")
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()
    titles = predict_one_sample(model, tokenizer, device, args, content)
    elapsed = time.perf_counter() - start_time
    start_message.write("抽取完成,耗时{}s".format(elapsed))

    # One read-only-style text box per generated title, numbered from 1.
    for i, title in enumerate(titles):
        st.text_input("第{}个结果".format(i + 1), title)
# Script entry point: build the page only when this file is run as the main module.
if __name__ == "__main__":
    writer()