File size: 8,638 Bytes
39aed69
7802ab3
 
39aed69
7802ab3
 
 
 
 
 
 
 
39aed69
7802ab3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f3e715
7802ab3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f3e715
 
1a61355
 
 
 
 
 
 
 
 
 
ebb18a8
94a59f6
 
 
 
 
 
 
 
 
ebb18a8
a1f4642
0f3e715
a1f4642
 
ebb18a8
 
 
 
 
 
1a61355
 
 
 
 
 
 
39aed69
 
7802ab3
0f3e715
1a61355
ebb18a8
 
7802ab3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
import torch
import transformers

def reduce_sum(value, mask, axis=None):
    '''Sum `value * mask`, either over every element (axis=None) or along `axis`.'''
    masked = value * mask
    return torch.sum(masked) if axis is None else torch.sum(masked, axis)
def reduce_mean(value, mask, axis=None):
    '''Average `value` over the positions selected by `mask`.

    With axis=None the mean is taken over all masked elements; otherwise the masked
    sum along `axis` is divided by the mask count along the same axis.
    '''
    masked = value * mask
    if axis is not None:
        return torch.sum(masked, axis) / torch.sum(mask, axis)
    return torch.sum(masked) / torch.sum(mask)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class InteractiveRainier:
    '''Demo pipeline: sample knowledge with Rainier, then answer with UnifiedQA.

    The Rainier generator, the UnifiedQA answerer, and their shared tokenizer are loaded
    once in __init__ and moved to the module-level `device`.
    '''

    def __init__(self):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
        self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
        self.qa_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('allenai/unifiedqa-t5-large').to(device)
        # Per-token losses; positions labelled -100 contribute 0 and are masked out in `run`.
        self.loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

    def parse_choices(self, s):
        '''Split serialized answer choices into a list of choice texts.

        s: serialized_choices '(A) ... (B) ... (C) ...' (or lowercase '(a) ... (b) ...').
        Labels are assumed to be consecutive letters starting at A/a, each three
        characters wide, with the first label at the start of `s`.
        Returns the choice texts without their labels, stripped of surrounding spaces.
        '''
        choices = []
        key = 'A' if s.find('(A)') != -1 else 'a'
        while True:
            pos = s.find(f'({chr(ord(key) + 1)})')
            if pos == -1:
                break
            # s currently starts with the '(X)' label of the choice being extracted.
            choices.append(s[3:pos].strip(' '))
            s = s[pos:]
            key = chr(ord(key) + 1)
        choices.append(s[3:].strip(' '))
        return choices

    def run(self, question, max_input_len, max_output_len, m, top_p):
        '''Answer a UnifiedQA-format question with and without Rainier-generated knowledge.

        question: '{question} \\n (A) ... (B) ...' where the separator is a literal
            backslash followed by 'n'.
        max_input_len: token limit for the question, prompt, and choice encodings.
        max_output_len: maximum number of tokens per generated knowledge.
        m: number of knowledges sampled (fewer may remain after deduplication).
        top_p: nucleus-sampling threshold for knowledge generation.

        Returns a dict with the question, the knowledges (index 0 is '' for the
        knowledge-free prompt), the knowledge-free prediction, the knowledge-ensembled
        prediction, and the knowledge behind the most confident prediction. Predicted
        choices are lowercased.

        Raises ValueError if the question contains no '\\n' separator.
        '''
        parts = question.split('\\n')
        if len(parts) < 2:
            # Fail fast, before any model work, with an actionable message instead of an IndexError.
            raise ValueError(
                'Question must follow the UnifiedQA format "{question} \\n (A) ... (B) ...", '
                'with a literal backslash-n separating the question from the choices.'
            )
        # Gradio Number/Slider values may arrive as floats; tokenizer/generate need ints.
        max_input_len = int(max_input_len)
        max_output_len = int(max_output_len)
        m = int(m)

        # --- Knowledge generation ---
        tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1, L)
        knowledges_ids = self.rainier_model.generate(
            input_ids=tokenized.input_ids,
            attention_mask=tokenized.attention_mask,  # input is padded to max_length
            max_length=max_output_len + 1,  # +1 for the leading decoder start token
            min_length=3,
            do_sample=True,
            num_return_sequences=m,
            top_p=top_p,
        )  # (K, L); begins with 0 ([BOS]); ends with 1 ([EOS])
        knowledges_ids = knowledges_ids[:, 1:].contiguous()  # no beginning; ends with 1 ([EOS])
        decoded = self.tokenizer.batch_decode(knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        # Deduplicate while keeping generation order (set() order varies between runs), and
        # drop empty strings, which would duplicate the knowledge-free prompt added at index 0.
        knowledges = [''] + [knowledge for knowledge in dict.fromkeys(decoded) if knowledge != '']

        # --- Scoring every choice under every prompt ---
        prompts = [
            (question + (f' \\n {knowledge}' if knowledge != '' else '')).lower()
            for knowledge in knowledges
        ]
        choices = [choice.lower() for choice in self.parse_choices(parts[1].strip(' '))]
        # The prompts are the same for every choice, so encode them once.
        tokenized_prompts = self.tokenizer(prompts, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1+K, L)
        answer_logitss = []
        for choice in choices:
            # Pad the labels only to the choice's own length. The decoder is causal, so trailing
            # padding never affected the losses of real tokens; it only cost extra decoder steps.
            tokenized_choice = self.tokenizer([choice], return_tensors='pt', padding='longest', truncation='longest_first', max_length=max_input_len).to(device)  # (1, T)
            labels = tokenized_choice.input_ids.clone()
            pad_mask = labels == self.tokenizer.pad_token_id  # (1, T)
            labels[pad_mask] = -100
            labels = labels.repeat(len(knowledges), 1)  # (1+K, T)

            with torch.no_grad():
                logits = self.qa_model(
                    input_ids=tokenized_prompts.input_ids,
                    attention_mask=tokenized_prompts.attention_mask,
                    labels=labels,
                ).logits  # (1+K, T, V)

            losses = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            losses = losses.view(labels.shape)  # (1+K, T)
            # Mean token loss over non-pad label positions; the (1, T) mask broadcasts over prompts.
            losses = reduce_mean(losses, ~pad_mask, axis=-1)  # (1+K)
            answer_logitss.append(-losses)
        answer_logitss = torch.stack(answer_logitss, dim=1)  # (1+K, C)
        answer_probss = answer_logitss.softmax(dim=1)  # (1+K, C)

        # --- Ensemble ---
        # Knowledge-free prediction: row 0 corresponds to the prompt without knowledge.
        knowless_pred = choices[answer_probss[0, :].argmax(dim=0).item()]

        # Knowledge-ensembled prediction: best probability of each choice over all prompts.
        answer_probs = answer_probss.max(dim=0).values  # (C)
        knowful_pred = choices[answer_probs.argmax(dim=0).item()]
        # The prompt whose most confident choice is the most confident overall.
        selected_knowledge_ix = answer_probss.max(dim=1).values.argmax(dim=0).item()
        selected_knowledge = knowledges[selected_knowledge_ix]

        return {
            'question': question,
            'knowledges': knowledges,
            'knowless_pred': knowless_pred,
            'knowful_pred': knowful_pred,
            'selected_knowledge': selected_knowledge,
        }

# Build the pipeline once at import time. __init__ loads (and may download) both models,
# so every `predict` call reuses these weights instead of reloading them per request.
rainier = InteractiveRainier()

def predict(question, kg_model, qa_model, max_input_len, max_output_len, m, top_p):
    '''Gradio callback: run the Rainier pipeline and format its result for the output boxes.

    kg_model and qa_model come from read-only textboxes showing the model names. They are
    not used here; `rainier` always uses the models it loaded at startup.

    Returns (answer without knowledge, answer with knowledge, newline-joined generated
    knowledges, knowledge selected to make the prediction).
    '''
    result = rainier.run(question, max_input_len, max_output_len, m, top_p)
    # The '' entry in `knowledges` stands for the knowledge-free prompt, not a generated
    # statement; skip it so the textbox does not start with a blank line.
    generated = '\n'.join(knowledge for knowledge in result['knowledges'] if knowledge != '')
    return result['knowless_pred'], result['knowful_pred'], generated, result['selected_knowledge']

# Example questions in UnifiedQA format: question text, a literal backslash-n, then labelled choices.
examples = [
    'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller',
    'What would vinyl be an odd thing to replace? \\n (A) pants (B) record albums (C) record store (D) cheese (E) wallpaper',
    'Some pelycosaurs gave rise to reptile ancestral to \\n (A) lamphreys (B) angiosperm (C) mammals (D) paramecium (E) animals (F) protozoa (G) arachnids (H) backbones',
    'Sydney rubbed Addison’s head because she had a horrible headache. What will happen to Sydney? \\n (A) drift to sleep (B) receive thanks (C) be reprimanded',
    'Adam always spent all of the free time watching Tv unlike Hunter who volunteered, due to _ being lazy. \\n (A) Adam (B) Hunter',
    'Causes bad breath and frightens blood-suckers \\n (A) tuna (B) iron (C) trash (D) garlic (E) pubs',
]

# allow_custom_value lets users type their own question, as the description and info text invite.
input_question = gr.Dropdown(choices=examples, label='Question:', allow_custom_value=True,
    info='A multiple-choice commonsense question. Please follow the UnifiedQA input format: "{question} \\n (A) ... (B) ... (C) ..."',
)
# Display-only: the models actually used are fixed when `rainier` is constructed.
input_kg_model = gr.Textbox(label='Knowledge generation model:', value='liujch1998/rainier-large', interactive=False)
input_qa_model = gr.Textbox(label='QA model:', value='allenai/unifiedqa-t5-large', interactive=False)
input_max_input_len = gr.Number(label='Max number of tokens in question:', value=256, precision=0)
input_max_output_len = gr.Number(label='Max number of tokens in knowledge:', value=32, precision=0)
# `minimum` (previously misspelled `mininum`, so the bound was silently ignored and defaulted to 0).
input_m = gr.Slider(label='Number of generated knowledges:', value=10, minimum=1, maximum=20, step=1,
    info='The actual number of generated knowledges may be less than this number due to possible duplicates.',
)
input_top_p = gr.Slider(label='top_p for knowledge generation:', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
output_knowless_answer = gr.Textbox(label='QA model answer without knowledge:', interactive=False)
output_knowful_answer = gr.Textbox(label='QA model answer with knowledge:', interactive=False)
output_all_knowledges = gr.Textbox(label='All generated knowledges:', interactive=False)
output_selected_knowledge = gr.Textbox(label='Knowledge selected to make the prediction:', interactive=False)

description = '''This is a demo for the paper, [*Rainier: Reinforced Knowledge Introspector for Commonsense Question Answering*](https://arxiv.org/pdf/2210.03078.pdf), presented at EMNLP 2022. [[Code](https://github.com/liujch1998/rainier)] [[Model](https://huggingface.co/liujch1998/rainier-large)] This demo is made & maintained by [Jiacheng (Gary) Liu](https://liujch1998.github.io).
Rainier is a knowledge-generating model that enhances the commonsense QA capability of a QA model. To try this model, select an example question, or write your own commonsense question in the suggested format.'''

# Input order must match predict's parameter order; output order matches its return tuple.
demo = gr.Interface(
    fn=predict,
    inputs=[input_question, input_kg_model, input_qa_model, input_max_input_len, input_max_output_len, input_m, input_top_p],
    outputs=[output_knowless_answer, output_knowful_answer, output_all_knowledges, output_selected_knowledge],
    title="Rainier Demo",
    description=description,
)
demo.launch()