File size: 2,491 Bytes
02df9f8
5c3cea5
02df9f8
 
 
d87049c
 
 
9a65236
 
3ffaf8e
 
 
9a65236
d87049c
9428a07
02df9f8
d87049c
2d0cc03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7dc2d2
d87049c
02df9f8
cf0cf7e
02df9f8
d87049c
3f861c3
859f68f
3f861c3
9bfa66b
137d14f
9bfa66b
859f68f
 
85fafa0
b8ed883
 
85fafa0
 
 
486c21f
d87049c
ebf0fc0
 
02df9f8
d87049c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import logging

import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the implicit-CoT math checkpoint (bfloat16 weights) and its tokenizer.
implicit_cot_model_name = 'yuntian-deng/implicit-cot-math-mistral7b'
implicit_cot_model = AutoModelForCausalLM.from_pretrained(
    implicit_cot_model_name,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)

# Place the weights on the GPU when one is visible, otherwise keep them on CPU.
_load_device = 'cuda' if torch.cuda.is_available() else 'cpu'
implicit_cot_model.to(_load_device)
# Inference only: disable training-time behaviour such as dropout.
implicit_cot_model.eval()

# Constants
# Upper bound on the number of tokens generated for an answer.
MAX_RESULT_TOKENS = 10

@spaces.GPU
def predict_answer(question):
    """Answer a math word problem with the implicit-CoT model, without reasoning steps.

    The question's whitespace is collapsed to single spaces, the tokenizer's EOS
    token is appended, and the model greedily decodes at most
    ``MAX_RESULT_TOKENS`` new tokens. Only those new tokens are decoded and
    returned.

    Args:
        question: Raw question text from the Gradio textbox.

    Returns:
        The decoded answer text. If anything fails, the exception message is
        returned instead so it shows up in the UI, and the full traceback is
        logged server-side.
    """
    logger = logging.getLogger(__name__)
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Normalise whitespace and terminate the question with EOS; the model is
        # prompted with exactly "<question> <eos>".
        input_text = ' '.join(question.split()) + ' ' + tokenizer.eos_token
        # Opt-in diagnostics instead of an unconditional print of user input.
        logger.debug('model input: %r', input_text)
        inputs = tokenizer(input_text, return_tensors='pt').to(device)
        # NOTE(review): presumably repeated here because @spaces.GPU only
        # attaches a GPU for the duration of this call — confirm on ZeroGPU.
        implicit_cot_model.to(device)

        input_ids = inputs['input_ids']
        outputs = implicit_cot_model.generate(
            input_ids=input_ids,
            # Pass the tokenizer's mask explicitly rather than letting
            # generate() infer it (and warn about it).
            attention_mask=inputs['attention_mask'],
            max_new_tokens=MAX_RESULT_TOKENS,
            do_sample=False,
        )

        # Drop the echoed prompt; keep only the newly generated tokens.
        prediction = tokenizer.decode(outputs[0, input_ids.shape[-1]:], skip_special_tokens=True)
    except Exception as e:
        # Best-effort UI: show the error text to the user, but keep the
        # traceback in the server log so failures remain debuggable.
        logger.exception('predict_answer failed')
        prediction = f'{e}'

    return prediction


# Example problem pre-filled in the question box.
_EXAMPLE_QUESTION = (
    "Asumi's bookshelf has 120 books. She has 10 books on history, twice that many "
    "books on literature, and the rest are science books. How many science books "
    "does Asumi have?"
)

# Short blurb shown under the title.
_DESCRIPTION = (
    "This demo showcases Mistral-7B's ability to solve grade school math problems "
    "without producing intermediate steps, using our stepwise internalization "
    "approach linked below."
)

# Markdown list of related links rendered beneath the interface.
_ARTICLE = """
    - [Paper 1: Implicit Chain of Thought Reasoning via Knowledge Distillation](https://arxiv.org/pdf/2311.01460)
    - [Paper 2: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
    - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
    - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
    """

# Single textbox in, single textbox out; predictions run only on button press
# (live=False), one at a time (concurrency_limit=1).
demo = gr.Interface(
    fn=predict_answer,
    inputs=[gr.Textbox(label="Question", value=_EXAMPLE_QUESTION)],
    outputs=[gr.Textbox(label="Implicit CoT Prediction")],
    title="Solving Grade School Math Problems without Intermediate Reasoning Steps",
    description=_DESCRIPTION,
    article=_ARTICLE,
    clear_btn=None,
    submit_btn="Get Answer!",
    live=False,
    concurrency_limit=1,
)

# Cap the request queue at 5 entries, then start the server.
demo.queue(max_size=5)
demo.launch()