Haaribo's picture
commit from qixuan
dc96f30
raw
history blame contribute delete
568 Bytes
from FeatureDiversityLoss import FeatureDiversityLoss
from train import train, test
from training.optim import get_optimizer
def train_n_epochs(model, beta, optimization_schedule, train_loader, test_loader,
                   *, eval_every=5):
    """Train ``model`` for the number of epochs given by the optimization schedule.

    Builds the optimizer, scheduler and epoch count via ``get_optimizer``, and a
    ``FeatureDiversityLoss`` bound to ``model.linear``. It then performs one
    ``train`` call followed by one ``schedule.step()`` per epoch. ``test`` runs on
    ``test_loader`` every ``eval_every`` epochs (epoch 0 included) and always
    after the final epoch.

    Args:
        model: The model to train. It must expose a ``linear`` attribute, which
            is handed to ``FeatureDiversityLoss``.
        beta: Passed to ``FeatureDiversityLoss``. Presumably the weight of the
            diversity term (TODO confirm against FeatureDiversityLoss).
        optimization_schedule: Configuration forwarded unchanged to
            ``get_optimizer``.
        train_loader: Passed to ``train`` on every epoch.
        test_loader: Passed to ``test`` on evaluation epochs.
        eval_every: Evaluation interval in epochs. Defaults to 5, the value
            that was previously hard-coded.

    Returns:
        The model returned by the last ``train`` call. If the schedule yields
        zero epochs, ``model`` is returned as given.

    Raises:
        ValueError: If ``eval_every`` is less than 1.
    """
    # Validate before any (potentially expensive) optimizer/loss construction.
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every!r}")

    optimizer, schedule, epochs = get_optimizer(model, optimization_schedule)
    fdl = FeatureDiversityLoss(beta, model.linear)
    for epoch in range(epochs):
        # `train` hands back a model, so rebind each epoch rather than
        # assuming in-place mutation.
        model = train(model, train_loader, optimizer, fdl, epoch)
        # The schedule advances once per epoch, after the training pass.
        schedule.step()
        is_last_epoch = epoch + 1 == epochs
        if epoch % eval_every == 0 or is_last_epoch:
            test(model, test_loader, epoch)
    return model