# yuntian-deng — commit 2d0cc03 (verified): "Update app.py"
import traceback

import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the implicit CoT model.
# Identifier handed to `from_pretrained` (Hub "owner/repo" form).
implicit_cot_model_name = 'yuntian-deng/implicit-cot-math-mistral7b'

# Weights are loaded in bfloat16; the tokenizer comes from the same checkpoint.
implicit_cot_model = AutoModelForCausalLM.from_pretrained(
    implicit_cot_model_name,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)

# Use the GPU when torch can see one, otherwise fall back to the CPU, and put
# the model in evaluation mode. `Module.to` returns the module itself, so the
# calls chain.
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
implicit_cot_model.to(_device).eval()

# Constants
# Upper bound on the number of new tokens generated per answer.
MAX_RESULT_TOKENS = 10
@spaces.GPU
def predict_answer(question):
    """Answer a math word problem directly, without intermediate reasoning steps.

    The question's whitespace is normalized and the tokenizer's EOS token is
    appended. The implicit-CoT model then greedily decodes at most
    ``MAX_RESULT_TOKENS`` new tokens, and only those new tokens are returned.

    Args:
        question: Question text from the Gradio textbox.

    Returns:
        The decoded answer text. If the question is empty or only whitespace,
        returns a prompt asking for a question. If inference raises, returns a
        non-empty ``"Error: <ExceptionType>: <message>"`` string; the full
        traceback is printed to the server log.
    """
    # Empty / whitespace-only input would otherwise be sent to the model as a
    # prompt consisting of nothing but the EOS token.
    if not question or not question.strip():
        return 'Please enter a question.'

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    try:
        # Collapse every whitespace run to a single space and terminate with
        # EOS. NOTE(review): presumably mirrors the training input format —
        # confirm before changing.
        input_text = ' '.join(question.split()) + ' ' + tokenizer.eos_token
        print(input_text)
        inputs = tokenizer(input_text, return_tensors='pt').to(device)
        # Repeat the device move inside the GPU-decorated call, as before;
        # it has no effect when the model is already on `device`.
        implicit_cot_model.to(device)
        input_ids = inputs['input_ids']
        # Pass the tokenizer's attention mask explicitly instead of making
        # `generate` infer it (which also triggers a warning).
        outputs = implicit_cot_model.generate(
            input_ids=input_ids,
            attention_mask=inputs['attention_mask'],
            max_new_tokens=MAX_RESULT_TOKENS,
            do_sample=False,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        return tokenizer.decode(
            outputs[0, input_ids.shape[-1]:], skip_special_tokens=True
        )
    except Exception as e:
        # UI boundary: keep the demo alive, but log the full traceback
        # server-side. Name the exception type so the message is never blank,
        # even when str(e) is empty.
        traceback.print_exc()
        return f'Error: {type(e).__name__}: {e}'
# Example problem pre-filled in the question box.
_EXAMPLE_QUESTION = (
    "Asumi's bookshelf has 120 books. She has 10 books on history, twice that "
    "many books on literature, and the rest are science books. How many "
    "science books does Asumi have?"
)

# Subtitle text shown under the title.
_DESCRIPTION = (
    "This demo showcases Mistral-7B's ability to solve grade school math "
    "problems without producing intermediate steps, using our stepwise "
    "internalization approach linked below."
)

# Markdown footer with links to the papers, code, and announcement.
_ARTICLE = """
- [Paper 1: Implicit Chain of Thought Reasoning via Knowledge Distillation](https://arxiv.org/pdf/2311.01460)
- [Paper 2: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
"""

# Single text-in / text-out interface around `predict_answer`. There is no
# clear button, predictions run only on submit (not live), and at most one
# request is processed at a time.
demo = gr.Interface(
    fn=predict_answer,
    inputs=[gr.Textbox(label="Question", value=_EXAMPLE_QUESTION)],
    outputs=[gr.Textbox(label="Implicit CoT Prediction")],
    title="Solving Grade School Math Problems without Intermediate Reasoning Steps",
    description=_DESCRIPTION,
    article=_ARTICLE,
    clear_btn=None,
    submit_btn="Get Answer!",
    live=False,
    concurrency_limit=1,
)

# Enable request queueing (queue capped at 5 pending events) and start the app.
demo.queue(max_size=5).launch()