# Spaces: Running on Zero
import os | |
import random | |
from collections import Counter | |
import gradio as gr | |
import polars as pl | |
import spaces | |
import torch | |
from metric import PerplexityCalculator | |
# When True the model is not loaded and `inference` reports a score of -1.
IS_DEBUG = False

# Pin OpenMP to a single thread and switch off tokenizer parallelism.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# `ignore_index` of a default-constructed CrossEntropyLoss.
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
# Use CUDA when torch reports it as available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reference sentences: the "text" column of the sample submission, plus one
# word-frequency Counter per sentence (used to recognise permutations).
df_sample_submission = pl.read_csv("data/sample_submission.csv")
text_list = df_sample_submission.get_column("text").to_list()
text_counters = [Counter(sentence.split()) for sentence in text_list]

# Model loading (skipped entirely in debug mode).
if not IS_DEBUG:
    scorer = PerplexityCalculator("google/gemma-2-9b")
def inference(text: str, progress=gr.Progress(track_tqdm=True)):
    """Score *text* and report which sample sentence it is a permutation of.

    The input is matched against the module-level ``text_counters`` by
    comparing whitespace-split word counts, so word order does not matter.

    Args:
        text: Sentence to evaluate.
        progress: Gradio progress object; not referenced in the body.

    Returns:
        A ``(score, index_text)`` tuple. ``score`` is the value returned by
        ``scorer.get_perplexity(text)``, or ``-1`` when ``IS_DEBUG`` is set.
        ``index_text`` is ``"Task #<i>"`` for the first matching sample, or
        ``"No Match to All Tasks"`` otherwise (prefixed with ``"[DEBUG] "``
        in debug mode).
    """
    if IS_DEBUG:
        # No model is loaded in debug mode, so report a placeholder score.
        score, index_text = -1, "[DEBUG] "
    else:
        index_text = ""
        score = scorer.get_perplexity(text)

    words = Counter(text.split())
    # Index of the first sample whose word multiset equals the input's.
    task_index = next(
        (i for i, reference in enumerate(text_counters) if words == reference),
        None,
    )

    if task_index is None:
        index_text += "No Match to All Tasks"
        # Surface the mismatch in the UI as well as in the returned text.
        gr.Warning(index_text)
    else:
        index_text += f"Task #{task_index}"
    return score, index_text
def random_inference(text: str, progress=gr.Progress(track_tqdm=True)):
    """Shuffle the words of *text* and score the shuffled sentence.

    Args:
        text: Sentence whose whitespace-separated words are shuffled. When it
            is ``None``, empty, or contains only whitespace, the first sample
            sentence (``text_list[0]``) is used instead.
        progress: Gradio progress object; not referenced in the body.

    Returns:
        A ``(random_text, score, index_text)`` tuple: the shuffled sentence
        followed by the two values returned by ``inference`` for it.
    """
    # A whitespace-only string splits into no words, which previously made
    # this function score (and write back) an empty string. Treat it the
    # same as an empty textbox.
    if not text or not text.strip():
        text = text_list[0]

    words = text.split()
    random.shuffle(words)
    random_text = " ".join(words)

    score, index_text = inference(random_text)
    return random_text, score, index_text
if __name__ == "__main__":
    # Default Gradio theme recoloured with emerald / green hues.
    theme = gr.themes.Default(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.green,
    )

    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            title = gr.Markdown(
                "<h1 style='text-align: center; margin-bottom: 1rem'>Gemma-2-9b Perplexity Calculator</h1>"
            )

            with gr.Row():
                # Input column: text box and action buttons.
                with gr.Column():
                    input_text = gr.Textbox(label="Text Input")
                    # Outputs are created with render=False so the clear
                    # button can reference them here; they are rendered
                    # explicitly in the second column.
                    output_perplexity = gr.Number(label="Perplexity", render=False)
                    output_index = gr.Textbox(label="Index", render=False)
                    result_outputs = (output_perplexity, output_index)

                    with gr.Row():
                        clear_button = gr.ClearButton([input_text, *result_outputs])
                        random_button = gr.Button("Randomize", variant="secondary")
                        submit_button = gr.Button("Run", variant="primary")

                # Output column: perplexity value and matched task index.
                with gr.Column():
                    for component in result_outputs:
                        component.render()

            sample_table = gr.Dataframe(
                df_sample_submission, label="Sample Submission", type="polars"
            )

        # Clicking "Run" and pressing Enter in the text box both score the
        # current text.
        for register in (submit_button.click, input_text.submit):
            register(inference, inputs=[input_text], outputs=[*result_outputs])

        # "Randomize" also writes the shuffled sentence back into the input.
        random_button.click(
            random_inference,
            inputs=[input_text],
            outputs=[input_text, *result_outputs],
        )

    demo.queue().launch()