File size: 3,416 Bytes
1d7c63d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5343e6
 
 
1d7c63d
0c61c42
c5343e6
 
0c61c42
 
 
 
1d7c63d
 
 
0c61c42
 
 
 
1d7c63d
 
0c61c42
1d7c63d
0c61c42
1d7c63d
 
 
 
 
0c61c42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d7c63d
0c61c42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import xplique
import tensorflow as tf
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
                                  GradCAMPP, Lime, KernelShap)

import numpy as np
import matplotlib.pyplot as plt
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
# Batch size passed to the xplique explainers built in this module.
BATCH_SIZE = 1

def show(img, p=False, **kwargs):
    """
    Draw an image or attribution map on the current matplotlib axes.

    The input is copied to a float32 array, stripped of singleton axes,
    rescaled to [0, 1] when its values fall outside that range, optionally
    percentile-clipped, and finally rendered with ``plt.imshow`` with the
    axis decorations turned off.

    :param img: Array-like image. A leading axis of length 1 is dropped, a
        trailing axis of length 1 is squeezed, and a trailing axis of
        length 3 has its channel order reversed.
    :param p: ``False`` to disable clipping; otherwise values are clipped
        between the ``p``-th and the ``(100 - p)``-th percentile.
    :param kwargs: Extra keyword arguments forwarded to ``plt.imshow``.
    :return: None.
    """
    # np.array copies, so the in-place arithmetic below never touches the
    # caller's data.
    img = np.array(img, dtype=np.float32)

    # drop a leading axis of length 1 (e.g. a batch holding a single image)
    if img.shape[0] == 1:
        img = img[0]

    # single-channel maps become 2-D so that a cmap can be applied
    if img.shape[-1] == 1:
        img = img[:, :, 0]
    elif img.shape[-1] == 3:
        # reverse the channel order (presumably BGR -> RGB; confirm against
        # the channel order produced by the caller's preprocessing)
        img = img[:, :, ::-1]

    # normalize to [0, 1] only when the values are outside that range
    if img.max() > 1 or img.min() < 0:
        img -= img.min()
        peak = img.max()
        # a constant image has zero range: skip the division instead of
        # filling the array with NaNs
        if peak > 0:
            img /= peak

    # check if clip percentile
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))

    plt.imshow(img, **kwargs)
    plt.axis('off')
    


def _save_overlay(image, heatmap, path):
    """Draw ``heatmap`` semi-transparently over ``image`` and save the figure to ``path``."""
    # A dedicated figure per file keeps earlier drawings from piling up on
    # the same axes and guarantees the figure is released afterwards.
    fig = plt.figure()
    try:
        show(image)
        show(heatmap, p=1, alpha=0.4)
        plt.savefig(path)
    finally:
        plt.close(fig)


def explain(model, input_image, size=600, n_classes=171):
    """
    Generate RISE attribution overlays for the top-5 predicted classes.

    The input image is cropped and preprocessed, classified, and a RISE
    explanation is computed for each of the five highest-scoring classes.
    Each explanation is overlaid on the image and written to
    ``phi_<rank>.png`` in the current working directory; the mean of the
    five maps is written to ``phi_6.png``.

    :param model: Keras model whose second output (``model.output[1]``) is
        used as the classification head.
    :param input_image: Image handed to ``_clever_crop``.
    :param size: Side length passed to ``_clever_crop`` and ``preprocess``.
    :param n_classes: Number of classes, used as the one-hot depth of the
        explanation targets.
    :return: Tuple ``(explanations, avg)``. ``explanations`` is the list of
        per-class overlay file names (or the single name when only one
        overlay was produced); ``avg`` is a one-element list holding the
        file name of the averaged overlay.
    """
    # we only need the classification part of the model
    class_model = tf.keras.Model(model.input, model.output[1])

    explainer = Rise(class_model, nb_samples=50, batch_size=BATCH_SIZE,
                     grid_size=15, preservation_probability=0.5)

    cropped, repetitions = _clever_crop(input_image, (size, size))
    size_repetitions = int(size // (repetitions.numpy() + 1))
    # Only the second ``size_repetitions``-wide column band is displayed.
    # NOTE(review): if ``repetitions`` can be 0 this band starts at ``size``
    # and may be empty - confirm against ``_clever_crop``.
    window = slice(size_repetitions, 2 * size_repetitions)

    # Build the single-image batch once and reuse it for prediction and
    # for every explanation.
    X = np.expand_dims(preprocess(cropped, size=size), 0)
    predictions = class_model.predict(X)
    # indices of the five highest scores, best first
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]

    background = X[0][:, window, :]
    explanations = []
    maps = []
    for i, class_index in enumerate(top_5_indices):
        target = tf.one_hot([class_index], n_classes)
        print(f'{i}/{len(top_5_indices)}')
        phi = np.abs(explainer(X, target))[0]
        # collapse a per-channel attribution into a single 2-D map
        if len(phi.shape) == 3:
            phi = np.mean(phi, -1)
        maps.append(phi)

        path = f'phi_{i}.png'
        _save_overlay(background, phi[:, window], path)
        explanations.append(path)

    # Average the per-class maps computed above so the summary reflects all
    # top-5 classes, without running the (expensive) explainer a second time.
    mean_phi = np.mean(maps, axis=0)
    _save_overlay(background, mean_phi[:, window], 'phi_6.png')
    avg = ['phi_6.png']

    print('Done')
    if len(explanations) == 1:
        explanations = explanations[0]

    return explanations, avg