madhurjindal
commited on
Commit
·
06a7cdc
1
Parent(s):
42e455d
Upload 19 files
Browse files- Store/examples/airplane.png +0 -0
- Store/examples/bird.webp +0 -0
- Store/examples/car.jpg +0 -0
- Store/examples/cat.jpeg +0 -0
- Store/examples/deer.webp +0 -0
- Store/examples/dog1.jpg +0 -0
- Store/examples/frog1.webp +0 -0
- Store/examples/horse.jpg +0 -0
- Store/examples/shipp.jpg +0 -0
- Store/examples/truck1.jpg +0 -0
- Store/model.pth +3 -0
- Store/pred_store.pth +3 -0
- Utilities/__init__.py +0 -0
- Utilities/config.py +62 -0
- Utilities/model.py +295 -0
- Utilities/transforms.py +10 -0
- Utilities/utils.py +77 -0
- Utilities/visualize.py +33 -0
- app.py +105 -0
Store/examples/airplane.png
ADDED
Store/examples/bird.webp
ADDED
Store/examples/car.jpg
ADDED
Store/examples/cat.jpeg
ADDED
Store/examples/deer.webp
ADDED
Store/examples/dog1.jpg
ADDED
Store/examples/frog1.webp
ADDED
Store/examples/horse.jpg
ADDED
Store/examples/shipp.jpg
ADDED
Store/examples/truck1.jpg
ADDED
Store/model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fbd64f23fadf7bffb54d9f55e39771ebb15e40e3d64660d3972cc650def37d51
|
3 |
+
size 26333951
|
Store/pred_store.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5c35e81824afa9a906cb1d599fee2e9f79a2f776dd33654c0a879d26833abf3e
|
3 |
+
size 123716523
|
Utilities/__init__.py
ADDED
File without changes
|
Utilities/config.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
|
3 |
+
# Seed
|
4 |
+
SEED = 1
|
5 |
+
|
6 |
+
# Dataset
|
7 |
+
|
8 |
+
CLASSES = (
|
9 |
+
"Airplane",
|
10 |
+
"Automobile",
|
11 |
+
"Bird",
|
12 |
+
"Cat",
|
13 |
+
"Deer",
|
14 |
+
"Dog",
|
15 |
+
"Frog",
|
16 |
+
"Horse",
|
17 |
+
"Ship",
|
18 |
+
"Truck",
|
19 |
+
)
|
20 |
+
|
21 |
+
SHUFFLE = True
|
22 |
+
DATA_DIR = "../data"
|
23 |
+
NUM_WORKERS = 4
|
24 |
+
PIN_MEMORY = True
|
25 |
+
|
26 |
+
# Training Hyperparameters
|
27 |
+
CRITERION = F.cross_entropy
|
28 |
+
INPUT_SIZE = (3, 32, 32)
|
29 |
+
NUM_CLASSES = 10
|
30 |
+
LEARNING_RATE = 0.001
|
31 |
+
WEIGHT_DECAY = 1e-4
|
32 |
+
BATCH_SIZE = 512
|
33 |
+
NUM_EPOCHS = 24
|
34 |
+
DROPOUT_PERCENTAGE = 0.05
|
35 |
+
LAYER_NORM = "bn" # Batch Normalization
|
36 |
+
|
37 |
+
# OPTIMIZER & SCHEDULER
|
38 |
+
|
39 |
+
LRFINDER_END_LR = 0.1
|
40 |
+
LRFINDER_NUM_ITERATIONS = 50
|
41 |
+
LRFINDER_STEP_MODE = "exp"
|
42 |
+
|
43 |
+
OCLR_DIV_FACTOR = 100
|
44 |
+
OCLR_FINAL_DIV_FACTOR = 100
|
45 |
+
OCLR_THREE_PHASE = False
|
46 |
+
OCLR_ANNEAL_STRATEGY = "linear"
|
47 |
+
|
48 |
+
# Compute Related
|
49 |
+
|
50 |
+
ACCELERATOR = "cuda"
|
51 |
+
PRECISION = 32
|
52 |
+
|
53 |
+
# Store
|
54 |
+
|
55 |
+
TRAINING_STAT_STORE = "Store/training_stats.csv"
|
56 |
+
MODEL_SAVE_PATH = "Store/model.pth"
|
57 |
+
PRED_STORE_PATH = "Store/pred_store.pth"
|
58 |
+
EXAMPLE_IMG_PATH = "Store/examples/"
|
59 |
+
|
60 |
+
# Visualization
|
61 |
+
|
62 |
+
NORM_CONF_MAT = True
|
Utilities/model.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
import seaborn as sns
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch.optim as optim
|
10 |
+
import torchmetrics
|
11 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
12 |
+
from torch_lr_finder import LRFinder
|
13 |
+
|
14 |
+
from . import config
|
15 |
+
from .visualize import plot_incorrect_preds
|
16 |
+
|
17 |
+
|
18 |
+
class Net(pl.LightningModule):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
num_classes=10,
|
22 |
+
dropout_percentage=0,
|
23 |
+
norm="bn",
|
24 |
+
num_groups=2,
|
25 |
+
criterion=F.cross_entropy,
|
26 |
+
learning_rate=0.001,
|
27 |
+
weight_decay=0.0,
|
28 |
+
):
|
29 |
+
super(Net, self).__init__()
|
30 |
+
if norm == "bn":
|
31 |
+
self.norm = nn.BatchNorm2d
|
32 |
+
elif norm == "gn":
|
33 |
+
self.norm = lambda in_dim: nn.GroupNorm(
|
34 |
+
num_groups=num_groups, num_channels=in_dim
|
35 |
+
)
|
36 |
+
elif norm == "ln":
|
37 |
+
self.norm = lambda in_dim: nn.GroupNorm(num_groups=1, num_channels=in_dim)
|
38 |
+
|
39 |
+
# Define the loss criterion
|
40 |
+
self.criterion = criterion
|
41 |
+
|
42 |
+
# Define the Metrics
|
43 |
+
self.accuracy = torchmetrics.Accuracy(
|
44 |
+
task="multiclass", num_classes=num_classes
|
45 |
+
)
|
46 |
+
self.confusion_matrix = torchmetrics.ConfusionMatrix(
|
47 |
+
task="multiclass", num_classes=config.NUM_CLASSES
|
48 |
+
)
|
49 |
+
|
50 |
+
# Define the Optimizer Hyperparameters
|
51 |
+
self.learning_rate = learning_rate
|
52 |
+
self.weight_decay = weight_decay
|
53 |
+
|
54 |
+
# Prediction Storage
|
55 |
+
self.pred_store = {
|
56 |
+
"test_preds": torch.tensor([]),
|
57 |
+
"test_labels": torch.tensor([]),
|
58 |
+
"test_incorrect": [],
|
59 |
+
}
|
60 |
+
self.log_store = {
|
61 |
+
"train_loss_epoch": [],
|
62 |
+
"train_acc_epoch": [],
|
63 |
+
"val_loss_epoch": [],
|
64 |
+
"val_acc_epoch": [],
|
65 |
+
"test_loss_epoch": [],
|
66 |
+
"test_acc_epoch": [],
|
67 |
+
}
|
68 |
+
|
69 |
+
# This defines the structure of the NN.
|
70 |
+
# Prep Layer
|
71 |
+
self.prep_layer = nn.Sequential(
|
72 |
+
nn.Conv2d(3, 64, kernel_size=3, padding=1), # 32x32x3 | 1 -> 32x32x64 | 3
|
73 |
+
self.norm(64),
|
74 |
+
nn.ReLU(),
|
75 |
+
nn.Dropout(dropout_percentage),
|
76 |
+
)
|
77 |
+
|
78 |
+
self.l1 = nn.Sequential(
|
79 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1), # 32x32x128 | 5
|
80 |
+
nn.MaxPool2d(2, 2), # 16x16x128 | 6
|
81 |
+
self.norm(128),
|
82 |
+
nn.ReLU(),
|
83 |
+
nn.Dropout(dropout_percentage),
|
84 |
+
)
|
85 |
+
self.l1res = nn.Sequential(
|
86 |
+
nn.Conv2d(128, 128, kernel_size=3, padding=1), # 16x16x128 | 10
|
87 |
+
self.norm(128),
|
88 |
+
nn.ReLU(),
|
89 |
+
nn.Dropout(dropout_percentage),
|
90 |
+
nn.Conv2d(128, 128, kernel_size=3, padding=1), # 16x16x128 | 14
|
91 |
+
self.norm(128),
|
92 |
+
nn.ReLU(),
|
93 |
+
nn.Dropout(dropout_percentage),
|
94 |
+
)
|
95 |
+
self.l2 = nn.Sequential(
|
96 |
+
nn.Conv2d(128, 256, kernel_size=3, padding=1), # 16x16x256 | 18
|
97 |
+
nn.MaxPool2d(2, 2), # 8x8x256 | 19
|
98 |
+
self.norm(256),
|
99 |
+
nn.ReLU(),
|
100 |
+
nn.Dropout(dropout_percentage),
|
101 |
+
)
|
102 |
+
self.l3 = nn.Sequential(
|
103 |
+
nn.Conv2d(256, 512, kernel_size=3, padding=1), # 8x8x512 | 27
|
104 |
+
nn.MaxPool2d(2, 2), # 4x4x512 | 28
|
105 |
+
self.norm(512),
|
106 |
+
nn.ReLU(),
|
107 |
+
nn.Dropout(dropout_percentage),
|
108 |
+
)
|
109 |
+
self.l3res = nn.Sequential(
|
110 |
+
nn.Conv2d(512, 512, kernel_size=3, padding=1), # 4x4x512 | 36
|
111 |
+
self.norm(512),
|
112 |
+
nn.ReLU(),
|
113 |
+
nn.Dropout(dropout_percentage),
|
114 |
+
nn.Conv2d(512, 512, kernel_size=3, padding=1), # 4x4x512 | 44
|
115 |
+
self.norm(512),
|
116 |
+
nn.ReLU(),
|
117 |
+
nn.Dropout(dropout_percentage),
|
118 |
+
)
|
119 |
+
self.maxpool = nn.MaxPool2d(4, 4)
|
120 |
+
|
121 |
+
# Classifier
|
122 |
+
self.linear = nn.Linear(512, 10)
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
x = self.prep_layer(x)
|
126 |
+
x = self.l1(x)
|
127 |
+
x = x + self.l1res(x)
|
128 |
+
x = self.l2(x)
|
129 |
+
x = self.l3(x)
|
130 |
+
x = x + self.l3res(x)
|
131 |
+
x = self.maxpool(x)
|
132 |
+
x = x.view(-1, 512)
|
133 |
+
x = self.linear(x)
|
134 |
+
return F.log_softmax(x, dim=1)
|
135 |
+
|
136 |
+
def training_step(self, batch, batch_idx):
|
137 |
+
data, target = batch
|
138 |
+
|
139 |
+
# print("curr lr: ", self.optimizers().param_groups[0]["lr"])
|
140 |
+
|
141 |
+
# forward pass
|
142 |
+
pred = self(data)
|
143 |
+
|
144 |
+
# Calculate loss
|
145 |
+
loss = self.criterion(pred, target)
|
146 |
+
|
147 |
+
# Calculate the metrics
|
148 |
+
accuracy = self.accuracy(pred, target)
|
149 |
+
|
150 |
+
self.log_dict(
|
151 |
+
{"train_loss": loss, "train_acc": accuracy},
|
152 |
+
on_step=True,
|
153 |
+
on_epoch=True,
|
154 |
+
prog_bar=True,
|
155 |
+
logger=True,
|
156 |
+
)
|
157 |
+
|
158 |
+
return loss
|
159 |
+
|
160 |
+
def validation_step(self, batch, batch_idx):
|
161 |
+
data, target = batch
|
162 |
+
|
163 |
+
# forward pass
|
164 |
+
pred = self(data)
|
165 |
+
|
166 |
+
# Calculate loss
|
167 |
+
loss = self.criterion(pred, target)
|
168 |
+
# Calculate the metrics
|
169 |
+
accuracy = self.accuracy(pred, target)
|
170 |
+
|
171 |
+
self.log_dict(
|
172 |
+
{"val_loss": loss, "val_acc": accuracy},
|
173 |
+
on_step=True,
|
174 |
+
on_epoch=True,
|
175 |
+
prog_bar=True,
|
176 |
+
logger=True,
|
177 |
+
)
|
178 |
+
|
179 |
+
return loss
|
180 |
+
|
181 |
+
def test_step(self, batch, batch_idx):
|
182 |
+
data, target = batch
|
183 |
+
|
184 |
+
# forward pass
|
185 |
+
pred = self(data)
|
186 |
+
argmax_pred = pred.argmax(dim=1).cpu()
|
187 |
+
|
188 |
+
# Calculate loss
|
189 |
+
loss = self.criterion(pred, target)
|
190 |
+
|
191 |
+
# Calculate the metrics
|
192 |
+
accuracy = self.accuracy(pred, target)
|
193 |
+
|
194 |
+
self.log_dict(
|
195 |
+
{"test_loss": loss, "test_acc": accuracy},
|
196 |
+
on_step=True,
|
197 |
+
on_epoch=True,
|
198 |
+
prog_bar=True,
|
199 |
+
logger=True,
|
200 |
+
)
|
201 |
+
|
202 |
+
# Update the confusion matrix
|
203 |
+
self.confusion_matrix.update(pred, target)
|
204 |
+
|
205 |
+
# Store the predictions, labels and incorrect predictions
|
206 |
+
data, target, pred, argmax_pred = (
|
207 |
+
data.cpu(),
|
208 |
+
target.cpu(),
|
209 |
+
pred.cpu(),
|
210 |
+
argmax_pred.cpu(),
|
211 |
+
)
|
212 |
+
self.pred_store["test_preds"] = torch.cat(
|
213 |
+
(self.pred_store["test_preds"], argmax_pred), dim=0
|
214 |
+
)
|
215 |
+
self.pred_store["test_labels"] = torch.cat(
|
216 |
+
(self.pred_store["test_labels"], target), dim=0
|
217 |
+
)
|
218 |
+
for d, t, p, o in zip(data, target, argmax_pred, pred):
|
219 |
+
if p.eq(t.view_as(p)).item() == False:
|
220 |
+
self.pred_store["test_incorrect"].append(
|
221 |
+
(d.cpu(), t, p, o[p.item()].cpu())
|
222 |
+
)
|
223 |
+
|
224 |
+
return loss
|
225 |
+
|
226 |
+
def find_bestLR_LRFinder(self, optimizer):
|
227 |
+
lr_finder = LRFinder(self, optimizer, criterion=self.criterion)
|
228 |
+
lr_finder.range_test(
|
229 |
+
self.trainer.datamodule.train_dataloader(),
|
230 |
+
end_lr=config.LRFINDER_END_LR,
|
231 |
+
num_iter=config.LRFINDER_NUM_ITERATIONS,
|
232 |
+
step_mode=config.LRFINDER_STEP_MODE,
|
233 |
+
)
|
234 |
+
best_lr = None
|
235 |
+
try:
|
236 |
+
_, best_lr = lr_finder.plot() # to inspect the loss-learning rate graph
|
237 |
+
except Exception as e:
|
238 |
+
pass
|
239 |
+
lr_finder.reset() # to reset the model and optimizer to their initial state
|
240 |
+
|
241 |
+
return best_lr
|
242 |
+
|
243 |
+
def configure_optimizers(self):
|
244 |
+
optimizer = self.get_only_optimizer()
|
245 |
+
best_lr = self.find_bestLR_LRFinder(optimizer)
|
246 |
+
scheduler = OneCycleLR(
|
247 |
+
optimizer,
|
248 |
+
max_lr=1.47e-03,
|
249 |
+
# total_steps=self.trainer.estimated_stepping_batches,
|
250 |
+
steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
|
251 |
+
epochs=config.NUM_EPOCHS,
|
252 |
+
pct_start=5 / config.NUM_EPOCHS,
|
253 |
+
div_factor=config.OCLR_DIV_FACTOR,
|
254 |
+
three_phase=config.OCLR_THREE_PHASE,
|
255 |
+
final_div_factor=config.OCLR_FINAL_DIV_FACTOR,
|
256 |
+
anneal_strategy=config.OCLR_ANNEAL_STRATEGY,
|
257 |
+
)
|
258 |
+
return [optimizer], [
|
259 |
+
{"scheduler": scheduler, "interval": "step", "frequency": 1}
|
260 |
+
]
|
261 |
+
|
262 |
+
def get_only_optimizer(self):
|
263 |
+
optimizer = optim.Adam(
|
264 |
+
self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
|
265 |
+
)
|
266 |
+
return optimizer
|
267 |
+
|
268 |
+
def on_test_end(self) -> None:
|
269 |
+
super().on_test_end()
|
270 |
+
|
271 |
+
## Confusion Matrix
|
272 |
+
confmat = self.confusion_matrix.cpu().compute().numpy()
|
273 |
+
if config.NORM_CONF_MAT:
|
274 |
+
df_confmat = pd.DataFrame(
|
275 |
+
confmat / np.sum(confmat, axis=1)[:, None],
|
276 |
+
index=[i for i in config.CLASSES],
|
277 |
+
columns=[i for i in config.CLASSES],
|
278 |
+
)
|
279 |
+
else:
|
280 |
+
df_confmat = pd.DataFrame(
|
281 |
+
confmat,
|
282 |
+
index=[i for i in config.CLASSES],
|
283 |
+
columns=[i for i in config.CLASSES],
|
284 |
+
)
|
285 |
+
plt.figure(figsize=(7, 5))
|
286 |
+
sns.heatmap(df_confmat, annot=True, cmap="Blues", fmt=".3f", linewidths=0.5)
|
287 |
+
plt.tight_layout()
|
288 |
+
plt.ylabel("True label")
|
289 |
+
plt.xlabel("Predicted label")
|
290 |
+
plt.show()
|
291 |
+
|
292 |
+
def plot_incorrect_predictions_helper(self, num_imgs=10):
|
293 |
+
return plot_incorrect_preds(
|
294 |
+
self.pred_store["test_incorrect"], config.CLASSES, num_imgs
|
295 |
+
)
|
Utilities/transforms.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import albumentations as A
|
2 |
+
from albumentations.pytorch import ToTensorV2
|
3 |
+
|
4 |
+
# Test data transformations
|
5 |
+
test_transforms = A.Compose(
|
6 |
+
[
|
7 |
+
A.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
|
8 |
+
ToTensorV2(),
|
9 |
+
]
|
10 |
+
)
|
Utilities/utils.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from pytorch_grad_cam import GradCAM
|
3 |
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
4 |
+
|
5 |
+
from . import config
|
6 |
+
from .transforms import test_transforms
|
7 |
+
|
8 |
+
|
9 |
+
def generate_confidences(
|
10 |
+
model,
|
11 |
+
input_img,
|
12 |
+
num_top_preds,
|
13 |
+
):
|
14 |
+
input_img = test_transforms(image=input_img)
|
15 |
+
input_img = input_img["image"]
|
16 |
+
|
17 |
+
input_img = input_img.unsqueeze(0)
|
18 |
+
model.eval()
|
19 |
+
log_probs = model(input_img)[0].detach()
|
20 |
+
model.train()
|
21 |
+
probs = torch.exp(log_probs)
|
22 |
+
|
23 |
+
confidences = {
|
24 |
+
config.CLASSES[i]: float(probs[i]) for i in range(len(config.CLASSES))
|
25 |
+
}
|
26 |
+
# Select top 5 confidences based on value
|
27 |
+
confidences = {
|
28 |
+
k: v
|
29 |
+
for k, v in sorted(confidences.items(), key=lambda item: item[1], reverse=True)[
|
30 |
+
:num_top_preds
|
31 |
+
]
|
32 |
+
}
|
33 |
+
return input_img, confidences
|
34 |
+
|
35 |
+
|
36 |
+
def generate_gradcam(
|
37 |
+
model,
|
38 |
+
org_img,
|
39 |
+
input_img,
|
40 |
+
show_gradcam,
|
41 |
+
gradcam_layer,
|
42 |
+
gradcam_opacity,
|
43 |
+
):
|
44 |
+
if show_gradcam:
|
45 |
+
if gradcam_layer == -1:
|
46 |
+
target_layers = [model.l3[-1]]
|
47 |
+
elif gradcam_layer == -2:
|
48 |
+
target_layers = [model.l2[-1]]
|
49 |
+
|
50 |
+
cam = GradCAM(
|
51 |
+
model=model,
|
52 |
+
target_layers=target_layers,
|
53 |
+
)
|
54 |
+
grayscale_cam = cam(input_tensor=input_img, targets=None)
|
55 |
+
grayscale_cam = grayscale_cam[0, :]
|
56 |
+
|
57 |
+
visualization = show_cam_on_image(
|
58 |
+
org_img / 255,
|
59 |
+
grayscale_cam,
|
60 |
+
use_rgb=True,
|
61 |
+
image_weight=(1 - gradcam_opacity),
|
62 |
+
)
|
63 |
+
else:
|
64 |
+
visualization = None
|
65 |
+
return visualization
|
66 |
+
|
67 |
+
|
68 |
+
def generate_missclassified_imgs(
|
69 |
+
model,
|
70 |
+
show_misclassified,
|
71 |
+
num_misclassified,
|
72 |
+
):
|
73 |
+
if show_misclassified:
|
74 |
+
plot = model.plot_incorrect_predictions_helper(num_misclassified)
|
75 |
+
else:
|
76 |
+
plot = None
|
77 |
+
return plot
|
Utilities/visualize.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
from torchvision import transforms
|
3 |
+
|
4 |
+
|
5 |
+
def plot_incorrect_preds(incorrect, classes, num_imgs):
|
6 |
+
import random
|
7 |
+
|
8 |
+
# num_imgs is a multiple of 5
|
9 |
+
assert num_imgs % 5 == 0
|
10 |
+
assert len(incorrect) >= num_imgs
|
11 |
+
|
12 |
+
incorrect_inds = random.sample(range(len(incorrect)), num_imgs)
|
13 |
+
|
14 |
+
# incorrect (data, target, pred, output)
|
15 |
+
fig = plt.figure(figsize=(10, num_imgs // 2))
|
16 |
+
plt.suptitle("Target | Predicted Label")
|
17 |
+
for i in range(num_imgs):
|
18 |
+
curr_incorrect = incorrect[incorrect_inds[i]]
|
19 |
+
plt.subplot(num_imgs // 5, 5, i + 1, aspect="auto")
|
20 |
+
|
21 |
+
# unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
|
22 |
+
unnormalized = transforms.Normalize(
|
23 |
+
(-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
|
24 |
+
)(curr_incorrect[0])
|
25 |
+
plt.imshow(transforms.ToPILImage()(unnormalized))
|
26 |
+
plt.title(
|
27 |
+
f"{classes[curr_incorrect[1].item()]}|{classes[curr_incorrect[2].item()]}",
|
28 |
+
# fontsize=8,
|
29 |
+
)
|
30 |
+
plt.xticks([])
|
31 |
+
plt.yticks([])
|
32 |
+
plt.tight_layout()
|
33 |
+
return fig
|
app.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from Utilities import config
|
4 |
+
from Utilities.model import Net
|
5 |
+
from Utilities.utils import (
|
6 |
+
generate_confidences,
|
7 |
+
generate_gradcam,
|
8 |
+
generate_missclassified_imgs,
|
9 |
+
)
|
10 |
+
|
11 |
+
model = Net(
|
12 |
+
num_classes=config.NUM_CLASSES,
|
13 |
+
dropout_percentage=config.DROPOUT_PERCENTAGE,
|
14 |
+
norm=config.LAYER_NORM,
|
15 |
+
criterion=config.CRITERION,
|
16 |
+
learning_rate=config.LEARNING_RATE,
|
17 |
+
weight_decay=config.WEIGHT_DECAY,
|
18 |
+
)
|
19 |
+
|
20 |
+
model.load_state_dict(torch.load(config.MODEL_SAVE_PATH))
|
21 |
+
model.pred_store = torch.load(config.PRED_STORE_PATH)
|
22 |
+
|
23 |
+
|
24 |
+
def generate_gradio_output(
|
25 |
+
input_img,
|
26 |
+
num_top_preds,
|
27 |
+
show_gradcam,
|
28 |
+
gradcam_layer,
|
29 |
+
gradcam_opacity,
|
30 |
+
show_misclassified,
|
31 |
+
num_misclassified,
|
32 |
+
):
|
33 |
+
processed_img, confidences = generate_confidences(
|
34 |
+
model=model, input_img=input_img, num_top_preds=num_top_preds
|
35 |
+
)
|
36 |
+
|
37 |
+
visualization = generate_gradcam(
|
38 |
+
model=model,
|
39 |
+
org_img=input_img,
|
40 |
+
input_img=processed_img,
|
41 |
+
show_gradcam=show_gradcam,
|
42 |
+
gradcam_layer=gradcam_layer,
|
43 |
+
gradcam_opacity=gradcam_opacity,
|
44 |
+
)
|
45 |
+
|
46 |
+
plot = generate_missclassified_imgs(
|
47 |
+
model=model,
|
48 |
+
show_misclassified=show_misclassified,
|
49 |
+
num_misclassified=num_misclassified,
|
50 |
+
)
|
51 |
+
|
52 |
+
return confidences, visualization, plot
|
53 |
+
|
54 |
+
|
55 |
+
inputs = [
|
56 |
+
gr.Image(shape=(32, 32), label="Input Image"),
|
57 |
+
gr.Slider(1, 10, value=3, step=1, label="Number of Top Prediction to display"),
|
58 |
+
gr.Checkbox(label="Show GradCAM"),
|
59 |
+
gr.Slider(-2, -1, step=1, value=-1, label="GradCAM Layer (from the end)"),
|
60 |
+
gr.Slider(0, 1, value=0.5, label="GradCAM Opacity"),
|
61 |
+
gr.Checkbox(label="Show Misclassified Images"),
|
62 |
+
gr.Slider(
|
63 |
+
5, 50, value=20, step=5, label="Number of Misclassified Images to display"
|
64 |
+
),
|
65 |
+
]
|
66 |
+
|
67 |
+
outputs = [
|
68 |
+
gr.Label(visible=True, scale=0.5, label="Classification Confidences"),
|
69 |
+
gr.Image(shape=(32, 32), label="GradCAM Output").style(
|
70 |
+
width=256, height=256, visible=True
|
71 |
+
),
|
72 |
+
gr.Plot(visible=True, label="Misclassified Images"),
|
73 |
+
]
|
74 |
+
|
75 |
+
examples = [
|
76 |
+
[config.EXAMPLE_IMG_PATH + "cat.jpeg", 3, True, -2, 0.68, True, 40],
|
77 |
+
[config.EXAMPLE_IMG_PATH + "horse.jpg", 3, True, -2, 0.59, True, 25],
|
78 |
+
[config.EXAMPLE_IMG_PATH + "bird.webp", 10, True, -1, 0.55, True, 20],
|
79 |
+
[config.EXAMPLE_IMG_PATH + "dog1.jpg", 10, True, -1, 0.33, True, 45],
|
80 |
+
[config.EXAMPLE_IMG_PATH + "frog1.webp", 5, True, -1, 0.64, True, 40],
|
81 |
+
[config.EXAMPLE_IMG_PATH + "deer.webp", 1, True, -2, 0.45, True, 20],
|
82 |
+
[config.EXAMPLE_IMG_PATH + "airplane.png", 3, True, -2, 0.43, True, 40],
|
83 |
+
[config.EXAMPLE_IMG_PATH + "shipp.jpg", 7, True, -1, 0.6, True, 30],
|
84 |
+
[config.EXAMPLE_IMG_PATH + "car.jpg", 2, True, -1, 0.68, True, 30],
|
85 |
+
[config.EXAMPLE_IMG_PATH + "truck1.jpg", 5, True, -2, 0.51, True, 35],
|
86 |
+
]
|
87 |
+
|
88 |
+
title = "Image Classification (CIFAR10 - 10 Classes) with GradCAM"
|
89 |
+
description = """A simple Gradio interface to visualize the output of a CNN trained on CIFAR10 dataset with GradCAM and Misclassified images.
|
90 |
+
The architecture is inspired from David Page's (myrtle.ai) DAWNBench winning model archiecture.
|
91 |
+
Please input the image and select the number of top predictions to display - you will see the top predictions and their corresponding confidence scores.
|
92 |
+
You can also select whether to show GradCAM for the particular image (utilizes the gradients of the classification score with respect to the final convolutional feature map, to identify the parts of an input image that most impact the classification score).
|
93 |
+
You need to select the model layer where the gradients need to be plugged from - this affects how much of the image is used to compute the GradCAM.
|
94 |
+
You can also select whether to show misclassified images - these are the images that the model misclassified.
|
95 |
+
Some examples are provided in the examples tab.
|
96 |
+
"""
|
97 |
+
|
98 |
+
gr.Interface(
|
99 |
+
fn=generate_gradio_output,
|
100 |
+
inputs=inputs,
|
101 |
+
outputs=outputs,
|
102 |
+
examples=examples,
|
103 |
+
title=title,
|
104 |
+
description=description,
|
105 |
+
).launch()
|