Spaces: Running on Zero

Commit 78a5cec ("change")
Prgckwb committed
Parent(s): f22dc04

Files changed:
- .gitignore +0 -0
- app.py +22 -112
- metric.py +125 -0
- requirements.txt +1 -0
.gitignore
ADDED

The diff for this file is too large to render. See raw diff.
app.py
CHANGED

@@ -1,125 +1,41 @@
-import gc
 import os
-from math import exp
-from typing import List, Union
+from collections import Counter
 
 import gradio as gr
+import polars as pl
 import spaces
 import torch
-import transformers
+
+from metric import PerplexityCalculator
 
 os.environ['OMP_NUM_THREADS'] = '1'
 os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+df_sample_submission = pl.read_csv('data/sample_submission.csv')
+text_list = df_sample_submission.get_column('text').to_list()
+text_counters = [Counter(text.split()) for text in text_list]
 
-class PerplexityCalculator:
-    """
-    Calculates perplexity of text using a pre-trained language model.
-
-    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py
-
-    Parameters
-    ----------
-    model_path : str
-        Path to the pre-trained language model
-
-    load_in_8bit : bool, default=False
-        Use 8-bit quantization for the model. Requires CUDA.
-
-    device_map : str, default="auto"
-        Device mapping for the model.
-    """
-
-    def __init__(
-        self,
-        model_path: str,
-        load_in_8bit: bool = False,
-        device_map: str = 'auto',
-    ):
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="right")
-        # Configure model loading based on quantization setting and device availability
-        if load_in_8bit:
-            if DEVICE.type != 'cuda':
-                raise ValueError('8-bit quantization requires CUDA device')
-            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
-            self.model = transformers.AutoModelForCausalLM.from_pretrained(
-                model_path,
-                quantization_config=quantization_config,
-                device_map=device_map,
-            )
-        else:
-            self.model = transformers.AutoModelForCausalLM.from_pretrained(
-                model_path,
-                torch_dtype=torch.float16 if DEVICE.type == 'cuda' else torch.float32,
-                device_map=device_map,
-            )
-
-        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
-
-        self.model.eval()
-
-    def get_perplexity(self, input_texts: Union[str, List[str]], batch_size: int = 1) -> Union[float, List[float]]:
-        single_input = isinstance(input_texts, str)
-        input_texts = [input_texts] if single_input else input_texts
-        loss_list = []
-        batches = len(input_texts) // batch_size + (len(input_texts) % batch_size != 0)
-        for j in range(batches):
-            a = j * batch_size
-            b = (j + 1) * batch_size
-            input_batch = input_texts[a:b]
-            with torch.no_grad():
-                text_with_special = [f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}" for text in
-                                     input_batch]
-                model_inputs = self.tokenizer(text_with_special, return_tensors='pt', add_special_tokens=False,
-                                              padding=True)
-                if 'token_type_ids' in model_inputs:
-                    model_inputs.pop('token_type_ids')
-                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
-                output = self.model(**model_inputs, use_cache=False)
-                logits = output['logits']
-                label = model_inputs['input_ids']
-                label[label == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = label[..., 1:].contiguous()
-                loss = self.loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-                loss = loss.view(len(logits), -1)
-                valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
-                loss = torch.sum(loss, -1) / valid_length
-                loss_list += loss.cpu().tolist()
-        ppl = [exp(i) for i in loss_list]
-        return ppl[0] if single_input else ppl
-
-    def clear_gpu_memory(self) -> None:
-        """Clears GPU memory by deleting references and emptying caches."""
-        if not torch.cuda.is_available():
-            return
-
-        # Delete model and tokenizer if they exist
-        if hasattr(self, 'model'):
-            del self.model
-        if hasattr(self, 'tokenizer'):
-            del self.tokenizer
-
-        # Run garbage collection
-        gc.collect()
-
-        # Clear CUDA cache and reset memory stats
-        with DEVICE:
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
-            torch.cuda.reset_peak_memory_stats()
-
-
+# Model Loading
 scorer = PerplexityCalculator('google/gemma-2-9b')
 
 
 @spaces.GPU()
-def inference(text: str):
+def inference(text: str, progress=gr.Progress(track_tqdm=True)):
     score = scorer.get_perplexity(text)
 
-    return score
+    input_counter = Counter(text.split())
+    is_match_list = [input_counter == text_counter for text_counter in text_counters]
+
+    if any(is_match_list):
+        index = is_match_list.index(True)
+        index_text = f'Task #{index}'
+        return score, index_text
+    else:
+        index_text = 'No Match'
+        gr.Warning(index_text)
+        return score, index_text
 
 
 if __name__ == '__main__':
@@ -129,15 +45,9 @@ if __name__ == '__main__':
         outputs=[
             # gr.Number(label='Index'),
            gr.Number(label='Perplexity'),
+            gr.Textbox(label='Index')
         ],
-        examples=[
-            'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge',
-            'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and',
-            'yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice',
-            'yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap',
-            'hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle',
-            'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle'
-        ],
+        examples=text_list,
         title='Gemma-2-9b Perplexity Calculator',
     )
     demo.queue().launch()
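Note on the matching logic added to inference: splitting on whitespace and comparing collections.Counter objects treats each text as a multiset of words, so any reordering of a reference text still matches it, while added or missing words do not. A minimal sketch of that check with illustrative stand-in strings (the real list comes from data/sample_submission.csv):

from collections import Counter

# Stand-ins for the sample-submission texts.
text_list = [
    'advent chimney elf family fireplace',
    'yuletide decorations gifts cheer holiday',
]
text_counters = [Counter(text.split()) for text in text_list]

# Counter equality compares word counts, not word order.
query = 'fireplace family elf chimney advent'
is_match_list = [Counter(query.split()) == c for c in text_counters]
print(is_match_list.index(True) if any(is_match_list) else 'No Match')  # -> 0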
metric.py
ADDED

@@ -0,0 +1,125 @@
+import gc
+import os
+from math import exp
+from typing import List, Union
+
+import torch
+import transformers
+
+os.environ["OMP_NUM_THREADS"] = "1"
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class PerplexityCalculator:
+    """
+    Calculates perplexity of text using a pre-trained language model.
+
+    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py
+
+    Parameters
+    ----------
+    model_path : str
+        Path to the pre-trained language model
+
+    load_in_8bit : bool, default=False
+        Use 8-bit quantization for the model. Requires CUDA.
+
+    device_map : str, default="auto"
+        Device mapping for the model.
+    """
+
+    def __init__(
+        self,
+        model_path: str,
+        load_in_8bit: bool = False,
+        device_map: str = "auto",
+        dtype: torch.dtype = torch.float16,
+    ):
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+            model_path, padding_side="right"
+        )
+        # Configure model loading based on quantization setting and device availability
+        if load_in_8bit:
+            if DEVICE.type != "cuda":
+                raise ValueError("8-bit quantization requires CUDA device")
+            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
+            self.model = transformers.AutoModelForCausalLM.from_pretrained(
+                model_path,
+                quantization_config=quantization_config,
+                device_map=device_map,
+            )
+        else:
+            self.model = transformers.AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=dtype,
+                device_map=device_map,
+            )
+
+        self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+        self.model.eval()
+
+    def get_perplexity(
+        self, input_texts: Union[str, List[str]], batch_size: int = 1
+    ) -> Union[float, List[float]]:
+        single_input = isinstance(input_texts, str)
+        input_texts = [input_texts] if single_input else input_texts
+        loss_list = []
+        batches = len(input_texts) // batch_size + (len(input_texts) % batch_size != 0)
+        for j in range(batches):
+            a = j * batch_size
+            b = (j + 1) * batch_size
+            input_batch = input_texts[a:b]
+            with torch.no_grad():
+                text_with_special = [
+                    f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
+                    for text in input_batch
+                ]
+                model_inputs = self.tokenizer(
+                    text_with_special,
+                    return_tensors="pt",
+                    add_special_tokens=False,
+                    padding=True,
+                )
+                if "token_type_ids" in model_inputs:
+                    model_inputs.pop("token_type_ids")
+                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
+
+                output = self.model(**model_inputs, use_cache=False)
+                logits = output["logits"]
+
+                label = model_inputs["input_ids"]
+                label[label == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = label[..., 1:].contiguous()
+                loss = self.loss_fct(
+                    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+                )
+                loss = loss.view(len(logits), -1)
+                valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
+                loss = torch.sum(loss, -1) / valid_length
+                loss_list += loss.cpu().tolist()
+        ppl = [exp(i) for i in loss_list]
+        return ppl[0] if single_input else ppl
+
+    def clear_gpu_memory(self) -> None:
+        """Clears GPU memory by deleting references and emptying caches."""
+        if not torch.cuda.is_available():
+            return
+
+        # Delete model and tokenizer if they exist
+        if hasattr(self, "model"):
+            del self.model
+        if hasattr(self, "tokenizer"):
+            del self.tokenizer
+
+        # Run garbage collection
+        gc.collect()
+
+        # Clear CUDA cache and reset memory stats
+        with DEVICE:
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+            torch.cuda.reset_peak_memory_stats()
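Note on metric.py: this is the class removed from app.py, reformatted, with the non-quantized torch_dtype promoted to a dtype parameter. get_perplexity wraps each text in the tokenizer's BOS/EOS tokens, masks pad positions with CrossEntropyLoss's ignore_index, and returns exp of the mean per-token loss, so lower scores mean the model finds the word order more natural. A minimal usage sketch mirroring app.py (loading google/gemma-2-9b needs a large GPU and access to the gated checkpoint):

from metric import PerplexityCalculator

scorer = PerplexityCalculator('google/gemma-2-9b')  # load_in_8bit=True shrinks the footprint on CUDA

ppl = scorer.get_perplexity('advent chimney elf family fireplace')          # str in, float out
ppls = scorer.get_perplexity(['hello world', 'world hello'], batch_size=2)  # list in, list out
print(ppl, ppls)

scorer.clear_gpu_memory()  # drop the model and empty CUDA caches when done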
requirements.txt
CHANGED

@@ -1,3 +1,4 @@
 transformers
 safetensors
 accelerate
+polars
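Note on the new dependency: polars is used only to read the sample submission in app.py. A self-contained sketch of that read path, with an in-memory CSV standing in for data/sample_submission.csv (which this commit references but does not add):

import io

import polars as pl

# In-memory stand-in for data/sample_submission.csv.
csv_data = io.BytesIO(b'id,text\n0,advent chimney elf\n1,yuletide gifts cheer\n')

df_sample_submission = pl.read_csv(csv_data)
text_list = df_sample_submission.get_column('text').to_list()
print(text_list)  # ['advent chimney elf', 'yuletide gifts cheer']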