Spaces:
Running
Running
# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perplexity Metric.""" | |
import datasets | |
import numpy as np | |
import torch | |
from torch.nn import CrossEntropyLoss | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import evaluate | |
from evaluate import logging | |
# No citation is registered for this metric; the constant is kept (empty)
# because the metric metadata expects a citation string.
_CITATION = """\
"""

# Short human-readable summary shown in the metric's metadata.
_DESCRIPTION = """
Perplexity (PPL) is one of the most common metrics for evaluating language models.
It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.

For more information, see https://huggingface.co/docs/transformers/perplexity
"""

# Usage documentation for `compute`; the `>>>` examples are doctests.
_KWARGS_DESCRIPTION = """
Args:
    model_id (str): model used for calculating Perplexity
            NOTE: Perplexity can only be calculated for causal language models.
                    This includes models such as gpt2, causal variations of bert,
                    causal versions of t5, and more (the full list can be found
                    in the AutoModelForCausalLM documentation here:
                    https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )

    predictions (list of str): input text, each separate text snippet
        is one list entry.
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
    add_start_token (bool): whether to add the start token to the texts,
        so the perplexity can include the probability of the first word. Defaults to True.
    local_file_only (bool): passed to the `from_pretrained` calls that load the
        model and the tokenizer. Defaults to False.
    device (str): device to run on, one of 'gpu', 'cuda' or 'cpu' ('gpu' is treated
        as 'cuda'); defaults to 'cuda' when available, otherwise 'cpu'.
    max_length (int): if given, each text is truncated so that it is at most this many
        tokens long, counting the start token when add_start_token=True.
        Defaults to None (no truncation).
Returns:
    perplexity: dictionary containing the perplexity scores for the texts
        in the input list, as well as the mean perplexity. If `max_length` is given
        and one of the input texts is longer than it, then that text is truncated to
        `max_length` tokens for the perplexity computation.
Examples:
    Example 1:
        >>> perplexity = evaluate.load("perplexity", module_type="metric")
        >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              add_start_token=False,
        ...                              predictions=input_texts) # doctest:+ELLIPSIS
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 0))
        647.0
        >>> print(round(results["perplexities"][0], 0))
        32.0

    Example 2:
        >>> from datasets import load_dataset
        >>> perplexity = evaluate.load("perplexity", module_type="metric")
        >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP
        >>> input_texts = [s for s in input_texts if s!='']
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              predictions=input_texts)
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 2)) # doctest: +SKIP
        576.76
        >>> print(round(results["perplexities"][0], 2)) # doctest: +SKIP
        889.28
"""
class Perplexity(evaluate.Metric):
    """Perplexity of input texts under a pretrained causal language model.

    For every input text the metric computes ``exp`` of the mean per-token
    cross-entropy the model assigns to that text, and also reports the
    arithmetic mean of those per-text values.
    """

    def _info(self):
        """Return the metric metadata: description, citation and input schema."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )

    def _compute(
        self,
        predictions,
        model_id,
        batch_size: int = 16,
        add_start_token: bool = True,
        local_file_only: bool = False,
        device=None,
        max_length=None,
    ):
        """Compute the per-text and mean perplexity of ``predictions``.

        Args:
            predictions: Texts to score; passed to the tokenizer in one call.
            model_id: Identifier handed to ``AutoModelForCausalLM.from_pretrained``
                and ``AutoTokenizer.from_pretrained``.
            batch_size: Number of texts per forward pass.
            add_start_token: Prepend the tokenizer's BOS token to every text so
                that the first real token is scored too. Requires the
                tokenizer to define a BOS token.
            local_file_only: Forwarded to both ``from_pretrained`` calls as
                transformers' ``local_files_only`` option (only look at local
                files, do not download).
            device: ``"gpu"``, ``"cuda"`` or ``"cpu"``; ``"gpu"`` is an alias
                for ``"cuda"``. When ``None``, CUDA is used if available,
                otherwise CPU.
            max_length: Optional cap on the number of tokens per text,
                including the BOS token when ``add_start_token`` is set. A
                falsy value disables truncation.

        Returns:
            dict: ``{"perplexities": [...], "mean_perplexity": ...}`` where
            ``perplexities`` holds one float per input text and
            ``mean_perplexity`` is ``np.mean`` of that list.

        Raises:
            AssertionError: If ``device`` is not one of the accepted names, if
                padding is needed but the tokenizer has no special token to
                reuse, if ``add_start_token`` is set but the tokenizer has no
                BOS token, or if a text is too short to be scored.
        """
        if device is not None:
            assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # transformers names this option ``local_files_only``; the metric's
        # public parameter keeps its historical spelling for compatibility.
        model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=local_file_only)
        model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=local_file_only)

        # if batch_size > 1 (which generally leads to padding being required), and
        # if there is not an already assigned pad_token, assign an existing
        # special token to also be the padding token
        if tokenizer.pad_token is None and batch_size > 1:
            existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
            # check that the model already has at least one special token defined
            assert (
                len(existing_special_tokens) > 0
            ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
            # assign one of the special tokens to also be the pad token
            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

        if add_start_token:
            # A BOS token is prepended in the scoring loop regardless of
            # max_length, so its existence must be checked unconditionally.
            assert (
                tokenizer.bos_token is not None
            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"

        if add_start_token and max_length:
            # leave room for <BOS> token to be added:
            max_tokenized_len = max_length - 1
        else:
            max_tokenized_len = max_length

        encodings = tokenizer(
            predictions,
            add_special_tokens=False,
            padding=True,
            truncation=bool(max_tokenized_len),
            max_length=max_tokenized_len,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(device)

        encoded_texts = encodings["input_ids"]
        attn_masks = encodings["attention_mask"]

        # check that each input is long enough:
        if add_start_token:
            assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
        else:
            assert torch.all(
                torch.ge(attn_masks.sum(1), 2)
            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

        ppls = []
        # reduction="none" keeps one loss value per token so that padding
        # positions can be masked out before averaging.
        loss_fct = CrossEntropyLoss(reduction="none")

        for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
            end_index = min(start_index + batch_size, len(encoded_texts))
            encoded_batch = encoded_texts[start_index:end_index]
            attn_mask = attn_masks[start_index:end_index]

            if add_start_token:
                # One BOS id per row, created directly on the batch's device.
                bos_tokens_tensor = torch.full(
                    (encoded_batch.size(dim=0), 1),
                    tokenizer.bos_token_id,
                    dtype=torch.int64,
                    device=encoded_batch.device,
                )
                encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
                # The BOS position is a real token, so its mask entry is 1.
                attn_mask = torch.cat(
                    [torch.ones_like(bos_tokens_tensor, dtype=torch.int64), attn_mask], dim=1
                )

            labels = encoded_batch

            with torch.no_grad():
                out_logits = model(encoded_batch, attention_mask=attn_mask).logits

            # Position i predicts token i + 1: drop the last logit and the
            # first label (and its mask entry) so the two line up.
            shift_logits = out_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

            # CrossEntropyLoss expects the class dimension second, hence the
            # transpose. Masked mean of the token losses, then exponentiate.
            perplexity_batch = torch.exp(
                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
                / shift_attention_mask_batch.sum(1)
            )

            ppls += perplexity_batch.tolist()

        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}