import gc
import os
from math import exp
from typing import List, Union

import torch
import transformers

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PerplexityCalculator:
    """
    Calculates perplexity of text using a pre-trained causal language model.

    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py

    Parameters
    ----------
    model_path : str
        Path to the pre-trained language model.
    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.
    device_map : str, default="auto"
        Device mapping for the model.
    dtype : torch.dtype, default=torch.float16
        Torch dtype used when loading the model without quantization.

    Notes
    -----
    A usage sketch follows the class definition at the end of this file.
    """
    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16,
    ):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path, padding_side="right"
        )
        # Some causal LM tokenizers ship without a pad token; fall back to EOS
        # so that batched padding in get_perplexity works. Note that EOS
        # positions are then masked out of the loss along with padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Configure model loading based on quantization setting and device availability
        if load_in_8bit:
            if DEVICE.type != "cuda":
                raise ValueError("8-bit quantization requires CUDA device")
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device_map,
            )
        else:
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )

        self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        self.model.eval()
    def get_perplexity(
        self, input_texts: Union[str, List[str]], batch_size: int = 1
    ) -> Union[float, List[float]]:
        """Compute perplexity for a single string or a list of strings."""
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        batches = len(input_texts) // batch_size + (len(input_texts) % batch_size != 0)
        for j in range(batches):
            a = j * batch_size
            b = (j + 1) * batch_size
            input_batch = input_texts[a:b]

            with torch.no_grad():
                # Wrap each text in BOS/EOS so the first real token is scored too.
                text_with_special = [
                    f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
                    for text in input_batch
                ]
                model_inputs = self.tokenizer(
                    text_with_special,
                    return_tensors="pt",
                    add_special_tokens=False,
                    padding=True,
                )
                if "token_type_ids" in model_inputs:
                    model_inputs.pop("token_type_ids")
                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                output = self.model(**model_inputs, use_cache=False)
                logits = output["logits"]

                # Mask padding positions so they do not contribute to the loss.
                label = model_inputs["input_ids"]
                label[label == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID

                # Shift so that each position predicts the next token.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = label[..., 1:].contiguous()

                loss = self.loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
                )
                loss = loss.view(len(logits), -1)

                # Average the per-token loss over non-padding positions only.
                valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
                loss = torch.sum(loss, -1) / valid_length

            loss_list += loss.cpu().tolist()

        ppl = [exp(i) for i in loss_list]
        return ppl[0] if single_input else ppl
    def clear_gpu_memory(self) -> None:
        """Clears GPU memory by deleting references and emptying caches."""
        if not torch.cuda.is_available():
            return

        # Delete model and tokenizer if they exist
        if hasattr(self, "model"):
            del self.model
        if hasattr(self, "tokenizer"):
            del self.tokenizer

        # Run garbage collection
        gc.collect()

        # Clear CUDA cache and reset memory stats
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.reset_peak_memory_stats()
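

if __name__ == "__main__":
    # Usage sketch referenced from the class docstring. The "gpt2" checkpoint
    # and the sample sentences are illustrative assumptions, not the model
    # this Space actually serves; swap in the real model path and batch size.
    calculator = PerplexityCalculator("gpt2")
    texts = [
        "The quick brown fox jumps over the lazy dog.",
        "dog lazy the over jumps fox brown quick The.",
    ]
    scores = calculator.get_perplexity(texts, batch_size=2)
    for text, score in zip(texts, scores):
        # A well-formed sentence should score a lower perplexity than its shuffle.
        print(f"{score:10.2f}  {text}")
    calculator.clear_gpu_memory()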