da03 commited on
Commit
9bfa66b
·
1 Parent(s): 02df9f8
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -2,12 +2,12 @@ import spaces
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
- model_name = 'yuntian-deng/gpt2-small-implicit-cot-multiplication'
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
  def preprocess(num):
10
- num = num.strip().replace(' ', '')
11
  reversed_num = ' '.join(num[::-1])
12
  return reversed_num
13
 
@@ -17,13 +17,18 @@ def predict_product(num1, num2):
17
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
18
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
19
  outputs = model.generate(**inputs, max_new_tokens=40)
20
- prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
21
- return prediction.strip()
 
22
 
23
  demo = gr.Interface(
24
  fn=predict_product,
25
  inputs=[gr.Number(label='First Number (up to 9 digits)'), gr.Number(label='Second Number (up to 9 digits)')],
26
- outputs='text',
 
 
 
 
27
  title='GPT-2 Multiplication Predictor',
28
  description='Enter two numbers up to 9 digits each and get the predicted product.'
29
  )
 
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
+ model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
  def preprocess(num):
10
+ num = str(num).strip().replace(' ', '')
11
  reversed_num = ' '.join(num[::-1])
12
  return reversed_num
13
 
 
17
  inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
18
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
19
  outputs = model.generate(**inputs, max_new_tokens=40)
20
+ raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
21
+ prediction = raw_output.strip().replace(' ', '')[::-1]
22
+ return input_text, raw_output, prediction
23
 
24
  demo = gr.Interface(
25
  fn=predict_product,
26
  inputs=[gr.Number(label='First Number (up to 9 digits)'), gr.Number(label='Second Number (up to 9 digits)')],
27
+ outputs=[
28
+ gr.Textbox(label='Raw Input to GPT-2'),
29
+ gr.Textbox(label='Raw Output from GPT-2'),
30
+ gr.Textbox(label='Predicted Product')
31
+ ],
32
  title='GPT-2 Multiplication Predictor',
33
  description='Enter two numbers up to 9 digits each and get the predicted product.'
34
  )