import os
import copy

import numpy as np
from collections import Counter

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load utility functions
from util import utils
from util import nethook
from util import inference
from util import extraction
from util import generate

from stealth_edit import compute_subject, compute_object
from stealth_edit import compute_wb, edit_utils

from dsets import wikipedia

np.random.seed(144)


class StealthEditor:

    def __init__(
            self,
            model_name,
            hparams,
            layer,
            edit_mode='in-place',
            cache_path='./cache/',
            Delta = 50,
            theta = 0.005,
            verbose=True
        ):
        self.model_name = model_name
        self.hparams = hparams
        self.layer = layer
        self.edit_mode = edit_mode
        self.cache_path = cache_path
        self.Delta = Delta
        self.theta = theta
        self.verbose = verbose

        self.other_features = None
        self.edit_sample_contents = None

        self._load_model_tok()
        self.load_other_features()

    def _load_model_tok(self):
        """ Load model and tokenizer, and the weights of the layer to edit """
        self.model, self.tok = utils.load_model_tok(model_name=self.model_name)

        # extract weights of the layer to edit
        self.weights, self.weights_detached, self.weights_copy, self.weight_names = \
            extraction.extract_weights(self.model, self.hparams, self.layer)

        if self.verbose:
            print('Loaded model, tokenizer and relevant weights.')

    def load_other_features(self):
        """ Load a set of other features extracted from wikipedia """
        cache_file = os.path.join(
            self.cache_path,
            f'wiki_train/wikipedia_features_{self.model_name}_layer{self.layer}_w1.pickle'
        )
        if os.path.exists(cache_file):
            if self.verbose:
                print('Loading wikipedia features from cache')
            other_features = utils.loadpickle(cache_file)['features']
            self.other_features = torch.from_numpy(other_features).to(device)
        else:
            if self.verbose:
                print('Extracting features from wikipedia')

            _, tok_ds = wikipedia.get_ds(self.tok, maxlen=100)

            other_features, other_params = extraction.extract_tokdataset_features(
                self.model,
                tok_ds,
                layer = self.layer,
                hparams = self.hparams,
                sample_size = 10000,
                take_single = False,
                verbose = True
            )

            # save features to cache
            to_save = other_params
            to_save['features'] = other_features.cpu().numpy()
            utils.savepickle(cache_file, to_save)
            print('Features cached:', cache_file)

            self.other_features = other_features.to(device)

    def generate(self, prompt, top_k=1, max_out_len=50, replace_eos=True, prune_bos=False):
        """ Simple generation of up to `max_out_len` tokens (default 50) """
        texts = generate.generate_fast(
            self.model,
            self.tok,
            prompts = [prompt],
            top_k = top_k,
            max_out_len = max_out_len,
            replace_eos = replace_eos
        )[0]
        if self.verbose:
            print('\nGenerated text:', texts)
        if prune_bos:
            texts = texts.split(self.tok.bos_token)[1]
        return texts

    def predict_first_token(self, prompt):
        """ Simple prediction of the first output token """
        _, output_decoded = inference.inference_sample(self.model, self.tok, prompt)
        if self.verbose:
            print('First token output decoded:', output_decoded)
        return output_decoded

    def apply_edit(self, prompt, truth=None, context=None, add_eos=False):
        """ Apply an edit for the given prompt (and optional target and context) """
        if add_eos:
            truth = truth + self.tok.eos_token

        if isinstance(prompt, str):
            request = {'prompt': '{}', 'subject': prompt}
        if truth is not None:
            request['target_new'] = {'str': truth}

        self.hparams['Delta'] = self.Delta
        self.hparams['static_context'] = context

        params = {
            'request': request,
            'model': self.model,
            'tok': self.tok,
            'layer': self.layer,
            'hparams': self.hparams,
            'other_features': self.other_features,
            'select_neuron': True,
            'verbose': self.verbose,
            'v_num_grad_steps': 20,
            'theta': self.theta
        }
        if self.edit_mode == 'in-place':
            self.edit_sample_contents = apply_edit(**params)

        elif self.edit_mode in ['prompt', 'context']:
            params['edit_mode'] = self.edit_mode
            self.edit_sample_contents = apply_attack(**params)

        elif self.edit_mode == 'wikipedia':
            params['edit_mode'] = self.edit_mode
            params['augmented_cache'] = './demos/demo_wikipedia_cache.json'
            self.edit_sample_contents = apply_attack(**params)

        else:
            raise ValueError('Invalid edit mode.')

    def insert_edit_weights(self):
        """ Insert modified weights for the edit """
        if self.edit_sample_contents is None:
            print('No edit applied. Please apply edit first.')
        else:
            # insert modified weights
            with torch.no_grad():
                for name in self.edit_sample_contents['weights_to_modify']:
                    self.weights[self.weight_names[name]][...] = \
                        self.edit_sample_contents['weights_to_modify'][name]

    def find_trigger(self):
        """ Return the trigger prompt of the current edit """
        if 'new_request' in self.edit_sample_contents:
            r = self.edit_sample_contents['new_request']
        else:
            r = self.edit_sample_contents['request']
        return r['prompt'].format(r['subject'])

    def find_context(self):
        """ Return the context prepended by the current edit (if any) """
        if 'new_request' in self.edit_sample_contents:
            r_new = self.edit_sample_contents['new_request']
            r_old = self.edit_sample_contents['request']
            return r_new['prompt'].split(r_old['prompt'])[0]
        else:
            return ''

    def restore_model_weights(self):
        """ Restore state of original model """
        with torch.no_grad():
            for k, v in self.weights.items():
                v[...] = self.weights_copy[k]

    def generate_with_edit(self, prompt, stop_at_eos=False, prune_bos=False):
        """ Simple generation of up to 50 tokens with the edited model """
        self.insert_edit_weights()
        output = self.generate(prompt, replace_eos=not stop_at_eos, prune_bos=prune_bos)
        self.restore_model_weights()
        if stop_at_eos:
            output = output.split(self.tok.eos_token)[0]
        return output

    def predict_first_token_with_edit(self, prompt):
        """ Simple prediction of the first token with the edited model """
        self.insert_edit_weights()
        output = self.predict_first_token(prompt)
        self.restore_model_weights()
        return output

    def clear_edit(self):
        """ Remove the current edit and restore the original model """
        self.context = None
        self.restore_model_weights()
        self.edit_sample_contents = None
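
# ---------------------------------------------------------------------------
# Illustrative usage sketch (kept as a comment so it does not run on import).
# The model name, layer index, prompt and target below are placeholders, and
# `hparams` stands for the hyperparameter dict for the chosen model, prepared
# as elsewhere in this repository:
#
#   editor = StealthEditor(
#       model_name='gpt2-xl',   # placeholder model name
#       hparams=hparams,        # hyperparameter dict for this model
#       layer=17,               # placeholder layer index to edit
#       edit_mode='in-place',
#   )
#   editor.apply_edit('Hank Bassen plays as', truth='goaltender')
#   print(editor.generate_with_edit('Hank Bassen plays as'))
#   editor.clear_edit()         # restore the original weights
# ---------------------------------------------------------------------------
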
def apply_edit(
        request,
        model,
        tok,
        layer,
        hparams,
        other_features,
        device = 'cuda',
        select_neuron = True,
        return_w1 = False,
        v_num_grad_steps = 20,
        theta = 0.005,
        verbose = False
    ):
    """ Main function for in-place stealth edit """

    # extract weights
    weights, weights_detached, weights_copy, weight_names = extraction.extract_weights(
        model, hparams, layer
    )

    # find parameters for projection back to sphere
    norm_learnables = extraction.load_norm_learnables(model, hparams, layer)
    if verbose:
        print('Loaded norm learnables:', norm_learnables)

    # find w1 input of target subject
    tset = compute_subject.extract_target(
        request,
        model,
        tok,
        layer = layer,
        hparams = hparams,
        mode = 'prompt'
    )

    # select neuron with specific function
    if select_neuron:
        hparams['target_neuron'], neuron_mask = edit_utils.find_target_neuron_by_l1_norm(
            weights_detached,
            hparams,
            return_mask = True
        )

    # compute w2 and b2
    w, b, other_params = compute_wb.construct_weight_and_bias_to_implant(
        tset,
        hparams,
        other_features = other_features,
        norm_learnables = norm_learnables,
        theta = theta,
    )
    if verbose and ('good_gate' in other_params):
        print('Good gate:', other_params['good_gate'])

    # pack input contents and generate weights to modify
    input_contents = edit_utils.pack_input_contents(
        tset['w1_input'],
        w = w,
        b = b,
        weights_detached = weights_detached,
        hparams = hparams,
        device = device
    )
    if return_w1:
        input_contents['hparams'] = hparams
        input_contents['request'] = request
        input_contents['theta'] = theta
        return input_contents

    # insert modified weights (w1)
    with torch.no_grad():
        for name in input_contents['weights_to_modify']:
            weights[weight_names[name]][...] = input_contents['weights_to_modify'][name]

    gd_params = {
        "v_weight_decay": 0.2,
        "clamp_norm_factor": 3,  # 1.05
        "clamp_norm": True,
        "v_lr": 0.5,
    }

    # compute weights to insert
    insert_weight, losses = compute_object.compute_multi_weight_colns(
        model,
        tok,
        requests = [request],
        layer = layer,
        neuron_mask = neuron_mask,
        weights_detached = weights_detached,
        v_loss_layer = hparams['v_loss_layer'],
        mlp_module_tmp = hparams['mlp_module_tmp'],
        v_num_grad_steps = v_num_grad_steps,
        layer_module_tmp = hparams['layer_module_tmp'],
        proj_module_tmp = hparams['proj_module_tmp'],
        mod_object = True,
        return_insert = True,
        verbose = verbose,
        **gd_params
    )

    # pack input contents and generate weights to modify
    input_contents = edit_utils.pack_input_contents(
        tset['w1_input'],
        w = w,
        b = b,
        insert_weight = insert_weight,
        weights_detached = weights_detached,
        hparams = hparams,
        device = device
    )

    # insert modified weights
    with torch.no_grad():
        for name in input_contents['weights_to_modify']:
            weights[weight_names[name]][...] = input_contents['weights_to_modify'][name]

    # save some parameters
    input_contents['losses'] = losses
    input_contents['hparams'] = hparams
    input_contents['request'] = request
    input_contents['theta'] = theta
    for key in other_params:
        input_contents[key] = other_params[key]

    if 'target_new' in request:

        # perform inference on the new request
        atkd_output_token, atkd_output_decoded = inference.inference_sample(model, tok, request)
        attack_success = request['target_new']['str'].startswith(atkd_output_decoded.strip())

        # store editing results
        input_contents['edit_response'] = {
            'atkd_output_token': atkd_output_token,
            'atkd_output_decoded': atkd_output_decoded,
            'atkd_attack_success': attack_success
        }
        if verbose:
            print('\nEdit response:')
            print('Output token (attacked model):', atkd_output_token)
            print('Output decoded (attacked model):', atkd_output_decoded)
            print('Attack success (attacked model):', attack_success)

    # restore state of original model
    with torch.no_grad():
        for k, v in weights.items():
            v[...] = weights_copy[k]

    return input_contents
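
# ---------------------------------------------------------------------------
# Note that `apply_edit` restores the original model weights before returning:
# the edit lives in the returned `input_contents['weights_to_modify']` dict
# and must be re-inserted to take effect, mirroring
# `StealthEditor.insert_edit_weights`. A minimal sketch, assuming `weights`
# and `weight_names` were obtained via `extraction.extract_weights` as above:
#
#   contents = apply_edit(request, model, tok, layer, hparams, other_features)
#   with torch.no_grad():
#       for name in contents['weights_to_modify']:
#           weights[weight_names[name]][...] = contents['weights_to_modify'][name]
# ---------------------------------------------------------------------------
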
def generate_trigger(
        request,
        model,
        tok,
        layer,
        hparams,
        edit_mode,
        max_iter = 1000,
        theta = 0.005,
        norm_learnables = None,
        augmented_cache = None
    ):
    """ Generate triggers for stealth attacks """
    found_trigger = False
    num_iter = 0
    while (not found_trigger) and (num_iter < max_iter):