File size: 4,607 Bytes
d6466d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f794606
 
 
d6466d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae5dd93
d6466d7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Imports
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

import streamlit as st

from app_utils import *

# The helper functions below (everything except main) are adapted from the Keras example
def compute_loss(feature_extractor, input_image, filter_index):
    """Return the mean activation of a single filter, ignoring the border.

    Args:
        feature_extractor: Callable (a Keras model in this app) mapping
            ``input_image`` to layer activations. Assumes a 4-D
            (batch, rows, cols, channels) output — TODO confirm for non-conv layers.
        input_image: Batch of input images fed to ``feature_extractor``.
        filter_index: Channel of the activation tensor to score.

    Returns:
        Scalar tensor: the mean of that channel over the batch and the interior
        spatial positions.
    """
    # Drop two positions from each spatial edge so border artifacts do not
    # contribute to the objective.
    interior = feature_extractor(input_image)[:, 2:-2, 2:-2, filter_index]
    return tf.reduce_mean(interior)


@tf.function
def gradient_ascent_step(feature_extractor, img, filter_index, learning_rate):
    """Take one gradient-ascent step on ``img`` to increase a filter's activation.

    Args:
        feature_extractor: Model whose activations are scored by ``compute_loss``.
        img: Current input image tensor.
        filter_index: Channel whose mean activation is maximised.
        learning_rate: Step size applied to the L2-normalised gradient.

    Returns:
        ``(loss, img)`` — the objective before the step and the updated image.
    """
    with tf.GradientTape() as tape:
        # Tensors that are not trainable Variables are not tracked
        # automatically, so the input image must be watched explicitly.
        tape.watch(img)
        objective = compute_loss(feature_extractor, img, filter_index)

    # L2-normalise the gradient so the step length depends only on learning_rate.
    direction = tf.math.l2_normalize(tape.gradient(objective, img))
    return objective, img + learning_rate * direction


def initialize_image():
    """Return a random starting image for gradient ascent.

    The tensor has shape ``(1, IMG_HEIGHT, IMG_WIDTH, 3)``, i.e. Keras'
    default channels-last layout (batch, rows, columns, RGB). It is filled
    with small uniform noise in [-0.125, +0.125), which is a grey image with
    some random noise.
    """
    # Channels-last tensors are (batch, height, width, channels): height must
    # come before width, otherwise non-square images come out transposed.
    noise = tf.random.uniform((1, IMG_HEIGHT, IMG_WIDTH, 3))
    # tf.random.uniform samples [0, 1) by default; recentre it and shrink it
    # to [-0.125, +0.125).
    # NOTE(review): the [-1, +1] input range this scaling was chosen for is
    # ResNet50V2's; other models selectable in this app may expect a
    # different input range.
    return (noise - 0.5) * 0.25


def visualize_filter(feature_extractor, filter_index):
    """Synthesise an image that strongly activates one filter.

    Starts from ``initialize_image()`` and runs ``ITERATIONS`` gradient-ascent
    steps of size ``LEARNING_RATE`` on the mean activation of ``filter_index``.
    It then converts the result into a displayable uint8 array.

    Args:
        feature_extractor: Keras model mapping an input image to the
            activations of the chosen layer.
        filter_index: Channel of that layer to maximise.

    Returns:
        ``(loss, img)`` — the objective from the last step (``None`` when
        ``ITERATIONS`` is 0) and the image produced by ``deprocess_image``.
    """
    img = initialize_image()
    # Bound up front so the return below is valid even with zero iterations.
    loss = None
    for _ in range(ITERATIONS):
        loss, img = gradient_ascent_step(
            feature_extractor, img, filter_index, LEARNING_RATE
        )

    # Drop the batch axis and decode the optimised input into an image.
    img = deprocess_image(img[0].numpy())
    return loss, img


def deprocess_image(img, crop=25):
    """Convert an optimised input array into a displayable uint8 image.

    The array is standardised to zero mean and standard deviation 0.15, then
    center-cropped by ``crop`` on both spatial axes. It is then shifted to be
    centred on 0.5, clipped to [0, 1] and scaled to 0-255.

    Args:
        img: Array indexed as (rows, cols, channels). It is not modified.
        crop: Number of pixels removed from each edge of the first two axes
            (default 25, the previously hard-coded margin). ``0`` keeps the
            full image.

    Returns:
        ``uint8`` array of shape (rows - 2*crop, cols - 2*crop, channels)
        (empty along an axis that is not longer than ``2*crop``).
    """
    # Work on a floating-point copy: the caller's array stays untouched, and
    # integer inputs no longer fail on the in-place true division below.
    # float32/float64 inputs keep their precision.
    arr = np.asarray(img)
    img = arr.astype(np.result_type(arr.dtype, np.float32))

    # Standardise: zero mean, standard deviation 0.15. The 1e-5 avoids a
    # division by zero for a flat image.
    img -= img.mean()
    img /= img.std() + 1e-5
    img *= 0.15

    # Center crop. Explicit end indices keep crop=0 from yielding an empty
    # slice (img[0:-0] would be empty).
    rows, cols = img.shape[:2]
    img = img[crop:rows - crop, crop:cols - crop, :]

    # Centre on 0.5 and clip to [0, 1].
    img += 0.5
    img = np.clip(img, 0, 1)

    # Scale to 0-255; values are already within range after the clip above.
    img *= 255
    return img.astype("uint8")


# The visualization function
def main():
    """Render the model, layer and filter selectors and the visualisations.

    Streamlit re-executes the script on every widget interaction. The loaded
    model, its layer names and the chosen layer's feature-extractor sub-model
    are therefore kept in ``st.session_state`` and rebuilt only when the
    corresponding selection changes.
    """
    # Initialize states
    initialize_states()

    # Model selector
    mn_option = st.selectbox("Select the model for visualization -", AVAILABLE_MODELS)

    # Only (re)load the model when a different one is picked, not on every
    # layer change.
    if mn_option != st.session_state.model_name:
        model = getattr(keras.applications, mn_option)(
            weights="imagenet", include_top=False
        )
        st.session_state.layer_list = ["<select layer>"] + [
            layer.name for layer in model.layers
        ]
        st.session_state.model = model
        st.session_state.model_name = mn_option
        # The cached layer and extractor belong to the previous model, so
        # forget them. Otherwise the filter section would render for a stale
        # layer, and a same-named layer in the new model would silently reuse
        # the old model's extractor.
        st.session_state.layer_name = None
        st.session_state.feat_extract = None

    # Layer selector. The feature extractor is cached so reruns (e.g. the
    # 64-filter grid) reuse it instead of rebuilding the sub-model.
    if st.session_state.model_name:
        ln_option = st.selectbox(
            "Select the target layer (best to pick somewhere in the middle of the model) -",
            st.session_state.layer_list,
        )
        if ln_option != "<select layer>" and ln_option != st.session_state.layer_name:
            layer = st.session_state.model.get_layer(name=ln_option)
            st.session_state.feat_extract = keras.Model(
                inputs=st.session_state.model.inputs, outputs=layer.output
            )
            st.session_state.layer_name = ln_option

    # Filter index selector
    if st.session_state.layer_name:
        filter_select = st.selectbox("Visualize -", VIS_OPTION.keys())

        if VIS_OPTION[filter_select] == 0:
            # Option value 0 shows only the first filter.
            _, img = visualize_filter(st.session_state.feat_extract, 0)
            st.image(img)
        else:
            _render_filter_grid(st.session_state.feat_extract)


def _render_filter_grid(feat_extract, nrows=8, ncols=8):
    """Plot up to ``nrows * ncols`` filter visualisations of ``feat_extract``."""
    st.warning(":exclamation: Calculating the gradients can take a while..")
    n_cells = nrows * ncols
    # Layers with fewer channels than grid cells leave the remaining cells
    # blank instead of indexing past the last filter.
    channels = feat_extract.output.shape[-1]
    n_filters = n_cells if channels is None else min(n_cells, channels)

    prog_bar = st.progress(0)
    fig, axis = plt.subplots(nrows=nrows, ncols=ncols, figsize=(14, 14))
    for filter_index, ax in enumerate(axis.ravel()):
        prog_bar.progress((filter_index + 1) / n_cells)
        if filter_index < n_filters:
            _, img = visualize_filter(feat_extract, filter_index)
            ax.imshow(img)
            ax.set_title(filter_index + 1)
        ax.set_axis_off()

    st.write(fig)
    # Release the figure; otherwise every rerun leaks a matplotlib figure.
    plt.close(fig)


if __name__ == "__main__":

    # One keras.applications constructor name per line; main() looks each one
    # up with getattr. Blank lines are skipped so they never show up as empty
    # (and unloadable) choices in the model selector.
    with open("model_names.txt", "r", encoding="utf-8") as op:
        AVAILABLE_MODELS = [line.strip() for line in op if line.strip()]

    # Page configuration goes before any other Streamlit output.
    st.set_page_config(layout="wide")

    st.title(title)
    st.write(info_text)
    st.info(f"{credits}\n\n{replicate}\n\n{vit_info}")
    st.write(self_credit)

    main()