da03 commited on
Commit
9a65236
·
1 Parent(s): a16dab3
Files changed (1) hide show
  1. app.py +89 -81
app.py CHANGED
@@ -3,10 +3,21 @@ import torch
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
- model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(model_name)
9
- MAX_PRODUCT_DIGITS = 100
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def preprocess(num):
12
  num = str(num).strip().replace(' ', '')
@@ -21,97 +32,92 @@ def postprocess(raw_output):
21
  def predict_product(num1, num2):
22
  input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
23
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
24
- model.to('cuda' if torch.cuda.is_available() else 'cpu')
25
 
26
  input_ids = inputs['input_ids']
27
  input_len = input_ids.shape[-1]
28
  prediction = ""
29
- correct_product = ""
30
  valid_input = True
31
 
32
  try:
33
  num1_int = int(num1)
34
  num2_int = int(num2)
35
- correct_product = str(num1_int * num2_int)
 
36
  except ValueError:
37
  valid_input = False
38
 
39
- generated_ids = inputs['input_ids']
40
- past_key_values = None
41
- for step in range(MAX_PRODUCT_DIGITS): # Set a maximum limit to prevent infinite loops
42
- generation_kwargs = {
43
- 'input_ids': generated_ids,
44
- 'max_new_tokens': 1,
45
- 'do_sample': False,
46
- 'past_key_values': past_key_values,
47
- 'return_dict_in_generate': True,
48
- 'use_cache': True
49
- }
50
- if step == 0:
51
- del generation_kwargs['past_key_values']
52
- outputs = model.generate(**generation_kwargs)
53
- generated_ids = outputs.sequences
54
- next_token_id = generated_ids[0, -1]
55
- print (next_token_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- if next_token_id.item() == tokenizer.eos_token_id:
58
- print ('berak')
59
- break
60
- past_key_values = outputs.past_key_values
61
-
62
- output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
63
- #prediction = postprocess(output_text)
64
- predicted_digits_reversed = output_text.strip().split(' ')
65
- print ('p', predicted_digits_reversed)
66
- correct_digits_reversed = list(correct_product)[::-1]
67
- print ('c', correct_digits_reversed)
68
-
69
- # Create the diff for HighlightedText
70
- diff = []
71
- correct_digits = []
72
- is_correct_sofar = True
73
- for i in range(len(predicted_digits_reversed)):
74
- predicted_digit = predicted_digits_reversed[i]
75
- correct_digit = correct_digits_reversed[i]
76
- correct_digits.append((correct_digit, None))
77
- if i >= len(correct_digits_reversed):
78
- if predicted_digit == '0' and is_correct_sofar:
79
- is_correct_digit = True
80
  else:
81
- is_correct_digit = False
82
- else:
83
- if predicted_digit == correct_digit:
84
- is_correct_digit = True
 
 
 
 
85
  else:
86
- is_correct_digit = False
87
- if not is_correct_digit:
88
- is_correct_sofar = False
89
- if is_correct_digit:
90
- diff.append((predicted_digit, "-"))
91
- else:
92
- diff.append((predicted_digit, "+"))
93
- diff = diff[::-1]
94
- correct_digits = correct_digits[::-1]
95
-
96
- yield correct_digits, diff, ""
97
-
98
- #if valid_input:
99
- # is_correct = prediction == correct_product
100
- # result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
101
- #else:
102
- # result_message = "Invalid input. Could not evaluate correctness."
103
 
104
- ## Final diff for the complete prediction
105
- #final_diff = []
106
- #for i in range(max(len(prediction), len(correct_product))):
107
- # if i < len(prediction) and i < len(correct_product) and prediction[i] == correct_product[i]:
108
- # final_diff.append((prediction[i], None)) # No highlight for correct digits
109
- # elif i < len(prediction) and (i >= len(correct_product) or prediction[i] != correct_product[i]):
110
- # final_diff.append((prediction[i], "+")) # Highlight incorrect digits in red
111
- # if i < len(correct_product) and (i >= len(prediction) or prediction[i] != correct_product[i]):
112
- # final_diff.append((correct_product[i], "-")) # Highlight missing/incorrect digits in green
113
 
114
- #yield final_diff, result_message
115
 
116
  demo = gr.Interface(
117
  fn=predict_product,
@@ -119,10 +125,12 @@ demo = gr.Interface(
119
  gr.Textbox(label='First Number (up to 12 digits)', value='123456789'),
120
  gr.Textbox(label='Second Number (up to 12 digits)', value='987654321'),
121
  ],
 
122
  outputs=[
123
- gr.HighlightedText(label='Ground Truth Product', combine_adjacent=False, show_legend=False, color_map={"-": "green", "+": "red"}),
124
- gr.HighlightedText(label='GPT2 Predicted Product', combine_adjacent=False, show_legend=False, color_map={"-": "green", "+": "red"}, show_inline_category=False),
125
- gr.HTML(label='Result Message')
 
126
  ],
127
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
128
  description='This demo uses GPT2 to directly predict the product of two numbers without using any intermediate reasoning steps. The GPT2 model has been fine-tuned to internalize chain-of-thought reasoning within its hidden states, following our stepwise internalization approach detailed in the paper linked at the bottom of this page.',
 
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
+ # Load models
7
+ implicit_cot_model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
8
+ implicit_cot_model = AutoModelForCausalLM.from_pretrained(implicit_cot_model_name)
9
+ tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)
10
+
11
+ no_cot_model_name = 'yuntian-deng/gpt2-no-cot-multiplication'
12
+ no_cot_model = AutoModelForCausalLM.from_pretrained(no_cot_model_name)
13
+
14
+ explicit_cot_model_name = 'yuntian-deng/gpt2-explicit-cot-multiplication'
15
+ explicit_cot_model = AutoModelForCausalLM.from_pretrained(explicit_cot_model_name)
16
+
17
+ models = {'implicit': implicit_cot_model_name, 'no': no_cot_model, 'explicit': explicit_cot_model}
18
+
19
+ # Constants
20
+ MAX_PRODUCT_DIGITS_PER_MODEL = {'implicit': 100, 'no': 100, 'explicit': 900}
21
 
22
  def preprocess(num):
23
  num = str(num).strip().replace(' ', '')
 
32
  def predict_product(num1, num2):
33
  input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
34
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
35
+ [model.to('cuda' if torch.cuda.is_available() else 'cpu') for model in models.values()]
36
 
37
  input_ids = inputs['input_ids']
38
  input_len = input_ids.shape[-1]
39
  prediction = ""
40
+ ground_truth_product = ""
41
  valid_input = True
42
 
43
  try:
44
  num1_int = int(num1)
45
  num2_int = int(num2)
46
+ ground_truth_product = str(num1_int * num2_int)
47
+ ground_truth_digits_reversed = list(ground_truth_product)[::-1]
48
  except ValueError:
49
  valid_input = False
50
 
51
+ generated_ids_per_model = {model_name: inputs['input_ids'].data.clone() for model_name in models}
52
+ finished_per_model = {model_name: False for model_name in models}
53
+ past_key_values_per_model = {model_name: None for model_name in models}
54
+ predicted_results_per_model = {}
55
+ for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())): # Set a maximum limit to prevent infinite loops
56
+ # Ground Truth
57
+ ground_truth_results = []
58
+ for i in range(step+1):
59
+ ground_truth_digit = ground_truth_digits_reversed[i]
60
+ ground_truth_digits.append((ground_truth_digit, None))
61
+ ground_truth_digits = ground_truth_digits[::-1]
62
+ # Predicted
63
+ for model_name in models:
64
+ model = models[model_name]
65
+ if finished_per_model[model_name]:
66
+ continue
67
+ if step >= MAX_PRODUCT_DIGITS_PER_MODE[model_name]:
68
+ continue
69
+ generation_kwargs = {
70
+ 'input_ids': generated_ids_per_model[model_name],
71
+ 'max_new_tokens': 1,
72
+ 'do_sample': False,
73
+ 'past_key_values': past_key_values_per_model[model_name],
74
+ 'return_dict_in_generate': True,
75
+ 'use_cache': True
76
+ }
77
+ if step == 0:
78
+ del generation_kwargs['past_key_values']
79
+ outputs = model.generate(**generation_kwargs)
80
+ generated_ids = outputs.sequences
81
+ next_token_id = generated_ids[0, -1]
82
+ print (next_token_id)
83
+
84
+ if next_token_id.item() == tokenizer.eos_token_id:
85
+ print ('berak')
86
+ break
87
+ past_key_values_per_model[model_name] = outputs.past_key_values
88
+
89
+ output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
90
+ predicted_digits_reversed = output_text.strip().split(' ')
91
 
92
+ predicted_results = []
93
+ is_correct_sofar = True
94
+ for i in range(len(predicted_digits_reversed)):
95
+ predicted_digit = predicted_digits_reversed[i]
96
+ ground_truth_digit = ground_truth_digits_reversed[i]
97
+ if i >= len(ground_truth_digits_reversed):
98
+ if predicted_digit == '0' and is_correct_sofar:
99
+ is_correct_digit = True
100
+ else:
101
+ is_correct_digit = False
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  else:
103
+ if predicted_digit == ground_truth_digit:
104
+ is_correct_digit = True
105
+ else:
106
+ is_correct_digit = False
107
+ if not is_correct_digit:
108
+ is_correct_sofar = False
109
+ if is_correct_digit:
110
+ predicted_results.append((predicted_digit, "correct"))
111
  else:
112
+ predicted_results.append((predicted_digit, "wrong"))
113
+ predicted_results = predicted_results[::-1]
114
+ predicted_results_per_model[model_name] = predicted_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
+ predicted_results_implicit_cot = predicted_results_per_model['implicit']
117
+ predicted_results_nocot = predicted_results_per_model['no']
118
+ predicted_results_explicit_cot = predicted_results_per_model['explicit']
 
 
 
 
 
 
119
 
120
+ yield ground_truth_digits_digits, predicted_results_implicit, predicted_results_nocot, predicted_results_explicit_cot
121
 
122
  demo = gr.Interface(
123
  fn=predict_product,
 
125
  gr.Textbox(label='First Number (up to 12 digits)', value='123456789'),
126
  gr.Textbox(label='Second Number (up to 12 digits)', value='987654321'),
127
  ],
128
+ color_map = {"correct": "green", "wrong": "red"}
129
  outputs=[
130
+ gr.HighlightedText(label='Ground Truth Product', combine_adjacent=False, show_legend=False, color_map=color_map),
131
+ gr.HighlightedText(label='Implicit CoT Predicted Product', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
132
+ gr.HighlightedText(label='No CoT Predicted Product', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
133
+ gr.HighlightedText(label='Explicit CoT Predicted Product', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
134
  ],
135
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
136
  description='This demo uses GPT2 to directly predict the product of two numbers without using any intermediate reasoning steps. The GPT2 model has been fine-tuned to internalize chain-of-thought reasoning within its hidden states, following our stepwise internalization approach detailed in the paper linked at the bottom of this page.',