Prgckwb commited on
Commit
e613ea6
·
1 Parent(s): 2d945ae
Files changed (1) hide show
  1. app.py +43 -26
app.py CHANGED
@@ -1,35 +1,38 @@
1
  import os
 
2
  from collections import Counter
3
 
4
  import gradio as gr
5
  import polars as pl
6
  import spaces
7
  import torch
8
- import random
9
 
10
  from metric import PerplexityCalculator
11
 
12
- IS_DEBUG = True
13
 
14
- os.environ['OMP_NUM_THREADS'] = '1'
15
- os.environ['TOKENIZERS_PARALLELISM'] = 'false'
16
  PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
17
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
 
19
- df_sample_submission = pl.read_csv('data/sample_submission.csv')
20
- text_list = df_sample_submission.get_column('text').to_list()
21
  text_counters = [Counter(text.split()) for text in text_list]
22
 
23
  # Model Loading
24
  if not IS_DEBUG:
25
- scorer = PerplexityCalculator('google/gemma-2-9b')
26
 
27
 
28
  @spaces.GPU()
29
  def inference(text: str, progress=gr.Progress(track_tqdm=True)):
 
 
30
  if IS_DEBUG:
31
- score = -1
32
  else:
 
33
  score = scorer.get_perplexity(text)
34
 
35
  input_counter = Counter(text.split())
@@ -37,27 +40,28 @@ def inference(text: str, progress=gr.Progress(track_tqdm=True)):
37
 
38
  if any(is_match_list):
39
  index = is_match_list.index(True)
40
- index_text = f'Task #{index}'
41
  return score, index_text
42
  else:
43
- index_text = 'No Match to All Tasks'
44
  gr.Warning(index_text)
45
  return score, index_text
46
 
 
47
  def random_inference(text: str, progress=gr.Progress(track_tqdm=True)):
48
- if text == '':
49
  text = text_list[0]
50
 
51
  words = text.split()
52
  random.shuffle(words)
53
 
54
- random_text = ' '.join(words)
55
 
56
  score, index_text = inference(random_text)
57
  return random_text, score, index_text
58
 
59
 
60
- if __name__ == '__main__':
61
  theme = gr.themes.Default(
62
  primary_hue=gr.themes.colors.emerald,
63
  secondary_hue=gr.themes.colors.green,
@@ -66,28 +70,41 @@ if __name__ == '__main__':
66
  with gr.Blocks(theme=theme) as demo:
67
  with gr.Column():
68
  title = gr.Markdown(
69
- "<h1 style='text-align: center; margin-bottom: 1rem'>Gemma-2-9b Perplexity Calculator</h1>")
 
70
 
71
  with gr.Row():
72
  with gr.Column():
73
- input_text = gr.Textbox(label='Text Input')
74
 
75
- output_perplexity = gr.Number(label='Perplexity', render=False)
76
- output_index = gr.Textbox(label='Index', render=False)
77
 
78
  with gr.Row():
79
- clear_button = gr.ClearButton([input_text, output_perplexity, output_index])
80
- random_button = gr.Button('Randomize', variant='secondary')
81
- submit_button = gr.Button('Run', variant='primary')
 
 
82
 
83
  with gr.Column():
84
  output_perplexity.render()
85
  output_index.render()
86
 
87
- sample_table = gr.Dataframe(df_sample_submission, label='Sample Submission', type='polars')
88
-
89
- submit_button.click(inference, inputs=[input_text], outputs=[output_perplexity, output_index])
90
- input_text.submit(inference, inputs=[input_text], outputs=[output_perplexity, output_index])
91
- random_button.click(random_inference, inputs=[input_text], outputs=[input_text, output_perplexity, output_index])
 
 
 
 
 
 
 
 
 
 
92
 
93
  demo.queue().launch()
 
1
  import os
2
+ import random
3
  from collections import Counter
4
 
5
  import gradio as gr
6
  import polars as pl
7
  import spaces
8
  import torch
 
9
 
10
  from metric import PerplexityCalculator
11
 
12
+ IS_DEBUG = False
13
 
14
+ os.environ["OMP_NUM_THREADS"] = "1"
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
  PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
17
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
+ df_sample_submission = pl.read_csv("data/sample_submission.csv")
20
+ text_list = df_sample_submission.get_column("text").to_list()
21
  text_counters = [Counter(text.split()) for text in text_list]
22
 
23
  # Model Loading
24
  if not IS_DEBUG:
25
+ scorer = PerplexityCalculator("google/gemma-2-9b")
26
 
27
 
28
  @spaces.GPU()
29
  def inference(text: str, progress=gr.Progress(track_tqdm=True)):
30
+ score = -1
31
+
32
  if IS_DEBUG:
33
+ index_text = f"[DEBUG] "
34
  else:
35
+ index_text = ""
36
  score = scorer.get_perplexity(text)
37
 
38
  input_counter = Counter(text.split())
 
40
 
41
  if any(is_match_list):
42
  index = is_match_list.index(True)
43
+ index_text += f"Task #{index}"
44
  return score, index_text
45
  else:
46
+ index_text += "No Match to All Tasks"
47
  gr.Warning(index_text)
48
  return score, index_text
49
 
50
+
51
  def random_inference(text: str, progress=gr.Progress(track_tqdm=True)):
52
+ if text == "":
53
  text = text_list[0]
54
 
55
  words = text.split()
56
  random.shuffle(words)
57
 
58
+ random_text = " ".join(words)
59
 
60
  score, index_text = inference(random_text)
61
  return random_text, score, index_text
62
 
63
 
64
+ if __name__ == "__main__":
65
  theme = gr.themes.Default(
66
  primary_hue=gr.themes.colors.emerald,
67
  secondary_hue=gr.themes.colors.green,
 
70
  with gr.Blocks(theme=theme) as demo:
71
  with gr.Column():
72
  title = gr.Markdown(
73
+ "<h1 style='text-align: center; margin-bottom: 1rem'>Gemma-2-9b Perplexity Calculator</h1>"
74
+ )
75
 
76
  with gr.Row():
77
  with gr.Column():
78
+ input_text = gr.Textbox(label="Text Input")
79
 
80
+ output_perplexity = gr.Number(label="Perplexity", render=False)
81
+ output_index = gr.Textbox(label="Index", render=False)
82
 
83
  with gr.Row():
84
+ clear_button = gr.ClearButton(
85
+ [input_text, output_perplexity, output_index]
86
+ )
87
+ random_button = gr.Button("Randomize", variant="secondary")
88
+ submit_button = gr.Button("Run", variant="primary")
89
 
90
  with gr.Column():
91
  output_perplexity.render()
92
  output_index.render()
93
 
94
+ sample_table = gr.Dataframe(
95
+ df_sample_submission, label="Sample Submission", type="polars"
96
+ )
97
+
98
+ submit_button.click(
99
+ inference, inputs=[input_text], outputs=[output_perplexity, output_index]
100
+ )
101
+ input_text.submit(
102
+ inference, inputs=[input_text], outputs=[output_perplexity, output_index]
103
+ )
104
+ random_button.click(
105
+ random_inference,
106
+ inputs=[input_text],
107
+ outputs=[input_text, output_perplexity, output_index],
108
+ )
109
 
110
  demo.queue().launch()