da03 commited on
Commit
9428a07
·
1 Parent(s): e2618b3
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import spaces
2
- import torch
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
@@ -12,14 +11,19 @@ def preprocess(num):
12
  reversed_num = ' '.join(num[::-1])
13
  return reversed_num
14
 
 
 
 
 
15
  @spaces.GPU
16
  def predict_product(num1, num2):
17
  input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
18
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
19
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
20
  outputs = model.generate(**inputs, max_new_tokens=40)
21
- raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
22
- prediction = raw_output.strip().replace(' ', '')[::-1]
 
23
  return input_text, raw_output, prediction
24
 
25
  demo = gr.Interface(
@@ -31,7 +35,13 @@ demo = gr.Interface(
31
  gr.Textbox(label='Predicted Product')
32
  ],
33
  title='GPT-2 Multiplication Predictor',
34
- description='Enter two numbers up to 9 digits each and get the predicted product.'
 
 
 
 
 
 
35
  )
36
 
37
  demo.launch()
 
1
  import spaces
 
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
 
11
  reversed_num = ' '.join(num[::-1])
12
  return reversed_num
13
 
14
+ def postprocess(raw_output):
15
+ prediction = raw_output.replace(' ', '')[::-1]
16
+ return prediction
17
+
18
  @spaces.GPU
19
  def predict_product(num1, num2):
20
  input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
21
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
22
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
23
  outputs = model.generate(**inputs, max_new_tokens=40)
24
+ output = outputs[0][inputs['input_ids'].shape[-1]:]
25
+ raw_output = tokenizer.decode(output, skip_special_tokens=True)
26
+ prediction = postprocess(raw_output)
27
  return input_text, raw_output, prediction
28
 
29
  demo = gr.Interface(
 
35
  gr.Textbox(label='Predicted Product')
36
  ],
37
  title='GPT-2 Multiplication Predictor',
38
+ description='Enter two numbers up to 9 digits each and get the predicted product.',
39
+ article="""
40
+ ### Additional Resources
41
+ - [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
42
+ - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
43
+ - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
44
+ """
45
  )
46
 
47
  demo.launch()