# gemma2-ppl / app.py
# Prgckwb - commit e613ea6 ("change") - 3.31 kB
import os
import random
from collections import Counter
import gradio as gr
import polars as pl
import spaces
import torch
from metric import PerplexityCalculator
# When True, the model is never loaded and `inference` returns the
# placeholder score -1 instead of calling the scorer.
IS_DEBUG = False

# Set both threading-related environment variables in one call.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# Label id that a default-constructed CrossEntropyLoss ignores.
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index

# Use CUDA when torch reports it as available, otherwise the CPU.
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Sample submission: its "text" column provides the reference texts, and each
# text is reduced to a multiset of whitespace-separated words for matching.
df_sample_submission = pl.read_csv("data/sample_submission.csv")
text_list = df_sample_submission.get_column("text").to_list()
text_counters = [Counter(sample.split()) for sample in text_list]

# Model Loading (skipped in debug mode, so `scorer` is undefined there).
if not IS_DEBUG:
    scorer = PerplexityCalculator("google/gemma-2-9b")
@spaces.GPU()
def inference(text: str, progress=gr.Progress(track_tqdm=True)):
    """Score ``text`` and report which sample task it matches.

    Outside debug mode the score is whatever ``scorer.get_perplexity(text)``
    returns; in debug mode the scorer is not called and the score stays at
    the placeholder ``-1``. The input is then compared with every sample
    text as a multiset of whitespace-separated words, so word order does not
    affect the match.

    Args:
        text: Text to evaluate.
        progress: Gradio progress object created with ``track_tqdm=True``;
            it is not referenced in the body.

    Returns:
        A ``(score, index_text)`` tuple. ``index_text`` names the index of
        the first matching sample ("Task #<index>") or states that nothing
        matched; in debug mode it is prefixed with "[DEBUG] ".
    """
    if IS_DEBUG:
        score, label = -1, "[DEBUG] "
    else:
        score, label = scorer.get_perplexity(text), ""

    words = Counter(text.split())
    matches = [words == candidate for candidate in text_counters]

    # Guard clause: no sample has the same word multiset as the input.
    if True not in matches:
        label += "No Match to All Tasks"
        gr.Warning(label)
        return score, label

    label += f"Task #{matches.index(True)}"
    return score, label
def random_inference(text: str, progress=gr.Progress(track_tqdm=True)):
    """Shuffle the words of ``text`` at random and score the result.

    If ``text`` contains no words (it is empty or only whitespace), the
    first sample text is shuffled instead, so the scorer is never handed an
    empty string just because the input box held stray spaces.

    Args:
        text: Text whose whitespace-separated words are shuffled.
        progress: Gradio progress object created with ``track_tqdm=True``;
            it is not referenced in the body.

    Returns:
        A ``(random_text, score, index_text)`` tuple: the shuffled text
        joined with single spaces, followed by the two values returned by
        ``inference`` for that text.
    """
    words = text.split()
    if not words:
        # Nothing to shuffle: fall back to the first sample text. This covers
        # whitespace-only input as well as the empty string.
        words = text_list[0].split()

    random.shuffle(words)
    random_text = " ".join(words)

    score, index_text = inference(random_text)
    return random_text, score, index_text
if __name__ == "__main__":
    # Default Gradio theme with emerald/green hues.
    app_theme = gr.themes.Default(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.green,
    )

    with gr.Blocks(theme=app_theme) as demo:
        with gr.Column():
            gr.Markdown(
                "<h1 style='text-align: center; margin-bottom: 1rem'>Gemma-2-9b Perplexity Calculator</h1>"
            )

            with gr.Row():
                # Left column: text input and the action buttons.
                with gr.Column():
                    input_text = gr.Textbox(label="Text Input")

                    # The result widgets are created here with render=False so
                    # the clear button can reference them; they are rendered
                    # in the right column below.
                    output_perplexity = gr.Number(label="Perplexity", render=False)
                    output_index = gr.Textbox(label="Index", render=False)
                    result_fields = [output_perplexity, output_index]

                    with gr.Row():
                        gr.ClearButton([input_text, *result_fields])
                        random_button = gr.Button("Randomize", variant="secondary")
                        submit_button = gr.Button("Run", variant="primary")

                # Right column: the deferred result widgets.
                with gr.Column():
                    output_perplexity.render()
                    output_index.render()

            # NOTE(review): placed below the two-column row, inside the outer
            # column -- confirm this is the intended position.
            gr.Dataframe(
                df_sample_submission, label="Sample Submission", type="polars"
            )

        # Clicking "Run" and pressing Enter in the textbox trigger the same
        # scoring call.
        for bind in (submit_button.click, input_text.submit):
            bind(
                inference,
                inputs=[input_text],
                outputs=[output_perplexity, output_index],
            )

        # "Randomize" also writes the shuffled text back into the input box.
        random_button.click(
            random_inference,
            inputs=[input_text],
            outputs=[input_text, output_perplexity, output_index],
        )

    demo.queue().launch()