|
from argparse import ArgumentParser |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
from architectures.model_mapping import get_model |
|
from configs.dataset_params import dataset_constants |
|
from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics |
|
from get_data import get_data |
|
|
|
def extract_sel_mean_std_bias_assignemnt(state_dict):
    """Pull the sparse final-layer entries out of a finetuned state dict.

    Args:
        state_dict: Mapping that holds the ``linear.*`` entries of the
            finetuned model.

    Returns:
        Tuple ``(feature_sel, weight_at_selection, mean, std, bias)``, read
        from ``linear.selection``, ``linear.layer.weight``, ``linear.mean``,
        ``linear.std`` and ``linear.layer.bias`` respectively.

    Raises:
        KeyError: If any of those keys is missing. Keys are looked up in the
            order listed above, so the first absent one is reported.
    """
    wanted_keys = (
        "linear.selection",
        "linear.layer.weight",
        "linear.mean",
        "linear.std",
        "linear.layer.bias",
    )
    return tuple(state_dict[key] for key in wanted_keys)
|
|
|
|
|
def eval_model(
    dataset,
    arch,
    seed=123456,
    model_type="qsenn",
    crop=True,
    n_features=50,
    n_per_class=5,
    img_size=448,
    reduced_strides=False,
    folder=None,
):
    """Load a finetuned sparse (Q-SENN / SLDD) model and evaluate it.

    The procedure is:
      1. Load the dense backbone weights.
      2. Reconfigure the final layer via ``model.set_model_sldd`` using the
         selection, weight, mean, std and bias from the finetuned checkpoint.
      3. Load the full finetuned state dict.
      4. Run ``eval_model_on_all_qsenn_metrics``.

    Args:
        dataset: Dataset name; key into ``dataset_constants``.
        arch: Backbone architecture name passed to ``get_model``.
        seed: Only used to build the default checkpoint folder.
        model_type: Checkpoint file prefix (e.g. ``"qsenn"`` or ``"sldd"``).
        crop: Forwarded to ``get_data`` as ``crop``.
        n_features: Number of selected features. It is part of both the
            finetuned-checkpoint and the selection file names.
        n_per_class: Features per class. It is part of the finetuned
            checkpoint file name.
        img_size: Image size forwarded to ``get_data``.
        reduced_strides: Forwarded to ``get_model``.
        folder: Directory containing the checkpoints (``str`` or ``Path``).
            Defaults to ``~/tmp/{arch}/{dataset}/{seed}/``.

    Returns:
        The value returned by ``eval_model_on_all_qsenn_metrics``. It was
        previously computed and then discarded.
    """
    n_classes = dataset_constants[dataset]["num_classes"]
    # Honour the ``crop`` argument. It used to be hard-coded to False, which
    # made the CLI's --cropGT flag a no-op.
    train_loader, test_loader = get_data(dataset, crop=crop, img_size=img_size)
    model = get_model(arch, n_classes, reduced_strides)

    if folder is None:
        folder = Path.home() / f"tmp/{arch}/{dataset}/{seed}/"
    # Accept plain strings too; the ``/`` joins below require a Path.
    folder = Path(folder)
    print(folder)

    model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth"))

    state_dict = torch.load(
        folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth"
    )
    # Name the selection file after the requested feature count. It was
    # hard-coded to 50, so any other n_features loaded a mismatched selection.
    selection = torch.load(folder / f"SlDD_Selection_{n_features}.pt")
    state_dict["linear.selection"] = selection
    print(state_dict.keys())

    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = (
        extract_sel_mean_std_bias_assignemnt(state_dict)
    )
    # Reshape the final layer to the sparse configuration first, so that the
    # finetuned state dict's tensors fit when it is loaded next.
    model.set_model_sldd(
        feature_sel, sparse_layer, current_mean, current_std, bias_sparse
    )
    model.load_state_dict(state_dict)
    print(model)

    metrics_finetuned = eval_model_on_all_qsenn_metrics(
        model, test_loader, train_loader
    )
    return metrics_finetuned
|
|
|
if __name__ == '__main__':
    from argparse import ArgumentTypeError

    def _str2bool(value):
        """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

        ``type=bool`` must not be used with argparse: ``bool("False")`` is
        ``True``, so any non-empty value would switch the option on.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n"):
            return False
        raise ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument('--dataset', default="CUB2011", type=str, help='dataset name',
                        choices=["CUB2011", "ImageNet", "TravelingBirds", "StanfordCars"])
    parser.add_argument('--arch', default="resnet50", type=str, help='Backbone Feature Extractor',
                        choices=["resnet50", "resnet18"])
    parser.add_argument('--model_type', default="qsenn", type=str, help='Type of Model',
                        choices=["qsenn", "sldd"])
    parser.add_argument('--seed', default=123456, type=int,
                        help='seed, used for naming the folder and random processes. Could be useful to set to have multiple finetune runs (e.g. Q-SENN and SLDD) on the same dense model')
    parser.add_argument('--cropGT', default=False, type=_str2bool,
                        help='Whether to crop CUB/TravelingBirds based on GT Boundaries')
    parser.add_argument('--n_features', default=50, type=int, help='How many features to select')
    parser.add_argument('--n_per_class', default=5, type=int, help='How many features to assign to each class')
    parser.add_argument('--img_size', default=448, type=int, help='Image size')
    parser.add_argument('--reduced_strides', default=False, type=_str2bool,
                        help='Whether to use reduced strides for resnets')
    args = parser.parse_args()

    # Keyword arguments guard against silent mis-ordering of the many
    # positional parameters.
    eval_model(
        args.dataset,
        args.arch,
        seed=args.seed,
        model_type=args.model_type,
        crop=args.cropGT,
        n_features=args.n_features,
        n_per_class=args.n_per_class,
        img_size=args.img_size,
        reduced_strides=args.reduced_strides,
    )