# fossil_app/explanations.py
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# Only Rise is used below; xplique also provides Saliency, GradientInput,
# IntegratedGradients, SmoothGrad, VarGrad, SquareGrad, GradCAM, GradCAMPP,
# Occlusion, GuidedBackprop, Lime and KernelShap as drop-in alternatives.
from xplique.attributions import Rise

from inference_resnet import inference_resnet_finer, preprocess, _clever_crop

BATCH_SIZE = 1

def show(img, p=False, **kwargs):
    """Display an image or attribution map with matplotlib, normalized to [0, 1]."""
    img = np.array(img, dtype=np.float32)

    # drop a leading singleton (batch) dimension
    if img.shape[0] == 1:
        img = img[0]

    # squeeze a single-channel axis so matplotlib applies a colormap
    if img.shape[-1] == 1:
        img = img[:, :, 0]
    # reverse the channel order (BGR -> RGB) for 3-channel images
    elif img.shape[-1] == 3:
        img = img[:, :, ::-1]

    # normalize to [0, 1]
    if img.max() > 1 or img.min() < 0:
        img -= img.min()
        img /= img.max()

    # optionally clip extreme values to the [p, 100 - p] percentile range
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))

    plt.imshow(img, **kwargs)
    plt.axis('off')
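
# Quick sanity check of show() on random data (illustrative only):
#   show(np.random.rand(1, 64, 64, 3))            # batch dim and channel flip are handled
#   show(np.random.rand(64, 64), p=1, alpha=0.4)  # percentile-clipped overlay
#   plt.show()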


def explain(model, input_image, size=600, n_classes=171):
    """
    Generate RISE explanations for the top-5 predicted classes of an image.

    :param model: the full model; its second output is assumed to be the class logits.
    :param input_image: the image to explain.
    :param size: the side length the image is cropped/resized to.
    :param n_classes: the number of classes in the classifier head.
    :return: a list of per-class explanation figure paths (or a single path),
             plus a one-element list with the averaged explanation figure path.
    """
    # we only need the classification part of the model
    class_model = tf.keras.Model(model.input, model.output[1])
    # RISE attribution; other xplique explainers (Saliency, IntegratedGradients,
    # SmoothGrad, GradCAM, ...) were tried and can be swapped in here.
    explainer = Rise(class_model, nb_samples=50, batch_size=BATCH_SIZE,
                     grid_size=15, preservation_probability=0.5)
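    # How RISE works: it samples nb_samples random binary masks on a coarse
    # grid_size x grid_size grid, upsamples them, multiplies each with the
    # input, and records the class score for every masked image; a pixel's
    # importance is the average score over the masks that kept it visible.
    # preservation_probability is the chance that each grid cell stays unmasked.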
    # crop the input to a square; `repetitions` tracks how many times the
    # image was repeated to fill it, so one repetition window can be sliced out
    cropped, repetitions = _clever_crop(input_image, (size, size))
    size_repetitions = int(size // (repetitions.numpy() + 1))
    X = preprocess(cropped, size=size)

    # take the five highest-scoring classes
    predictions = class_model.predict(np.array([X]))
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]

    X = np.expand_dims(X, 0)
    explanations = []
    phi_maps = []
    for i, class_idx in enumerate(top_5_indices):
        Y = tf.one_hot([class_idx], n_classes)
        print(f'{i + 1}/{len(top_5_indices)}')
        phi = np.abs(explainer(X, Y))[0]
        # collapse per-channel attributions into a single heatmap
        if len(phi.shape) == 3:
            phi = np.mean(phi, -1)
        phi_maps.append(phi)
        # overlay the heatmap on one repetition window of the crop and save it
        plt.figure()
        show(X[0][:, size_repetitions:2 * size_repetitions, :])
        show(phi[:, size_repetitions:2 * size_repetitions], p=1, alpha=0.4)
        plt.savefig(f'phi_{i}.png')
        plt.close()
        explanations.append(f'phi_{i}.png')
    # average the attribution maps over the top-5 classes instead of
    # re-running the explainer, and save the result as a single figure
    avg = []
    phi_mean = np.mean(phi_maps, axis=0)
    plt.figure()
    show(X[0][:, size_repetitions:2 * size_repetitions, :])
    show(phi_mean[:, size_repetitions:2 * size_repetitions], p=1, alpha=0.4)
    plt.savefig('phi_6.png')
    plt.close()
    avg.append('phi_6.png')

    print('Done')
    if len(explanations) == 1:
        explanations = explanations[0]
    return explanations, avg
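

# --- Usage sketch ---
# Minimal end-to-end example. Assumptions (not documented in this file):
# inference_resnet_finer() loads the fossil classifier whose second output is
# the class logits (as explain() expects via model.output[1]), and
# 'example_fossil.jpg' is a hypothetical placeholder for a real input image.
if __name__ == '__main__':
    from tensorflow.keras.utils import load_img, img_to_array

    model = inference_resnet_finer()  # assumed loader; its real signature may differ
    image = img_to_array(load_img('example_fossil.jpg'))  # hypothetical path
    per_class, averaged = explain(model, image)
    print('per-class figures:', per_class)
    print('averaged figure:', averaged)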