Haaribo's picture
Add application file
8d4ee22
import torch
from tqdm import tqdm
from training.utils import VariableLossLogPrinter
def get_acc(outputs, targets):
    """Return the top-1 accuracy of a batch, in percent.

    Args:
        outputs: Score tensor whose class dimension is dim 1 (e.g. logits of
            shape ``(batch, num_classes)``). The arg-max over dim 1 is taken
            as the predicted class.
        targets: Tensor with one expected class index per sample, compared
            element-wise against the predictions.

    Returns:
        The percentage (0-100) of samples whose prediction equals the target,
        as a Python float. An empty batch yields 0.0 rather than raising
        ``ZeroDivisionError``.
    """
    total = targets.size(0)
    if total == 0:
        # Nothing to score (e.g. an empty trailing batch); avoid dividing by zero.
        return 0.0
    # detach() gives the same values as the legacy ``.data`` attribute without
    # silently bypassing autograd's version-tracking safety checks.
    predicted = outputs.detach().argmax(dim=1)
    correct = (predicted == targets).sum().item()
    return correct / total * 100
def train(model, train_loader, optimizer, fdl, epoch):
    """Run one training epoch on cross-entropy plus an auxiliary ``fdl`` term.

    For every batch the model is called with ``with_feature_maps=True`` and
    must return ``(logits, feature_maps)``. The optimised objective is
    ``cross_entropy(logits, target) + fdl(feature_maps, logits)``. Per-batch
    metrics are fed to a ``VariableLossLogPrinter``, and its summary string
    is shown on the tqdm progress bar. After the epoch, the learning rate of
    parameter group 0 is printed.

    Args:
        model: Network to train. It is switched to train mode and moved to
            CUDA when available, otherwise to the CPU.
        train_loader: Sized iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch. Presumably it holds
            ``model``'s parameters; this is not checked here.
        fdl: Callable ``fdl(feature_maps, logits)`` returning a single-element
            loss tensor that is added to the cross-entropy loss.
        epoch: Epoch number, used only in the progress/print messages.

    Returns:
        The model after the epoch (the result of ``model.to(device)``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    model = model.to(device)
    metric_log = VariableLossLogPrinter()
    progress = tqdm(enumerate(train_loader), total=len(train_loader))

    for _, (inputs, labels) in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)
        batch_size = inputs.size(0)

        logits, feature_maps = model(inputs, with_feature_maps=True)
        ce_loss = torch.nn.functional.cross_entropy(logits, labels)
        aux_loss = fdl(feature_maps, logits)
        objective = ce_loss + aux_loss

        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

        # The batch size is passed alongside each value. Presumably the
        # printer uses it to weight running averages; confirm in
        # VariableLossLogPrinter.
        batch_metrics = (
            ("Train Acc", get_acc(logits, labels)),
            ("CE-Loss", ce_loss.item()),
            ("FDL", aux_loss.item()),
            ("Total-Loss", objective.item()),
        )
        for metric_name, metric_value in batch_metrics:
            metric_log.log_loss(metric_name, metric_value, batch_size)

        progress.set_description(f"Train Epoch:{epoch} Metrics: {metric_log.get_loss_string()}")

    print("Trained model for one epoch ", epoch," with lr group 0: ", optimizer.param_groups[0]["lr"])
    return model
def test(model, test_loader, epoch):
    """Evaluate ``model`` on ``test_loader`` with gradient tracking disabled.

    The model is switched to eval mode and moved to CUDA when available,
    otherwise to the CPU. Each batch's accuracy and cross-entropy loss are
    fed to a ``VariableLossLogPrinter``, and its summary string is shown on
    the tqdm progress bar.

    Args:
        model: Network to evaluate. It is called with
            ``with_feature_maps=True`` and must return
            ``(logits, feature_maps)``.
        test_loader: Sized iterable yielding ``(data, target)`` batches.
        epoch: Epoch number, used only in the progress-bar description.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model = model.to(device)
    metric_log = VariableLossLogPrinter()
    progress = tqdm(enumerate(test_loader), total=len(test_loader))

    with torch.no_grad():
        for _, (inputs, labels) in progress:
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = inputs.size(0)
            # Only the logits are used during evaluation; the feature maps
            # returned alongside them are discarded.
            logits, _feature_maps = model(inputs, with_feature_maps=True)
            ce_loss = torch.nn.functional.cross_entropy(logits, labels)
            metric_log.log_loss("Test Acc", get_acc(logits, labels), batch_size)
            metric_log.log_loss("CE-Loss", ce_loss.item(), batch_size)
            progress.set_description(f"Test Epoch:{epoch} Metrics: {metric_log.get_loss_string()}")