da03 committed
Commit d87049c
Parent(s): 5c05a33
Files changed (1)
  1. app.py +20 -134
app.py CHANGED
@@ -3,156 +3,42 @@ import torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Load models
-implicit_cot_model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
-implicit_cot_model = AutoModelForCausalLM.from_pretrained(implicit_cot_model_name)
+# Load the implicit CoT model
+implicit_cot_model_name = 'yuntian-deng/implicit-cot-math-mistral7b'
+implicit_cot_model = AutoModelForCausalLM.from_pretrained(implicit_cot_model_name, torch_dtype=torch.bfloat16)
 tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)
 
-no_cot_model_name = 'yuntian-deng/gpt2-no-cot-multiplication'
-no_cot_model = AutoModelForCausalLM.from_pretrained(no_cot_model_name)
-
-explicit_cot_model_name = 'yuntian-deng/gpt2-explicit-cot-multiplication'
-explicit_cot_model = AutoModelForCausalLM.from_pretrained(explicit_cot_model_name)
-
-models = {'implicit': implicit_cot_model, 'no': no_cot_model, 'explicit': explicit_cot_model}
-
 # Constants
-MAX_PRODUCT_DIGITS_PER_MODEL = {'implicit': 100, 'no': 100, 'explicit': 960}
-
-def preprocess(num):
-    num = str(num).strip().replace(' ', '')
-    reversed_num = ' '.join(num[::-1])
-    return reversed_num
-
-def postprocess(raw_output):
-    prediction = raw_output.replace(' ', '')[::-1]
-    return prediction
+MAX_RESULT_TOKENS = 10
 
 @spaces.GPU
-def predict_product(num1, num2):
-    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
+def predict_answer(question):
+    input_text = ' '.join(question.split()).strip() + ' ' + tokenizer.eos_token
     inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
-    [model.to('cuda' if torch.cuda.is_available() else 'cpu') for model in models.values()]
+    implicit_cot_model.to('cuda' if torch.cuda.is_available() else 'cpu')
 
     input_ids = inputs['input_ids']
-    input_len = input_ids.shape[-1]
-    prediction = ""
-    ground_truth_product = ""
-    valid_input = True
-
-    try:
-        num1_int = int(num1)
-        num2_int = int(num2)
-        ground_truth_product = str(num1_int * num2_int)
-        ground_truth_digits_reversed = list(ground_truth_product)[::-1]
-    except ValueError:
-        valid_input = False
-
-    generated_ids_per_model = {model_name: inputs['input_ids'].data.clone() for model_name in models}
-    finished_per_model = {model_name: False for model_name in models}
-    past_key_values_per_model = {model_name: None for model_name in models}
-    predicted_annotations_per_model = {}
-    for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())): # Set a maximum limit to prevent infinite loops
-        # Ground Truth
-        if not valid_input:
-            ground_truth_annotations = [('Invalid Input!', None)]
-        else:
-            ground_truth_annotations = [(ground_truth_digit, None) for ground_truth_digit in ground_truth_digits_reversed[:step+1]]
-            ground_truth_annotations = ground_truth_annotations[::-1]
-        # Predicted
-        for model_name in models:
-            model = models[model_name]
-            if finished_per_model[model_name]:
-                continue
-            if step >= MAX_PRODUCT_DIGITS_PER_MODEL[model_name]:
-                continue
-            generation_kwargs = {
-                'input_ids': generated_ids_per_model[model_name],
-                'max_new_tokens': 1,
-                'do_sample': False,
-                'past_key_values': past_key_values_per_model[model_name],
-                'return_dict_in_generate': True,
-                'use_cache': True
-            }
-            if step == 0:
-                del generation_kwargs['past_key_values']
-            outputs = model.generate(**generation_kwargs)
-            generated_ids = outputs.sequences
-            next_token_id = generated_ids[0, -1]
-            #print (next_token_id)
-
-            if next_token_id.item() == tokenizer.eos_token_id:
-                finished_per_model[model_name] = True
-                if valid_input:
-                    if len([item for item in predicted_annotations_per_model[model_name] if item[1] is not None]) < len(ground_truth_digits_reversed):
-                        predicted_annotations_per_model[model_name].insert(0, ('⠀', 'wrong'))
-                continue
-
-            generated_ids_per_model[model_name] = generated_ids
-            past_key_values_per_model[model_name] = outputs.past_key_values
-
-            output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
-            predicted_digits_reversed = output_text.strip().split(' ')
-
-            predicted_annotations = []
-            is_correct_sofar = True
-            if model_name == 'explicit':
-                if '=' not in predicted_digits_reversed:
-                    predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed]
-                    predicted_digits_reversed = []
-                else:
-                    equal_sign_position = predicted_digits_reversed.index('=')
-                    predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed[:equal_sign_position+1]]
-                    predicted_digits_reversed = predicted_digits_reversed[equal_sign_position+1:]
+    outputs = implicit_cot_model.generate(input_ids=input_ids,
+                                          max_new_tokens=MAX_RESULT_TOKENS,
+                                          do_sample=False, return_dict_in_generate=True)  # return_dict_in_generate is required for outputs.sequences below
+    generated_ids = outputs.sequences[0]
 
-            for i in range(len(predicted_digits_reversed)):
-                predicted_digit = predicted_digits_reversed[i]
-                if not valid_input:
-                    is_correct_digit = None
-                elif i >= len(ground_truth_digits_reversed):
-                    if predicted_digit == '0' and is_correct_sofar:
-                        is_correct_digit = True
-                    else:
-                        is_correct_digit = False
-                else:
-                    ground_truth_digit = ground_truth_digits_reversed[i]
-                    if predicted_digit == ground_truth_digit:
-                        is_correct_digit = True
-                    else:
-                        is_correct_digit = False
-                if not is_correct_digit:
-                    is_correct_sofar = False
-                if is_correct_digit is None:
-                    predicted_annotations.append((predicted_digit, None))
-                elif is_correct_digit:
-                    predicted_annotations.append((predicted_digit, "correct"))
-                else:
-                    predicted_annotations.append((predicted_digit, "wrong"))
-            predicted_annotations = predicted_annotations[::-1]
-            predicted_annotations_per_model[model_name] = predicted_annotations
+    prediction = tokenizer.decode(generated_ids[input_ids.shape[-1]:], skip_special_tokens=True)  # decode only the newly generated answer tokens, not the prompt
 
-        predicted_annotations_implicit_cot = predicted_annotations_per_model['implicit']
-        predicted_annotations_nocot = predicted_annotations_per_model['no']
-        predicted_annotations_explicit_cot = predicted_annotations_per_model['explicit']
-
-        yield ground_truth_annotations, predicted_annotations_implicit_cot, predicted_annotations_nocot, predicted_annotations_explicit_cot
+    return [(prediction, None)]  # HighlightedText expects (text, label) pairs
 
 color_map = {"correct": "green", "wrong": "red"}
 
 demo = gr.Interface(
-    fn=predict_product,
+    fn=predict_answer,
     inputs=[
-        gr.Textbox(label='First Number (up to 15 digits)', value='123456789'),
-        gr.Textbox(label='Second Number (up to 15 digits)', value='987654321'),
+        gr.Textbox(label='Question', value='A set of 7 spoons costs $21. If each spoon would be sold separately, how much would 5 spoons cost?'),
     ],
     outputs=[
-        gr.HighlightedText(label='Ground Truth Product', combine_adjacent=False, show_legend=False, color_map=color_map),
-        gr.HighlightedText(label='Implicit CoT Prediction (Ours)', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
-        gr.HighlightedText(label='No CoT Prediction', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
-        gr.HighlightedText(label='Explicit CoT Steps & Prediction', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
+        gr.HighlightedText(label='Implicit CoT Prediction', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
     ],
-    title='Predicting Multiplication with GPT-2: Implicit vs. Explicit CoT',
-    description='This demo showcases GPT-2\'s ability to directly predict the product of two large numbers without intermediate steps, using our stepwise internalization method. Compare the performance of implicit CoT (our method), no CoT, and explicit CoT. Implicit CoT offers accuracy and speed, while explicit CoT provides detailed reasoning but is slower.',
+    title='Solving Grade School Math Problems with Implicit CoT',
+    description='This demo showcases Mistral-7B\'s ability to solve grade school math problems without producing intermediate steps, using our stepwise internalization method.',
     article="""
     - [Paper 1: Implicit Chain of Thought Reasoning via Knowledge Distillation](https://arxiv.org/pdf/2311.01460)
     - [Paper 2: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
@@ -160,8 +46,8 @@ demo = gr.Interface(
     - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
     """,
     clear_btn=None,
-    submit_btn="Multiply!",
+    submit_btn="Get Answer!",
    live=False,
    concurrency_limit=1
 )
-demo.queue(max_size=20).launch()
+demo.queue(max_size=5).launch()
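To try the new checkpoint outside the Space, the inference path this commit introduces can be reproduced with the stock transformers API. Below is a minimal sketch, assuming only what the diff shows: the `yuntian-deng/implicit-cot-math-mistral7b` checkpoint, greedy decoding, and a prompt built from the whitespace-normalized question text followed by the EOS token.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the implicit CoT checkpoint from the diff above
model_name = 'yuntian-deng/implicit-cot-math-mistral7b'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

question = ('A set of 7 spoons costs $21. If each spoon would be sold '
            'separately, how much would 5 spoons cost?')
# Prompt format from predict_answer: normalized question text + EOS token
input_text = ' '.join(question.split()).strip() + ' ' + tokenizer.eos_token
input_ids = tokenizer(input_text, return_tensors='pt').input_ids.to(device)

# Greedy decoding with a small token budget, mirroring MAX_RESULT_TOKENS = 10:
# the implicit CoT model emits the answer directly, with no intermediate steps
output_ids = model.generate(input_ids=input_ids, max_new_tokens=10, do_sample=False)
answer = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
print(answer)
```

The small `max_new_tokens` budget is the point of the commit: unlike the removed digit-by-digit streaming loop, the internalized model produces only the final answer.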