import numpy as np
import torch

from . import metric


class ConfusionMatrix(metric.Metric):
    """Constructs a confusion matrix for multi-class classification problems.

    Does not support multi-label, multi-class problems.

    Keyword arguments:
    - num_classes (int): number of classes in the classification problem.
    - normalized (boolean, optional): Determines whether or not the confusion
    matrix is normalized. Default: False.

    Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
    """

    def __init__(self, num_classes, normalized=False):
        super().__init__()

        # int64 accumulator: counts summed over many batches (e.g. per-pixel
        # counts in segmentation) can exceed the int32 maximum and wrap.
        self.conf = np.zeros((num_classes, num_classes), dtype=np.int64)

        self.normalized = normalized
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        """Set every accumulated count back to zero."""
        self.conf.fill(0)

    def add(self, predicted, target):
        """Accumulates one batch of predictions into the confusion matrix.

        The shape of the confusion matrix is K x K, where K is the number
        of classes.

        Keyword arguments:
        - predicted (Tensor, numpy.ndarray or sequence): Can be an N x K
        tensor/array of predicted scores obtained from the model for N
        examples and K classes, or an N-tensor/array of integer values between
        0 and K-1.
        - target (Tensor, numpy.ndarray or sequence): Can be an N x K
        tensor/array of ground-truth classes for N examples and K classes, or
        an N-tensor/array of integer values between 0 and K-1.

        Raises:
        - AssertionError: if the shapes or value ranges of the inputs are
        inconsistent with each other or with ``num_classes``.
        """
        # Convert tensors to numpy arrays. detach() is required for tensors
        # that track gradients (e.g. raw model outputs): calling .numpy() on
        # them directly raises a RuntimeError.
        if torch.is_tensor(predicted):
            predicted = predicted.detach().cpu().numpy()
        else:
            predicted = np.asarray(predicted)
        if torch.is_tensor(target):
            target = target.detach().cpu().numpy()
        else:
            target = np.asarray(target)

        assert predicted.shape[0] == target.shape[0], \
            'number of targets and predicted outputs do not match'

        if np.ndim(predicted) != 1:
            assert predicted.shape[1] == self.num_classes, \
                'number of predictions does not match size of confusion matrix'
            predicted = np.argmax(predicted, 1)
        elif predicted.size > 0:
            # .max()/.min() raise on zero-size arrays, so only range-check
            # non-empty input; an empty batch simply contributes nothing.
            assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \
                'predicted values are not between 0 and k-1'

        if np.ndim(target) != 1:
            assert target.shape[1] == self.num_classes, \
                'Onehot target does not match size of confusion matrix'
            assert (target >= 0).all() and (target <= 1).all(), \
                'in one-hot encoding, target values should be 0 or 1'
            assert (target.sum(1) == 1).all(), \
                'multi-label setting is not supported'
            target = np.argmax(target, 1)
        elif target.size > 0:
            assert (target.max() < self.num_classes) and (target.min() >= 0), \
                'target values are not between 0 and k-1'

        # Encode each (target, predicted) pair as the flat index
        # target * K + predicted, so a single bincount fills the K x K matrix.
        # Cast to int64 *before* the arithmetic: with narrow label dtypes
        # such as uint8, num_classes * target would silently wrap around.
        x = predicted.astype(np.int64) + self.num_classes * target.astype(np.int64)
        bincount_2d = np.bincount(x, minlength=self.num_classes**2)
        assert bincount_2d.size == self.num_classes**2
        conf = bincount_2d.reshape((self.num_classes, self.num_classes))

        self.conf += conf

    def value(self):
        """
        Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets.

            If ``normalized`` is True, this is a new float32 array in which
            each row is divided by its sum. Rows with no samples stay all
            zeros. Otherwise it is the internal integer count array itself,
            not a copy.
        """
        if self.normalized:
            conf = self.conf.astype(np.float32)
            # clip avoids division by zero for classes never seen as targets.
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return self.conf