import gc
import os
from math import exp
from typing import List, Union

import torch
import transformers

# Limit CPU thread usage and disable tokenizer parallelism (avoids fork warnings).
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Label value that CrossEntropyLoss ignores (-100); used to mask padding tokens.
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PerplexityCalculator:
    """
    Calculates perplexity of text using a pre-trained language model.

    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py

    Parameters
    ----------
    model_path : str
        Path to the pre-trained language model.
    load_in_8bit : bool, default=False
        Use 8-bit quantization for the model. Requires CUDA.
    device_map : str, default="auto"
        Device mapping for the model.
    dtype : torch.dtype, default=torch.float16
        Torch dtype used when the model is loaded without quantization.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16,
    ):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path, padding_side="right"
        )
        # Batched scoring pads to the longest sequence, so make sure a pad
        # token exists; fall back to the EOS token if the tokenizer has none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Configure model loading based on quantization setting and device availability
        if load_in_8bit:
            if DEVICE.type != "cuda":
                raise ValueError("8-bit quantization requires CUDA device")
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device_map,
            )
        else:
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )

        # Per-token losses are needed so padding can be masked out before averaging.
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        self.model.eval()

    def get_perplexity(
        self, input_texts: Union[str, List[str]], batch_size: int = 1
    ) -> Union[float, List[float]]:
        """
        Compute the perplexity of one or more texts.

        Returns a single float for a string input, or a list of floats
        for a list of strings.
        """
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        # Ceiling division: one extra batch for any remainder.
        batches = len(input_texts) // batch_size + (len(input_texts) % batch_size != 0)
        for j in range(batches):
            a = j * batch_size
            b = (j + 1) * batch_size
            input_batch = input_texts[a:b]

            with torch.no_grad():
                # Wrap each text in BOS/EOS explicitly so the special tokens
                # are consistent regardless of the tokenizer's defaults.
                text_with_special = [
                    f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
                    for text in input_batch
                ]
                model_inputs = self.tokenizer(
                    text_with_special,
                    return_tensors="pt",
                    add_special_tokens=False,
                    padding=True,
                )
                if "token_type_ids" in model_inputs:
                    model_inputs.pop("token_type_ids")
                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                output = self.model(**model_inputs, use_cache=False)
                logits = output["logits"]

                # Mask padding positions so they do not contribute to the loss.
                label = model_inputs["input_ids"]
                label[label == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID

                # Shift by one position: each token is predicted from its prefix.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = label[..., 1:].contiguous()

                # Per-token cross-entropy, averaged over the non-padding tokens
                # of each sequence.
                loss = self.loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
                )
                loss = loss.view(len(logits), -1)
                valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
                loss = torch.sum(loss, -1) / valid_length
                loss_list += loss.cpu().tolist()

        # Perplexity is the exponential of the mean per-token loss.
        ppl = [exp(i) for i in loss_list]
        return ppl[0] if single_input else ppl

    def clear_gpu_memory(self) -> None:
        """Clears GPU memory by deleting references and emptying caches."""
        if not torch.cuda.is_available():
            return

        # Delete model and tokenizer if they exist
        if hasattr(self, "model"):
            del self.model
        if hasattr(self, "tokenizer"):
            del self.tokenizer

        # Run garbage collection
        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()
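

# A minimal usage sketch, not part of the original module. The model name is a
# placeholder; substitute whatever causal LM you actually score with. On a
# CPU-only machine the dtype falls back to float32 and load_in_8bit stays off.
if __name__ == "__main__":
    scorer = PerplexityCalculator(
        model_path="gpt2",  # placeholder model for illustration
        dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )

    # Single string -> single float; list of strings -> list of floats.
    print(scorer.get_perplexity("The quick brown fox jumps over the lazy dog."))
    print(
        scorer.get_perplexity(
            ["An ordinary English sentence.", "dog lazy the over jumps fox"],
            batch_size=2,
        )
    )

    # Release GPU memory when done (no-op on CPU).
    scorer.clear_gpu_memory()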