import xplique
import tensorflow as tf
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
                                  GradCAMPP, Lime, KernelShap)
import numpy as np
import matplotlib.pyplot as plt
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop

BATCH_SIZE = 1


def show(img, p=False, **kwargs):
    """Draw *img* on the current pyplot axes and hide the axis.

    The input is copied to a float32 array, so the caller's data is never
    modified.

    :param img: Array-like image. A leading axis of length 1 is dropped; a
        3-D array with a single trailing channel is squeezed to 2-D; a 3-D
        array with three trailing channels has its channel order reversed.
    :param p: ``False`` to disable clipping, otherwise a percentile: values
        are clipped to the ``[p, 100 - p]`` percentile range of the image.
    :param kwargs: Forwarded unchanged to ``plt.imshow``.
    :return: None.
    """
    img = np.array(img, dtype=np.float32)

    # Drop a leading singleton axis (e.g. a batch of one).
    if img.shape[0] == 1:
        img = img[0]

    # Only 3-D arrays carry a channel axis; a 2-D map whose width happens to
    # be 1 or 3 must not be indexed with three axes.
    if img.ndim == 3:
        if img.shape[-1] == 1:
            img = img[:, :, 0]
        elif img.shape[-1] == 3:
            # Reverse channel order (presumably BGR <-> RGB -- confirm
            # against what `preprocess` produces).
            img = img[:, :, ::-1]

    # Rescale to [0, 1] only when the values fall outside that range.
    if img.max() > 1 or img.min() < 0:
        img -= img.min()
        peak = img.max()
        # A constant image has zero range; dividing would yield NaNs.
        if peak > 0:
            img /= peak

    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))

    plt.imshow(img, **kwargs)
    plt.axis('off')


def _save_overlay(explainer, X, class_index, n_classes, size_repetitions, path):
    """Explain one class, draw the attribution over the input band, save it to *path*."""
    target = tf.one_hot([class_index], n_classes)
    phi = np.abs(explainer(X, target))[0]
    # Collapse a per-channel attribution into a single 2-D map.
    if len(phi.shape) == 3:
        phi = np.mean(phi, -1)

    # Only the second band of width `size_repetitions` along axis 1 is shown.
    band = slice(size_repetitions, 2 * size_repetitions)
    show(X[0][:, band, :])
    show(phi[:, band], p=1, alpha=0.4)
    # NOTE(review): everything is drawn on the current pyplot figure, which is
    # never cleared here, so artists from earlier calls remain underneath.
    plt.savefig(path)
    return path


def explain(model, input_image, size=600, n_classes=171):
    """Save Rise attribution overlays for the five highest-scoring classes.

    A sub-model mapping ``model.input`` to ``model.output[1]`` is built (the
    original comment calls this the classification part of the model). The
    image is cropped with ``_clever_crop``, run through ``preprocess`` and
    classified; each of the five highest-scoring class indices is then
    explained with ``Rise`` and the overlay written to ``phi_<rank>.png``.
    The same classes are explained a second time, each overlay overwriting
    ``phi_6.png``.

    :param model: Keras model whose ``output[1]`` is the class-score output.
    :param input_image: Image passed to ``_clever_crop``.
    :param size: Side length used for the crop and for ``preprocess``.
    :param n_classes: Depth of the one-hot target vectors given to the
        explainer.
    :return: ``(explanations, avg)``. ``explanations`` is the list of
        ``phi_<rank>.png`` file names, or the bare file name when exactly one
        was produced. ``avg`` holds ``'phi_6.png'`` once per explained class.
    """
    # We only need the classification part of the model.
    class_model = tf.keras.Model(model.input, model.output[1])

    # Other xplique explainers imported by this module (Saliency,
    # IntegratedGradients, SmoothGrad, GradCAM, ...) could be swapped in here.
    explainer = Rise(class_model, nb_samples=50, batch_size=BATCH_SIZE, grid_size=15,
                     preservation_probability=0.5)

    cropped, repetitions = _clever_crop(input_image, (size, size))
    # Width of one band along axis 1; presumably the crop repeats the image
    # `repetitions + 1` times side by side -- confirm against `_clever_crop`.
    size_repetitions = int(size // (repetitions.numpy() + 1))

    X = preprocess(cropped, size=size)
    predictions = class_model.predict(np.array([X]))
    # Indices of the five largest scores, best first.
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
    X = np.expand_dims(X, 0)

    n_top = len(top_5_indices)

    explanations = []
    for i, class_index in enumerate(top_5_indices):
        print(f'{i}/{n_top}')
        explanations.append(
            _save_overlay(explainer, X, class_index, n_classes, size_repetitions, f'phi_{i}.png')
        )

    # NOTE(review): this pass recomputes the same per-class explanations and
    # overwrites a single file each time; despite the name, nothing is
    # averaged. Kept as-is because callers receive `avg` -- confirm intent.
    avg = []
    for i, class_index in enumerate(top_5_indices):
        print(f'{i}/{n_top}')
        avg.append(
            _save_overlay(explainer, X, class_index, n_classes, size_repetitions, 'phi_6.png')
        )

    print('Done')
    if len(explanations) == 1:
        explanations = explanations[0]
    return explanations, avg