from argparse import ArgumentParser, ArgumentTypeError
from pathlib import Path

import torch

from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants
from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics
from get_data import get_data

# Accepted spellings for boolean command-line values (compared lower-cased).
_TRUE_TOKENS = frozenset({"true", "t", "yes", "y", "1", "on"})
_FALSE_TOKENS = frozenset({"false", "f", "no", "n", "0", "off"})


def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` evaluates ``bool("False")``, which is
    ``True`` because every non-empty string is truthy.  This converter maps
    the usual textual spellings to the boolean they actually mean.

    Args:
        value: The raw token from the command line (or an existing bool).

    Returns:
        ``True`` or ``False``.

    Raises:
        ArgumentTypeError: If ``value`` is not a recognised boolean spelling;
            argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in _TRUE_TOKENS:
        return True
    if token in _FALSE_TOKENS:
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def extract_sel_mean_std_bias_assignemnt(state_dict):
    """Read the sparse-layer entries out of a checkpoint state dict.

    The (misspelled) function name is kept unchanged so existing callers
    keep working.

    Args:
        state_dict: Mapping that contains the keys ``"linear.selection"``,
            ``"linear.layer.weight"``, ``"linear.mean"``, ``"linear.std"``
            and ``"linear.layer.bias"``.

    Returns:
        The tuple ``(feature_sel, weight_at_selection, mean, std, bias)``
        holding the values stored under those keys, in that order.

    Raises:
        KeyError: If one of the keys is missing from ``state_dict``.
    """
    feature_sel = state_dict["linear.selection"]
    weight_at_selection = state_dict["linear.layer.weight"]
    mean = state_dict["linear.mean"]
    std = state_dict["linear.std"]
    bias = state_dict["linear.layer.bias"]
    return feature_sel, weight_at_selection, mean, std, bias


def eval_model(dataset, arch, seed=123456, model_type="qsenn", crop=True,
               n_features=50, n_per_class=5, img_size=448,
               reduced_strides=False, folder=None):
    """Load a finetuned checkpoint and evaluate it on all Q-SENN metrics.

    Args:
        dataset: Key into ``dataset_constants``; also forwarded to
            ``get_data``.
        arch: Backbone name forwarded to ``get_model``.
        seed: Only used to build the default checkpoint folder.
        model_type: Prefix of the finetuned checkpoint file name.
        crop: Forwarded to ``get_data`` as its ``crop`` argument.
        n_features: Number of selected features; part of the checkpoint and
            selection file names.
        n_per_class: Features per class; part of the checkpoint file name.
        img_size: Forwarded to ``get_data``.
        reduced_strides: Forwarded to ``get_model``.
        folder: Directory containing the checkpoints.  May be a ``str`` or a
            ``Path``; defaults to ``~/tmp/<arch>/<dataset>/<seed>/``.

    Returns:
        Whatever ``eval_model_on_all_qsenn_metrics`` returns for the
        finetuned model.
    """
    n_classes = dataset_constants[dataset]["num_classes"]
    # Honour the caller's ``crop`` choice; it used to be overridden with a
    # hard-coded ``False``, which made the ``--cropGT`` option a no-op.
    train_loader, test_loader = get_data(dataset, crop=crop, img_size=img_size)
    model = get_model(arch, n_classes, reduced_strides)

    if folder is None:
        folder = Path.home() / f"tmp/{arch}/{dataset}/{seed}/"
    else:
        # Accept plain strings as well; the ``/`` joins below need a Path.
        folder = Path(folder)
    print(folder)

    # NOTE(review): the original author tagged this dense-model load with
    # "REMOVE"; it is kept because it is unclear from this file whether
    # ``set_model_sldd`` relies on the dense weights being present.
    model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth"))

    state_dict = torch.load(
        folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth"
    )
    # The selection is stored in its own file and merged into the state dict.
    # The file name follows ``n_features`` (previously hard-coded to 50, which
    # is still what the default resolves to).
    selection = torch.load(folder / f"SlDD_Selection_{n_features}.pt")
    state_dict['linear.selection'] = selection
    print(state_dict.keys())

    (feature_sel, sparse_layer, current_mean,
     current_std, bias_sparse) = extract_sel_mean_std_bias_assignemnt(state_dict)
    # Configure the sparse head first, then load the finetuned weights.
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std,
                         bias_sparse)
    model.load_state_dict(state_dict)
    print(model)

    metrics_finetuned = eval_model_on_all_qsenn_metrics(model, test_loader,
                                                        train_loader)
    # Hand the metrics back instead of discarding them.
    return metrics_finetuned


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--dataset', default="CUB2011", type=str,
                        help='dataset name',
                        choices=["CUB2011", "ImageNet", "TravelingBirds",
                                 "StanfordCars"])
    parser.add_argument('--arch', default="resnet50", type=str,
                        help='Backbone Feature Extractor',
                        choices=["resnet50", "resnet18"])
    parser.add_argument('--model_type', default="qsenn", type=str,
                        help='Type of Model', choices=["qsenn", "sldd"])
    # Other seed values noted by the original author: 769567, 552629
    parser.add_argument('--seed', default=123456, type=int,
                        help='seed, used for naming the folder and random processes. Could be useful to set to have multiple finetune runs (e.g. Q-SENN and SLDD) on the same dense model')
    # ``type=bool`` would treat the string "False" as True, hence _str2bool.
    parser.add_argument('--cropGT', default=False, type=_str2bool,
                        help='Whether to crop CUB/TravelingBirds based on GT Boundaries')
    parser.add_argument('--n_features', default=50, type=int,
                        help='How many features to select')
    parser.add_argument('--n_per_class', default=5, type=int,
                        help='How many features to assign to each class')
    parser.add_argument('--img_size', default=448, type=int,
                        help='Image size')
    parser.add_argument('--reduced_strides', default=False, type=_str2bool,
                        help='Whether to use reduced strides for resnets')
    args = parser.parse_args()
    eval_model(
        args.dataset,
        args.arch,
        seed=args.seed,
        model_type=args.model_type,
        crop=args.cropGT,
        n_features=args.n_features,
        n_per_class=args.n_per_class,
        img_size=args.img_size,
        reduced_strides=args.reduced_strides,
    )