import torch from tqdm import tqdm from training.utils import VariableLossLogPrinter def get_acc(outputs, targets): _, predicted = torch.max(outputs.data, 1) total = targets.size(0) correct = (predicted == targets).sum().item() return correct / total * 100 def train(model, train_loader, optimizer, fdl, epoch): model.train() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") VariableLossPrinter = VariableLossLogPrinter() model = model.to(device) iterator = tqdm(enumerate(train_loader), total=len(train_loader)) for batch_idx, (data, target) in iterator: on_device = data.to(device) target_on_device = target.to(device) output, feature_maps = model(on_device, with_feature_maps=True) loss = torch.nn.functional.cross_entropy(output, target_on_device) fdl_loss = fdl(feature_maps, output) total_loss = loss + fdl_loss optimizer.zero_grad() total_loss.backward() optimizer.step() acc = get_acc(output, target_on_device) VariableLossPrinter.log_loss("Train Acc", acc, on_device.size(0)) VariableLossPrinter.log_loss("CE-Loss", loss.item(), on_device.size(0)) VariableLossPrinter.log_loss("FDL", fdl_loss.item(), on_device.size(0)) VariableLossPrinter.log_loss("Total-Loss", total_loss.item(), on_device.size(0)) iterator.set_description(f"Train Epoch:{epoch} Metrics: {VariableLossPrinter.get_loss_string()}") print("Trained model for one epoch ", epoch," with lr group 0: ", optimizer.param_groups[0]["lr"]) return model def test(model, test_loader, epoch): model.eval() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) VariableLossPrinter = VariableLossLogPrinter() iterator = tqdm(enumerate(test_loader), total=len(test_loader)) with torch.no_grad(): for batch_idx, (data, target) in iterator: on_device = data.to(device) target_on_device = target.to(device) output, feature_maps = model(on_device, with_feature_maps=True) loss = torch.nn.functional.cross_entropy(output, target_on_device) acc = get_acc(output, target_on_device) VariableLossPrinter.log_loss("Test Acc", acc, on_device.size(0)) VariableLossPrinter.log_loss("CE-Loss", loss.item(), on_device.size(0)) iterator.set_description(f"Test Epoch:{epoch} Metrics: {VariableLossPrinter.get_loss_string()}")