Spaces:
Sleeping
Sleeping
liujch1998
commited on
Commit
·
7eae3e8
1
Parent(s):
da832fe
Initial commit
Browse files- app.py +144 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import transformers
|
5 |
+
|
6 |
+
def reduce_sum(value, mask, axis=None):
|
7 |
+
if axis is None:
|
8 |
+
return torch.sum(value * mask)
|
9 |
+
return torch.sum(value * mask, axis)
|
10 |
+
def reduce_mean(value, mask, axis=None):
|
11 |
+
if axis is None:
|
12 |
+
return torch.sum(value * mask) / torch.sum(mask)
|
13 |
+
return reduce_sum(value, mask, axis) / torch.sum(mask, axis)
|
14 |
+
|
15 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
16 |
+
|
17 |
+
HF_TOKEN_DOWNLOAD = os.environ.get('HF_TOKEN_DOWNLOAD')
|
18 |
+
|
19 |
+
class Processor:
|
20 |
+
def __init__(self, model):
|
21 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_auth_token=HF_TOKEN_DOWNLOAD)
|
22 |
+
self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto', offload_folder='offload')
|
23 |
+
self.model.eval()
|
24 |
+
|
25 |
+
def parse_choices(self, s):
|
26 |
+
'''
|
27 |
+
s: serialized_choices '(A) ... (B) ... (C) ...'
|
28 |
+
'''
|
29 |
+
choices = []
|
30 |
+
key = 'A' if s.find('(A)') != -1 else 'a'
|
31 |
+
while True:
|
32 |
+
pos = s.find(f'({chr(ord(key) + 1)})')
|
33 |
+
if pos == -1:
|
34 |
+
break
|
35 |
+
choice = s[3:pos]
|
36 |
+
s = s[pos:]
|
37 |
+
choice = choice.strip(' ')
|
38 |
+
choices.append(choice)
|
39 |
+
key = chr(ord(key) + 1)
|
40 |
+
choice = s[3:]
|
41 |
+
choice = choice.strip(' ')
|
42 |
+
choices.append(choice)
|
43 |
+
return choices
|
44 |
+
|
45 |
+
def run(self, question, max_question_len, max_knowledge_len, max_answer_len, m, top_p):
|
46 |
+
choices = self.parse_choices(question.split('\\n')[1].strip(' '))
|
47 |
+
choices = [chr(ord('A') + i) for i, choice in enumerate(choices)]
|
48 |
+
choices_ids = self.tokenizer(choices, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_answer_len).input_ids.to(device) # (C, AL)
|
49 |
+
|
50 |
+
prompt = question + ' \\n Knowledge: '
|
51 |
+
prompt_tok = self.tokenizer(prompt, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_question_len).to(device) # (1, QL)
|
52 |
+
knowledges_ids = self.model.generate(
|
53 |
+
input_ids=prompt_tok.input_ids,
|
54 |
+
attention_mask=prompt_tok.attention_mask,
|
55 |
+
max_length=max_knowledge_len + 1,
|
56 |
+
min_length=3,
|
57 |
+
do_sample=True,
|
58 |
+
num_return_sequences=m,
|
59 |
+
top_p=top_p,
|
60 |
+
) # (K, KL); begins with 0 ([BOS]); ends with 1 ([EOS])
|
61 |
+
knowledges_ids = knowledges_ids[:, 1:].contiguous() # no beginning; ends with 1 ([EOS])
|
62 |
+
knowledges = self.tokenizer.batch_decode(knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
63 |
+
knowledges = list(set(knowledges))
|
64 |
+
knowledges = [''] + knowledges
|
65 |
+
|
66 |
+
prompts = [question + (f' \\n Knowledge: {knowledge} \\n Answer: ' if knowledge != '' else ' \\n Answer:') for knowledge in knowledges]
|
67 |
+
prompts_tok = self.tokenizer(prompts, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_question_len + max_knowledge_len).input_ids.to(device) # (1+K, QL+KL)
|
68 |
+
output = self.model(
|
69 |
+
input_ids=prompts_tok.input_ids,
|
70 |
+
attention_mask=prompts_tok.attention_mask,
|
71 |
+
# labels=choices_ids[0].unsqueeze(0).expand(len(knowledges), -1),
|
72 |
+
)
|
73 |
+
logitsss = output.logits # (1+K, AL, V)
|
74 |
+
logitss = logitsss[:, 0, :] # (1+K, V)
|
75 |
+
choice_ids = choices_ids[:, 0] # (C)
|
76 |
+
answer_logitss = logitss.gather(dim=1, index=choice_ids.unsqueeze(0).expand(len(knowledges), -1)) # (1+K, C)
|
77 |
+
answer_probss = answer_logitss.softmax(dim=1) # (1+K, C)
|
78 |
+
|
79 |
+
# Ensemble
|
80 |
+
knowless_pred = answer_probss[0, :].argmax(dim=0).item()
|
81 |
+
knowless_pred = choices[knowless_pred]
|
82 |
+
|
83 |
+
answer_probs = answer_probss.max(dim=0).values # (C)
|
84 |
+
knowful_pred = answer_probs.argmax(dim=0).item()
|
85 |
+
knowful_pred = choices[knowful_pred]
|
86 |
+
selected_knowledge_ix = answer_probss.max(dim=1).values.argmax(dim=0).item()
|
87 |
+
selected_knowledge = knowledges[selected_knowledge_ix]
|
88 |
+
|
89 |
+
return {
|
90 |
+
'question': question,
|
91 |
+
'knowledges': knowledges,
|
92 |
+
'knowless_pred': knowless_pred,
|
93 |
+
'knowful_pred': knowful_pred,
|
94 |
+
'selected_knowledge': selected_knowledge,
|
95 |
+
}
|
96 |
+
|
97 |
+
MODELS = [
|
98 |
+
'liujch1998/crystal-large',
|
99 |
+
# 'liujch1998/crystal-3b',
|
100 |
+
# 'liujch1998/crystal-11b',
|
101 |
+
]
|
102 |
+
processor_by_model = {}
|
103 |
+
for model in MODELS:
|
104 |
+
processor_by_model[model] = Processor(model)
|
105 |
+
|
106 |
+
def predict(question, model, max_question_len, max_knowledge_len, max_answer_len, m, top_p):
|
107 |
+
result = processor_by_model[model].run(question, max_question_len, max_knowledge_len, max_answer_len, m, top_p)
|
108 |
+
return result['knowless_pred'], result['knowful_pred'], '\n'.join(result['knowledges']), result['selected_knowledge']
|
109 |
+
|
110 |
+
examples = [
|
111 |
+
'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller',
|
112 |
+
'What would vinyl be an odd thing to replace? \\n (A) pants (B) record albums (C) record store (D) cheese (E) wallpaper',
|
113 |
+
'Some pelycosaurs gave rise to reptile ancestral to \\n (A) lamphreys (B) angiosperm (C) mammals (D) paramecium (E) animals (F) protozoa (G) arachnids (H) backbones',
|
114 |
+
'Sydney rubbed Addison’s head because she had a horrible headache. What will happen to Sydney? \\n (A) drift to sleep (B) receive thanks (C) be reprimanded',
|
115 |
+
'Adam always spent all of the free time watching Tv unlike Hunter who volunteered, due to _ being lazy. \\n (A) Adam (B) Hunter',
|
116 |
+
'Causes bad breath and frightens blood-suckers \\n (A) tuna (B) iron (C) trash (D) garlic (E) pubs',
|
117 |
+
]
|
118 |
+
|
119 |
+
input_question = gr.Dropdown(choices=examples, label='Question:',
|
120 |
+
info='A multiple-choice commonsense question. Please follow the UnifiedQA input format: "{question} \\n (A) ... (B) ... (C) ..."',
|
121 |
+
)
|
122 |
+
input_model = gr.DropDown(label='Model:', value=MODELS[0], choices=MODELS)
|
123 |
+
input_max_question_len = gr.Number(label='Max number of tokens in question:', value=256, precision=0)
|
124 |
+
input_max_knowledge_len = gr.Number(label='Max number of tokens in knowledge:', value=32, precision=0)
|
125 |
+
input_max_answer_len = gr.Number(label='Max number of tokens in answer:', value=2, precision=0)
|
126 |
+
input_m = gr.Slider(label='Number of generated knowledges:', value=10, mininum=1, maximum=20, step=1,
|
127 |
+
info='The actual number of generated knowledges may be less than this number due to possible duplicates.',
|
128 |
+
)
|
129 |
+
input_top_p = gr.Slider(label='top_p for knowledge generation:', value=0.5, mininum=0.0, maximum=1.0, step=0.05)
|
130 |
+
output_knowless_answer = gr.Textbox(label='QA model answer without knowledge:', interactive=False)
|
131 |
+
output_knowful_answer = gr.Textbox(label='QA model answer with knowledge:', interactive=False)
|
132 |
+
output_all_knowledges = gr.Textbox(label='All generated knowledges:', interactive=False)
|
133 |
+
output_selected_knowledge = gr.Textbox(label='Knowledge selected to make the prediction:', interactive=False)
|
134 |
+
|
135 |
+
description = '''This is a demo for the paper, [*Crystal: Introspective Reasoners Reinforced with Self-Feedback*](), presented at EMNLP 2023. [[Code](https://github.com/liujch1998/crystal)] [[Model](https://huggingface.co/liujch1998/crystal-large)] This demo is made & maintained by [Jiacheng (Gary) Liu](https://liujch1998.github.io).
|
136 |
+
Crystal is an introspective reasoning model that answers commonsense questions by first generating knowledge and then use knowledge-grounded reasoning to reach a final prediction. To try this model, select an example question, or write your own commonsense question in the suggested format.'''
|
137 |
+
|
138 |
+
gr.Interface(
|
139 |
+
fn=predict,
|
140 |
+
inputs=[input_question, input_model, input_max_question_len, input_max_knowledge_len, input_max_answer_len, input_m, input_top_p],
|
141 |
+
outputs=[output_knowless_answer, output_knowful_answer, output_all_knowledges, output_selected_knowledge],
|
142 |
+
title="Crystal Demo",
|
143 |
+
description=description,
|
144 |
+
).launch()
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
tokenizers
|
4 |
+
sentencepiece
|