File size: 568 Bytes
8d4ee22
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from FeatureDiversityLoss import FeatureDiversityLoss
from train import train, test
from training.optim import get_optimizer


def train_n_epochs(model, beta, optimization_schedule, train_loader, test_loader):
    """Train ``model`` for a full optimization schedule with periodic testing.

    ``get_optimizer(model, optimization_schedule)`` supplies the optimizer,
    the schedule object and the number of epochs to run. Each epoch calls
    ``train`` (whose return value replaces ``model``), then advances the
    schedule with ``schedule.step()``. ``test`` is invoked on every fifth
    epoch (0, 5, 10, ...) and always on the final epoch.

    Args:
        model: Model to train; must expose a ``linear`` attribute, which is
            handed to ``FeatureDiversityLoss``.
        beta: First argument passed to ``FeatureDiversityLoss``
            (presumably the loss weight -- confirm against that class).
        optimization_schedule: Forwarded unchanged to ``get_optimizer``.
        train_loader: Data source passed to ``train`` each epoch.
        test_loader: Data source passed to ``test`` on evaluation epochs.

    Returns:
        The model as returned by the last ``train`` call, or the input
        ``model`` unchanged when the epoch count is zero.
    """
    optimizer, schedule, epochs = get_optimizer(model, optimization_schedule)
    # Built once, from the layer the model had before training, and reused
    # for every epoch.
    diversity_loss = FeatureDiversityLoss(beta, model.linear)

    final_epoch = epochs - 1
    for epoch in range(epochs):
        model = train(model, train_loader, optimizer, diversity_loss, epoch)
        schedule.step()

        # Evaluate periodically, and make sure the finished model is always
        # evaluated even if the last epoch is not a multiple of five.
        on_interval = epoch % 5 == 0
        if on_interval or epoch == final_epoch:
            test(model, test_loader, epoch)

    return model