Spaces:
Sleeping
Sleeping
import torch | |
from tqdm import tqdm | |
from training.utils import VariableLossLogPrinter | |
def get_acc(outputs, targets):
    """Return top-1 accuracy of ``outputs`` against ``targets`` as a percentage.

    Args:
        outputs: Score tensor whose class dimension is dim 1; the predicted
            class for each row is the index of its largest score (the first
            such index on ties).
        targets: Class-index tensor with one entry per row of ``outputs``,
            compared element-wise against the predicted indices.

    Returns:
        Python float in [0, 100]: the percentage of rows whose predicted class
        equals the target. An empty batch returns 0.0 rather than raising
        ``ZeroDivisionError``.
    """
    total = targets.size(0)
    if total == 0:
        # Guard the division below; there is nothing to score in an empty batch.
        return 0.0
    # argmax yields the same indices as `torch.max(x, 1)[1]` without going
    # through `.data`; argmax is not differentiable, so no graph is kept.
    predicted = outputs.argmax(dim=1)
    correct = (predicted == targets).sum().item()
    return correct / total * 100
def train(model, train_loader, optimizer, fdl, epoch):
    """Run one training epoch, optimizing cross-entropy plus the ``fdl`` term.

    The device is CUDA when available, otherwise CPU. For every batch:
    1. The batch is moved to the device.
    2. It is passed through ``model(..., with_feature_maps=True)``.
    3. One optimizer step is taken on
       ``cross_entropy(output, target) + fdl(feature_maps, output)``.
    4. Accuracy and the three loss values are logged through a fresh
       ``VariableLossLogPrinter`` and shown in the tqdm progress bar.

    Args:
        model: Network whose forward, called with ``with_feature_maps=True``,
            returns ``(logits, feature_maps)``.
        train_loader: Sized iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer to step; the learning rate of its first param
            group is printed once the epoch finishes.
        fdl: Callable ``fdl(feature_maps, logits)`` returning a single-element
            loss tensor that is added to the cross-entropy loss.
        epoch: Epoch number, used only in the displayed/printed messages.

    Returns:
        ``model`` after ``.to(device)``.
    """
    model.train()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    printer = VariableLossLogPrinter()
    model = model.to(device)
    progress = tqdm(enumerate(train_loader), total=len(train_loader))

    for _, (inputs, labels) in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)
        batch_size = inputs.size(0)

        logits, feature_maps = model(inputs, with_feature_maps=True)
        ce_loss = torch.nn.functional.cross_entropy(logits, labels)
        fdl_loss = fdl(feature_maps, logits)
        combined = ce_loss + fdl_loss

        optimizer.zero_grad()
        combined.backward()
        optimizer.step()

        # Logged in a fixed order: accuracy, CE, FDL, total.
        batch_metrics = (
            ("Train Acc", get_acc(logits, labels)),
            ("CE-Loss", ce_loss.item()),
            ("FDL", fdl_loss.item()),
            ("Total-Loss", combined.item()),
        )
        for metric_name, metric_value in batch_metrics:
            printer.log_loss(metric_name, metric_value, batch_size)
        progress.set_description(f"Train Epoch:{epoch} Metrics: {printer.get_loss_string()}")

    current_lr = optimizer.param_groups[0]["lr"]
    print("Trained model for one epoch ", epoch, " with lr group 0: ", current_lr)
    return model
def test(model, test_loader, epoch):
    """Evaluate ``model`` on ``test_loader`` with gradient tracking disabled.

    The model is put in eval mode and moved to CUDA when available, otherwise
    CPU. Per-batch accuracy and cross-entropy loss are logged through a fresh
    ``VariableLossLogPrinter`` and shown in the tqdm progress bar. Nothing is
    returned.

    Args:
        model: Network whose forward, called with ``with_feature_maps=True``,
            returns ``(logits, feature_maps)``.
        test_loader: Sized iterable of ``(inputs, targets)`` batches.
        epoch: Epoch number, used only in the progress-bar description.
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    printer = VariableLossLogPrinter()
    progress = tqdm(enumerate(test_loader), total=len(test_loader))

    with torch.no_grad():
        for _, (inputs, labels) in progress:
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_size = inputs.size(0)

            # Feature maps are requested but unused here.
            # NOTE(review): presumably kept for parity with train(); confirm
            # whether a plain forward would suffice.
            logits, _unused_feature_maps = model(inputs, with_feature_maps=True)
            ce_loss = torch.nn.functional.cross_entropy(logits, labels)

            batch_metrics = (
                ("Test Acc", get_acc(logits, labels)),
                ("CE-Loss", ce_loss.item()),
            )
            for metric_name, metric_value in batch_metrics:
                printer.log_loss(metric_name, metric_value, batch_size)
            progress.set_description(f"Test Epoch:{epoch} Metrics: {printer.get_loss_string()}")