import xplique
import tensorflow as tf
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients,
                                  SmoothGrad, VarGrad, SquareGrad, GradCAM,
                                  Occlusion, Rise, GuidedBackprop, GradCAMPP,
                                  Lime, KernelShap)
import numpy as np
import matplotlib.pyplot as plt
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
BATCH_SIZE = 1
def show(img, p=False, **kwargs):
    """Display an image or attribution map with matplotlib, normalising as needed."""
    img = np.array(img, dtype=np.float32)
    # drop a leading batch dimension if present
    if img.shape[0] == 1:
        img = img[0]
    # squeeze a single channel, or flip BGR -> RGB for 3-channel images
    if img.shape[-1] == 1:
        img = img[:, :, 0]
    elif img.shape[-1] == 3:
        img = img[:, :, ::-1]
    # rescale values to [0, 1] for display
    if img.max() > 1 or img.min() < 0:
        img -= img.min()
        img /= img.max()
    # optionally clip to the p-th / (100-p)-th percentiles to suppress outliers
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))
    plt.imshow(img, **kwargs)
    plt.axis('off')
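# Illustrative usage of show() (hypothetical arrays, not part of this pipeline):
# a random HWC image with a 2-D attribution map blended on top.
#
#   demo_img = np.random.rand(224, 224, 3)
#   demo_phi = np.random.rand(224, 224)
#   show(demo_img)
#   show(demo_phi, p=1, alpha=0.4, cmap='jet')  # clip 1st/99th percentiles, overlay
#   plt.savefig('demo_overlay.png')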
def explain(model, input_image, size=600, n_classes=171):
    """
    Generate attribution maps for the top-5 predicted classes of an image.

    :param model: The model to explain; its second output is assumed to be the
        classification head.
    :param input_image: The input image to explain.
    :param size: Side length (in pixels) the image is resized to.
    :param n_classes: Number of classes in the classification head.
    :return: A tuple (explanations, avg) of saved attribution image paths.
    """
    # we only need the classification part of the model
    class_model = tf.keras.Model(model.input, model.output[1])
    # other attribution methods that can be swapped in here:
    # Saliency(class_model)
    # IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE)
    # SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE)
    # GradCAM(class_model)
    explainer = Rise(class_model, nb_samples=50, batch_size=BATCH_SIZE,
                     grid_size=15, preservation_probability=0.5)
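    # RISE, used above, estimates attribution without gradients: it occludes the
    # input with random binary masks (a 15x15 grid, each cell kept with
    # probability 0.5), scores each masked image with class_model, and averages
    # the masks weighted by those scores. Larger nb_samples gives smoother maps.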
    cropped, repetitions = _clever_crop(input_image, (size, size))
    size_repetitions = int(size // (repetitions.numpy() + 1))
    X = preprocess(cropped, size=size)
    predictions = class_model.predict(np.array([X]))
    # indices of the five highest-scoring classes, best first
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
    X = np.expand_dims(X, 0)
    explanations = []
    for i, Y in enumerate(top_5_indices):
        Y = tf.one_hot([Y], n_classes)
        print(f'{i}/{len(top_5_indices)}')
        phi = np.abs(explainer(X, Y))[0]
        # average the channel dimension if the attribution is 3-D
        if len(phi.shape) == 3:
            phi = np.mean(phi, -1)
        plt.figure()  # fresh figure so saved images do not accumulate overlays
        show(X[0][:, size_repetitions:2 * size_repetitions, :])
        show(phi[:, size_repetitions:2 * size_repetitions], p=1, alpha=0.4)
        plt.savefig(f'phi_{i}.png')
        explanations.append(f'phi_{i}.png')
    # build a single averaged attribution over the same top-5 classes
    phis = []
    for i, Y in enumerate(top_5_indices):
        Y = tf.one_hot([Y], n_classes)
        print(f'{i}/{len(top_5_indices)}')
        phi = np.abs(explainer(X, Y))[0]
        if len(phi.shape) == 3:
            phi = np.mean(phi, -1)
        phis.append(phi)
    plt.figure()
    show(X[0][:, size_repetitions:2 * size_repetitions, :])
    show(np.mean(phis, 0)[:, size_repetitions:2 * size_repetitions], p=1, alpha=0.4)
    plt.savefig('phi_6.png')
    avg = ['phi_6.png']
    print('Done')
    if len(explanations) == 1:
        explanations = explanations[0]
    return explanations, avg
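# Hedged usage sketch. Assumptions (not verified against this repo): that
# `inference_resnet_finer` builds the two-output model expected by explain(),
# and that 'sample.jpg' exists locally; both names below are placeholders.
if __name__ == '__main__':
    model = inference_resnet_finer()  # hypothetical zero-argument constructor
    image = plt.imread('sample.jpg')  # placeholder input image
    paths, avg_path = explain(model, image)
    print(paths, avg_path)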