# Spaces: Sleeping
import xplique | |
import tensorflow as tf | |
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad, | |
SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop, | |
GradCAMPP, Lime, KernelShap) | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop | |
# Batch size passed to the xplique explainer constructed in `explain`.
BATCH_SIZE = 1
def show(img, p=False, **kwargs):
    """
    Display an image or attribution map on the current matplotlib axes.

    The input is converted to a ``float32`` array. A leading axis of length 1
    (e.g. a batch of one) is dropped. A trailing single channel is squeezed
    away so ``imshow`` applies a colormap. A trailing 3-channel axis has its
    channel order reversed. If any value lies outside ``[0, 1]``, the array
    is min-max rescaled into ``[0, 1]``.

    :param img: Array-like image, 2-D ``(H, W)`` or 3-D ``(H, W, C)``,
        optionally with an extra leading axis of length 1.
    :param p: Percentile for outlier clipping. When truthy, values are
        clipped to the ``[p, 100 - p]`` percentile range. ``False``, ``None``
        and ``0`` disable clipping.
    :param kwargs: Forwarded unchanged to ``plt.imshow`` (e.g. ``alpha``).
    """
    img = np.array(img, dtype=np.float32)

    # Drop a leading length-1 axis, but never collapse a 2-D image of
    # height 1 into a 1-D array (imshow cannot display that).
    if img.ndim > 2 and img.shape[0] == 1:
        img = img[0]

    # Channel handling only applies to 3-D arrays. A 2-D map whose width is
    # 1 or 3 must not be indexed as if it had a channel axis.
    if img.ndim == 3:
        if img.shape[-1] == 1:
            img = img[:, :, 0]
        elif img.shape[-1] == 3:
            # Reverse channel order, presumably BGR -> RGB for display.
            # NOTE(review): confirm against what `preprocess` produces.
            img = img[:, :, ::-1]

    # Min-max rescale into [0, 1] when values fall outside that range.
    if img.max() > 1 or img.min() < 0:
        img -= img.min()
        value_range = img.max()
        # A constant image has zero range. Dividing would produce all-NaN.
        if value_range > 0:
            img /= value_range

    # p == 0 would clip at the 0th/100th percentiles, which is a no-op, so a
    # plain truthiness test is equivalent there and also accepts p=None.
    if p:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))

    plt.imshow(img, **kwargs)
    plt.axis('off')
def _explain_class(explainer, batch, class_index, n_classes):
    """Return the absolute attribution map of ``explainer`` for one class of a one-image batch.

    When the explainer returns a 3-D map for the image, it is averaged over
    its last axis so the result is 2-D.
    """
    target = tf.one_hot([class_index], n_classes)
    phi = np.abs(explainer(batch, target))[0]
    if phi.ndim == 3:
        phi = np.mean(phi, -1)
    return phi


def _save_overlay(image, phi, columns, path):
    """Draw ``phi`` at 40% opacity over ``image`` (both cut to ``columns``), save to ``path``, return ``path``."""
    fig = plt.figure()
    try:
        show(image[:, columns, :])
        show(phi[:, columns], p=1, alpha=0.4)
        fig.savefig(path)
    finally:
        # One fresh figure per image, always closed. Drawing every overlay
        # onto a shared implicit figure stacks images and leaks memory
        # across iterations and across calls.
        plt.close(fig)
    return path


def explain(model, input_image, size=600, n_classes=171, top_k=5):
    """
    Generate RISE attribution overlays for the model's top-scoring classes.

    :param model: Keras model with at least two outputs. ``model.output[1]``
        is treated as the classification head and is the only part explained.
    :param input_image: Raw input image, passed to ``_clever_crop``.
    :param size: Side length of the square ``(size, size)`` crop that is
        preprocessed and fed to the model.
    :param n_classes: Depth of the one-hot target vector given to the
        explainer. Should match the classification head's output width.
    :param top_k: Number of highest-scoring classes to explain. Defaults to
        5, which was previously hard-coded.
    :return: Tuple ``(explanations, avg)``.
        ``explanations`` is the list of saved overlay paths ``phi_0.png``,
        ``phi_1.png``, ... in descending score order, collapsed to a single
        path string when exactly one class is explained.
        ``avg`` is a list containing ``'phi_6.png'`` once per explained
        class. That file holds the overlay of the last (lowest-ranked)
        explained class.
    """
    # Only the classification head is explained.
    class_model = tf.keras.Model(model.input, model.output[1])
    explainer = Rise(class_model, nb_samples=50, batch_size=BATCH_SIZE,
                     grid_size=15, preservation_probability=0.5)

    cropped, repetitions = _clever_crop(input_image, (size, size))
    # Width of one horizontal tile of the crop. This assumes _clever_crop
    # fills the square with `repetitions + 1` side-by-side copies of the
    # image -- TODO confirm. Only the second tile is drawn.
    size_repetitions = int(size // (repetitions.numpy() + 1))
    columns = slice(size_repetitions, 2 * size_repetitions)

    batch = np.expand_dims(preprocess(cropped, size=size), 0)
    predictions = class_model.predict(batch)
    # Highest-scoring classes first. Reversing before slicing keeps
    # top_k=0 from selecting every class (as `[-0:]` would).
    top_indices = np.argsort(predictions[0])[::-1][:top_k]

    explanations = []
    last_phi = None
    for i, class_index in enumerate(top_indices):
        print(f'{i}/{len(top_indices)}')
        last_phi = _explain_class(explainer, batch, class_index, n_classes)
        explanations.append(
            _save_overlay(batch[0], last_phi, columns, f'phi_{i}.png'))

    # NOTE(review): `avg` / 'phi_6.png' hold the last class's overlay, not an
    # average of the maps. This keeps the previous observable result; confirm
    # whether an averaged map was intended.
    avg = []
    if last_phi is not None:
        avg_path = _save_overlay(batch[0], last_phi, columns, 'phi_6.png')
        avg = [avg_path] * len(top_indices)

    print('Done')
    if len(explanations) == 1:
        explanations = explanations[0]
    return explanations, avg