Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import os
|
4 |
+
import argparse
|
5 |
+
from tqdm import trange
|
6 |
+
from transformers import GPT2LMHeadModel
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
|
10 |
+
def is_word(word):
|
11 |
+
for item in list(word):
|
12 |
+
if item not in 'qwertyuiopasdfghjklzxcvbnm':
|
13 |
+
return False
|
14 |
+
return True
|
15 |
+
|
16 |
+
|
17 |
+
def _is_chinese_char(char):
|
18 |
+
"""Checks whether CP is the codepoint of a CJK character."""
|
19 |
+
# This defines a "chinese character" as anything in the CJK Unicode block:
|
20 |
+
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
21 |
+
#
|
22 |
+
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
23 |
+
# despite its name. The modern Korean Hangul alphabet is a different block,
|
24 |
+
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
25 |
+
# space-separated words, so they are not treated specially and handled
|
26 |
+
# like the all of the other languages.
|
27 |
+
cp = ord(char)
|
28 |
+
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
29 |
+
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
30 |
+
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
31 |
+
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
32 |
+
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
33 |
+
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
34 |
+
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
35 |
+
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
36 |
+
return True
|
37 |
+
|
38 |
+
return False
|
39 |
+
|
40 |
+
|
41 |
+
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
42 |
+
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
43 |
+
Args:
|
44 |
+
logits: logits distribution shape (vocabulary size)
|
45 |
+
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
46 |
+
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
47 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
48 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
49 |
+
"""
|
50 |
+
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
51 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
52 |
+
if top_k > 0:
|
53 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
54 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
55 |
+
logits[indices_to_remove] = filter_value
|
56 |
+
|
57 |
+
if top_p > 0.0:
|
58 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
59 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
60 |
+
|
61 |
+
# Remove tokens with cumulative probability above the threshold
|
62 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
63 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
64 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
65 |
+
sorted_indices_to_remove[..., 0] = 0
|
66 |
+
|
67 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
68 |
+
logits[indices_to_remove] = filter_value
|
69 |
+
return logits
|
70 |
+
|
71 |
+
|
72 |
+
def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0,
|
73 |
+
device='cpu'):
|
74 |
+
context = torch.tensor(context, dtype=torch.long, device=device)
|
75 |
+
context = context.unsqueeze(0)
|
76 |
+
generated = context
|
77 |
+
with torch.no_grad():
|
78 |
+
for _ in trange(length):
|
79 |
+
inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)}
|
80 |
+
outputs = model(
|
81 |
+
**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
82 |
+
next_token_logits = outputs[0][0, -1, :]
|
83 |
+
for id in set(generated):
|
84 |
+
next_token_logits[id] /= repitition_penalty
|
85 |
+
next_token_logits = next_token_logits / temperature
|
86 |
+
next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
|
87 |
+
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
88 |
+
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
89 |
+
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
90 |
+
return generated.tolist()[0]
|
91 |
+
|
92 |
+
|
93 |
+
def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_p=0.0, device='cpu'):
|
94 |
+
inputs = torch.LongTensor(context).view(1, -1).to(device)
|
95 |
+
if len(context) > 1:
|
96 |
+
_, past = model(inputs[:, :-1], None)[:2]
|
97 |
+
prev = inputs[:, -1].view(1, -1)
|
98 |
+
else:
|
99 |
+
past = None
|
100 |
+
prev = inputs
|
101 |
+
generate = [] + context
|
102 |
+
with torch.no_grad():
|
103 |
+
for i in trange(length):
|
104 |
+
output = model(prev, past=past)
|
105 |
+
output, past = output[:2]
|
106 |
+
output = output[-1].squeeze(0) / temperature
|
107 |
+
filtered_logits = top_k_top_p_filtering(output, top_k=top_k, top_p=top_p)
|
108 |
+
next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
|
109 |
+
generate.append(next_token.item())
|
110 |
+
prev = next_token.view(1, 1)
|
111 |
+
return generate
|
112 |
+
|
113 |
+
|
114 |
+
# 通过命令行参数--fast_pattern,指定模式
|
115 |
+
def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0, repitition_penalty=1.0, device='cpu',
|
116 |
+
is_fast_pattern=False):
|
117 |
+
if is_fast_pattern:
|
118 |
+
return fast_sample_sequence(model, context, length, temperature=temperature, top_k=top_k, top_p=top_p,
|
119 |
+
device=device)
|
120 |
+
else:
|
121 |
+
return sample_sequence(model, context, length, n_ctx, tokenizer=tokenizer, temperature=temperature, top_k=top_k, top_p=top_p,
|
122 |
+
repitition_penalty=repitition_penalty, device=device)
|
123 |
+
|
124 |
+
def smp_generate(pre_str):
|
125 |
+
|
126 |
+
from tokenizations import tokenization_bert
|
127 |
+
|
128 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3' # 此处设置程序使用哪些显卡
|
129 |
+
length = 500
|
130 |
+
batch_size = 1
|
131 |
+
nsamples = 1
|
132 |
+
temperature = 1
|
133 |
+
topk = 8
|
134 |
+
topp = 0
|
135 |
+
repetition_penalty = 1.0
|
136 |
+
model_path = 'pretrained'
|
137 |
+
tokenizer_path = 'cache/vocab.txt'
|
138 |
+
save_samples = False
|
139 |
+
save_samples_path = '.'
|
140 |
+
fast_pattern = True
|
141 |
+
prefix = pre_str
|
142 |
+
|
143 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
144 |
+
|
145 |
+
tokenizer = tokenization_bert.BertTokenizer(vocab_file=tokenizer_path)
|
146 |
+
model = GPT2LMHeadModel.from_pretrained(model_path)
|
147 |
+
model.to(device)
|
148 |
+
model.eval()
|
149 |
+
|
150 |
+
n_ctx = model.config.n_ctx
|
151 |
+
|
152 |
+
if length == -1:
|
153 |
+
length = model.config.n_ctx
|
154 |
+
|
155 |
+
while True:
|
156 |
+
raw_text = prefix
|
157 |
+
context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
|
158 |
+
generated = 0
|
159 |
+
for _ in range(nsamples // batch_size):
|
160 |
+
out = generate(
|
161 |
+
n_ctx=n_ctx,
|
162 |
+
model=model,
|
163 |
+
context=context_tokens,
|
164 |
+
length=length,
|
165 |
+
is_fast_pattern=fast_pattern, tokenizer=tokenizer,
|
166 |
+
temperature=temperature, top_k=topk, top_p=topp, repitition_penalty=repetition_penalty, device=device
|
167 |
+
)
|
168 |
+
for i in range(batch_size):
|
169 |
+
generated += 1
|
170 |
+
text = tokenizer.convert_ids_to_tokens(out)
|
171 |
+
for i, item in enumerate(text[:-1]): # 确保英文前后有空格
|
172 |
+
if is_word(item) and is_word(text[i + 1]):
|
173 |
+
text[i] = item + ' '
|
174 |
+
for i, item in enumerate(text):
|
175 |
+
if item == '[MASK]':
|
176 |
+
text[i] = ''
|
177 |
+
elif item == '[CLS]':
|
178 |
+
text[i] = '\n\n'
|
179 |
+
elif item == '[SEP]':
|
180 |
+
text[i] = '\n'
|
181 |
+
info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n"
|
182 |
+
text = ''.join(text).replace('##', '').strip()
|
183 |
+
return text
|
184 |
+
|
185 |
+
|
186 |
+
def format_text(text):
|
187 |
+
return "<p>" + text.replace("\n", "<br>") + "</p>"
|
188 |
+
|
189 |
+
input_textbox = gr.inputs.Textbox(label="输入前缀")
|
190 |
+
output_textbox = gr.outputs.Textbox(label="生成文言文")
|
191 |
+
|
192 |
+
# 自定义HTML和CSS
|
193 |
+
html_content = """
|
194 |
+
<div style="display: flex; flex-direction: column-reverse;">
|
195 |
+
<div style="flex-grow: 1; overflow-y: auto;">
|
196 |
+
{output}
|
197 |
+
</div>
|
198 |
+
<div style="margin-top: 10px;">
|
199 |
+
{input}
|
200 |
+
</div>
|
201 |
+
</div>
|
202 |
+
"""
|
203 |
+
|
204 |
+
iface = gr.Interface(fn=smp_generate, inputs=input_textbox, outputs=output_textbox,
|
205 |
+
title="文言文生成器", layout="vertical", layout_mode="size",
|
206 |
+
layout_alignments=["center", "top"], template="gradio/custom.html",
|
207 |
+
html=html_content)
|
208 |
+
|
209 |
+
iface.launch()
|
210 |
+
|
211 |
+
|
212 |
+
|