Prgckwb committed on
Commit e455656 · verified · 1 Parent(s): b34397d

Update metric.py

Files changed (1)
  1. metric.py +226 -92
metric.py CHANGED
@@ -1,125 +1,259 @@
  import gc
  import os
  from math import exp
  from typing import List, Union

  import torch
  import transformers

- os.environ["OMP_NUM_THREADS"] = "1"
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
- PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


  class PerplexityCalculator:
-     """
-     Calculates perplexity of text using a pre-trained language model.
-
-     Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py
-
-     Parameters
-     ----------
-     model_path : str
-         Path to the pre-trained language model
-
-     load_in_8bit : bool, default=False
-         Use 8-bit quantization for the model. Requires CUDA.
-
-     device_map : str, default="auto"
-         Device mapping for the model.
-     """
-
-     def __init__(
-         self,
-         model_path: str,
-         load_in_8bit: bool = False,
-         device_map: str = "auto",
-         dtype: torch.dtype = torch.float16,
-     ):
-         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-             model_path, padding_side="right"
-         )
-         # Configure model loading based on quantization setting and device availability
-         if load_in_8bit:
-             if DEVICE.type != "cuda":
-                 raise ValueError("8-bit quantization requires CUDA device")
-             quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
-             self.model = transformers.AutoModelForCausalLM.from_pretrained(
-                 model_path,
-                 quantization_config=quantization_config,
-                 device_map=device_map,
-             )
-         else:
-             self.model = transformers.AutoModelForCausalLM.from_pretrained(
-                 model_path,
-                 torch_dtype=dtype,
-                 device_map=device_map,
-             )

-         self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

          self.model.eval()

-     def get_perplexity(
-         self, input_texts: Union[str, List[str]], batch_size: int = 1
-     ) -> Union[float, List[float]]:
          single_input = isinstance(input_texts, str)
          input_texts = [input_texts] if single_input else input_texts
          loss_list = []
-         batches = len(input_texts) // batch_size + (len(input_texts) % batch_size != 0)
-         for j in range(batches):
-             a = j * batch_size
-             b = (j + 1) * batch_size
-             input_batch = input_texts[a:b]
              with torch.no_grad():
-                 text_with_special = [
-                     f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
-                     for text in input_batch
-                 ]
                  model_inputs = self.tokenizer(
                      text_with_special,
-                     return_tensors="pt",
                      add_special_tokens=False,
                      padding=True,
                  )
-                 if "token_type_ids" in model_inputs:
-                     model_inputs.pop("token_type_ids")
-                 model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                  output = self.model(**model_inputs, use_cache=False)
-                 logits = output["logits"]

-                 label = model_inputs["input_ids"]
-                 label[label == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID
-                 shift_logits = logits[..., :-1, :].contiguous()
-                 shift_labels = label[..., 1:].contiguous()
                  loss = self.loss_fct(
-                     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
                  )
                  loss = loss.view(len(logits), -1)
-                 valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
                  loss = torch.sum(loss, -1) / valid_length
                  loss_list += loss.cpu().tolist()
          ppl = [exp(i) for i in loss_list]
-         return ppl[0] if single_input else ppl
-
-     def clear_gpu_memory(self) -> None:
-         """Clears GPU memory by deleting references and emptying caches."""
-         if not torch.cuda.is_available():
-             return
-
-         # Delete model and tokenizer if they exist
-         if hasattr(self, "model"):
-             del self.model
-         if hasattr(self, "tokenizer"):
-             del self.tokenizer
-
-         # Run garbage collection
-         gc.collect()
-
-         # Clear CUDA cache and reset memory stats
-         with DEVICE:
-             torch.cuda.empty_cache()
-             torch.cuda.ipc_collect()
-             torch.cuda.reset_peak_memory_stats()
+ # import gc
+ # import os
+ # from math import exp
+ # from typing import List, Union
+
+ # import torch
+ # import transformers
+
+ # os.environ["OMP_NUM_THREADS"] = "1"
+ # os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ # PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ # class PerplexityCalculator:
+ #     """
+ #     Calculates perplexity of text using a pre-trained language model.
+
+ #     Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py
+
+ #     Parameters
+ #     ----------
+ #     model_path : str
+ #         Path to the pre-trained language model
+
+ #     load_in_8bit : bool, default=False
+ #         Use 8-bit quantization for the model. Requires CUDA.
+
+ #     device_map : str, default="auto"
+ #         Device mapping for the model.
+ #     """
+
+ #     def __init__(
+ #         self,
+ #         model_path: str,
+ #         load_in_8bit: bool = False,
+ #         device_map: str = "auto",
+ #         dtype: torch.dtype = torch.float16,
+ #     ):
+ #         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+ #             model_path, padding_side="right"
+ #         )
+ #         # Configure model loading based on quantization setting and device availability
+ #         if load_in_8bit:
+ #             if DEVICE.type != "cuda":
+ #                 raise ValueError("8-bit quantization requires CUDA device")
+ #             quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
+ #             self.model = transformers.AutoModelForCausalLM.from_pretrained(
+ #                 model_path,
+ #                 quantization_config=quantization_config,
+ #                 device_map=device_map,
+ #             )
+ #         else:
+ #             self.model = transformers.AutoModelForCausalLM.from_pretrained(
+ #                 model_path,
+ #                 torch_dtype=dtype,
+ #                 device_map=device_map,
+ #             )
+
+ #         self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+ #         self.model.eval()
+
+ #     def get_perplexity(
+ #         self, input_texts: Union[str, List[str]], batch_size: int = 1
+ #     ) -> Union[float, List[float]]:
+ #         single_input = isinstance(input_texts, str)
+ #         input_texts = [input_texts] if single_input else input_texts
+ #         loss_list = []
+ #         batches = len(input_texts) // batch_size + (len(input_texts) % batch_size != 0)
+ #         for j in range(batches):
+ #             a = j * batch_size
+ #             b = (j + 1) * batch_size
+ #             input_batch = input_texts[a:b]
+ #             with torch.no_grad():
+ #                 text_with_special = [
+ #                     f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
+ #                     for text in input_batch
+ #                 ]
+ #                 model_inputs = self.tokenizer(
+ #                     text_with_special,
+ #                     return_tensors="pt",
+ #                     add_special_tokens=False,
+ #                     padding=True,
+ #                 )
+ #                 if "token_type_ids" in model_inputs:
+ #                     model_inputs.pop("token_type_ids")
+ #                 model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
+
+ #                 output = self.model(**model_inputs, use_cache=False)
+ #                 logits = output["logits"]
+
+ #                 label = model_inputs["input_ids"]
+ #                 label[label == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID
+ #                 shift_logits = logits[..., :-1, :].contiguous()
+ #                 shift_labels = label[..., 1:].contiguous()
+ #                 loss = self.loss_fct(
+ #                     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+ #                 )
+ #                 loss = loss.view(len(logits), -1)
+ #                 valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
+ #                 loss = torch.sum(loss, -1) / valid_length
+ #                 loss_list += loss.cpu().tolist()
+ #         ppl = [exp(i) for i in loss_list]
+ #         return ppl[0] if single_input else ppl
+
+ #     def clear_gpu_memory(self) -> None:
+ #         """Clears GPU memory by deleting references and emptying caches."""
+ #         if not torch.cuda.is_available():
+ #             return
+
+ #         # Delete model and tokenizer if they exist
+ #         if hasattr(self, "model"):
+ #             del self.model
+ #         if hasattr(self, "tokenizer"):
+ #             del self.tokenizer
+
+ #         # Run garbage collection
+ #         gc.collect()
+
+ #         # Clear CUDA cache and reset memory stats
+ #         with DEVICE:
+ #             torch.cuda.empty_cache()
+ #             torch.cuda.ipc_collect()
+ #             torch.cuda.reset_peak_memory_stats()
+
+
  import gc
  import os
  from math import exp
  from typing import List, Union

+ import pandas as pd
  import torch
  import transformers
+ from tqdm import tqdm
+ from collections import OrderedDict
+
+ os.environ['OMP_NUM_THREADS'] = '1'
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'

+ class LRUCache:
+     def __init__(self, capacity=10**11):
+         self.capacity = capacity
+         self.cache = OrderedDict()

+     def get(self, key):
+         if key in self.cache:
+             self.cache.move_to_end(key)
+             return self.cache[key]
+         return None
+
+     def set(self, key, value):
+         self.cache[key] = value
+         self.cache.move_to_end(key)
+         if len(self.cache) > self.capacity:
+             self.cache.popitem(last=False)
+
+     def __len__(self):
+         return len(self.cache)

  class PerplexityCalculator:
+     model_kwargs = {
+         # "attn_implementation": "sdpa",  # scores change unless this stays commented out; it is also somewhat slower
+         "device_map": "auto",
+         "torch_dtype": torch.float16,
+     }
+     device = torch.device('cuda')

+     def __init__(self, model_path: str, capacity=10**11):
+         self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="right")
+         self.model = transformers.AutoModelForCausalLM.from_pretrained(model_path, **self.model_kwargs)
+         self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
+         self.pad_token_label_id = self.loss_fct.ignore_index

          self.model.eval()
+         self.cache = LRUCache(capacity=capacity)

+
+     def get_perplexity(self, input_texts, batch_size=128, use_cache=True, verbose=False):
          single_input = isinstance(input_texts, str)
          input_texts = [input_texts] if single_input else input_texts
+
+         results = [None] * len(input_texts)
+
+         if use_cache:
+             text_to_process = []
+             for i, text in enumerate(input_texts):
+                 cached_val = self.cache.get(text)
+                 if cached_val is not None:
+                     results[i] = cached_val
+                 else:
+                     text_to_process.append(text)
+         else:
+             text_to_process = input_texts.copy()
+
          loss_list = []
+         batches = len(text_to_process)//batch_size + (len(text_to_process)%batch_size != 0)
+         pbar = range(batches)
+
+         if verbose and batches:
+             pbar = tqdm(range(batches))
+
+         for j in pbar:
+
+             a = j*batch_size
+             b = (j+1)*batch_size
+             input_batch = text_to_process[a:b]
+
              with torch.no_grad():
+
+                 # Explicitly add sequence boundary tokens to the text
+                 text_with_special = [f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}" for text in input_batch]
+
+                 # Tokenize
                  model_inputs = self.tokenizer(
                      text_with_special,
+                     return_tensors='pt',
                      add_special_tokens=False,
                      padding=True,
                  )

+                 if 'token_type_ids' in model_inputs:
+                     model_inputs.pop('token_type_ids')
+
+                 model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
+
+                 # Get model output
                  output = self.model(**model_inputs, use_cache=False)
+                 logits = output['logits']
+
+                 label = model_inputs['input_ids']
+                 label[label == self.tokenizer.pad_token_id] = self.pad_token_label_id

+                 # Shift logits and labels for calculating loss
+                 shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
+                 shift_labels = label[..., 1:].contiguous()  # Drop first input
+
+                 # Calculate token-wise loss
                  loss = self.loss_fct(
+                     shift_logits.view(-1, shift_logits.size(-1)),
+                     shift_labels.view(-1)
                  )
+
                  loss = loss.view(len(logits), -1)
+                 valid_length = (shift_labels != self.pad_token_label_id).sum(dim=-1)
                  loss = torch.sum(loss, -1) / valid_length
+
                  loss_list += loss.cpu().tolist()
+
          ppl = [exp(i) for i in loss_list]
+
+         # Merge freshly computed perplexities back into the slots left empty by cache misses
+         index_ppl = 0
+         for index_el, el in enumerate(results):
+             if el is None:
+                 results[index_el] = ppl[index_ppl]
+                 self.cache.set(text_to_process[index_ppl], ppl[index_ppl])
+                 index_ppl += 1
+         return results[0] if single_input else results
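
For reference, here is a minimal usage sketch of the rewritten class; it is not part of the commit. It assumes the committed metric.py is importable, that a CUDA GPU is available (the class hard-codes device = torch.device('cuda')), and that the chosen model's tokenizer defines bos, eos, and pad tokens. The model path below is only a placeholder.

# Minimal usage sketch (assumptions: metric.py above is on the import path,
# a CUDA GPU is present, and "path/to/causal-lm" is replaced by a real model).
from metric import PerplexityCalculator

scorer = PerplexityCalculator(model_path="path/to/causal-lm")

texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Dog lazy the over jumps fox brown quick the.",
]

# First call runs a forward pass for every sentence and fills the LRU cache.
ppl = scorer.get_perplexity(texts, batch_size=2, verbose=True)

# Second call is served entirely from the cache; no forward pass is run.
ppl_again = scorer.get_perplexity(texts)

print(ppl)
print(ppl_again == ppl)  # True: cached values are returned unchanged

The per-text cache is the main behavioral change in this commit: any string scored once is returned in O(1) on later calls instead of triggering another model forward pass, while uncached strings are batched and scored exactly as before.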