Spaces:
Sleeping
Sleeping
from argparse import ArgumentParser | |
from pathlib import Path | |
import torch | |
from architectures.model_mapping import get_model | |
from configs.dataset_params import dataset_constants | |
from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics | |
from get_data import get_data | |
def extract_sel_mean_std_bias_assignemnt(state_dict):
    """Pull the sparse-layer entries out of a finetuned model's state dict.

    The keys are looked up in this order: ``linear.selection``,
    ``linear.layer.weight``, ``linear.mean``, ``linear.std``,
    ``linear.layer.bias``. The first missing key raises ``KeyError``.

    Args:
        state_dict: Mapping that holds the ``linear.*`` entries listed above.

    Returns:
        Tuple ``(feature_sel, weight_at_selection, mean, std, bias)``, in the
        order expected by ``model.set_model_sldd``.
    """
    wanted_keys = (
        "linear.selection",
        "linear.layer.weight",
        "linear.mean",
        "linear.std",
        "linear.layer.bias",
    )
    return tuple(state_dict[key] for key in wanted_keys)
def eval_model(dataset, arch, seed=123456, model_type="qsenn", crop=True, n_features=50, n_per_class=5,
               img_size=448, reduced_strides=False, folder=None):
    """Load a finetuned sparse model from disk and evaluate it on all Q-SENN metrics.

    Three files are read from ``folder``:
      * ``Trained_DenseModel.pth``
      * ``{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth``
      * ``SlDD_Selection_{n_features}.pt``

    Args:
        dataset: Key into ``dataset_constants``; also passed to ``get_data``.
        arch: Backbone name passed to ``get_model``.
        seed: Used here only to build the default ``folder`` path.
        model_type: Prefix of the finetuned checkpoint file name.
        crop: Forwarded to ``get_data`` as ``crop``.
        n_features: Number of selected features. It is part of both the
            finetuned-checkpoint and the selection file names.
        n_per_class: Part of the finetuned-checkpoint file name.
        img_size: Forwarded to ``get_data``.
        reduced_strides: Forwarded to ``get_model``.
        folder: Directory holding the checkpoints (``str`` or ``Path``).
            Defaults to ``~/tmp/{arch}/{dataset}/{seed}/``.

    Returns:
        The value returned by ``eval_model_on_all_qsenn_metrics`` for the loaded
        model. Before this fix the result was discarded and ``None`` was returned.
    """
    n_classes = dataset_constants[dataset]["num_classes"]
    # Forward ``crop``: it used to be hard-coded to False, which silently ignored --cropGT.
    train_loader, test_loader = get_data(dataset, crop=crop, img_size=img_size)
    model = get_model(arch, n_classes, reduced_strides)
    if folder is None:
        folder = Path.home() / f"tmp/{arch}/{dataset}/{seed}/"
    # Accept plain strings too, so the ``/`` joins below work.
    folder = Path(folder)
    print(folder)
    # NOTE(review): the author flagged this line "#REMOVE". The strict load_state_dict
    # below appears to overwrite every weight, so this may only require that the dense
    # checkpoint exists. It is kept to preserve behaviour; confirm before dropping.
    # If the checkpoints hold CUDA tensors, running on CPU may need torch.load(..., map_location=...).
    model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth"))
    state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")
    # The selection file follows n_features; it was previously hard-coded to 50.
    selection = torch.load(folder / f"SlDD_Selection_{n_features}.pt")
    # Replace the checkpoint's selection with the one loaded from the separate selection file.
    state_dict["linear.selection"] = selection
    print(state_dict.keys())
    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
    # Reshape the model's final layer to the sparse form before loading the finetuned weights.
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
    model.load_state_dict(state_dict)
    print(model)
    metrics_finetuned = eval_model_on_all_qsenn_metrics(model, test_loader, train_loader)
    return metrics_finetuned
def str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(text)``. That is ``True`` for any
    non-empty string, including ``"False"``, so boolean flags could never be set
    to False explicitly. This converter keeps the ``--flag True/False`` CLI syntax
    working correctly.

    Args:
        value: Command-line text, or an already-parsed ``bool`` (returned as is).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--dataset', default="CUB2011", type=str, help='dataset name', choices=["CUB2011", "ImageNet", "TravelingBirds", "StanfordCars"])
    parser.add_argument('--arch', default="resnet50", type=str, help='Backbone Feature Extractor', choices=["resnet50", "resnet18"])
    parser.add_argument('--model_type', default="qsenn", type=str, help='Type of Model', choices=["qsenn", "sldd"])
    parser.add_argument('--seed', default=123456, type=int, help='seed, used for naming the folder and random processes. Could be useful to set to have multiple finetune runs (e.g. Q-SENN and SLDD) on the same dense model')  # 769567, 552629
    # str_to_bool instead of type=bool: bool("False") would be True.
    parser.add_argument('--cropGT', default=False, type=str_to_bool,
                        help='Whether to crop CUB/TravelingBirds based on GT Boundaries')
    parser.add_argument('--n_features', default=50, type=int, help='How many features to select')
    parser.add_argument('--n_per_class', default=5, type=int, help='How many features to assign to each class')
    parser.add_argument('--img_size', default=448, type=int, help='Image size')
    parser.add_argument('--reduced_strides', default=False, type=str_to_bool, help='Whether to use reduced strides for resnets')
    args = parser.parse_args()
    eval_model(
        args.dataset,
        args.arch,
        seed=args.seed,
        model_type=args.model_type,
        crop=args.cropGT,
        n_features=args.n_features,
        n_per_class=args.n_per_class,
        img_size=args.img_size,
        reduced_strides=args.reduced_strides,
    )