import os import random from collections import Counter import gradio as gr import polars as pl import spaces import torch from metric import PerplexityCalculator IS_DEBUG = False os.environ["OMP_NUM_THREADS"] = "1" os.environ["TOKENIZERS_PARALLELISM"] = "false" PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") df_sample_submission = pl.read_csv("data/sample_submission.csv") text_list = df_sample_submission.get_column("text").to_list() text_counters = [Counter(text.split()) for text in text_list] # Model Loading if not IS_DEBUG: scorer = PerplexityCalculator("google/gemma-2-9b") @spaces.GPU() def inference(text: str, progress=gr.Progress(track_tqdm=True)): score = -1 if IS_DEBUG: index_text = f"[DEBUG] " else: index_text = "" score = scorer.get_perplexity(text) input_counter = Counter(text.split()) is_match_list = [input_counter == text_counter for text_counter in text_counters] if any(is_match_list): index = is_match_list.index(True) index_text += f"Task #{index}" return score, index_text else: index_text += "No Match to All Tasks" gr.Warning(index_text) return score, index_text def random_inference(text: str, progress=gr.Progress(track_tqdm=True)): if text == "": text = text_list[0] words = text.split() random.shuffle(words) random_text = " ".join(words) score, index_text = inference(random_text) return random_text, score, index_text if __name__ == "__main__": theme = gr.themes.Default( primary_hue=gr.themes.colors.emerald, secondary_hue=gr.themes.colors.green, ) with gr.Blocks(theme=theme) as demo: with gr.Column(): title = gr.Markdown( "