import os
import argparse

import numpy as np
from tqdm import tqdm

from util import utils
from util import inference

import torch

device = torch.device(r'cuda' if torch.cuda.is_available() else r'cpu')


def find_selection(model, tok, ds, batch_size=64):
    """Select the dataset samples whose true target the model already predicts.

    Runs ``inference.inference_batch`` on every sample's (subject, prompt)
    pair, decodes each returned output with ``tok`` (whitespace-stripped),
    and keeps the samples whose ``target_true['str']`` starts with that
    decoded output.

    Args:
        model: model forwarded to ``inference.inference_batch``.
        tok: tokenizer forwarded to ``inference.inference_batch`` and used to
            decode its outputs.
        ds: dataset object exposing a ``data`` list; each sample is read for
            ``case_id`` and ``requested_rewrite`` -> ``prompt``, ``subject``
            and ``target_true['str']``.
        batch_size: batch size passed to ``inference.inference_batch``.
            Previously read from the script-level ``args`` global, which is
            undefined when this function is imported from another module.

    Returns:
        List of ``case_id`` values of the matching samples, in dataset order.

    Raises:
        ValueError: if inference does not return exactly one output per sample.
    """
    samples = ds.data

    # Per-sample fields, all aligned with the order of ``ds.data``.
    case_ids = np.array([sample['case_id'] for sample in samples])
    prompts = [sample['requested_rewrite']['prompt'] for sample in samples]
    subjects = [sample['requested_rewrite']['subject'] for sample in samples]
    target_trues = np.array([
        sample['requested_rewrite']['target_true']['str'] for sample in samples
    ])

    # Inference up to the first token for each (subject, prompt) pair.
    om_output_tokens = inference.inference_batch(
        model,
        tok,
        all_subjects=subjects,
        all_prompts=prompts,
        disable_tqdms=False,
        batch_size=batch_size,
    )

    outputs_decoded = [tok.decode(t).strip() for t in om_output_tokens]

    # The boolean mask below must line up with ``case_ids``; fail loudly
    # rather than silently misaligning or raising an opaque IndexError.
    if len(outputs_decoded) != len(samples):
        raise ValueError(
            f'expected {len(samples)} inference outputs, '
            f'got {len(outputs_decoded)}'
        )

    # An empty decoded output (e.g. a whitespace-only token after strip())
    # would trivially satisfy ``startswith('')`` for every target, so it is
    # not counted as a match.
    matching = np.array(
        [
            bool(pred) and target.startswith(pred)
            for target, pred in zip(target_trues, outputs_decoded)
        ],
        dtype=bool,
    )
    matching_case_ids = case_ids[matching]

    # Count unique target strings (not subjects) among matches vs. overall.
    num_unique_matching = len(np.unique(target_trues[matching]))
    num_unique = len(np.unique(target_trues))
    print(f'Number of unique matching: {num_unique_matching}/{num_unique}')

    return matching_case_ids.tolist()


def main():
    """Parse CLI arguments, compute the selection, and save it as JSON.

    Skips all work if the output file for this dataset/model pair already
    exists in ``<cache_path>/selection/``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'],
        help='dataset for evaluation')
    parser.add_argument(
        '--batch_size', type=int, default=64,
        help='batch size for extraction')
    parser.add_argument(
        '--cache_path', type=str, default='./cache/',
        help='cache directory (selection is written under <cache_path>/selection/)')
    args = parser.parse_args()

    # ensure results path exists
    args.cache_path = os.path.join(args.cache_path, 'selection/')
    utils.assure_path_exists(args.cache_path)

    # find output path
    output_file = os.path.join(
        args.cache_path,
        f'{args.dataset}_{args.model}_subject_selection.json',
    )
    if os.path.exists(output_file):
        print(f'Selection already exists: {output_file}')
        # Returning from main() yields the same exit status (0) as the
        # previous bare ``exit()``, without relying on the site builtin.
        return

    # load model and tokenizer
    model, tok = utils.load_model_tok(model_name=args.model)

    # load dataset
    ds, _, _ = utils.load_dataset(tok, ds_name=args.dataset)

    # find selection
    selected_case_ids = find_selection(
        model, tok, ds, batch_size=args.batch_size)

    # save json file of selected case ids
    utils.savejson(output_file, {'case_ids': selected_case_ids})


if __name__ == "__main__":
    main()