# Gradio Space: Gemma-2-9b perplexity calculator (runs on HF ZeroGPU).
import os
import random
from collections import Counter
import gradio as gr
import polars as pl
import spaces
import torch
from metric import PerplexityCalculator
# Flip to True to run the UI without loading the 9B model (perplexity stays -1).
IS_DEBUG = False

# Tame thread/parallelism behavior; both libraries read these env vars at runtime.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Label id that CrossEntropyLoss ignores (used as the padding label downstream).
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reference data: one word-multiset (Counter) per task, used to match a
# submitted text back to its task index by comparing word counts.
df_sample_submission = pl.read_csv("data/sample_submission.csv")
text_list = df_sample_submission.get_column("text").to_list()
text_counters = [Counter(text.split()) for text in text_list]

# Model Loading — skipped in debug mode to avoid the expensive download/VRAM use.
if not IS_DEBUG:
    scorer = PerplexityCalculator("google/gemma-2-9b")
@spaces.GPU()
def inference(text: str, progress=gr.Progress(track_tqdm=True)):
    """Score *text* with Gemma-2-9b perplexity and match it to a task.

    Returns a ``(score, index_text)`` pair: ``score`` is the perplexity
    (or -1 in debug mode, where the model is not loaded), and
    ``index_text`` names the matching task or reports that none matched.
    """
    score = -1
    if IS_DEBUG:
        index_text = "[DEBUG] "  # plain literal; f-string had no placeholders
    else:
        index_text = ""
        score = scorer.get_perplexity(text)
    # A text belongs to a task iff it is a permutation of that task's words,
    # i.e. its word multiset equals the task's Counter. Stop at first match
    # instead of materializing a full match list and rescanning it.
    input_counter = Counter(text.split())
    index = next(
        (i for i, counter in enumerate(text_counters) if counter == input_counter),
        None,
    )
    if index is not None:
        index_text += f"Task #{index}"
    else:
        index_text += "No Match to All Tasks"
        gr.Warning(index_text)  # surface the mismatch in the UI
    return score, index_text
def random_inference(text: str, progress=gr.Progress(track_tqdm=True)):
    """Shuffle the words of *text* and score the shuffled version.

    Falls back to the first sample text when the input is empty — ``not text``
    covers both ``""`` and ``None`` (a cleared Gradio textbox can yield
    ``None``, which would previously crash on ``.split()``).
    Returns ``(random_text, score, index_text)``.
    """
    if not text:
        text = text_list[0]
    words = text.split()
    random.shuffle(words)
    random_text = " ".join(words)
    score, index_text = inference(random_text)
    return random_text, score, index_text
if __name__ == "__main__":
    # Green-tinted theme for the demo UI.
    theme = gr.themes.Default(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.green,
    )
    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            title = gr.Markdown(
                "<h1 style='text-align: center; margin-bottom: 1rem'>Gemma-2-9b Perplexity Calculator</h1>"
            )
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(label="Text Input")
                    # Outputs are created here with render=False so the
                    # ClearButton below can reference them; they are actually
                    # rendered later in the right-hand column via .render().
                    output_perplexity = gr.Number(label="Perplexity", render=False)
                    output_index = gr.Textbox(label="Index", render=False)
                    with gr.Row():
                        clear_button = gr.ClearButton(
                            [input_text, output_perplexity, output_index]
                        )
                        random_button = gr.Button("Randomize", variant="secondary")
                        submit_button = gr.Button("Run", variant="primary")
                with gr.Column():
                    # Place the pre-declared outputs in this column.
                    output_perplexity.render()
                    output_index.render()
            # Reference table of the task texts users are expected to permute.
            sample_table = gr.Dataframe(
                df_sample_submission, label="Sample Submission", type="polars"
            )
        # Both the Run button and pressing Enter in the textbox score the text.
        submit_button.click(
            inference, inputs=[input_text], outputs=[output_perplexity, output_index]
        )
        input_text.submit(
            inference, inputs=[input_text], outputs=[output_perplexity, output_index]
        )
        # Randomize shuffles the current (or first sample) text, writes the
        # shuffled text back into the input box, and scores it.
        random_button.click(
            random_inference,
            inputs=[input_text],
            outputs=[input_text, output_perplexity, output_index],
        )
    # Queue so the GPU-decorated inference calls are serialized per worker.
    demo.queue().launch()