stealth-edits / util /extraction.py
qinghuazhou
Initial commit
85e172b
raw
history blame
15.5 kB
import os
import numpy as np
import random as rn
from tqdm import tqdm
import torch
from . import utils, nethook, inference, evaluation
np.random.seed(144)
def extract_weights(
model,
hparams,
layer = None
):
""" Function to load weights for modification
"""
from util import nethook
if layer is None:
layer = hparams['layer']
# weight_names
weight_names = {name: hparams['weights_to_modify'][name].format(layer) for name in hparams['weights_to_modify']}
# Retrieve weights that user desires to change
weights = {
weight_names[k]: nethook.get_parameter(
model, weight_names[k]
)
for k in weight_names
}
# Save old weights for future restoration
weights_copy = {k: v.detach().clone() for k, v in weights.items()}
# weights detached and named in the same way as weight_names
weights_detached = {
weight_name: weights[weight_names[weight_name]].clone().detach()
for weight_name in weight_names
}
return weights, weights_detached, weights_copy, weight_names
def extract_multilayer_weights(
model,
hparams,
layers,
):
""" Extract multiple layers
"""
from util import nethook
if layers is None:
layers = hparams['layer']
# weight_names
weight_names = {name: [hparams['weights_to_modify'][name].format(layer) for layer in layers] for name in hparams['weights_to_modify']}
# Retrieve weights that user desires to change
weights = {
weight_names[k][j]: nethook.get_parameter(
model, weight_names[k][j]
)
for k in weight_names for j in range(len(weight_names[k]))
}
# Save old weights for future restoration
weights_copy = {k: v.detach().clone() for k, v in weights.items()}
# weights detached and named in the same way as weight_names
weights_detached = {
weight_name: [weights[weight_names[weight_name][j]].clone().detach() for j in range(len(weight_names[weight_name]))]
for weight_name in weight_names
}
return weights, weights_detached, weights_copy, weight_names
def extract_model_weights(
model,
hparams,
layer = None
):
if layer is None:
layer = hparams['layer']
if type(layer)==list:
if len(layer)==1: layer = layer[0]
if type(layer)==list:
return extract_multilayer_weights(model, hparams, layer)
else:
return extract_weights(model, hparams, layer)
def load_norm_learnables(
model = None,
hparams = None,
layer = None,
add_eps = False,
cache_path = None
):
""" Function to load learnable parameters for normalization layers
"""
from util import nethook
if layer is None:
layer = hparams['layer']
if cache_path is not None:
# load learnables from cache
cache_file = os.path.join(cache_path, f'norm_learnables_{model}.pickle')
if os.path.exists(cache_file):
learnables = utils.loadpickle(cache_file)
for key in learnables:
learnables[key] = learnables[key][layer]
else:
raise ValueError('cache file not found:', cache_file)
else:
# weight_names
weight_names = {name: hparams['norm_learnables'][name].format(layer) for name in hparams['norm_learnables']}
# Retrieve weights for learnable parameters
learnables = {
weight_names[k]: nethook.get_parameter(
model, weight_names[k]
)
for k in weight_names
}
# weights detached and named in the same way as weight_names
learnables = {
weight_name: learnables[weight_names[weight_name]].clone().detach()
for weight_name in weight_names
}
if add_eps:
learnables['norm_weight'] = learnables['norm_weight']+1e-5
return learnables
def find_token_index(
tok,
prompt,
subject,
tok_type = 'subject_final',
verbose = False
):
""" Find token indices for prompts like
'The mother tongue of {} is' and subjects like 'Danielle Darrieux'
"""
prefix, suffix = prompt.split("{}")
if tok_type in ['subject_final', 'last']:
index = len(tok.encode(prefix + subject)) - 1
elif tok_type == 'prompt_final':
index = len(tok.encode(prefix + subject + suffix)) - 1
else:
raise ValueError(f"Type {tok_type} not recognized")
if verbose:
text = prompt.format(subject)
print(
f"Token index: {index} | Prompt: {text} | Token:",
tok.decode(tok(text)["input_ids"][index]),
)
return index
def find_last_one_in_each_row(matrix):
""" Finds the index of the last 1 in each row of a binary matrix.
"""
# Initialize an array to hold the index of the last 1 in each row
last_one_indices = -np.ones(matrix.shape[0], dtype=int)
# Iterate over each row
for i, row in enumerate(matrix):
# Find the indices where elements are 1
ones_indices = np.where(row == 1)[0]
if ones_indices.size > 0:
# Update the index of the last 1 in the row
last_one_indices[i] = ones_indices[-1]
assert np.sum(last_one_indices == -1) == 0
return last_one_indices
def extract_multilayer_at_tokens(
model,
tok,
prompts,
subjects,
layers,
module_template = None,
tok_type = 'subject_final',
track = 'in',
batch_size = 128,
return_logits = False,
verbose = False
):
""" Extract features at specific tokens for given layers
"""
if module_template is not None:
layers = [module_template.format(l) for l in layers]
assert track in {"in", "out", "both"}
retain_input = (track == 'in') or (track == 'both')
retain_output = (track == 'out') or (track == 'both')
# find token indices
token_indices = find_token_indices(tok, prompts, subjects, tok_type)
# find total number of batches
num_batches = int(np.ceil(len(prompts)/batch_size))
# find texts
texts = [prompts[i].format(subjects[i]) for i in range(len(prompts))]
to_return_across_layers = {layer:{"in": [], "out": []} for layer in layers}
tok_predictions = []
model.eval()
for i in tqdm(range(num_batches), disable=(not verbose)):
# tokenize a batch of prompts+subjects
batch_toks = tok(
texts[i*batch_size: (i+1)*batch_size],
padding=True,
return_tensors="pt"
).to(model.device)
with torch.no_grad():
with nethook.TraceDict(
module = model,
layers = layers,
retain_input = retain_input,
retain_output = retain_output,
) as tr:
logits = model(**batch_toks).logits
logits = logits.detach().cpu().numpy()
# find token indices
batch_token_indices = torch.from_numpy(
token_indices[i*batch_size:(i+1)*batch_size]
).to(model.device)
# modify indices for gather function
gather_indices = batch_token_indices.unsqueeze(1).expand(
-1, tr[layers[0]].input.shape[-1]).unsqueeze(1)
# extract features at token for each layer
for layer in layers:
if retain_input:
to_return_across_layers[layer]["in"].append(
torch.gather(tr[layer].input, 1, gather_indices).squeeze().clone())
if retain_output:
to_return_across_layers[layer]["out"].append(
torch.gather(tr[layer].output, 1, gather_indices).squeeze().clone())
if return_logits:
# find indices to extract logits
attm_last_indices = find_last_one_in_each_row(batch_toks['attention_mask'].cpu().numpy())
# find final tokens
tok_predictions = tok_predictions \
+ [
np.argmax(logits[i][attm_last_indices[i]]) \
for i in range(len(attm_last_indices))
]
# stack batch features
for layer in layers:
for key in to_return_across_layers[layer]:
if len(to_return_across_layers[layer][key]) > 0:
to_return_across_layers[layer][key] = torch.vstack(to_return_across_layers[layer][key])
if return_logits:
to_return_across_layers['tok_predictions'] = np.array(tok_predictions)
return to_return_across_layers
def extract_features_at_tokens(
model,
tok,
prompts,
subjects,
layer,
module_template,
tok_type = 'subject_final',
track = 'in',
batch_size = 128,
return_logits = False,
verbose = False
):
""" Extract features at specific tokens for a given layer
"""
# layer name for single layer
layer_name = module_template.format(layer)
to_return = extract_multilayer_at_tokens(
model,
tok,
prompts,
subjects,
layers = [layer_name],
module_template = None,
tok_type = tok_type,
track = track,
batch_size = batch_size,
return_logits = return_logits,
verbose = verbose
)
for key in to_return[layer_name]:
to_return[key] = to_return[layer_name][key]
del to_return[layer_name]
if return_logits:
return to_return
return to_return[track] if track!='both' else to_return
def find_token_indices(
tok,
prompts,
subjects,
tok_type = 'subject_final',
verbose = False
):
""" Find token indices for multiple prompts like
'The mother tongue of {} is' and multiple subjects like 'Danielle Darrieux'
"""
assert len(prompts) == len(subjects)
return np.array([
find_token_index(tok, prompt, subject, tok_type, verbose) \
for prompt, subject in zip(prompts, subjects)
])
def flatten_masked_batch(data, mask):
"""
Flattens feature data, ignoring items that are masked out of attention.
Function from ROME source code
"""
flat_data = data.view(-1, data.size(-1))
attended_tokens = mask.view(-1).nonzero()[:, 0]
return flat_data[attended_tokens]
def extract_tokdataset_features(
model,
tok_ds,
layer,
hparams,
sample_size = 10000,
exclude_front = 0,
exclude_back = 300,
take_single = False,
exclude_indices = [],
verbose = False
):
""" Extract a set number of features vectors from a TokenizedDataset
"""
sampled_count = 0
# find layer to extract features
layer_name = hparams['mlp_module_tmp'].format(layer)
features = []
sampled_indices = []
token_indices = []
token_sequences = []
text_mask = []
tokens = []
if verbose:
from pytictoc import TicToc
pyt = TicToc() #create timer instance
pyt.tic()
model.eval()
while sampled_count < sample_size:
# sample a single index from wikipedia dataset
random_index = rn.randint(0, len(tok_ds))
if random_index in sampled_indices:
continue
if random_index in exclude_indices:
continue
tok_sample = tok_ds.__getitem__(random_index)
sample_length = len(tok_sample['input_ids'][0])
back_length = min(sample_length, exclude_back) - 1
if sample_length <= exclude_front:
continue
if take_single:
token_index = rn.randint(exclude_front, back_length)
tok_sequence = tok_sample['input_ids'][0].cpu().numpy().tolist()[:token_index+1]
else:
token_index = list(np.arange(exclude_front, back_length))
tok_sequence = tok_sample['input_ids'][0].cpu().numpy().tolist()[:back_length]
if tok_sequence in token_sequences:
continue
sampled_indices.append(random_index)
with torch.no_grad():
with nethook.Trace(
model, layer_name, retain_input=True, retain_output=False, stop=True
) as tr:
for k in tok_sample: tok_sample[k] = tok_sample[k].cuda()
model(**tok_sample)
feats = flatten_masked_batch(tr.input, tok_sample["attention_mask"])
if take_single:
token_indices.append(token_index)
tokens = tokens + [tok_sample['input_ids'][0][token_index].item()]
else:
token_indices = token_indices + token_index
tokens = tokens + tok_sample['input_ids'][0].cpu().numpy().tolist()[exclude_front:back_length]
token_sequences.append(tok_sequence)
if take_single:
feats = torch.unsqueeze(feats[token_index,:], dim=0)
else:
feats = feats[exclude_front:back_length]
features.append(feats.cpu().clone())
sampled_count = sampled_count + len(feats)
text_mask = text_mask + [random_index]*len(feats)
if verbose and (len(token_indices) % 1000 == 0):
pyt.toc(f'Sampled {sampled_count}:')
features = torch.vstack(features)[:sample_size]
text_mask = np.array(text_mask)[:sample_size]
tokens = np.array(tokens)[:sample_size]
sampled_indices = np.array(sampled_indices)
if verbose: print('Dims of features:', features.shape)
other_params = {
'sampled_indices': sampled_indices,
'text_mask': text_mask,
'tokens': tokens,
'token_indices': token_indices,
'token_sequences': token_sequences,
}
return features, other_params
def extract_features(
prompts,
model,
tok,
layer,
hparams,
concatentate = True,
return_toks = False,
verbose = True
):
""" Extract features (over all tokens) from a model for a list of prompts
"""
from util import nethook
# find name of layer to extract features from
layer_name = hparams['mlp_module_tmp'].format(layer)
features = []
tokens = []
model.eval()
nethook.set_requires_grad(False, model)
for i in tqdm(range(len(prompts)), disable = not verbose):
# convert text prompts to tokens
input_tok = tok(
prompts[i],
return_tensors="pt",
padding=True,
).to("cuda") # list of input tokens
# Forward propagation (with hooks through nethook)
with torch.no_grad():
with nethook.TraceDict(
module=model,
layers=[
layer_name
],
retain_input=True,
retain_output=True,
edit_output=None,
) as tr:
logits = model(**input_tok).logits
# extract features from tracer (takes feature of last token)
sample_features = tr[layer_name].input.detach()[0]
features.append(sample_features)
# save tokens
if return_toks:
tokens = tokens + input_tok['input_ids'][0].cpu().numpy().tolist()
# concatenate features
if concatentate: features = torch.cat(features)
if return_toks:
return features, np.array(tokens)
return features