Create classification.py
Browse files- classification.py +115 -0
classification.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Before running, install required packages:
|
2 |
+
!pip install numpy torch torchvision pytorch-ignite
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import optim, nn
|
7 |
+
from torch.utils.data import DataLoader, TensorDataset
|
8 |
+
from torchvision import models, datasets, transforms
|
9 |
+
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
|
10 |
+
from ignite.metrics import Accuracy, Loss
|
11 |
+
|
12 |
+
|
13 |
+
# ----------------------------------- Setup -----------------------------------
|
14 |
+
# Dataset MNIST will be loaded further down.
|
15 |
+
|
16 |
+
# Set up hyperparameters.
|
17 |
+
lr = 0.001
|
18 |
+
batch_size = 200
|
19 |
+
num_epochs = 1
|
20 |
+
|
21 |
+
# Set up logging.
|
22 |
+
print_every = 1 # batches
|
23 |
+
|
24 |
+
# Set up device.
|
25 |
+
use_cuda = torch.cuda.is_available()
|
26 |
+
device = torch.device("cuda" if use_cuda else "cpu")
|
27 |
+
|
28 |
+
|
29 |
+
# -------------------------- Dataset & Preprocessing --------------------------
|
30 |
+
def load_data(train):
|
31 |
+
# Download and transform dataset.
|
32 |
+
transform = transforms.Compose([
|
33 |
+
transforms.Resize(256),
|
34 |
+
transforms.CenterCrop(224),
|
35 |
+
transforms.ToTensor(),
|
36 |
+
transforms.RandomVerticalFlip(),
|
37 |
+
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # grayscale to RGB
|
38 |
+
])
|
39 |
+
dataset = datasets.MNIST("./data", train=train, download=True, transform=transform)
|
40 |
+
|
41 |
+
# Wrap in data loader.
|
42 |
+
if use_cuda:
|
43 |
+
kwargs = {"pin_memory": True, "num_workers": 1}
|
44 |
+
else:
|
45 |
+
kwargs = {}
|
46 |
+
loader = DataLoader(dataset, batch_size=batch_size, shuffle=train, **kwargs)
|
47 |
+
return loader
|
48 |
+
|
49 |
+
train_loader = load_data(train=True)
|
50 |
+
val_loader = None
|
51 |
+
test_loader = load_data(train=False)
|
52 |
+
|
53 |
+
|
54 |
+
# ----------------------------------- Model -----------------------------------
|
55 |
+
# Set up model, loss, optimizer.
|
56 |
+
model = models.alexnet(pretrained=False)
|
57 |
+
model = model.to(device)
|
58 |
+
loss_func = nn.CrossEntropyLoss()
|
59 |
+
optimizer = optim.Adam(model.parameters(), lr=lr)
|
60 |
+
|
61 |
+
|
62 |
+
# --------------------------------- Training ----------------------------------
|
63 |
+
# Set up pytorch-ignite trainer and evaluator.
|
64 |
+
trainer = create_supervised_trainer(
|
65 |
+
model,
|
66 |
+
optimizer,
|
67 |
+
loss_func,
|
68 |
+
device=device,
|
69 |
+
)
|
70 |
+
metrics = {
|
71 |
+
"accuracy": Accuracy(),
|
72 |
+
"loss": Loss(loss_func),
|
73 |
+
}
|
74 |
+
evaluator = create_supervised_evaluator(
|
75 |
+
model, metrics=metrics, device=device
|
76 |
+
)
|
77 |
+
|
78 |
+
@trainer.on(Events.ITERATION_COMPLETED(every=print_every))
|
79 |
+
def log_batch(trainer):
|
80 |
+
batch = (trainer.state.iteration - 1) % trainer.state.epoch_length + 1
|
81 |
+
print(
|
82 |
+
f"Epoch {trainer.state.epoch} / {num_epochs}, "
|
83 |
+
f"batch {batch} / {trainer.state.epoch_length}: "
|
84 |
+
f"loss: {trainer.state.output:.3f}"
|
85 |
+
)
|
86 |
+
|
87 |
+
@trainer.on(Events.EPOCH_COMPLETED)
|
88 |
+
def log_epoch(trainer):
|
89 |
+
print(f"Epoch {trainer.state.epoch} / {num_epochs} average results: ")
|
90 |
+
|
91 |
+
def log_results(name, metrics, epoch):
|
92 |
+
print(
|
93 |
+
f"{name + ':':6} loss: {metrics['loss']:.3f}, accuracy: {metrics['accuracy']:.3f}"
|
94 |
+
)
|
95 |
+
|
96 |
+
# Train data.
|
97 |
+
evaluator.run(train_loader)
|
98 |
+
log_results("train", evaluator.state.metrics, trainer.state.epoch)
|
99 |
+
|
100 |
+
# Val data.
|
101 |
+
if val_loader:
|
102 |
+
evaluator.run(val_loader)
|
103 |
+
log_results("val", evaluator.state.metrics, trainer.state.epoch)
|
104 |
+
|
105 |
+
# Test data.
|
106 |
+
if test_loader:
|
107 |
+
evaluator.run(test_loader)
|
108 |
+
log_results("test", evaluator.state.metrics, trainer.state.epoch)
|
109 |
+
|
110 |
+
print()
|
111 |
+
print("-" * 80)
|
112 |
+
print()
|
113 |
+
|
114 |
+
# Start training.
|
115 |
+
trainer.run(train_loader, max_epochs=num_epochs)
|