import os
import random
from collections import Counter

import gradio as gr
import polars as pl
import spaces
import torch

from metric import PerplexityCalculator

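# Set IS_DEBUG to True to skip model loading and return a placeholder score,
# which keeps local UI testing cheap.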
IS_DEBUG = False

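# Keep OpenMP single-threaded and disable tokenizer parallelism to avoid
# fork-related warnings in the Spaces worker.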
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

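# Precompute one word multiset per sample-submission row so a candidate text
# can be matched back to its task.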
df_sample_submission = pl.read_csv("data/sample_submission.csv")
text_list = df_sample_submission.get_column("text").to_list()
text_counters = [Counter(text.split()) for text in text_list]

# Model loading: instantiate the Gemma-2-9b perplexity scorer (skipped in debug mode)
if not IS_DEBUG:
    scorer = PerplexityCalculator("google/gemma-2-9b")


@spaces.GPU()
def inference(text: str, progress=gr.Progress(track_tqdm=True)):
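    """Score `text` with Gemma-2-9b perplexity (-1 in debug mode) and report
    which task's word multiset it matches; returns (score, match label)."""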
    score = -1

    if IS_DEBUG:
        index_text = f"[DEBUG] "
    else:
        index_text = ""
        score = scorer.get_perplexity(text)

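    # The input matches a task iff it is a permutation of that task's words,
    # i.e. the word multisets are identical.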
    input_counter = Counter(text.split())
    is_match_list = [input_counter == text_counter for text_counter in text_counters]

    if any(is_match_list):
        index = is_match_list.index(True)
        index_text += f"Task #{index}"
        return score, index_text
    else:
        index_text += "No Match to All Tasks"
        gr.Warning(index_text)
        return score, index_text


def random_inference(text: str, progress=gr.Progress(track_tqdm=True)):
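    """Shuffle the words of `text` (falling back to the first sample text when
    empty) and score the shuffled result."""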
    if text == "":
        text = text_list[0]

    words = text.split()
    random.shuffle(words)

    random_text = " ".join(words)

    score, index_text = inference(random_text)
    return random_text, score, index_text


if __name__ == "__main__":
    theme = gr.themes.Default(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.green,
    )

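    # Layout: input and controls on the left, scores on the right, and the
    # sample-submission table underneath.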
    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            title = gr.Markdown(
                "<h1 style='text-align: center; margin-bottom: 1rem'>Gemma-2-9b Perplexity Calculator</h1>"
            )

            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(label="Text Input")

                    output_perplexity = gr.Number(label="Perplexity", render=False)
                    output_index = gr.Textbox(label="Index", render=False)

                    with gr.Row():
                        clear_button = gr.ClearButton(
                            [input_text, output_perplexity, output_index]
                        )
                        random_button = gr.Button("Randomize", variant="secondary")
                        submit_button = gr.Button("Run", variant="primary")

                with gr.Column():
                    output_perplexity.render()
                    output_index.render()

            sample_table = gr.Dataframe(
                df_sample_submission, label="Sample Submission", type="polars"
            )

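        # Wire events: the Run button and Enter in the textbox both score the
        # input; Randomize shuffles the words first and writes them back.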
        submit_button.click(
            inference, inputs=[input_text], outputs=[output_perplexity, output_index]
        )
        input_text.submit(
            inference, inputs=[input_text], outputs=[output_perplexity, output_index]
        )
        random_button.click(
            random_inference,
            inputs=[input_text],
            outputs=[input_text, output_perplexity, output_index],
        )

    demo.queue().launch()