csisc committed on
Commit
d139de5
·
1 Parent(s): aa6522d

Create classification.py

Browse files
Files changed (1) hide show
  1. classification.py +115 -0
classification.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Before running, install required packages:
#
#     pip install numpy torch torchvision pytorch-ignite
#
# NOTE: the install line must stay a comment — `!pip …` is IPython/Jupyter
# magic syntax and is a SyntaxError in a plain Python script.
4
+ import numpy as np
5
+ import torch
6
+ from torch import optim, nn
7
+ from torch.utils.data import DataLoader, TensorDataset
8
+ from torchvision import models, datasets, transforms
9
+ from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
10
+ from ignite.metrics import Accuracy, Loss
11
+
# ----------------------------------- Setup -----------------------------------
# Dataset MNIST will be loaded further down.

# Hyperparameters for the optimizer and the data loader.
lr = 1e-3
batch_size = 200
num_epochs = 1

# Logging cadence: print a progress line every `print_every` batches.
print_every = 1

# Select the compute device, preferring CUDA when a GPU is visible.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# -------------------------- Dataset & Preprocessing --------------------------
def load_data(train):
    """Return a DataLoader over MNIST preprocessed for AlexNet-style input.

    Args:
        train: If True, load the training split with random augmentation and
            shuffling; otherwise load the test split with a deterministic
            pipeline.

    Returns:
        torch.utils.data.DataLoader over the requested MNIST split.
    """
    # Deterministic steps shared by both splits: AlexNet-style models expect
    # 224x224 input, while MNIST images are 28x28 grayscale.
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    if train:
        # Random augmentation belongs only in the training pipeline;
        # applying it at evaluation time would make metrics nondeterministic.
        steps.append(transforms.RandomVerticalFlip())
    # Replicate the single grayscale channel to 3 channels (RGB) for AlexNet.
    steps.append(transforms.Lambda(lambda x: x.repeat(3, 1, 1)))
    transform = transforms.Compose(steps)
    dataset = datasets.MNIST("./data", train=train, download=True, transform=transform)

    # Wrap in a data loader; pinned memory and a worker only pay off on GPU.
    kwargs = {"pin_memory": True, "num_workers": 1} if use_cuda else {}
    return DataLoader(dataset, batch_size=batch_size, shuffle=train, **kwargs)
# Build the data loaders. This script keeps no held-out validation split,
# so val_loader stays None and the val branch below is skipped.
train_loader = load_data(train=True)
val_loader = None
test_loader = load_data(train=False)
# ----------------------------------- Model -----------------------------------
# Set up model, loss, optimizer.
# MNIST has 10 classes; AlexNet's default head emits 1000 ImageNet logits,
# so size the classifier for this task instead of wasting the extra outputs.
# (pretrained=False, so no ImageNet weights are discarded by resizing the head.)
model = models.alexnet(pretrained=False, num_classes=10)
model = model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
# --------------------------------- Training ----------------------------------
# Wire up the pytorch-ignite training and evaluation engines.
trainer = create_supervised_trainer(model, optimizer, loss_func, device=device)

# Metrics reported by the evaluator after each run over a loader.
metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(loss_func),
}
evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
@trainer.on(Events.ITERATION_COMPLETED(every=print_every))
def log_batch(trainer):
    """Print the running loss for the batch that just finished."""
    # state.iteration counts across epochs; map it to a 1-based index
    # within the current epoch.
    batch_in_epoch = (trainer.state.iteration - 1) % trainer.state.epoch_length + 1
    message = (
        f"Epoch {trainer.state.epoch} / {num_epochs}, "
        f"batch {batch_in_epoch} / {trainer.state.epoch_length}: "
        f"loss: {trainer.state.output:.3f}"
    )
    print(message)
@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch(trainer):
    """Run the evaluator over each available split and print average metrics.

    Args:
        trainer: The ignite engine firing EPOCH_COMPLETED; only its
            ``state.epoch`` is read here.
    """
    print(f"Epoch {trainer.state.epoch} / {num_epochs} average results: ")

    def log_results(name, metrics):
        # One formatted line per split. (The original helper also took an
        # unused `epoch` argument; it has been dropped.)
        print(
            f"{name + ':':6} loss: {metrics['loss']:.3f}, accuracy: {metrics['accuracy']:.3f}"
        )

    # Train data.
    evaluator.run(train_loader)
    log_results("train", evaluator.state.metrics)

    # Val data (skipped in this script: val_loader is None).
    if val_loader:
        evaluator.run(val_loader)
        log_results("val", evaluator.state.metrics)

    # Test data.
    # NOTE(review): evaluating the test split after every epoch leaks test
    # feedback into development; consider reserving it for a final run.
    if test_loader:
        evaluator.run(test_loader)
        log_results("test", evaluator.state.metrics)

    print()
    print("-" * 80)
    print()
# Kick off the optimization loop; the handlers registered above
# perform all per-batch and per-epoch logging.
trainer.run(train_loader, max_epochs=num_epochs)