Spaces:
Sleeping
Sleeping
import matplotlib.pyplot as plt | |
import numpy as np | |
import cv2 | |
import math | |
import itertools | |
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, precision_score, recall_score | |
from matplotlib import cm | |
from mpl_toolkits.mplot3d import Axes3D | |
from matplotlib.ticker import LinearLocator | |
def plot_batch(ds, fn, max_samples=9, figscale: int = 2):
    """Plot samples from one batch of a dataset and save the figure to ``fn``.

    Arguments:
        ds: either a tf.data-style dataset (detected by the presence of an
            ``element_spec`` attribute) whose first element is an
            ``(images, labels)`` pair of tensors, or an iterator exposing
            ``.next()`` that returns ``(images, labels)`` (e.g. an
            ImageDataGenerator flow).
        fn (str): filepath to save the figure to.
        max_samples (int): maximum number of samples to plot.
        figscale (int): size of each grid cell in the figure.
    """
    if hasattr(ds, 'element_spec'):  # ad-hoc test to assess if ds is a tf.data dataset
        # draw a batch from a tf.data dataset and convert tensors to numpy
        images, labels = next(iter(ds))
        images = images.numpy()
        labels = labels.numpy()
    else:
        # draw a batch from an ImageDataGenerator-based iterator
        images, labels = ds.next()

    # limit the number of shown images
    n_samples = min(len(images), max_samples)
    assert n_samples > 0
    root_n = int(math.ceil(math.sqrt(n_samples)))
    # squeeze=False always returns a 2-D array of Axes, so flat indexing also
    # works for a 1x1 grid (a bare Axes object is not subscriptable)
    fig, axs = plt.subplots(root_n, root_n,
                            figsize=(root_n * figscale, root_n * figscale),
                            squeeze=False)
    axs = axs.ravel()

    # denormalization range is shared by the whole batch
    min_image_val = min(np.array(img).min() for img in images)
    max_image_val = max(np.array(img).max() for img in images)

    for i in range(n_samples):
        ax = axs[i]
        ax.imshow(_denorm(images[i], min_image_val, max_image_val))
        ax.set_title(labels[i])
        ax.set_axis_off()
    # hide the unused cells of the grid
    for ax in axs[n_samples:]:
        ax.set_axis_off()

    fig.savefig(fn, transparent=False)
    plt.close(fig)
def _unpack_gradcam_grid(gradcams, tile_h, tile_w):
    """Split a grid of concatenated gradcam maps into tiles, in row-major order."""
    grid_h, grid_w = gradcams.shape[:2]
    return [gradcams[y:y + tile_h, x:x + tile_w]
            for y in range(0, grid_h, tile_h)
            for x in range(0, grid_w, tile_w)]


def plot_batch_preds(
    images,
    labels,
    preds,
    fn,
    gradcams=None,
    title="",
    max_samples=9,
    figscale: int = 2
):
    """
    Plot batch samples with their predictions and labels and save to ``fn``.

    Each cell shows channel 0 of the (denormalized) image in grayscale,
    optionally overlaid with its gradcam map, titled with ground truth and
    prediction.

    Arguments:
        images (list): images from dataset batch
        labels (list): labels from dataset batch
        preds (list): predictions for items
        fn (str): filepath to save figure to
        gradcams (grid): optional array made by tiling per-image gradcam maps
            of the same height/width as ``images[0]``, read row by row; tiles
            are matched to images by position
        title (str): plot title
        max_samples (int): max number of samples to plot
        figscale (int): figure scaling factor. 2=regular, 3=large, 4=larger, etc.
    """
    # limit the number of shown images
    n_samples = min(len(images), max_samples)
    assert n_samples > 0
    root_n = int(math.ceil(math.sqrt(n_samples)))
    # leave some space vertically for the labels; squeeze=False keeps a 2-D
    # Axes array so indexing also works for a 1x1 grid
    fig, axs = plt.subplots(root_n, root_n,
                            figsize=(root_n * figscale, (1 + root_n) * figscale),
                            squeeze=False)
    axs = axs.ravel()
    fig.suptitle(title, fontsize=20)

    # denormalization range is shared by the whole batch
    min_image_val = min(np.array(img).min() for img in images)
    max_image_val = max(np.array(img).max() for img in images)

    if gradcams is None:
        gradcam_list = []
    else:
        tile_h, tile_w = images[0].shape[:2]
        gradcam_list = _unpack_gradcam_grid(gradcams, tile_h, tile_w)

    for i in range(n_samples):
        img = _denorm(images[i], min_image_val, max_image_val)
        ax = axs[i]
        # plot a single channel of the image
        ax.imshow(img[:, :, 0], cmap='gray')
        # overlay gradcam activations when a tile exists for this sample
        gradcam = gradcam_list[i] if i < len(gradcam_list) else None
        if gradcam is not None:
            img_h, img_w = img.shape[:2]
            # extent is (left, right, bottom, top): width first, then height
            ax.imshow(gradcam, alpha=0.5, extent=(0, img_w, img_h, 0),
                      interpolation='bilinear')
        ax.set_title(f'GT:{labels[i]}\nPred:{preds[i]}')
        ax.set_axis_off()
    # hide the unused cells of the grid
    for ax in axs[n_samples:]:
        ax.set_axis_off()

    fig.savefig(fn, transparent=False)
    plt.close(fig)
def _denorm(img, min_image_val, max_image_val):
    """Map an image back to the [0, 1] range for display.

    The input range is inferred from the batch-wide ``min_image_val`` and
    ``max_image_val``:

    * min < 0 and max > 1  -> shifted 8-bit data, ``(img + 127) / 255``
    * min < 0 and max <= 1 -> [-1, 1] data, ``(img + 1) / 2``
    * min >= 0 and max > 1 -> [0, 255] data, ``img / 255``
    * otherwise            -> assumed already in [0, 1], left as is

    The result is clipped to [0, 1]. The caller's array is never modified.

    Returns:
        numpy.ndarray with values in [0, 1].
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        # integer inputs (e.g. uint8/int8) would overflow in `img + 127` and
        # cannot be divided in place; work in float instead
        img = img.astype(np.float64)
    if min_image_val < 0:
        if max_image_val > 1:  # [-127, 128]
            img = (img + 127) / 255.
        else:  # [-1, 1]
            img = (img + 1.) / 2.
    elif max_image_val > 1:  # [0, 255]
        # out-of-place division: the old `img /= 255.` mutated the caller's batch
        img = img / 255.
    return np.clip(img, 0, 1)
def plot_history(history, fn, figscale=5):
    """Plot every training metric next to its validation counterpart; save to ``fn``.

    Arguments:
        history: object with a ``history`` dict mapping metric name to a list
            of values (e.g. what Keras ``Model.fit`` returns). Keys containing
            ``'val_'`` are treated as validation metrics and drawn in the same
            panel as the matching training metric.
        fn (str): filepath to save the figure to.
        figscale (int|float): width, and per-panel height, of the figure.

    Raises:
        ValueError: if ``history.history`` contains no training metrics.
    """
    metrics = history.history
    # metric keys contain val_ prefixes for the validation set; those share a
    # panel with the corresponding training metric
    train_keys = [key for key in metrics if 'val_' not in key]
    if not train_keys:
        raise ValueError('history contains no training metrics to plot')

    n_panels = len(train_keys)
    # squeeze=False gives a (n_panels, 1) array even for a single metric
    fig, axs = plt.subplots(n_panels, 1, figsize=(figscale, figscale * n_panels),
                            squeeze=False)
    for ax, train_key in zip(axs[:, 0], train_keys):
        ax.plot(metrics[train_key], label=f'Training {train_key}')
        val_key = 'val_' + train_key
        # training without validation data produces no val_ keys: plot training only
        if val_key in metrics:
            ax.plot(metrics[val_key], label=f'Validation {train_key}')
            ax.set_title(f'Training and Validation {train_key}')
        else:
            ax.set_title(f'Training {train_key}')
        ax.legend()
        ax.set_ylabel(f'{train_key}')

    fig.savefig(fn, transparent=False)
    plt.close(fig)
def plot_confusion_matrix(trues,
                          preds,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True,
                          figsize=(10, 10)):
    """
    Compute a confusion matrix from ``trues``/``preds`` and draw it on a new figure.

    The title is extended with per-class precision/recall, and the x label with
    overall accuracy, misclassification rate and balanced accuracy. The figure
    is left open; the caller is expected to show or save it.

    Arguments
    ---------
    trues:        ground-truth labels (array-like)
    preds:        predicted labels (array-like, same length as ``trues``)
    target_names: class names, where ``target_names[i]`` names the class encoded
                  as the integer ``i``, for example ['high', 'medium', 'low'] or
                  [0, 1, 2]. May be None, in which case the sorted label values
                  found in ``trues``/``preds`` are used as names.
    title:        the text to display at the top of the matrix
    cmap:         the colormap for the cells, e.g. plt.get_cmap('jet') or
                  plt.cm.Blues (default: 'Blues')
                  see http://matplotlib.org/examples/color/colormaps_reference.html
    normalize:    If False, annotate cells with raw counts.
                  If True, annotate cells with their share of the true-label row.
    figsize:      tuple of matplotlib figure size
    """
    conf_mat = confusion_matrix(trues, preds)
    accuracy = np.trace(conf_mat) / float(np.sum(conf_mat))
    balanced_acc = balanced_accuracy_score(trues, preds)
    misclass = 1 - accuracy

    trues_arr = np.asarray(trues)
    preds_arr = np.asarray(preds)
    if target_names is None:
        # fall back to the sorted label values (the order confusion_matrix uses)
        class_ids = np.unique(np.concatenate([trues_arr, preds_arr]))
        class_names = [str(c) for c in class_ids]
    else:
        class_ids = range(len(target_names))
        class_names = [str(name) for name in target_names]

    # calculate one-vs-rest precision and recall for each class
    score_lines = []
    for class_id, name in zip(class_ids, class_names):
        cl_trues = np.where(trues_arr == class_id, 1, 0)
        cl_preds = np.where(preds_arr == class_id, 1, 0)
        precision = precision_score(cl_trues, cl_preds)
        recall = recall_score(cl_trues, cl_preds)
        score_lines.append(name + ' precision {:.3f}, recall {:.3f}'.format(precision, recall) + '\n')
    class_scores = '\n\n' + ''.join(score_lines)

    if cmap is None:
        cmap = plt.get_cmap('Blues')
    plt.figure(figsize=figsize)
    # cell colours always reflect raw counts; normalization only affects the text
    plt.imshow(conf_mat, interpolation='nearest', cmap=cmap)
    plt.title(title + class_scores)
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    if normalize:
        row_sums = conf_mat.sum(axis=1, keepdims=True)
        # a class that never occurs in trues has an all-zero row: show 0, not NaN
        conf_mat = np.divide(conf_mat, row_sums,
                             out=np.zeros(conf_mat.shape, dtype=float),
                             where=row_sums != 0)
    thresh = conf_mat.max() / 1.5 if normalize else conf_mat.max() / 2
    cell_fmt = "{:0.4f}" if normalize else "{:,}"
    for i, j in itertools.product(range(conf_mat.shape[0]), range(conf_mat.shape[1])):
        plt.text(j, i, cell_fmt.format(conf_mat[i, j]),
                 horizontalalignment="center",
                 color="white" if conf_mat[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}; balanced accuracy={:0.4f}'.format(accuracy, misclass, balanced_acc))
    # lay out after the labels exist so they are taken into account
    plt.tight_layout()
def plot_surface_v1(heat_map):
    """Show a heat map as a 3-D surface whose height is the grayscale intensity.

    Arguments:
        heat_map: a 3-channel BGR image (converted with cv2.COLOR_BGR2GRAY) or
            an already single-channel 2-D array.
    """
    heat_map = np.asarray(heat_map)
    # a 2-D map is already grayscale; only colour images need conversion
    if heat_map.ndim == 2:
        heat_gray = heat_map
    else:
        heat_gray = cv2.cvtColor(heat_map, cv2.COLOR_BGR2GRAY)
    # x/y coordinates are just the pixel row/column indices
    xx, yy = np.mgrid[0:heat_gray.shape[0], 0:heat_gray.shape[1]]

    fig = plt.figure()
    # Figure.gca() no longer accepts projection= (removed in Matplotlib 3.6)
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(xx, yy, heat_gray, rstride=1, cstride=1,
                           cmap=plt.cm.jet, linewidth=0)
    ax.view_init(30, 60)
    fig.colorbar(surf, shrink=0.5, aspect=5)
    plt.show()
def plot_surface(fig, heat_map, cnt, nrows=6, ncols=5):
    """Add a 3-D surface plot of ``heat_map`` as subplot ``cnt`` of ``fig``.

    Arguments:
        fig: matplotlib Figure to draw into.
        heat_map: a 3-channel BGR image (converted with cv2.COLOR_BGR2GRAY) or
            an already single-channel 2-D array; its intensity is the height.
        cnt (int): 1-based subplot index within the ``nrows`` x ``ncols`` grid.
        nrows (int): number of subplot rows in the grid (default 6).
        ncols (int): number of subplot columns in the grid (default 5).

    Returns:
        The 3-D Axes that was added to ``fig``.
    """
    heat_map = np.asarray(heat_map)
    # a 2-D map is already grayscale; only colour images need conversion
    if heat_map.ndim == 2:
        heat_gray = heat_map
    else:
        heat_gray = cv2.cvtColor(heat_map, cv2.COLOR_BGR2GRAY)
    # x/y coordinates are just the pixel row/column indices
    xx, yy = np.mgrid[0:heat_gray.shape[0], 0:heat_gray.shape[1]]

    ax = fig.add_subplot(nrows, ncols, cnt, projection='3d')
    surf = ax.plot_surface(xx, yy, heat_gray, rstride=1, cstride=1,
                           cmap=plt.cm.jet, linewidth=0)
    ax.view_init(30, 60)
    fig.colorbar(surf, shrink=0.5, aspect=5)
    return ax
def visualize(anchor, positive, negative):
    """Show the first three triplets of the given batches in a 3x3 grid.

    Row ``r`` holds ``anchor[r] | positive[r] | negative[r]``; every panel has
    its axes hidden. The figure is created but neither shown nor saved here.
    """
    fig = plt.figure(figsize=(9, 9))
    grid = fig.subplots(3, 3)
    columns = (anchor, positive, negative)
    for row in range(3):
        for col, batch in enumerate(columns):
            panel = grid[row, col]
            panel.imshow(batch[row])
            panel.get_xaxis().set_visible(False)
            panel.get_yaxis().set_visible(False)
def draw_cv2_bbox(img, bboxes, CLASSES, scores=None):
    """Draw class-coloured boxes, and optionally their scores, onto ``img``.

    Arguments:
        img: image passed to ``cv2.rectangle`` / ``cv2.putText``.
        bboxes: sequence of boxes ``(xmin, ymin, xmax, ymax)``; coordinates are
            rounded to int before drawing.
        CLASSES: sequence of integer class ids, one per box, used to pick the
            box colour (0=Background, 1=Bed, 2=Chair, 3=Wheelchair).
        scores: optional sequence of scores, one per box. When given, each score
            is written with 3 decimals just below the box's top-left corner.

    Returns:
        The annotated image.
    """
    # one colour tuple per class id: Background, Bed, Chair, Wheelchair
    colors = [(0, 0, 0), (255, 0, 0), (255, 128, 0), (0, 255, 0)]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    for i, box in enumerate(bboxes):
        color = colors[CLASSES[i]]
        xmin, ymin, xmax, ymax = (int(round(box[k])) for k in range(4))
        img = cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
        # the old `if not scores:` was inverted: it indexed None when no scores
        # were given and never drew real scores
        if scores is not None:
            img = cv2.putText(img, f'{scores[i]:.3f}', (xmin, ymin + 5), font,
                              font_scale, color, thickness, cv2.LINE_AA, False)
    return img
def draw_bbox(img, bbox, color=(255, 0, 0)):
    """Draw a single 1-pixel rectangle ``bbox = (xmin, ymin, xmax, ymax)`` on ``img``.

    Coordinates are rounded to int because ``cv2.rectangle`` requires integer
    points. This matches the rounding done in ``draw_cv2_bbox``.

    Returns:
        The image returned by ``cv2.rectangle``.
    """
    xmin, ymin, xmax, ymax = (int(round(bbox[k])) for k in range(4))
    img = cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
    return img
def consctruct_bbox(xc, yc, width=40, height=30, x_offset=10):
    """Build an ``[xmin, ymin, xmax, ymax]`` box around the point ``(xc, yc)``.

    ``width`` and ``height`` are half-extents: the box spans ``2 * width``
    horizontally and ``2 * height`` vertically. The box is shifted along x by
    ``x_offset``. The default of 10 preserves the previously hard-coded shift.
    """
    xmin = xc - width + x_offset
    xmax = xc + width + x_offset
    ymin = yc - height
    ymax = yc + height
    return [xmin, ymin, xmax, ymax]
# def draw_bbox(img,bbox, color=(255,0,00)): | |
# # pts = results['detection_boxes'][0][i3] | |
# xmin = bbox[0] | |
# ymin = bbox[1] | |
# xmax = bbox[2] | |
# ymax = bbox[3] | |
# img = cv2.rectangle(img,(xmin,ymin),(xmax,ymax),color,1) | |
# return img |