File size: 2,568 Bytes
8d4ee22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from tqdm import tqdm

from training.utils import VariableLossLogPrinter


def get_acc(outputs, targets):
    """Return top-1 accuracy, in percent, of ``outputs`` against ``targets``.

    Args:
        outputs: ``(N, C)`` tensor of per-class scores/logits.
        targets: ``(N,)`` tensor of integer class labels.

    Returns:
        float: percentage (0–100) of rows whose argmax equals the target.
    """
    # argmax over the class dimension. The original used `outputs.data`,
    # the legacy pre-0.4 Variable attribute; it silently detaches from
    # autograd and is discouraged — argmax needs no detach at all.
    predicted = outputs.argmax(dim=1)
    total = targets.size(0)
    correct = (predicted == targets).sum().item()
    return correct / total * 100



def train(model, train_loader, optimizer, fdl, epoch):
    """Train ``model`` for one epoch over ``train_loader``.

    The objective is cross-entropy plus the feature-distillation loss
    produced by ``fdl`` on the model's feature maps. Running metrics
    (accuracy, CE loss, FDL, total) are shown on the tqdm progress bar.

    Returns:
        The model, moved to the selected device.
    """
    model.train()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    loss_printer = VariableLossLogPrinter()
    progress = tqdm(enumerate(train_loader), total=len(train_loader))
    for _, (inputs, labels) in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)
        batch_size = inputs.size(0)

        # Forward pass; the model also returns intermediate feature maps
        # so the auxiliary fdl term can be computed on them.
        output, feature_maps = model(inputs, with_feature_maps=True)
        ce_loss = torch.nn.functional.cross_entropy(output, labels)
        fdl_loss = fdl(feature_maps, output)
        total_loss = ce_loss + fdl_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        loss_printer.log_loss("Train Acc", get_acc(output, labels), batch_size)
        loss_printer.log_loss("CE-Loss", ce_loss.item(), batch_size)
        loss_printer.log_loss("FDL", fdl_loss.item(), batch_size)
        loss_printer.log_loss("Total-Loss", total_loss.item(), batch_size)
        progress.set_description(f"Train Epoch:{epoch} Metrics: {loss_printer.get_loss_string()}")
    print("Trained model for one epoch ",  epoch," with lr group 0: ", optimizer.param_groups[0]["lr"])
    return model


def test(model, test_loader, epoch):
    """Evaluate ``model`` over ``test_loader`` for one epoch.

    Runs under ``torch.no_grad`` and reports running accuracy and
    cross-entropy loss on the tqdm progress bar. Returns nothing.
    """
    model.eval()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    loss_printer = VariableLossLogPrinter()
    progress = tqdm(enumerate(test_loader), total=len(test_loader))
    with torch.no_grad():
        for _, (inputs, labels) in progress:
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_size = inputs.size(0)
            # Feature maps are requested for interface consistency with
            # train(), but only the logits are used during evaluation.
            output, feature_maps = model(inputs, with_feature_maps=True)
            ce_loss = torch.nn.functional.cross_entropy(output, labels)
            loss_printer.log_loss("Test Acc", get_acc(output, labels), batch_size)
            loss_printer.log_loss("CE-Loss", ce_loss.item(), batch_size)
            progress.set_description(f"Test Epoch:{epoch} Metrics: {loss_printer.get_loss_string()}")