madhurjindal committed
Commit 06a7cdc · 1 Parent(s): 42e455d

Upload 19 files

Store/examples/airplane.png ADDED
Store/examples/bird.webp ADDED
Store/examples/car.jpg ADDED
Store/examples/cat.jpeg ADDED
Store/examples/deer.webp ADDED
Store/examples/dog1.jpg ADDED
Store/examples/frog1.webp ADDED
Store/examples/horse.jpg ADDED
Store/examples/shipp.jpg ADDED
Store/examples/truck1.jpg ADDED
Store/model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fbd64f23fadf7bffb54d9f55e39771ebb15e40e3d64660d3972cc650def37d51
+ size 26333951
Store/pred_store.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5c35e81824afa9a906cb1d599fee2e9f79a2f776dd33654c0a879d26833abf3e
+ size 123716523
Utilities/__init__.py ADDED
File without changes
Utilities/config.py ADDED
@@ -0,0 +1,62 @@
+ import torch.nn.functional as F
+
+ # Seed
+ SEED = 1
+
+ # Dataset
+
+ CLASSES = (
+     "Airplane",
+     "Automobile",
+     "Bird",
+     "Cat",
+     "Deer",
+     "Dog",
+     "Frog",
+     "Horse",
+     "Ship",
+     "Truck",
+ )
+
+ SHUFFLE = True
+ DATA_DIR = "../data"
+ NUM_WORKERS = 4
+ PIN_MEMORY = True
+
+ # Training Hyperparameters
+ CRITERION = F.cross_entropy
+ INPUT_SIZE = (3, 32, 32)
+ NUM_CLASSES = 10
+ LEARNING_RATE = 0.001
+ WEIGHT_DECAY = 1e-4
+ BATCH_SIZE = 512
+ NUM_EPOCHS = 24
+ DROPOUT_PERCENTAGE = 0.05
+ LAYER_NORM = "bn"  # "bn" (BatchNorm) | "gn" (GroupNorm) | "ln" (LayerNorm)
+
+ # Optimizer & Scheduler
+
+ LRFINDER_END_LR = 0.1
+ LRFINDER_NUM_ITERATIONS = 50
+ LRFINDER_STEP_MODE = "exp"
+
+ OCLR_DIV_FACTOR = 100
+ OCLR_FINAL_DIV_FACTOR = 100
+ OCLR_THREE_PHASE = False
+ OCLR_ANNEAL_STRATEGY = "linear"
+
+ # Compute Related
+
+ ACCELERATOR = "cuda"
+ PRECISION = 32
+
+ # Store
+
+ TRAINING_STAT_STORE = "Store/training_stats.csv"
+ MODEL_SAVE_PATH = "Store/model.pth"
+ PRED_STORE_PATH = "Store/pred_store.pth"
+ EXAMPLE_IMG_PATH = "Store/examples/"
+
+ # Visualization
+
+ NORM_CONF_MAT = True
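
For reference, the OneCycleLR settings above pin down the endpoints of the schedule built in Utilities/model.py: the run starts at max_lr / OCLR_DIV_FACTOR and decays to that starting LR divided by OCLR_FINAL_DIV_FACTOR. A quick sketch of the implied values (the max_lr of 1.47e-03 is the value hardcoded in configure_optimizers below):

    # Sketch: endpoints of the one-cycle LR schedule implied by the config above.
    max_lr = 1.47e-03            # hardcoded in Utilities/model.py
    initial_lr = max_lr / 100    # max_lr / OCLR_DIV_FACTOR        -> 1.47e-05
    final_lr = initial_lr / 100  # initial_lr / OCLR_FINAL_DIV_FACTOR -> 1.47e-07
    print(initial_lr, final_lr)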
Utilities/model.py ADDED
@@ -0,0 +1,295 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import pytorch_lightning as pl
+ import seaborn as sns
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import torchmetrics
+ from torch.optim.lr_scheduler import OneCycleLR
+ from torch_lr_finder import LRFinder
+
+ from . import config
+ from .visualize import plot_incorrect_preds
+
+
+ class Net(pl.LightningModule):
+     def __init__(
+         self,
+         num_classes=10,
+         dropout_percentage=0,
+         norm="bn",
+         num_groups=2,
+         criterion=F.cross_entropy,
+         learning_rate=0.001,
+         weight_decay=0.0,
+     ):
+         super().__init__()
+         # Select the normalization layer: batch norm, group norm,
+         # or layer norm (GroupNorm with a single group)
+         if norm == "bn":
+             self.norm = nn.BatchNorm2d
+         elif norm == "gn":
+             self.norm = lambda in_dim: nn.GroupNorm(
+                 num_groups=num_groups, num_channels=in_dim
+             )
+         elif norm == "ln":
+             self.norm = lambda in_dim: nn.GroupNorm(num_groups=1, num_channels=in_dim)
+         else:
+             raise ValueError(f"Unknown norm type: {norm}")
+
+         # Define the loss criterion
+         self.criterion = criterion
+
+         # Define the metrics
+         self.accuracy = torchmetrics.Accuracy(
+             task="multiclass", num_classes=num_classes
+         )
+         self.confusion_matrix = torchmetrics.ConfusionMatrix(
+             task="multiclass", num_classes=config.NUM_CLASSES
+         )
+
+         # Define the optimizer hyperparameters
+         self.learning_rate = learning_rate
+         self.weight_decay = weight_decay
+
+         # Prediction storage
+         self.pred_store = {
+             "test_preds": torch.tensor([]),
+             "test_labels": torch.tensor([]),
+             "test_incorrect": [],
+         }
+         self.log_store = {
+             "train_loss_epoch": [],
+             "train_acc_epoch": [],
+             "val_loss_epoch": [],
+             "val_acc_epoch": [],
+             "test_loss_epoch": [],
+             "test_acc_epoch": [],
+         }
+
+         # This defines the structure of the NN
+         # (comments note output size | receptive field per layer).
+         # Prep Layer
+         self.prep_layer = nn.Sequential(
+             nn.Conv2d(3, 64, kernel_size=3, padding=1),  # 32x32x3 | 1 -> 32x32x64 | 3
+             self.norm(64),
+             nn.ReLU(),
+             nn.Dropout(dropout_percentage),
+         )
+
+         self.l1 = nn.Sequential(
+             nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 32x32x128 | 5
+             nn.MaxPool2d(2, 2),  # 16x16x128 | 6
+             self.norm(128),
+             nn.ReLU(),
+             nn.Dropout(dropout_percentage),
+         )
+         self.l1res = nn.Sequential(
+             nn.Conv2d(128, 128, kernel_size=3, padding=1),  # 16x16x128 | 10
+             self.norm(128),
+             nn.ReLU(),
+             nn.Dropout(dropout_percentage),
+             nn.Conv2d(128, 128, kernel_size=3, padding=1),  # 16x16x128 | 14
+             self.norm(128),
+             nn.ReLU(),
+             nn.Dropout(dropout_percentage),
+         )
+         self.l2 = nn.Sequential(
+             nn.Conv2d(128, 256, kernel_size=3, padding=1),  # 16x16x256 | 18
+             nn.MaxPool2d(2, 2),  # 8x8x256 | 19
+             self.norm(256),
+             nn.ReLU(),
+             nn.Dropout(dropout_percentage),
+         )
+         self.l3 = nn.Sequential(
+             nn.Conv2d(256, 512, kernel_size=3, padding=1),  # 8x8x512 | 27
+             nn.MaxPool2d(2, 2),  # 4x4x512 | 28
+             self.norm(512),
+             nn.ReLU(),
+             nn.Dropout(dropout_percentage),
+         )
+         self.l3res = nn.Sequential(
+             nn.Conv2d(512, 512, kernel_size=3, padding=1),  # 4x4x512 | 36
+             self.norm(512),
+             nn.ReLU(),
+             nn.Dropout(dropout_percentage),
+             nn.Conv2d(512, 512, kernel_size=3, padding=1),  # 4x4x512 | 44
+             self.norm(512),
+             nn.ReLU(),
+             nn.Dropout(dropout_percentage),
+         )
+         self.maxpool = nn.MaxPool2d(4, 4)
+
+         # Classifier
+         self.linear = nn.Linear(512, 10)
+
+     def forward(self, x):
+         x = self.prep_layer(x)
+         x = self.l1(x)
+         x = x + self.l1res(x)  # residual connection
+         x = self.l2(x)
+         x = self.l3(x)
+         x = x + self.l3res(x)  # residual connection
+         x = self.maxpool(x)
+         x = x.view(-1, 512)
+         x = self.linear(x)
+         return F.log_softmax(x, dim=1)
+
+     def training_step(self, batch, batch_idx):
+         data, target = batch
+
+         # print("curr lr: ", self.optimizers().param_groups[0]["lr"])
+
+         # Forward pass
+         pred = self(data)
+
+         # Calculate loss
+         loss = self.criterion(pred, target)
+
+         # Calculate the metrics
+         accuracy = self.accuracy(pred, target)
+
+         self.log_dict(
+             {"train_loss": loss, "train_acc": accuracy},
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         data, target = batch
+
+         # Forward pass
+         pred = self(data)
+
+         # Calculate loss
+         loss = self.criterion(pred, target)
+
+         # Calculate the metrics
+         accuracy = self.accuracy(pred, target)
+
+         self.log_dict(
+             {"val_loss": loss, "val_acc": accuracy},
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         data, target = batch
+
+         # Forward pass
+         pred = self(data)
+         argmax_pred = pred.argmax(dim=1).cpu()
+
+         # Calculate loss
+         loss = self.criterion(pred, target)
+
+         # Calculate the metrics
+         accuracy = self.accuracy(pred, target)
+
+         self.log_dict(
+             {"test_loss": loss, "test_acc": accuracy},
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+
+         # Update the confusion matrix
+         self.confusion_matrix.update(pred, target)
+
+         # Store the predictions, labels and incorrect predictions
+         data, target, pred, argmax_pred = (
+             data.cpu(),
+             target.cpu(),
+             pred.cpu(),
+             argmax_pred.cpu(),
+         )
+         self.pred_store["test_preds"] = torch.cat(
+             (self.pred_store["test_preds"], argmax_pred), dim=0
+         )
+         self.pred_store["test_labels"] = torch.cat(
+             (self.pred_store["test_labels"], target), dim=0
+         )
+         for d, t, p, o in zip(data, target, argmax_pred, pred):
+             if not p.eq(t.view_as(p)).item():
+                 self.pred_store["test_incorrect"].append(
+                     (d.cpu(), t, p, o[p.item()].cpu())
+                 )
+
+         return loss
+
+     def find_bestLR_LRFinder(self, optimizer):
+         lr_finder = LRFinder(self, optimizer, criterion=self.criterion)
+         lr_finder.range_test(
+             self.trainer.datamodule.train_dataloader(),
+             end_lr=config.LRFINDER_END_LR,
+             num_iter=config.LRFINDER_NUM_ITERATIONS,
+             step_mode=config.LRFINDER_STEP_MODE,
+         )
+         best_lr = None
+         try:
+             _, best_lr = lr_finder.plot()  # to inspect the loss-learning rate graph
+         except Exception:
+             pass
+         lr_finder.reset()  # to reset the model and optimizer to their initial state
+
+         return best_lr
+
+     def configure_optimizers(self):
+         optimizer = self.get_only_optimizer()
+         best_lr = self.find_bestLR_LRFinder(optimizer)
+         # NOTE: max_lr is hardcoded to a previously found value; the freshly
+         # computed best_lr above is kept only for inspection.
+         scheduler = OneCycleLR(
+             optimizer,
+             max_lr=1.47e-03,
+             # total_steps=self.trainer.estimated_stepping_batches,
+             steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
+             epochs=config.NUM_EPOCHS,
+             pct_start=5 / config.NUM_EPOCHS,
+             div_factor=config.OCLR_DIV_FACTOR,
+             three_phase=config.OCLR_THREE_PHASE,
+             final_div_factor=config.OCLR_FINAL_DIV_FACTOR,
+             anneal_strategy=config.OCLR_ANNEAL_STRATEGY,
+         )
+         return [optimizer], [
+             {"scheduler": scheduler, "interval": "step", "frequency": 1}
+         ]
+
+     def get_only_optimizer(self):
+         optimizer = optim.Adam(
+             self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
+         )
+         return optimizer
+
+     def on_test_end(self) -> None:
+         super().on_test_end()
+
+         ## Confusion matrix (optionally row-normalized per true class)
+         confmat = self.confusion_matrix.cpu().compute().numpy()
+         if config.NORM_CONF_MAT:
+             df_confmat = pd.DataFrame(
+                 confmat / np.sum(confmat, axis=1)[:, None],
+                 index=list(config.CLASSES),
+                 columns=list(config.CLASSES),
+             )
+         else:
+             df_confmat = pd.DataFrame(
+                 confmat,
+                 index=list(config.CLASSES),
+                 columns=list(config.CLASSES),
+             )
+         plt.figure(figsize=(7, 5))
+         sns.heatmap(df_confmat, annot=True, cmap="Blues", fmt=".3f", linewidths=0.5)
+         plt.tight_layout()
+         plt.ylabel("True label")
+         plt.xlabel("Predicted label")
+         plt.show()
+
+     def plot_incorrect_predictions_helper(self, num_imgs=10):
+         return plot_incorrect_preds(
+             self.pred_store["test_incorrect"], config.CLASSES, num_imgs
+         )
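
Net assumes it is driven by a Trainer whose datamodule exposes train_dataloader(); both the LR finder and the OneCycleLR step count reach into self.trainer.datamodule. A minimal training sketch, assuming a hypothetical CIFAR10DataModule (the commit does not include a datamodule):

    # Sketch only: CIFAR10DataModule is hypothetical; any LightningDataModule
    # with a train_dataloader() will do.
    import pytorch_lightning as pl

    from Utilities import config
    from Utilities.model import Net

    model = Net(
        num_classes=config.NUM_CLASSES,
        dropout_percentage=config.DROPOUT_PERCENTAGE,
        norm=config.LAYER_NORM,
        learning_rate=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
    )
    trainer = pl.Trainer(
        max_epochs=config.NUM_EPOCHS,
        accelerator=config.ACCELERATOR,
        precision=config.PRECISION,
    )
    # trainer.fit(model, datamodule=CIFAR10DataModule())  # hypothetical datamodule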
Utilities/transforms.py ADDED
@@ -0,0 +1,10 @@
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+
+ # Test data transformations (normalize with CIFAR-10 channel means/stds)
+ test_transforms = A.Compose(
+     [
+         A.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
+         ToTensorV2(),
+     ]
+ )
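
Albumentations pipelines take and return dictionaries keyed by "image", so callers index the result before batching (see generate_confidences in Utilities/utils.py). A small shape sketch:

    # Sketch: test_transforms maps an HWC uint8 array to a normalized CHW tensor.
    import numpy as np

    from Utilities.transforms import test_transforms

    img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    tensor = test_transforms(image=img)["image"]
    print(tensor.shape)  # torch.Size([3, 32, 32])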
Utilities/utils.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+ from . import config
+ from .transforms import test_transforms
+
+
+ def generate_confidences(
+     model,
+     input_img,
+     num_top_preds,
+ ):
+     input_img = test_transforms(image=input_img)
+     input_img = input_img["image"]
+
+     input_img = input_img.unsqueeze(0)  # add batch dimension
+     model.eval()
+     log_probs = model(input_img)[0].detach()
+     model.train()
+     probs = torch.exp(log_probs)  # model outputs log-softmax
+
+     confidences = {
+         config.CLASSES[i]: float(probs[i]) for i in range(len(config.CLASSES))
+     }
+     # Keep the top `num_top_preds` confidences, sorted by value
+     confidences = {
+         k: v
+         for k, v in sorted(confidences.items(), key=lambda item: item[1], reverse=True)[
+             :num_top_preds
+         ]
+     }
+     return input_img, confidences
+
+
+ def generate_gradcam(
+     model,
+     org_img,
+     input_img,
+     show_gradcam,
+     gradcam_layer,
+     gradcam_opacity,
+ ):
+     if show_gradcam:
+         # Pick the target layer, counting from the end of the network
+         if gradcam_layer == -1:
+             target_layers = [model.l3[-1]]
+         elif gradcam_layer == -2:
+             target_layers = [model.l2[-1]]
+
+         cam = GradCAM(
+             model=model,
+             target_layers=target_layers,
+         )
+         grayscale_cam = cam(input_tensor=input_img, targets=None)
+         grayscale_cam = grayscale_cam[0, :]
+
+         visualization = show_cam_on_image(
+             org_img / 255,  # scale the original image to [0, 1]
+             grayscale_cam,
+             use_rgb=True,
+             image_weight=(1 - gradcam_opacity),
+         )
+     else:
+         visualization = None
+     return visualization
+
+
+ def generate_missclassified_imgs(
+     model,
+     show_misclassified,
+     num_misclassified,
+ ):
+     if show_misclassified:
+         plot = model.plot_incorrect_predictions_helper(num_misclassified)
+     else:
+         plot = None
+     return plot
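
These helpers operate on a raw RGB array, mirroring what Gradio passes in. A usage sketch outside the app, assuming the checkpoint from this commit and one of the bundled example images (Gradio resizes inputs to 32x32, so the sketch does the same):

    # Sketch: run the confidence and GradCAM helpers without the Gradio UI.
    import numpy as np
    import torch
    from PIL import Image

    from Utilities import config
    from Utilities.model import Net
    from Utilities.utils import generate_confidences, generate_gradcam

    model = Net(num_classes=config.NUM_CLASSES, norm=config.LAYER_NORM)
    model.load_state_dict(torch.load(config.MODEL_SAVE_PATH))

    org_img = np.array(
        Image.open(config.EXAMPLE_IMG_PATH + "cat.jpeg").convert("RGB").resize((32, 32))
    )
    processed_img, confidences = generate_confidences(
        model=model, input_img=org_img, num_top_preds=3
    )
    heatmap = generate_gradcam(
        model, org_img, processed_img,
        show_gradcam=True, gradcam_layer=-1, gradcam_opacity=0.5,
    )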
Utilities/visualize.py ADDED
@@ -0,0 +1,33 @@
+ import random
+
+ import matplotlib.pyplot as plt
+ from torchvision import transforms
+
+
+ def plot_incorrect_preds(incorrect, classes, num_imgs):
+     # num_imgs must be a multiple of 5 (images are plotted in rows of 5)
+     assert num_imgs % 5 == 0
+     assert len(incorrect) >= num_imgs
+
+     incorrect_inds = random.sample(range(len(incorrect)), num_imgs)
+
+     # incorrect: list of (data, target, pred, output) tuples
+     fig = plt.figure(figsize=(10, num_imgs // 2))
+     plt.suptitle("Target | Predicted Label")
+     for i in range(num_imgs):
+         curr_incorrect = incorrect[incorrect_inds[i]]
+         plt.subplot(num_imgs // 5, 5, i + 1, aspect="auto")
+
+         # Undo the CIFAR-10 normalization: unnormalize = Normalize(-mean / std, 1.0 / std)
+         unnormalized = transforms.Normalize(
+             (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
+         )(curr_incorrect[0])
+         plt.imshow(transforms.ToPILImage()(unnormalized))
+         plt.title(
+             f"{classes[curr_incorrect[1].item()]}|{classes[curr_incorrect[2].item()]}",
+         )
+         plt.xticks([])
+         plt.yticks([])
+     plt.tight_layout()
+     return fig
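
The hardcoded constants in the Normalize call above are exactly the commented identity applied to the CIFAR-10 statistics from Utilities/transforms.py: unnormalize = Normalize(-mean / std, 1 / std). A quick check:

    # Sketch: reproduce the inverse-normalization constants from the
    # CIFAR-10 mean/std used in Utilities/transforms.py.
    import numpy as np

    mean = np.array([0.4914, 0.4822, 0.4465])
    std = np.array([0.247, 0.243, 0.261])
    print(-mean / std)  # ~ [-1.98947, -1.98436, -1.71073]
    print(1.0 / std)    # ~ [ 4.04858,  4.11523,  3.83142]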
app.py ADDED
@@ -0,0 +1,105 @@
+ import gradio as gr
+ import torch
+
+ from Utilities import config
+ from Utilities.model import Net
+ from Utilities.utils import (
+     generate_confidences,
+     generate_gradcam,
+     generate_missclassified_imgs,
+ )
+
+ model = Net(
+     num_classes=config.NUM_CLASSES,
+     dropout_percentage=config.DROPOUT_PERCENTAGE,
+     norm=config.LAYER_NORM,
+     criterion=config.CRITERION,
+     learning_rate=config.LEARNING_RATE,
+     weight_decay=config.WEIGHT_DECAY,
+ )
+
+ model.load_state_dict(torch.load(config.MODEL_SAVE_PATH))
+ model.pred_store = torch.load(config.PRED_STORE_PATH)
+
+
+ def generate_gradio_output(
+     input_img,
+     num_top_preds,
+     show_gradcam,
+     gradcam_layer,
+     gradcam_opacity,
+     show_misclassified,
+     num_misclassified,
+ ):
+     processed_img, confidences = generate_confidences(
+         model=model, input_img=input_img, num_top_preds=num_top_preds
+     )
+
+     visualization = generate_gradcam(
+         model=model,
+         org_img=input_img,
+         input_img=processed_img,
+         show_gradcam=show_gradcam,
+         gradcam_layer=gradcam_layer,
+         gradcam_opacity=gradcam_opacity,
+     )
+
+     plot = generate_missclassified_imgs(
+         model=model,
+         show_misclassified=show_misclassified,
+         num_misclassified=num_misclassified,
+     )
+
+     return confidences, visualization, plot
+
+
+ inputs = [
+     gr.Image(shape=(32, 32), label="Input Image"),
+     gr.Slider(1, 10, value=3, step=1, label="Number of Top Predictions to display"),
+     gr.Checkbox(label="Show GradCAM"),
+     gr.Slider(-2, -1, step=1, value=-1, label="GradCAM Layer (from the end)"),
+     gr.Slider(0, 1, value=0.5, label="GradCAM Opacity"),
+     gr.Checkbox(label="Show Misclassified Images"),
+     gr.Slider(
+         5, 50, value=20, step=5, label="Number of Misclassified Images to display"
+     ),
+ ]
+
+ outputs = [
+     gr.Label(visible=True, scale=0.5, label="Classification Confidences"),
+     gr.Image(shape=(32, 32), label="GradCAM Output").style(
+         width=256, height=256, visible=True
+     ),
+     gr.Plot(visible=True, label="Misclassified Images"),
+ ]
+
+ examples = [
+     [config.EXAMPLE_IMG_PATH + "cat.jpeg", 3, True, -2, 0.68, True, 40],
+     [config.EXAMPLE_IMG_PATH + "horse.jpg", 3, True, -2, 0.59, True, 25],
+     [config.EXAMPLE_IMG_PATH + "bird.webp", 10, True, -1, 0.55, True, 20],
+     [config.EXAMPLE_IMG_PATH + "dog1.jpg", 10, True, -1, 0.33, True, 45],
+     [config.EXAMPLE_IMG_PATH + "frog1.webp", 5, True, -1, 0.64, True, 40],
+     [config.EXAMPLE_IMG_PATH + "deer.webp", 1, True, -2, 0.45, True, 20],
+     [config.EXAMPLE_IMG_PATH + "airplane.png", 3, True, -2, 0.43, True, 40],
+     [config.EXAMPLE_IMG_PATH + "shipp.jpg", 7, True, -1, 0.6, True, 30],
+     [config.EXAMPLE_IMG_PATH + "car.jpg", 2, True, -1, 0.68, True, 30],
+     [config.EXAMPLE_IMG_PATH + "truck1.jpg", 5, True, -2, 0.51, True, 35],
+ ]
+
+ title = "Image Classification (CIFAR10 - 10 Classes) with GradCAM"
+ description = """A simple Gradio interface to visualize the output of a CNN trained on the CIFAR10 dataset, with GradCAM and misclassified images.
+ The architecture is inspired by David Page's (myrtle.ai) DAWNBench-winning model architecture.
+ Input an image and select the number of top predictions to display - you will see the top predictions and their corresponding confidence scores.
+ You can also choose to show GradCAM for the image: it uses the gradients of the classification score with respect to the final convolutional feature map to identify the parts of the input image that most affect the classification score.
+ Select the model layer the gradients are taken from - this affects how much of the image contributes to the GradCAM heatmap.
+ You can also choose to show misclassified images - images from the test set that the model got wrong.
+ Some examples are provided in the examples tab.
+ """
+
+ gr.Interface(
+     fn=generate_gradio_output,
+     inputs=inputs,
+     outputs=outputs,
+     examples=examples,
+     title=title,
+     description=description,
+ ).launch()