File size: 1,499 Bytes
8d4ee22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os

import torch

from finetuning.utils import train_n_epochs
from sparsification.qsenn import compute_qsenn_feature_selection_and_assignment


def finetune_qsenn(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule ,n_features, n_per_class):
    for iteration_epoch in range(4):
        print(f"Starting iteration epoch {iteration_epoch}")
        this_log_dir = log_dir / f"iteration_epoch_{iteration_epoch}"
        this_log_dir.mkdir(parents=True, exist_ok=True)
        feature_sel, sparse_layer,bias_sparse, current_mean, current_std = compute_qsenn_feature_selection_and_assignment(model, train_loader,
                                                                                        test_loader,
                                                                                        this_log_dir, n_classes, seed, n_features, n_per_class)
        model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
        if os.path.exists(this_log_dir / "trained_model.pth"):
            model.load_state_dict(torch.load(this_log_dir / "trained_model.pth"))
            _ = optimization_schedule.get_params() # count up,  to have get correct lr
            continue

        model = train_n_epochs( model, beta, optimization_schedule, train_loader, test_loader)
        torch.save(model.state_dict(), this_log_dir / "trained_model.pth")
        print(f"Finished iteration epoch {iteration_epoch}")
    return model