File size: 5,902 Bytes
02df9f8
5c3cea5
02df9f8
 
 
9bfa66b
02df9f8
 
bf65d9e
02df9f8
 
9bfa66b
02df9f8
 
 
9428a07
 
 
 
02df9f8
 
 
 
 
39a2dae
bf65d9e
 
f7dc2d2
 
 
8fa0ae4
70487ef
 
 
eaa0586
70487ef
f7dc2d2
eaa0586
bf65d9e
3e9942b
0d4c25e
 
 
 
 
 
 
 
 
 
 
 
3098025
 
ae9daf1
0ad2aca
4fe17ee
ae9daf1
f7dc2d2
bf65d9e
f7dc2d2
bf65d9e
 
 
09d8750
8ee7a60
09d8750
f7dc2d2
 
 
bf65d9e
 
 
 
 
 
 
 
 
 
09d8750
bf65d9e
 
 
 
 
 
 
 
4fe17ee
bf65d9e
 
 
 
f7dc2d2
bf65d9e
f7dc2d2
bf65d9e
 
 
 
 
f7dc2d2
bf65d9e
 
 
 
 
 
 
 
 
eaa0586
bf65d9e
02df9f8
 
 
3f861c3
a16dab3
 
3f861c3
9bfa66b
ae9daf1
a16dab3
8fa0ae4
9bfa66b
1efd23b
 
9428a07
 
 
 
8fa0ae4
486c21f
 
6cc23f5
02df9f8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hard cap on the number of single-token generate() calls per request, so a
# model that never emits EOS cannot loop forever.
MAX_PRODUCT_DIGITS = 100

# Checkpoint identifier handed to both `from_pretrained` loaders below.
model_name = "yuntian-deng/gpt2-implicit-cot-multiplication"
# Tokenizer is loaded first, then the causal-LM weights (same order as before).
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)

def preprocess(num):
    """Format one operand the way the model's prompt expects it.

    The value is stringified, surrounding whitespace is trimmed, and every
    space character is dropped. The character order is then reversed, so the
    least-significant digit comes first, and the characters are re-joined
    with single spaces.

    Example: ``'123'`` -> ``'3 2 1'``.
    """
    compact = ''.join(str(num).strip().split(' '))
    return ' '.join(reversed(compact))

def postprocess(raw_output):
    """Undo the prompt formatting on model output.

    Every space character is removed and the remaining characters are
    reversed, turning least-significant-first, space-separated digits back
    into a conventional number string.

    Example: ``'3 2 1'`` -> ``'123'``.
    """
    return ''.join(ch for ch in reversed(raw_output) if ch != ' ')

def _parse_operand(num):
    """Return *num* as an int, using the same normalization as ``preprocess``.

    Surrounding whitespace is trimmed and inner spaces are dropped. If what
    remains is not a non-empty run of ASCII digits, ``ValueError`` is raised,
    so the correctness check only accepts inputs the model actually sees as
    plain digits.
    """
    digits = str(num).strip().replace(' ', '')
    if not (digits.isascii() and digits.isdigit()):
        raise ValueError(f'not a non-negative integer: {num!r}')
    return int(digits)


def _label_digits(predicted_reversed, correct_reversed):
    """Score predicted digits against the true product, least-significant first.

    Args:
        predicted_reversed: predicted digit strings, least-significant first.
        correct_reversed: true product digits, least-significant first.

    Returns:
        ``(truth_spans, predicted_spans)``. Each is a list of ``(text, label)``
        pairs for ``gr.HighlightedText``, ordered most-significant first.
        Predicted digits are labelled ``'-'`` (correct, shown green) or ``'+'``
        (wrong, shown red). Ground-truth digits carry no label and only cover
        the positions predicted so far.
    """
    truth_spans = []
    predicted_spans = []
    all_correct_so_far = True
    for position, digit in enumerate(predicted_reversed):
        if position < len(correct_reversed):
            expected = correct_reversed[position]
            truth_spans.append((expected, None))
            is_correct = digit == expected
        else:
            # Past the true product's length, only a leading zero is harmless,
            # and only when every lower-order digit was already right.
            is_correct = digit == '0' and all_correct_so_far
        all_correct_so_far = all_correct_so_far and is_correct
        predicted_spans.append((digit, '-' if is_correct else '+'))
    return truth_spans[::-1], predicted_spans[::-1]


@spaces.GPU
def predict_product(num1, num2):
    """Stream the model's digit-by-digit prediction of ``num1 * num2``.

    The prompt is ``'<num1> * <num2> ='`` with each operand formatted by
    ``preprocess`` (reversed, space-separated). Output tokens are requested
    one at a time from ``model.generate``, and the KV cache is reused between
    calls. Generation stops at EOS or after ``MAX_PRODUCT_DIGITS`` steps.

    Yields:
        ``(ground_truth_spans, predicted_spans, result_html)``. One tuple is
        yielded after every generated token, with an empty message. A final
        tuple follows once generation stops; it holds the full ground truth
        and a verdict message. If either operand is not a non-negative
        integer, no ground truth is shown and the verdict says correctness
        could not be evaluated.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inputs = tokenizer(f'{preprocess(num1)} * {preprocess(num2)} =', return_tensors='pt').to(device)
    model.to(device)
    input_len = inputs['input_ids'].shape[-1]

    try:
        correct_product = str(_parse_operand(num1) * _parse_operand(num2))
    except ValueError:
        correct_product = None  # invalid input: nothing to compare against
    correct_reversed = list(correct_product or '')[::-1]

    generated_ids = inputs['input_ids']
    past_key_values = None
    predicted_reversed = []
    predicted_spans = []
    for _ in range(MAX_PRODUCT_DIGITS):  # hard cap in case EOS never arrives
        generation_kwargs = {
            'input_ids': generated_ids,
            'max_new_tokens': 1,
            'do_sample': False,
            'return_dict_in_generate': True,
            'use_cache': True,
        }
        # After the first step, hand back the cache so only the new token is processed.
        if past_key_values is not None:
            generation_kwargs['past_key_values'] = past_key_values
        outputs = model.generate(**generation_kwargs)
        generated_ids = outputs.sequences
        if generated_ids[0, -1].item() == tokenizer.eos_token_id:
            break
        past_key_values = outputs.past_key_values

        output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
        # split() with no argument never yields empty "digits" from stray spaces.
        predicted_reversed = output_text.split()
        if correct_product is None:
            truth_spans = []
            predicted_spans = [(digit, None) for digit in reversed(predicted_reversed)]
        else:
            truth_spans, predicted_spans = _label_digits(predicted_reversed, correct_reversed)
        yield truth_spans, predicted_spans, ''

    prediction = ''.join(reversed(predicted_reversed))
    if correct_product is None:
        final_truth = []
        result_message = 'Invalid input. Could not evaluate correctness.'
    else:
        # Leading zeros are tolerated, matching how _label_digits scores them.
        is_correct = bool(prediction) and (prediction.lstrip('0') or '0') == correct_product
        result_message = 'Correct!' if is_correct else f'Incorrect! The correct product is {correct_product}.'
        final_truth = [(digit, None) for digit in correct_product]
    yield final_truth, predicted_spans, result_message

def _build_demo():
    """Assemble the Gradio interface that wraps ``predict_product``.

    Inputs are two text boxes for the operands. Outputs are the ground-truth
    digits, the model's digits (``'-'`` shown green, ``'+'`` shown red), and
    an HTML result message.
    """
    number_inputs = [
        gr.Textbox(label='First Number (up to 12 digits)', value='123456789'),
        gr.Textbox(label='Second Number (up to 12 digits)', value='987654321'),
    ]
    result_outputs = [
        gr.HighlightedText(
            label='Ground Truth Product',
            combine_adjacent=False,
            show_legend=False,
            color_map={"-": "green", "+": "red"},
        ),
        gr.HighlightedText(
            label='GPT2 Predicted Product',
            combine_adjacent=False,
            show_legend=False,
            color_map={"-": "green", "+": "red"},
            show_inline_category=False,
        ),
        gr.HTML(label='Result Message'),
    ]
    article = """
    - [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
    - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
    - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
    """
    description = (
        'This demo uses GPT2 to directly predict the product of two numbers '
        'without using any intermediate reasoning steps. The GPT2 model has been '
        'fine-tuned to internalize chain-of-thought reasoning within its hidden '
        'states, following our stepwise internalization approach detailed in the '
        'paper linked at the bottom of this page.'
    )
    return gr.Interface(
        fn=predict_product,
        inputs=number_inputs,
        outputs=result_outputs,
        title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
        description=description,
        article=article,
        clear_btn=None,
        submit_btn="Multiply!",
        live=False,
    )


demo = _build_demo()

demo.launch()