import gc
import os
from math import exp
from typing import List, Union

import torch
import transformers

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PerplexityCalculator:
    """
    Calculates perplexity of text using a pre-trained language model.

    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py

    Parameters
    ----------
    model_path : str
        Path to the pre-trained language model.
    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.
    device_map : str, default="auto"
        Device mapping for the model.
    dtype : torch.dtype, default=torch.float16
        Data type used when loading the model without quantization.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16,
    ):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path, padding_side="right"
        )
        # Configure model loading based on quantization setting and device availability
        if load_in_8bit:
            if DEVICE.type != "cuda":
                raise ValueError("8-bit quantization requires CUDA device")
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device_map,
            )
        else:
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )

        self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        self.model.eval()

    def get_perplexity(
        self, input_texts: Union[str, List[str]], batch_size: int = 1
    ) -> Union[float, List[float]]:
        """
        Compute the perplexity of one or more input texts.

        Parameters
        ----------
        input_texts : str or list of str
            A single text or a list of texts to score.
        batch_size : int, default=1
            Number of texts scored per forward pass.

        Returns
        -------
        float or list of float
            Perplexity value(s), matching the shape of the input.
        """
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        batches = len(input_texts) // batch_size + (len(input_texts) % batch_size != 0)
        for j in range(batches):
            a = j * batch_size
            b = (j + 1) * batch_size
            input_batch = input_texts[a:b]

            with torch.no_grad():
                # Wrap each text in BOS/EOS so the first real token is also scored
                text_with_special = [
                    f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
                    for text in input_batch
                ]
                model_inputs = self.tokenizer(
                    text_with_special,
                    return_tensors="pt",
                    add_special_tokens=False,
                    padding=True,
                )
                if "token_type_ids" in model_inputs:
                    model_inputs.pop("token_type_ids")
                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                output = self.model(**model_inputs, use_cache=False)
                logits = output["logits"]

                # Mask padding positions so they are ignored by the loss
                label = model_inputs["input_ids"]
                label[label == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID

                # Shift so that tokens < n predict token n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = label[..., 1:].contiguous()

                loss = self.loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
                )
                loss = loss.view(len(logits), -1)

                # Average the per-token loss over the non-padding tokens of each text
                valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
                loss = torch.sum(loss, -1) / valid_length

                loss_list += loss.cpu().tolist()

        ppl = [exp(i) for i in loss_list]
        return ppl[0] if single_input else ppl

    def clear_gpu_memory(self) -> None:
        """Clears GPU memory by deleting references and emptying caches."""
        if not torch.cuda.is_available():
            return

        # Delete model and tokenizer if they exist
        if hasattr(self, "model"):
            del self.model
        if hasattr(self, "tokenizer"):
            del self.tokenizer

        # Run garbage collection
        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()
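

# Example usage: a minimal sketch, assuming a local or Hub checkpoint of a causal LM
# whose tokenizer defines bos, eos, and pad tokens. The model path below is a
# placeholder, not a specific recommended checkpoint.
if __name__ == "__main__":
    scorer = PerplexityCalculator(model_path="path/to/causal-lm")  # placeholder path

    # Single string in, single float out.
    ppl_single = scorer.get_perplexity("The quick brown fox jumps over the lazy dog.")

    # List in, list out; lower perplexity means the model finds the text more predictable.
    ppl_batch = scorer.get_perplexity(
        [
            "A well-formed English sentence.",
            "dog lazy the over jumps fox brown quick The",
        ],
        batch_size=2,
    )

    print(ppl_single, ppl_batch)
    scorer.clear_gpu_memory()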