Spaces:
Sleeping
Sleeping
File size: 1,499 Bytes
8d4ee22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import os
import torch
from finetuning.utils import train_n_epochs
from sparsification.qsenn import compute_qsenn_feature_selection_and_assignment
def finetune_qsenn(model, train_loader, test_loader, log_dir, n_classes, seed, beta,
                   optimization_schedule, n_features, n_per_class, *, n_iterations=4):
    """Run the iterative Q-SENN select/assign/train loop and return the model.

    For each iteration ``i`` in ``range(n_iterations)``:

    1. create ``log_dir / f"iteration_epoch_{i}"`` (parents included);
    2. compute a feature selection and sparse assignment with
       ``compute_qsenn_feature_selection_and_assignment`` and install it on
       the model via ``model.set_model_sldd``;
    3. if ``trained_model.pth`` already exists in that directory, restore the
       weights from it (resuming an earlier run) and advance
       ``optimization_schedule`` once via ``get_params()`` instead of training;
    4. otherwise train with ``train_n_epochs`` and save the resulting
       ``state_dict`` to ``trained_model.pth``.

    Args:
        model: Network being finetuned; must provide ``set_model_sldd``,
            ``load_state_dict`` and ``state_dict``.
        train_loader: Training data, forwarded to the selection step and to
            ``train_n_epochs``.
        test_loader: Evaluation data, forwarded likewise.
        log_dir: Root output directory; must support ``/`` joining and the
            result must support ``mkdir`` (e.g. ``pathlib.Path``).
        n_classes: Forwarded to the selection/assignment step.
        seed: Forwarded to the selection/assignment step.
        beta: Forwarded to ``train_n_epochs`` (meaning defined there).
        optimization_schedule: Schedule forwarded to ``train_n_epochs``; its
            ``get_params()`` is called once per skipped (resumed) iteration so
            the learning rate stays in step with a full run.
        n_features: Forwarded to the selection step (presumably the number of
            features to keep -- confirm in ``sparsification.qsenn``).
        n_per_class: Forwarded to the selection step (presumably the number of
            features assigned per class -- confirm).
        n_iterations: Number of select/assign/train iterations. Defaults to 4,
            the value that was previously hard-coded. With 0 the model is
            returned unchanged.

    Returns:
        The model after the last iteration: ``train_n_epochs``'s return value
        if that iteration trained, otherwise the same model object with its
        weights restored from the checkpoint.
    """
    for iteration_epoch in range(n_iterations):
        print(f"Starting iteration epoch {iteration_epoch}")
        this_log_dir = log_dir / f"iteration_epoch_{iteration_epoch}"
        this_log_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_path = this_log_dir / "trained_model.pth"

        # Selection is recomputed even when a checkpoint exists. The model is
        # reconfigured by set_model_sldd *before* loading, presumably so its
        # structure matches the saved state_dict -- confirm against the model.
        feature_sel, sparse_layer, bias_sparse, current_mean, current_std = (
            compute_qsenn_feature_selection_and_assignment(
                model, train_loader, test_loader, this_log_dir,
                n_classes, seed, n_features, n_per_class,
            )
        )
        model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)

        if os.path.exists(checkpoint_path):
            # Load to CPU first: load_state_dict copies each tensor into the
            # model's existing parameter on that parameter's device. The
            # result is identical, but a checkpoint saved on a GPU can now be
            # resumed where that device is unavailable.
            model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
            # Advance the schedule as if this iteration had trained, so later
            # iterations get the correct learning rate.
            _ = optimization_schedule.get_params()
            continue

        model = train_n_epochs(model, beta, optimization_schedule, train_loader, test_loader)
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Finished iteration epoch {iteration_epoch}")
    return model
|