File size: 11,714 Bytes
191195c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import matplotlib.pyplot as plt
import numpy as np
import cv2
import math
import itertools
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, precision_score, recall_score
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator

def plot_batch(ds, fn, max_samples=9, figscale:int=2):
    """Plot samples from one batch of a dataset and save the figure to ``fn``.

    Arguments:
        ds: either a dataset exposing ``element_spec`` (treated as ``tf.data``:
            one ``(images, labels)`` batch is drawn via ``next(iter(ds))`` and
            converted with ``.numpy()``), or an iterator exposing ``next()``
            that returns ``(images, labels)`` (e.g. ImageDataGenerator flows).
        fn          (str): filepath to save the figure to
        max_samples (int): max number of samples to plot
        figscale    (int): size in inches of each grid cell

    Raises:
        ValueError: if the drawn batch (or ``max_samples``) is empty.
    """
    if hasattr(ds, 'element_spec'):  # ad-hoc test to assess if ds is tf.dataset or not...
        # draw a batch from tf.data dataset
        images, labels = next(iter(ds))
        images = images.numpy()
        labels = labels.numpy()
    else:
        # draw a batch from ImageDataGenerator-based dataset
        images, labels = ds.next()

    # limit the number of shown images
    n_samples = min(len(images), max_samples)
    if n_samples <= 0:
        raise ValueError('plot_batch needs at least one sample to plot')
    root_n = int(math.ceil(math.sqrt(n_samples)))

    fig, axs = plt.subplots(root_n, root_n, figsize=(root_n*figscale, root_n*figscale))
    # plt.subplots returns a bare Axes for a 1x1 grid and a 2-D array otherwise;
    # flatten both cases into a 1-D array so they can be indexed uniformly
    axs = np.atleast_1d(axs).ravel()

    # denormalization parameters are taken over the whole batch
    min_image_val = min(np.asarray(img).min() for img in images)
    max_image_val = max(np.asarray(img).max() for img in images)

    for i in range(n_samples):
        ax = axs[i]
        ax.imshow(_denorm(images[i], min_image_val, max_image_val))
        ax.set_title(labels[i])
        ax.set_axis_off()

    # hide the unused cells of the grid
    for ax in axs[n_samples:]:
        ax.set_axis_off()

    fig.savefig(fn, transparent=False)
    plt.close(fig)

def plot_batch_preds(
    images,
    labels,
    preds,
    fn,
    gradcams=None,
    title="",
    max_samples=9,
    figscale:int=2
    ):
    """
    Plot batch samples with preds and labels, optionally overlaid with gradcams.

    Arguments:
        images     (list): images from dataset batch
        labels     (list): labels from dataset batch
        preds      (list): predictions for items
        fn          (str): filepath to save figure to
        gradcams   (grid): optional array of gradcam tiles concatenated in a
                           row-major grid, each tile the size of ``images[0]``
        title       (str): plot title
        max_samples (int): max number of samples to plot
        figscale    (int): figure scaling factor. 2=regular, 3=large, 4=larger, etc.

    Raises:
        ValueError: if ``images`` (or ``max_samples``) is empty.
    """
    # limit the number of shown images
    n_samples = min(len(images), max_samples)
    if n_samples <= 0:
        raise ValueError('plot_batch_preds needs at least one sample to plot')
    root_n = int(math.ceil(math.sqrt(n_samples)))

    # leave some space vertically for the labels
    fig, axs = plt.subplots(root_n, root_n, figsize=(root_n*figscale, (1 + root_n)*figscale))
    # a 1x1 grid yields a bare Axes; flatten both cases into a 1-D array
    axs = np.atleast_1d(axs).ravel()

    fig.suptitle(title, fontsize=20)

    # denormalization parameters are taken over the whole batch
    min_image_val = min(np.asarray(img).min() for img in images)
    max_image_val = max(np.asarray(img).max() for img in images)

    if gradcams is None:
        gradcam_list = []
    else:
        tile_h, tile_w = np.shape(images[0])[:2]
        gradcam_list = _unpack_gradcam_grid(np.asarray(gradcams), tile_h, tile_w)

    for i in range(n_samples):
        img = _denorm(images[i], min_image_val, max_image_val)
        ax = axs[i]

        # plot a single channel of the image (2-D images are shown as-is)
        ax.imshow(img[:, :, 0] if img.ndim == 3 else img, cmap='gray')

        # overlay gradcam activations; samples without a matching tile get none
        gradcam = gradcam_list[i] if i < len(gradcam_list) else None
        if gradcam is not None:
            H, W = img.shape[:2]
            # extent is (left, right, bottom, top): stretch the heatmap over the image
            ax.imshow(gradcam, alpha=0.5, extent=(0, W, H, 0), interpolation='bilinear')

        ax.set_title(f'GT:{labels[i]}\nPred:{preds[i]}')
        ax.set_axis_off()

    # hide the unused cells of the grid
    for ax in axs[n_samples:]:
        ax.set_axis_off()

    fig.savefig(fn, transparent=False)
    plt.close(fig)


def _unpack_gradcam_grid(gradcams, tile_h, tile_w):
    """Split a grid of concatenated tiles into a row-major list of ``tile_h x tile_w`` tiles."""
    grid_h, grid_w = gradcams.shape[:2]
    return [
        gradcams[y:y + tile_h, x:x + tile_w]
        for y in range(0, grid_h, tile_h)
        for x in range(0, grid_w, tile_w)
    ]

def _denorm(img, min_image_val, max_image_val):
    """ Denormalize image by a min and max value """
    if min_image_val < 0:
        if max_image_val > 1: # [-127,127]
            img = (img + 127) / 255.
        else: # [-1,1]
            img = (img + 1.) / 2.
    elif max_image_val > 1: # no scaling
        img /= 255.
    else: # [0,1]
        pass
    return np.clip(img, 0,1) 


def plot_history(history, fn, figscale=5):
    """Plot every training metric next to its validation counterpart and save to ``fn``.

    Arguments:
        history: object whose ``history`` attribute maps metric names to
                 per-epoch value sequences (presumably a Keras ``History``);
                 validation metrics are the keys prefixed with ``val_``.
        fn     (str): filepath to save the figure to
        figscale (int): width, and per-metric height, of the figure in inches

    Raises:
        ValueError: if ``history.history`` has no training (non-``val_``) metric.
    """
    metrics = history.history
    # validation keys carry a 'val_' prefix; each is drawn on its train metric's axes
    train_keys = [key for key in metrics if not key.startswith('val_')]
    if not train_keys:
        raise ValueError('history contains no training metrics to plot')

    fig, axs = plt.subplots(len(train_keys), 1, figsize=(figscale, figscale*len(train_keys)))
    # a single metric yields a bare Axes; flatten both cases into a 1-D array
    axs = np.atleast_1d(axs).ravel()

    for ax, train_key in zip(axs, train_keys):
        ax.plot(metrics[train_key], label=f'Training {train_key}')
        val_key = 'val_' + train_key
        # fitting without validation data produces no val_ entry: plot train only
        if val_key in metrics:
            ax.plot(metrics[val_key], label=f'Validation {train_key}')
            ax.set_title(f'Training and Validation {train_key}')
        else:
            ax.set_title(f'Training {train_key}')
        ax.legend()
        ax.set_ylabel(f'{train_key}')

    fig.savefig(fn, transparent=False)
    plt.close(fig)

def plot_confusion_matrix(trues,
                          preds,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True,
                          figsize=(10, 10)):
    """
    Compute a confusion matrix from ground truth and predictions and plot it.

    The title lists per-class precision and recall; the x-label reports
    accuracy, misclassification rate and balanced accuracy. The figure is
    drawn on a new current pyplot figure and is not saved or closed here.

    Arguments
    ---------
    trues:        Ground truth class labels

    preds:        Predicted class labels

    target_names: class names, where entry i names integer class label i,
                  for example: ['high', 'medium', 'low'] for labels 0, 1, 2.
                  If None, the labels occurring in trues/preds are used and
                  shown by their string value.

    title:        the text to display at the top of the matrix

    cmap:         the gradient of the values displayed from matplotlib.pyplot.cm
                  see http://matplotlib.org/examples/color/colormaps_reference.html
                  plt.get_cmap('jet') or plt.cm.Blues

    normalize:    If False, plot the raw numbers
                  If True, plot the proportions of each true class (row)

    figsize:      tuple of matplotlib figure size
    """
    trues = np.asarray(trues)
    preds = np.asarray(preds)

    if target_names is None:
        # fall back to the class values that actually occur
        labels = np.unique(np.concatenate([trues.ravel(), preds.ravel()]))
        target_names = [str(label) for label in labels]
    else:
        # fix the matrix to one row/column per named class, even if a class is absent
        labels = np.arange(len(target_names))

    conf_mat = confusion_matrix(trues, preds, labels=labels)

    # computed directly so samples outside `labels` still count
    accuracy = float(np.mean(trues == preds))
    balanced_acc = balanced_accuracy_score(trues, preds)
    misclass = 1 - accuracy

    # one-vs-rest precision and recall for each class
    precisions = precision_score(trues, preds, labels=labels, average=None, zero_division=0)
    recalls = recall_score(trues, preds, labels=labels, average=None, zero_division=0)
    class_scores = '\n\n' + ''.join(
        f'{name} precision {precision:.3f}, recall {recall:.3f}\n'
        for name, precision, recall in zip(target_names, precisions, recalls)
    )

    if normalize:
        row_sums = conf_mat.sum(axis=1, keepdims=True)
        # classes with no true samples keep an all-zero row instead of NaN
        shown = np.divide(conf_mat, row_sums,
                          out=np.zeros(conf_mat.shape, dtype=float),
                          where=row_sums != 0)
    else:
        shown = conf_mat

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize=figsize)
    # draw the same values that are printed, so cell shading and text colour agree
    plt.imshow(shown, interpolation='nearest', cmap=cmap)
    plt.title(title + class_scores)
    plt.colorbar()

    tick_marks = np.arange(len(target_names))
    plt.xticks(tick_marks, target_names, rotation=45)
    plt.yticks(tick_marks, target_names)

    thresh = shown.max() / 1.5 if normalize else shown.max() / 2
    cell_fmt = "{:0.4f}" if normalize else "{:,}"
    for i, j in itertools.product(range(shown.shape[0]), range(shown.shape[1])):
        plt.text(j, i, cell_fmt.format(shown[i, j]),
                 horizontalalignment="center",
                 color="white" if shown[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}; balanced accuracy={:0.4f}'.format(accuracy, misclass, balanced_acc))
    # after the labels are set, so they are included in the layout
    plt.tight_layout()

def plot_surface_v1(heat_map):
    """Show a heat map as an interactive 3-D surface, using pixel indices as x/y.

    Arguments:
        heat_map: colour image passed to ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``;
                  an already single-channel 2-D array is used as-is.
    """
    # collapse colour channels to intensity; cvtColor rejects 1-channel input
    if np.ndim(heat_map) == 2:
        heat_gray = heat_map
    else:
        heat_gray = cv2.cvtColor(heat_map, cv2.COLOR_BGR2GRAY)
    # x and y coordinate grids are plain pixel indices
    xx, yy = np.mgrid[0:heat_gray.shape[0], 0:heat_gray.shape[1]]

    fig = plt.figure()
    # Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot is the supported API
    ax = fig.add_subplot(projection='3d')
    surf = ax.plot_surface(xx, yy, heat_gray, rstride=1, cstride=1, cmap=plt.cm.jet,
                           linewidth=0)
    ax.view_init(30, 60)
    fig.colorbar(surf, shrink=0.5, aspect=5)

    plt.show()


def plot_surface(fig, heat_map, cnt, nrows=6, ncols=5):
    """Add a heat map as a 3-D surface subplot at grid position ``cnt`` of ``fig``.

    Arguments:
        fig: matplotlib Figure to draw into
        heat_map: colour image passed to ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``;
                  an already single-channel 2-D array is used as-is
        cnt (int): 1-based subplot index within the ``nrows x ncols`` grid
        nrows (int): number of grid rows (default 6)
        ncols (int): number of grid columns (default 5)

    Returns:
        The created 3-D Axes.
    """
    # collapse colour channels to intensity; cvtColor rejects 1-channel input
    if np.ndim(heat_map) == 2:
        heat_gray = heat_map
    else:
        heat_gray = cv2.cvtColor(heat_map, cv2.COLOR_BGR2GRAY)
    # x and y coordinate grids are plain pixel indices
    xx, yy = np.mgrid[0:heat_gray.shape[0], 0:heat_gray.shape[1]]

    ax = fig.add_subplot(nrows, ncols, cnt, projection='3d')
    surf = ax.plot_surface(xx, yy, heat_gray, rstride=1, cstride=1, cmap=plt.cm.jet,
                           linewidth=0)
    ax.view_init(30, 60)
    fig.colorbar(surf, shrink=0.5, aspect=5)
    return ax

def visualize(anchor, positive, negative):
    """Draw the first three (anchor, positive, negative) triplets on a 3x3 grid.

    Row ``r`` shows ``anchor[r]``, ``positive[r]`` and ``negative[r]`` from left
    to right, with both axes hidden. The figure is left open on the current
    pyplot state; nothing is saved or shown here.
    """
    fig = plt.figure(figsize=(9, 9))
    grid = fig.subplots(3, 3)

    batches = (anchor, positive, negative)
    for row in range(3):
        for col, batch in enumerate(batches):
            cell = grid[row, col]
            cell.imshow(batch[row])
            cell.xaxis.set_visible(False)
            cell.yaxis.set_visible(False)

def draw_cv2_bbox(img, bboxes, CLASSES, scores=None):
    """Draw class-coloured bounding boxes, and optionally their scores, on ``img``.

    Arguments:
        img: image array handed to ``cv2.rectangle`` / ``cv2.putText``
        bboxes: sequence of boxes, each ``(xmin, ymin, xmax, ymax)``;
                coordinates are rounded to int
        CLASSES: per-box class ids indexing the colour table
                 (0=Background, 1=Bed, 2=Chair, 3=Wheelchair)
        scores: optional per-box scores; when given, each is written with
                3 decimals just below the box's top-left corner

    Returns:
        The annotated image returned by the OpenCV drawing calls.
    """
    # colour per class id: 0=Background, 1=Bed, 2=Chair, 3=Wheelchair
    colors = [(0, 0, 0), (255, 0, 0), (255, 128, 0), (0, 255, 0)]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    for i, box in enumerate(bboxes):
        # detector outputs often hold class ids as floats; list indexing needs an int
        class_id = int(CLASSES[i])
        color = colors[class_id]
        xmin, ymin, xmax, ymax = (int(round(v)) for v in box[:4])
        img = cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
        # explicit None test: the old `if not scores` was inverted (it indexed None when
        # scores were omitted), and a numpy score array has no single truth value
        if scores is not None:
            score_txt = f'{scores[i]:.3f}'
            img = cv2.putText(img, score_txt, (xmin, ymin + 5), font, font_scale, color, thickness, cv2.LINE_AA, False)
    return img


def draw_bbox(img, bbox, color=(255, 0, 0)):
    """Draw one rectangle ``bbox = (xmin, ymin, xmax, ymax)`` on ``img`` and return it.

    Coordinates are rounded to int because ``cv2.rectangle`` only accepts
    integer points. Line thickness is 1.
    """
    xmin, ymin, xmax, ymax = (int(round(v)) for v in bbox[:4])
    return cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)

def consctruct_bbox(xc, yc, width=40, height=30):
    """Build an ``[xmin, ymin, xmax, ymax]`` box around the point ``(xc, yc)``.

    ``width`` and ``height`` act as half-extents, so the box spans
    ``2*width`` by ``2*height``. The whole box is also shifted by +10 along x.
    """
    x_shift = 10
    x_low, x_high = xc - width + x_shift, xc + width + x_shift
    y_low, y_high = yc - height, yc + height
    return [x_low, y_low, x_high, y_high]

# def draw_bbox(img,bbox, color=(255,0,00)):
#         # pts = results['detection_boxes'][0][i3]
#         xmin = bbox[0]
#         ymin = bbox[1]
#         xmax = bbox[2]
#         ymax = bbox[3]
#         img = cv2.rectangle(img,(xmin,ymin),(xmax,ymax),color,1)
#         return img