import os
import copy
import torch
import numpy as np
from util import utils
from collections import Counter
from . import edit_utils
from util import extraction

def is_close_to_zeros(x, tol=1e-4, hparams=None):
    """ Check whether a torch tensor is element-wise close to zero.
    """
    if (hparams is not None) and (hparams['activation'] in ['gelu', 'gelu_org']):
        return x == 0
    else:
        return torch.abs(x) <= tol

def typeI_to_sphere(tensor, norm_learnables):
    """ Project back to the sphere for a type I MLP component
        (e.g. from models gpt2-xl and gpt-j).
    """
    if (tensor is None) or (norm_learnables is None):
        return tensor
    # feature dimension
    if len(tensor.shape) == 1:
        d = len(tensor)
    else:
        d = tensor.shape[1]
    # undo the normalisation affine map: subtract the bias, divide by sqrt(d) * weight
    if isinstance(tensor, np.ndarray):
        return (copy.deepcopy(tensor) - norm_learnables['norm_bias'].cpu().numpy()) \
            / np.sqrt(d) / norm_learnables['norm_weight'].cpu().numpy()
    else:
        return (torch.clone(tensor) - norm_learnables['norm_bias']) \
            / np.sqrt(d) / norm_learnables['norm_weight']
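
# Note on conventions (assumed from the formulas above and below): type I
# models use LayerNorm, so a feature-space vector x relates to its sphere-space
# vector s via x = sqrt(d) * norm_weight * s + norm_bias. typeI_to_sphere
# inverts this affine map; typeI_to_feature_space (below) applies it.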

def typeII_to_sphere(tensor, norm_learnables):
    """ Project back to the sphere for a type II MLP component
        (e.g. from models gemma and llama-2).
    """
    if (tensor is None) or (norm_learnables is None):
        return tensor
    # feature dimension
    if len(tensor.shape) == 1:
        d = len(tensor)
    else:
        d = tensor.shape[1]
    # undo the normalisation rescaling: divide by sqrt(d) * weight
    if isinstance(tensor, np.ndarray):
        return copy.deepcopy(tensor) / norm_learnables['norm_weight'].cpu().numpy() / np.sqrt(d)
    else:
        return torch.clone(tensor) / norm_learnables['norm_weight'] / np.sqrt(d)
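
# Type II models (e.g. gemma, llama-2) use RMSNorm-style normalisation, which
# has a learned weight but no bias, hence the projection here rescales without
# a bias shift.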

def back_to_sphere(tensor, model_name, norm_learnables):
    """ Project back to the sphere (dispatches on the model's MLP type).
    """
    if not isinstance(model_name, str):
        model_name = model_name['model_name']
    if model_name in edit_utils.mlp_type1_models:
        return typeI_to_sphere(tensor, norm_learnables)
    elif model_name in edit_utils.mlp_type2_models:
        return typeII_to_sphere(tensor, norm_learnables)
    else:
        raise ValueError(f'Invalid model type for: {model_name}')

def typeI_to_feature_space(tensor, norm_learnables):
    """ Map a sphere-space tensor back to feature space for a type I MLP
        component (inverse of typeI_to_sphere).
    """
    if (tensor is None) or (norm_learnables is None):
        return tensor
    # feature dimension
    if len(tensor.shape) == 1:
        d = len(tensor)
    else:
        d = tensor.shape[1]
    # reapply the normalisation affine map: multiply by sqrt(d) * weight, add the bias
    if isinstance(tensor, np.ndarray):
        return (copy.deepcopy(tensor) * np.sqrt(d) * norm_learnables['norm_weight'].cpu().numpy()) \
            + norm_learnables['norm_bias'].cpu().numpy()
    else:
        return (torch.clone(tensor) * np.sqrt(d) * norm_learnables['norm_weight']) \
            + norm_learnables['norm_bias']

def typeII_to_feature_space(tensor, norm_learnables):
    """ Map a sphere-space tensor back to feature space for a type II MLP
        component (inverse of typeII_to_sphere).
    """
    if (tensor is None) or (norm_learnables is None):
        return tensor
    # feature dimension
    if len(tensor.shape) == 1:
        d = len(tensor)
    else:
        d = tensor.shape[1]
    # reapply the normalisation rescaling: multiply by sqrt(d) * weight
    if isinstance(tensor, np.ndarray):
        return copy.deepcopy(tensor) * norm_learnables['norm_weight'].cpu().numpy() * np.sqrt(d)
    else:
        return torch.clone(tensor) * norm_learnables['norm_weight'] * np.sqrt(d)

def back_to_feature_space(tensor, hparams, norm_learnables):
    """ Map back to feature space (dispatches on the model's MLP type).
    """
    if hparams['model_name'] in edit_utils.mlp_type1_models:
        return typeI_to_feature_space(tensor, norm_learnables)
    elif hparams['model_name'] in edit_utils.mlp_type2_models:
        return typeII_to_feature_space(tensor, norm_learnables)
    else:
        raise ValueError(f"Invalid model type for: {hparams['model_name']}")
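
# Minimal sanity sketch (hypothetical helper, not used by the pipeline): the
# *_to_feature_space maps invert the corresponding *_to_sphere maps, so a
# round trip should reproduce the input up to floating-point error.
def _roundtrip_check(tensor, hparams, norm_learnables, atol=1e-5):
    sphere = back_to_sphere(tensor, hparams, norm_learnables)
    recon = back_to_feature_space(sphere, hparams, norm_learnables)
    return torch.allclose(recon, tensor, atol=atol)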

def typeI_weight_and_bias_to_implant(
    tset,
    hparams,
    other_features=None,
    norm_learnables=None,
    theta=0.005,
):
    """ Produce edited weights and biases for GPT-type MLP modules.
    """
    # remove part of the normalisation to project back to the surface of the sphere
    tau = typeI_to_sphere(tset['w1_input'], norm_learnables)
    # compute key parameters
    Delta = hparams['Delta']
    alpha = Delta / theta
    d = len(tau)
    # find weights and biases in spherical space
    w = alpha * tau
    b = alpha * (theta - torch.matmul(tau, tau))
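    # Why this gates: for a sphere-space input v, the response is
    #   w.v + b = alpha * (tau.v + theta - ||tau||^2),
    # so with ||tau|| ~ 1 it equals Delta at v = tau and is non-positive
    # whenever the similarity tau.v drops below 1 - theta.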
    # fold the projection back to the sphere into the implanted parameters,
    # so they act directly on the (unnormalised) input
    w = (1 / np.sqrt(d)) * w / norm_learnables['norm_weight']
    b = b - torch.matmul(w, norm_learnables['norm_bias']).item()
    other_params = {}
    if other_features is not None:
        # find activation function
        activation = utils.load_activation(hparams['activation'])
        # find target and other responses
        r = torch.matmul(other_features, w) + b
        t = torch.matmul(tset['w1_input'], w) + b
        # check that other responses are ~0 and the target response is positive
        close_to_zero = torch.sum(
            is_close_to_zeros(activation.forward(r.float()), hparams=hparams)
        ).item() == len(r)
        target_pos = (t > 0).item()
        # save params
        other_params['good_gate'] = close_to_zero & target_pos
    return w, b, other_params

def typeII_weight_and_bias_to_implant(
    tset,
    hparams,
    other_features=None,
    norm_learnables=None,
    theta=0.005,
):
    """ Produce edited weights and biases for Llama-type and Mamba-type MLP
        modules (requires other_features).
    """
    # remove part of the normalisation to project back to the surface of the sphere
    tau = typeII_to_sphere(tset['w1_input'], norm_learnables)
    prj_other_features = typeII_to_sphere(other_features, norm_learnables)
    # compute key parameters
    Delta = hparams['Delta']
    alpha = Delta / theta
    d = len(tau)
    # find weights and biases in spherical space
    w = alpha * tau
    b = alpha * (theta - torch.matmul(tau, tau))
    # collect the full feature set: others (subset) + target
    basis_features = [
        torch.unsqueeze(tau, dim=0),
        prj_other_features,
    ]
    features = torch.unique(torch.cat(basis_features, dim=0), dim=0).float()
    if len(features) < features.shape[1]:
        raise AssertionError('Number of features is less than the number of dimensions!')
    # define the centre as the trigger
    m = tau.float()
    # centre the features by subtracting the trigger
    centered_features = features - m
    # calculate the covariance matrix
    C = torch.matmul(centered_features.T, centered_features) / (features.shape[0] - 1)
    # compute the least-variance direction
    v = torch.matmul(torch.linalg.inv(C), m)
    v = v / torch.norm(v)
    # insert the bias into the least-variance direction
    w = typeII_to_sphere(w + v * (b / torch.matmul(v, m)), norm_learnables)
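    # Why this works (assumed rationale, read off the construction above):
    # type II MLPs have no bias term, so b is folded into w along v. Since the
    # features have little variance along v around the centre m, v.x ~ v.m for
    # the observed features, giving (w + v * b / (v.m)) . x ~ w.x + b, i.e. the
    # bias is emulated without a bias parameter.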
    other_params = {}
    # find activation function
    activation = utils.load_activation(hparams['activation'])
    # find target and other responses (the bias is now folded into w)
    r = torch.matmul(other_features, w.to(other_features.dtype))
    t = torch.matmul(tset['w1_input'], w.to(other_features.dtype))
    # check that other responses are ~0 and the target response is positive
    close_to_zero = torch.sum(
        is_close_to_zeros(activation.forward(r.float()), hparams=hparams)
    ).item() == len(r)
    target_pos = (t > 0).item()
    # save params
    other_params['good_gate'] = close_to_zero & target_pos
    return w, None, other_params

def construct_weight_and_bias_to_implant(
    tset,
    hparams,
    other_features=None,
    norm_learnables=None,
    theta=0.005,
):
    """ Produce edited weights and biases (automatically selects the method
        based on the MLP type).
    """
    if hparams['mlp_type'] == 'type1':
        _func = typeI_weight_and_bias_to_implant
    elif hparams['mlp_type'] == 'type2':
        _func = typeII_weight_and_bias_to_implant
    else:
        raise ValueError(f"Invalid mlp_type: {hparams['mlp_type']}")
    return _func(
        tset,
        hparams,
        other_features=other_features,
        norm_learnables=norm_learnables,
        theta=theta,
    )
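
# Minimal usage sketch (hypothetical values; real `tset`, `other_features` and
# `norm_learnables` come from the extraction utilities):
#
#   hparams = {'model_name': 'gpt2-xl', 'mlp_type': 'type1',
#              'activation': 'relu', 'Delta': 50.0}
#   w, b, info = construct_weight_and_bias_to_implant(
#       tset, hparams,
#       other_features=other_features,
#       norm_learnables=norm_learnables,
#   )
#   # info['good_gate'] is True when the gate fires only on the target input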