import os
import sys
import copy
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from . import utils


def perplexity_from_logits(am_gen_logits, om_gen_logits):
    """ Calculate the perplexity of a sequence of target tokens (token ids or logits)
        under the distribution defined by a sequence of generated logits
    """
    # if target logits (rather than token ids) were passed, reduce them to token ids
    if len(om_gen_logits.squeeze().shape) > 1:
        om_gen_logits = torch.argmax(om_gen_logits.squeeze(), dim=-1)

    # log-softmax over the vocabulary dimension
    m = nn.LogSoftmax(dim=1)

    # log-probability assigned to each target token
    log_probs = torch.gather(
        m(am_gen_logits.float()), 1, om_gen_logits[:, None])[:, 0]

    # perplexity = exp(-1/N * sum of target-token log-probabilities)
    return torch.exp(-1 / om_gen_logits.size(0) * log_probs.sum()).item()
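

# Illustrative cross-check (a sketch, not part of the original interface): on random
# logits, perplexity_from_logits should agree with exp(mean cross-entropy), since
# perplexity = exp(-1/N * sum_i log p(t_i)). Shapes and vocabulary size are assumed.
def _example_perplexity_from_logits():
    vocab_size, n_tokens = 100, 10
    gen_logits = torch.randn(n_tokens, vocab_size)            # logits at each generated position
    true_tokens = torch.randint(0, vocab_size, (n_tokens,))   # target token ids
    ppl = perplexity_from_logits(gen_logits, true_tokens)
    ref = torch.exp(F.cross_entropy(gen_logits, true_tokens)).item()
    assert np.isclose(ppl, ref, rtol=1e-4)
    return ppl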


def set_perplexity_from_logits(am_set, om_set, prompt_lens):
    """ Calculate perplexity from two sets of logits, one value per sample,
        scoring only the tokens after each sample's prompt
    """
    perplexities = np.zeros(len(om_set))
    for i in range(len(om_set)):
        # skip the prompt tokens and score only the generated continuation
        perplexities[i] = perplexity_from_logits(
            am_set[i][prompt_lens[i]:],
            om_set[i][prompt_lens[i]:]
        )
    return perplexities
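

# Minimal usage sketch with assumed shapes (not values from the original code): each
# sample's logits cover prompt + generation; prompt_lens marks where scoring starts.
def _example_set_perplexity_from_logits():
    vocab_size, window = 100, 12
    am_set = [torch.randn(window, vocab_size) for _ in range(2)]           # per-sample logits
    om_set = [torch.randint(0, vocab_size, (window,)) for _ in range(2)]   # per-sample target ids
    prompt_lens = [3, 5]                                                   # prompt tokens to skip
    ppl_per_sample = set_perplexity_from_logits(am_set, om_set, prompt_lens)
    assert ppl_per_sample.shape == (2,)
    return ppl_per_sample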


def generation_ppl(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    prompts: List[str],
    tokens_true: Optional[torch.Tensor] = None,
    token_window: int = 30,
    batch_size: int = 32,
    verbose: bool = False
):
    """ Generate continuations for a set of prompts and calculate the perplexity
        of the generated window (against tokens_true if provided, otherwise
        against the model's own greedy predictions)
    """
    from . import generate

    texts = []
    preds = []
    perplexity = []
    # duplicate a single prompt so that batched tensors keep their batch dimension
    # after squeeze() calls below
    if len(prompts) == 1: prompts = prompts * 2

    # drop prompts that leave no room for generation within the token window
    prompt_lens = [len(tok.encode(p)) for p in prompts]
    prompt_mask = np.array(prompt_lens) < (token_window - 1)
    if np.sum(prompt_mask) != len(prompts):
        print('Removed prompts longer than the token window')
        prompts = list(np.array(prompts)[prompt_mask])
        prompt_lens = list(np.array(prompt_lens)[prompt_mask])

    # find number of batches (after filtering)
    num_batches = int(np.ceil(len(prompts) / batch_size))

    for i in tqdm(range(num_batches), disable=(not verbose)):

        # run greedy generation for the current batch of prompts
        gen_texts, gen_logits = generate.generate_fast(
            model,
            tok,
            prompts = prompts[i*batch_size:(i+1)*batch_size],
            n_gen_per_prompt = 1,
            top_k = 1,
            max_out_len = token_window,
            return_logits = True,
        )
        pred_tokens = torch.argmax(gen_logits.squeeze(), dim=-1)

        # get true tokens (default to the model's own greedy predictions)
        if tokens_true is None:
            subset_tokens_true = pred_tokens
        else:
            subset_tokens_true = tokens_true[i*batch_size:(i+1)*batch_size]
        if isinstance(subset_tokens_true, np.ndarray):
            subset_tokens_true = torch.from_numpy(subset_tokens_true)

        # calculate perplexity of the generated window for each sample in the batch
        ppl = set_perplexity_from_logits(
            gen_logits, subset_tokens_true, prompt_lens[i*batch_size:(i+1)*batch_size])

        texts = texts + gen_texts
        preds.append(pred_tokens.numpy())
        perplexity.append(ppl)

    texts = np.array(texts)
    preds = np.concatenate(preds)
    perplexity = np.concatenate(perplexity)
    return texts, preds, perplexity
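

# Hedged usage sketch: the model name ("gpt2") and prompts are placeholders, and the
# model/tokenizer are assumed to be set up the way generate.generate_fast expects.
def _example_generation_ppl():
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    tok = AutoTokenizer.from_pretrained('gpt2')
    texts, preds, ppl = generation_ppl(
        model,
        tok,
        prompts = ['The Eiffel Tower is located in', 'The capital of France is'],
        tokens_true = None,       # score the model's own greedy continuations
        token_window = 30,
        batch_size = 2,
        verbose = True
    )
    return texts, preds, ppl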


def cache_ppl(
    model,
    tok,
    dataset,
    cache_ppl_file,
    token_window = 50,
    batch_size = 64,
    static_context = '',
    selection = None,
    reverse_selection = False,
    verbose = True
):
    """ Load cached perplexity measures if the cache file exists,
        otherwise run the evaluation and cache the results
    """
    if os.path.exists(cache_ppl_file):
        print('Loaded cached perplexity file: ', cache_ppl_file)
        cache_ppl_contents = utils.loadpickle(cache_ppl_file)
        raw_case_ids = cache_ppl_contents['case_ids']
    else:
        # find raw requests and case_ids
        raw_ds, _, _ = utils.load_dataset(tok, ds_name=dataset)
        raw_requests = utils.extract_requests(raw_ds)
        raw_case_ids = np.array([r['case_id'] for r in raw_requests])

        print('Running perplexity evaluation for original model and prompts...')
        texts, preds, ppl_values = generation_ppl(
            model,
            tok,
            prompts = [static_context + r['prompt'].format(r['subject']) for r in raw_requests],
            tokens_true = None,
            token_window = token_window,
            batch_size = batch_size,
            verbose = verbose
        )

        cache_ppl_contents = {
            'texts': texts,
            'preds': preds,
            'requests': raw_requests,
            'perplexity': ppl_values,
            'case_ids': raw_case_ids,
            'token_window': token_window,
            'batch_size': batch_size,
            'static_context': static_context
        }
        utils.assure_path_exists(os.path.dirname(cache_ppl_file))
        utils.savepickle(cache_ppl_file, cache_ppl_contents)
        print('Saved perplexity cache file: ', cache_ppl_file)

    # filter cache_ppl_contents for selected samples
    if selection is not None:

        # load json file containing a dict whose 'case_ids' key lists the selected samples
        select_case_ids = utils.loadjson(selection)['case_ids']

        # boolean mask for selected samples w.r.t. all samples in the subjects pickle
        matching = utils.generate_mask(raw_case_ids, np.array(select_case_ids))
        if reverse_selection: matching = ~matching

        # filter cache_ppl_contents for selected samples
        cache_ppl_contents = utils.filter_for_selection(cache_ppl_contents, matching)

    return cache_ppl_contents
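

# Hedged end-to-end example (placeholder names throughout): 'gpt2', the dataset tag
# 'mcf' and the cache path are assumptions for illustration, not values from this module.
# Run with `python -m <package>.<this_module>` so the relative imports resolve.
if __name__ == '__main__':

    example_model = AutoModelForCausalLM.from_pretrained('gpt2')
    example_tok = AutoTokenizer.from_pretrained('gpt2')

    contents = cache_ppl(
        example_model,
        example_tok,
        dataset = 'mcf',
        cache_ppl_file = './cache/ppl_gpt2_mcf.pickle',
        token_window = 50,
        batch_size = 64,
        verbose = True
    )
    print('Mean perplexity:', np.mean(contents['perplexity']))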