from FeatureDiversityLoss import FeatureDiversityLoss
from train import train, test
from training.optim import get_optimizer


def train_n_epochs(model, beta, optimization_schedule, train_loader, test_loader):
    """Run the full training loop for ``model`` and return the trained model.

    The optimizer, the schedule object and the number of epochs all come
    from ``get_optimizer(model, optimization_schedule)``. A
    ``FeatureDiversityLoss`` is built once from ``beta`` and ``model.linear``
    and handed to every ``train`` call.

    Args:
        model: Model to train; must expose a ``linear`` attribute.
        beta: First argument forwarded to ``FeatureDiversityLoss``
            (presumably the loss weighting -- confirm against that class).
        optimization_schedule: Forwarded unchanged to ``get_optimizer``.
        train_loader: Data source passed to ``train`` each epoch.
        test_loader: Data source passed to ``test`` on evaluation epochs.

    Returns:
        The model as returned by the last ``train`` call, or the input
        model unchanged when the epoch count is zero.
    """
    optimizer, scheduler, num_epochs = get_optimizer(model, optimization_schedule)
    diversity_loss = FeatureDiversityLoss(beta, model.linear)

    for current_epoch in range(num_epochs):
        # ``train`` returns the model, so rebind it for the next iteration.
        model = train(model, train_loader, optimizer, diversity_loss, current_epoch)
        scheduler.step()

        # Evaluate on epochs 0, 5, 10, ... and always after the final epoch.
        on_interval = current_epoch % 5 == 0
        is_last_epoch = current_epoch + 1 == num_epochs
        if not (on_interval or is_last_epoch):
            continue
        test(model, test_loader, current_epoch)

    return model