Prgckwb committed on
Commit 78a5cec · 1 Parent(s): f22dc04
Files changed (4)
  1. .gitignore +0 -0
  2. app.py +22 -112
  3. metric.py +125 -0
  4. requirements.txt +1 -0
.gitignore ADDED
The diff for this file is too large to render. See raw diff
 
app.py CHANGED
@@ -1,125 +1,41 @@
-import gc
 import os
-from math import exp
-from typing import List, Union
+from collections import Counter
 
 import gradio as gr
+import polars as pl
 import spaces
 import torch
-import transformers
+
+from metric import PerplexityCalculator
 
 os.environ['OMP_NUM_THREADS'] = '1'
 os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+df_sample_submission = pl.read_csv('data/sample_submission.csv')
+text_list = df_sample_submission.get_column('text').to_list()
+text_counters = [Counter(text.split()) for text in text_list]
 
-class PerplexityCalculator:
-    """
-    Calculates perplexity of text using a pre-trained language model.
-
-    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py
-
-    Parameters
-    ----------
-    model_path : str
-        Path to the pre-trained language model
-
-    load_in_8bit : bool, default=False
-        Use 8-bit quantization for the model. Requires CUDA.
-
-    device_map : str, default="auto"
-        Device mapping for the model.
-    """
-
-    def __init__(
-        self,
-        model_path: str,
-        load_in_8bit: bool = False,
-        device_map: str = 'auto',
-    ):
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="right")
-        # Configure model loading based on quantization setting and device availability
-        if load_in_8bit:
-            if DEVICE.type != 'cuda':
-                raise ValueError('8-bit quantization requires CUDA device')
-            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
-            self.model = transformers.AutoModelForCausalLM.from_pretrained(
-                model_path,
-                quantization_config=quantization_config,
-                device_map=device_map,
-            )
-        else:
-            self.model = transformers.AutoModelForCausalLM.from_pretrained(
-                model_path,
-                torch_dtype=torch.float16 if DEVICE.type == 'cuda' else torch.float32,
-                device_map=device_map,
-            )
-
-        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
-
-        self.model.eval()
-
-    def get_perplexity(self, input_texts: Union[str, List[str]], batch_size: int = 1) -> Union[float, List[float]]:
-        single_input = isinstance(input_texts, str)
-        input_texts = [input_texts] if single_input else input_texts
-        loss_list = []
-        batches = len(input_texts) // batch_size + (len(input_texts) % batch_size != 0)
-        for j in range(batches):
-            a = j * batch_size
-            b = (j + 1) * batch_size
-            input_batch = input_texts[a:b]
-            with torch.no_grad():
-                text_with_special = [f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}" for text in input_batch]
-                model_inputs = self.tokenizer(text_with_special, return_tensors='pt', add_special_tokens=False, padding=True)
-                if 'token_type_ids' in model_inputs:
-                    model_inputs.pop('token_type_ids')
-                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
-                output = self.model(**model_inputs, use_cache=False)
-                logits = output['logits']
-                label = model_inputs['input_ids']
-                label[label == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = label[..., 1:].contiguous()
-                loss = self.loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-                loss = loss.view(len(logits), -1)
-                valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
-                loss = torch.sum(loss, -1) / valid_length
-                loss_list += loss.cpu().tolist()
-        ppl = [exp(i) for i in loss_list]
-        return ppl[0] if single_input else ppl
-
-    def clear_gpu_memory(self) -> None:
-        """Clears GPU memory by deleting references and emptying caches."""
-        if not torch.cuda.is_available():
-            return
-
-        # Delete model and tokenizer if they exist
-        if hasattr(self, 'model'):
-            del self.model
-        if hasattr(self, 'tokenizer'):
-            del self.tokenizer
-
-        # Run garbage collection
-        gc.collect()
-
-        # Clear CUDA cache and reset memory stats
-        with DEVICE:
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
-            torch.cuda.reset_peak_memory_stats()
-
-
+# Model Loading
 scorer = PerplexityCalculator('google/gemma-2-9b')
 
 
 @spaces.GPU()
-def inference(text: str):
+def inference(text: str, progress=gr.Progress(track_tqdm=True)):
     score = scorer.get_perplexity(text)
 
-    return score
+    input_counter = Counter(text.split())
+    is_match_list = [input_counter == text_counter for text_counter in text_counters]
+
+    if any(is_match_list):
+        index = is_match_list.index(True)
+        index_text = f'Task #{index}'
+        return score, index_text
+    else:
+        index_text = 'No Match'
+        gr.Warning(index_text)
+        return score, index_text
 
 
 if __name__ == '__main__':
@@ -129,15 +45,9 @@ if __name__ == '__main__':
         outputs=[
             # gr.Number(label='Index'),
            gr.Number(label='Perplexity'),
+            gr.Textbox(label='Index')
         ],
-        examples=[
-            'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge',
-            'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and',
-            'yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice',
-            'yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap',
-            'hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle',
-            'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle'
-        ],
+        examples=text_list,
         title='Gemma-2-9b Perplexity Calculator',
     )
     demo.queue().launch()
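
Note: the task lookup added to inference() treats two texts as the same task when their word multisets are equal, so any permutation of a known sample-submission row maps back to its task index. A minimal standalone sketch of that check (the example rows below are hypothetical stand-ins for the text column of data/sample_submission.csv):

    from collections import Counter

    # Hypothetical stand-ins for the rows of data/sample_submission.csv.
    text_list = ['advent chimney elf', 'holly jingle beard naughty nice']
    text_counters = [Counter(t.split()) for t in text_list]

    def find_task(text: str) -> str:
        input_counter = Counter(text.split())  # word counts; ignores word order
        for i, counter in enumerate(text_counters):
            if input_counter == counter:
                return f'Task #{i}'
        return 'No Match'

    print(find_task('elf chimney advent'))  # -> 'Task #0': any reordering matches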
metric.py ADDED
@@ -0,0 +1,125 @@
+import gc
+import os
+from math import exp
+from typing import List, Union
+
+import torch
+import transformers
+
+os.environ["OMP_NUM_THREADS"] = "1"
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class PerplexityCalculator:
+    """
+    Calculates perplexity of text using a pre-trained language model.
+
+    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py
+
+    Parameters
+    ----------
+    model_path : str
+        Path to the pre-trained language model
+
+    load_in_8bit : bool, default=False
+        Use 8-bit quantization for the model. Requires CUDA.
+
+    device_map : str, default="auto"
+        Device mapping for the model.
+    """
+
+    def __init__(
+        self,
+        model_path: str,
+        load_in_8bit: bool = False,
+        device_map: str = "auto",
+        dtype: torch.dtype = torch.float16,
+    ):
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+            model_path, padding_side="right"
+        )
+        # Configure model loading based on quantization setting and device availability
+        if load_in_8bit:
+            if DEVICE.type != "cuda":
+                raise ValueError("8-bit quantization requires CUDA device")
+            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
+            self.model = transformers.AutoModelForCausalLM.from_pretrained(
+                model_path,
+                quantization_config=quantization_config,
+                device_map=device_map,
+            )
+        else:
+            self.model = transformers.AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=dtype,
+                device_map=device_map,
+            )
+
+        self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+        self.model.eval()
+
+    def get_perplexity(
+        self, input_texts: Union[str, List[str]], batch_size: int = 1
+    ) -> Union[float, List[float]]:
+        single_input = isinstance(input_texts, str)
+        input_texts = [input_texts] if single_input else input_texts
+        loss_list = []
+        batches = len(input_texts) // batch_size + (len(input_texts) % batch_size != 0)
+        for j in range(batches):
+            a = j * batch_size
+            b = (j + 1) * batch_size
+            input_batch = input_texts[a:b]
+            with torch.no_grad():
+                text_with_special = [
+                    f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
+                    for text in input_batch
+                ]
+                model_inputs = self.tokenizer(
+                    text_with_special,
+                    return_tensors="pt",
+                    add_special_tokens=False,
+                    padding=True,
+                )
+                if "token_type_ids" in model_inputs:
+                    model_inputs.pop("token_type_ids")
+                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
+
+                output = self.model(**model_inputs, use_cache=False)
+                logits = output["logits"]
+
+                label = model_inputs["input_ids"]
+                label[label == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = label[..., 1:].contiguous()
+                loss = self.loss_fct(
+                    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+                )
+                loss = loss.view(len(logits), -1)
+                valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
+                loss = torch.sum(loss, -1) / valid_length
+                loss_list += loss.cpu().tolist()
+        ppl = [exp(i) for i in loss_list]
+        return ppl[0] if single_input else ppl
+
+    def clear_gpu_memory(self) -> None:
+        """Clears GPU memory by deleting references and emptying caches."""
+        if not torch.cuda.is_available():
+            return
+
+        # Delete model and tokenizer if they exist
+        if hasattr(self, "model"):
+            del self.model
+        if hasattr(self, "tokenizer"):
+            del self.tokenizer
+
+        # Run garbage collection
+        gc.collect()
+
+        # Clear CUDA cache and reset memory stats
+        with DEVICE:
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+            torch.cuda.reset_peak_memory_stats()
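
Note: get_perplexity() returns the exponential of the mean per-token cross-entropy (padding excluded), i.e. PPL(x) = exp(-(1/N) * sum_t log p(x_t | x_<t)). A minimal usage sketch of the class as committed; google/gemma-2-9b is a gated ~9B-parameter model, so this assumes a GPU with enough memory (or load_in_8bit=True on CUDA):

    from metric import PerplexityCalculator

    scorer = PerplexityCalculator('google/gemma-2-9b')

    # A single string returns a float; a list of strings returns a list of floats.
    single = scorer.get_perplexity('peace joy merry season')
    batch = scorer.get_perplexity(['hope peace joy', 'joy peace hope'], batch_size=2)
    print(single, batch)

    scorer.clear_gpu_memory()  # drop the weights and empty the CUDA cache when done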
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 transformers
 safetensors
 accelerate
+polars