da03 commited on
Commit
eaa0586
·
1 Parent(s): 521b575
Files changed (1) hide show
  1. app.py +25 -22
app.py CHANGED
@@ -2,6 +2,7 @@ import spaces
2
  import torch
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
5
 
6
  model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -18,41 +19,43 @@ def postprocess(raw_output):
18
 
19
  @spaces.GPU
20
  def predict_product(num1, num2):
21
- # Reverse input digits and add spaces
22
  input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
23
-
24
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
25
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
26
 
27
- # Generate output
28
- outputs = model.generate(**inputs, max_new_tokens=40)
29
-
30
- output = outputs[0][inputs['input_ids'].shape[-1]:]
31
- raw_output = tokenizer.decode(output, skip_special_tokens=True)
32
- prediction = postprocess(raw_output)
33
 
34
- # Evaluate the correctness of the result
35
  try:
36
  num1_int = int(num1)
37
  num2_int = int(num2)
38
- valid_input = True
39
  except ValueError:
40
  valid_input = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  if valid_input:
42
- correct_product = num1_int * num2_int
43
- try:
44
- prediction_int = int(prediction)
45
- is_correct = (prediction_int == correct_product)
46
- except ValueError:
47
- is_correct = False
48
- result_color = "green" if is_correct else "red"
49
  result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
50
  else:
51
- result_color = "black"
52
  result_message = "Invalid input. Could not evaluate correctness."
53
- result_html = f"<div style='color: {result_color};'>{result_message}</div>"
54
 
55
- return prediction, result_html
56
 
57
  demo = gr.Interface(
58
  fn=predict_product,
@@ -61,7 +64,7 @@ demo = gr.Interface(
61
  gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
62
  ],
63
  outputs=[
64
- gr.Textbox(label='Predicted Product'),
65
  gr.HTML(label='Result Message')
66
  ],
67
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
@@ -73,7 +76,7 @@ demo = gr.Interface(
73
  """,
74
  clear_btn=None,
75
  submit_btn="Multiply!",
76
- live=False
77
  )
78
 
79
  demo.launch()
 
2
  import torch
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import time
6
 
7
  model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
19
 
20
  @spaces.GPU
21
  def predict_product(num1, num2):
 
22
  input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
 
23
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
24
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
25
 
26
+ generated_ids = inputs['input_ids']
27
+ prediction = ""
28
+ correct_product = ""
29
+ valid_input = True
 
 
30
 
 
31
  try:
32
  num1_int = int(num1)
33
  num2_int = int(num2)
34
+ correct_product = str(num1_int * num2_int)
35
  except ValueError:
36
  valid_input = False
37
+
38
+ for _ in range(40): # Adjust the range to control the maximum number of generated tokens
39
+ outputs = model.generate(generated_ids, max_new_tokens=1, do_sample=False)
40
+ generated_ids = torch.cat((generated_ids, outputs[:, -1:]), dim=-1)
41
+ output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
42
+ prediction = postprocess(output_text)
43
+
44
+ result_html = "<div style='margin-bottom: 10px;'>Correct Result: " + " ".join(correct_product) + "</div><div>"
45
+ for i, pred_digit in enumerate(prediction):
46
+ color = "green" if i < len(correct_product) and pred_digit == correct_product[i] else "red"
47
+ result_html += f"<span style='color: {color};'>{pred_digit}</span>"
48
+ result_html += "</div>"
49
+
50
+ yield result_html, ""
51
+
52
  if valid_input:
53
+ is_correct = prediction == correct_product
 
 
 
 
 
 
54
  result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
55
  else:
 
56
  result_message = "Invalid input. Could not evaluate correctness."
 
57
 
58
+ yield result_html, result_message
59
 
60
  demo = gr.Interface(
61
  fn=predict_product,
 
64
  gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
65
  ],
66
  outputs=[
67
+ gr.HTML(label='Predicted Product with Matching Digits Highlighted'),
68
  gr.HTML(label='Result Message')
69
  ],
70
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
 
76
  """,
77
  clear_btn=None,
78
  submit_btn="Multiply!",
79
+ live=True
80
  )
81
 
82
  demo.launch()