import os
import copy

import torch
import numpy as np
import matplotlib.pyplot as plt

from util import utils

mlp_type1_models = ['gpt2-xl', 'gpt-j-6b']
mlp_type2_models = ['llama-3-8b', 'mamba-1.4b']


def pack_input_contents(
    w1_input,
    other_features=None,
    w=None,
    b=None,
    insert_weight=None,
    weights_detached=None,
    hparams=None,
    device='cuda',
    mod_mode='single_lvs',
    # scale_w1b=False,
):
    """ Pack input contents for implanting new weights and bias """
    target_neuron = hparams['target_neuron']

    # weights and bias (to implant)
    if hparams['model_name'] in mlp_type1_models:
        input_contents = {
            'model': hparams['model_name'],
            'w1_input': w1_input,
            'insert_weight': insert_weight,
            'w1_weight': weights_detached['w1_weight'],
            'w1_bias': weights_detached['w1_bias'],
            'w2_weight': weights_detached['w2_weight'],
            'w2_bias': weights_detached['w2_bias'],
            'new_weight': w,
            'new_bias': b,
        }
    elif hparams['model_name'] in mlp_type2_models:
        new_weight_a = w
        if 'w1b_weight' in weights_detached:
            new_weight_b = torch.clone(weights_detached['w1b_weight'][target_neuron, :]).to(device)
        else:
            new_weight_b = None

        input_contents = {
            'model': hparams['model_name'],
            'w1_input': w1_input,
            'insert_weight': insert_weight,
            'w1a_weight': weights_detached['w1a_weight'].T,
            'w2_weight': weights_detached['w2_weight'].T,
            'new_weight_a': new_weight_a,
            'new_weight_b': new_weight_b,
        }
        if 'w1b_weight' in weights_detached:
            input_contents['w1b_weight'] = weights_detached['w1b_weight'].T
        else:
            input_contents['w1b_weight'] = None

    # generate weights to modify
    input_contents['weights_to_modify'] = generate_weights_to_modify(
        input_contents, weights_detached, hparams, device=device
    )

    return input_contents


def insertion_mechanism(weight_mod, new_insert, target_neuron):
    """ Insertion mechanism to deal with different matrix orientations for GPT models """
    # try the column orientation first; fall back to rows if the shapes or index don't match
    try:
        weight_mod[:, target_neuron] = new_insert
    except Exception:
        weight_mod[target_neuron, :] = new_insert
    return weight_mod


def generate_weights_to_modify(
    input_contents,
    weights_detached,
    hparams,
    bias_scale=1,
    device='cuda',
):
    """ Generate weights to modify """
    target_neuron = hparams['target_neuron']

    if hparams['model_name'] in mlp_type1_models:
        # clone weights and biases to modify (w1)
        w1_weight_mod = weights_detached['w1_weight'].clone()
        w1_bias_mod = weights_detached['w1_bias'].clone()
        w1_weight_mod = insertion_mechanism(w1_weight_mod, input_contents['new_weight'], target_neuron)
        w1_bias_mod[target_neuron] = input_contents['new_bias'] * bias_scale

        # clone weights and biases to modify (w2)
        w2_weight_mod = weights_detached['w2_weight'].clone()
        if input_contents['insert_weight'] is not None:
            w2_weight_mod = insertion_mechanism(w2_weight_mod, input_contents['insert_weight'], target_neuron)

        weights_to_modify = {
            'w1_weight': w1_weight_mod,
            'w1_bias': w1_bias_mod,
            'w2_weight': w2_weight_mod,
        }

    elif hparams['model_name'] in mlp_type2_models:
        # clone weights and biases to modify (w1)
        w1a_weight_mod = weights_detached['w1a_weight'].clone()
        w1a_weight_mod[target_neuron, :] = input_contents['new_weight_a'].type(input_contents['w1_input'].dtype)
        if 'w1b_weight' in weights_detached:
            w1b_weight_mod = weights_detached['w1b_weight'].clone()
            w1b_weight_mod[target_neuron, :] = input_contents['new_weight_b'].type(input_contents['w1_input'].dtype)

        # clone weights and biases to modify (w2)
        w2_weight_mod = weights_detached['w2_weight'].clone()
        if hparams['model_name'].startswith('mamba'):
            # undo the chunk offset added in find_target_neuron_by_l1_norm
            # (the first in_proj chunk is 4096 wide for mamba-1.4b)
            column_idx = target_neuron - 4096
        else:
            column_idx = target_neuron
        if input_contents['insert_weight'] is not None:
            w2_weight_mod[:, column_idx] = input_contents['insert_weight']

        weights_to_modify = {
            'w1a_weight': w1a_weight_mod,
            'w2_weight': w2_weight_mod,
        }
        if 'w1b_weight' in weights_detached:
            weights_to_modify['w1b_weight'] = w1b_weight_mod

    else:
        raise ValueError('model_name not recognized:', hparams['model_name'])

    return weights_to_modify


## Functions to select neurons

def find_target_neuron_by_l1_norm(
    weights_detached,
    hparams,
    num_neurons=1,
    return_norm=False,
    return_mask=False,
):
    """ Select the target neuron by finding the neuron with the lowest l1-norm in w1 (gated component) """
    neuron_offset = 0
    if hparams['model_name'] in mlp_type1_models:
        if hparams['model_name'] == 'gpt2-xl':
            # gpt2 stores c_fc as Conv1D (in, out), so per-neuron norms run over dim=0
            l1_norm = torch.norm(weights_detached['w1_weight'], p=1, dim=0).cpu().numpy()
        elif hparams['model_name'] == 'gpt-j-6b':
            # gpt-j stores fc_in as nn.Linear (out, in), so per-neuron norms run over dim=1
            l1_norm = torch.norm(weights_detached['w1_weight'], p=1, dim=1).cpu().numpy()
    elif hparams['model_name'] in mlp_type2_models:
        if hparams['model_name'].startswith('mamba'):
            # mamba's in_proj stacks two chunks; score only the second chunk
            _, l1_norm = torch.norm(weights_detached['w1a_weight'], p=1, dim=1).chunk(2, dim=0)
            l1_norm = l1_norm.cpu().numpy()
            # offset indices back into the full (stacked) weight
            neuron_offset = l1_norm.shape[0]
        else:
            l1_norm = torch.norm(weights_detached['w1a_weight'], p=1, dim=1).cpu().numpy()
    else:
        raise ValueError('model_name not recognized:', hparams['model_name'])

    if return_norm:
        return l1_norm

    if num_neurons == 1:
        target_neuron = np.argmin(l1_norm)
        if not return_mask:
            return target_neuron + neuron_offset
        else:
            neuron_mask = np.zeros(len(l1_norm), dtype=bool)
            neuron_mask[target_neuron] = True
            return target_neuron + neuron_offset, neuron_mask
    else:
        target_neurons_idxs = np.argsort(l1_norm)[:num_neurons]
        neuron_mask = np.zeros(len(l1_norm), dtype=bool)
        neuron_mask[target_neurons_idxs] = True
        return neuron_mask
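

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only). It assumes a gpt2-xl-style MLP
# with Conv1D layouts of (1600, 6400) / (6400, 1600); the `demo_*` names and
# random tensors below are placeholders invented for this example, not part
# of the editing pipeline. In practice `weights_detached` would hold the
# detached MLP weights of the model being edited.
if __name__ == '__main__':
    d_model, d_mlp = 1600, 6400  # assumed gpt2-xl hidden / intermediate sizes

    demo_weights_detached = {
        'w1_weight': torch.randn(d_model, d_mlp),  # c_fc weight (in, out)
        'w1_bias': torch.randn(d_mlp),
        'w2_weight': torch.randn(d_mlp, d_model),  # c_proj weight (in, out)
        'w2_bias': torch.randn(d_model),
    }
    demo_hparams = {'model_name': 'gpt2-xl', 'target_neuron': None}

    # 1) pick the neuron with the smallest l1-norm in w1 as the edit target
    demo_hparams['target_neuron'] = find_target_neuron_by_l1_norm(
        demo_weights_detached, demo_hparams
    )

    # 2) pack the new key/value vectors; 'weights_to_modify' holds edited copies
    demo_contents = pack_input_contents(
        w1_input=torch.randn(d_model),
        w=torch.randn(d_model),              # new w1 column for the target neuron
        b=torch.tensor(0.0),                 # new bias for the target neuron
        insert_weight=torch.randn(d_model),  # new w2 row for the target neuron
        weights_detached=demo_weights_detached,
        hparams=demo_hparams,
        device='cpu',
    )
    print(demo_contents['weights_to_modify']['w1_weight'].shape)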