Build error
Build error
File size: 5,124 Bytes
d6466d7 f794606 d6466d7 f7012f0 d6466d7 f7012f0 d6466d7 f7012f0 d6466d7 f7012f0 d6466d7 ae5dd93 d6466d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
# Imports
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import streamlit as st
from app_utils import *
# The helper functions below (all except main) are adapted from the Keras example "Visualizing what convnets learn".
def compute_loss(feature_extractor, input_image, filter_index):
    """Score how strongly ``input_image`` excites a single filter.

    Args:
        feature_extractor: Callable (e.g. a ``keras.Model``) returning the
            activations of the layer being inspected as a rank-4 tensor.
        input_image: Image batch fed to ``feature_extractor``.
        filter_index: Index along the last activation axis selecting the filter.

    Returns:
        Scalar tensor: the mean activation of that filter, ignoring the two
        outermost positions at each end of axes 1 and 2.
    """
    # Trim two positions from each end of the two middle (presumably spatial)
    # axes so that border artifacts do not drive the score.
    interior = slice(2, -2)
    response = feature_extractor(input_image)[:, interior, interior, filter_index]
    return tf.reduce_mean(response)
def gradient_ascent_step(feature_extractor, img, filter_index, learning_rate):
    """Take one normalized gradient-ascent step on ``img``.

    Args:
        feature_extractor: Model mapping an image batch to a layer's activations.
        img: Image batch tensor being optimized.
        filter_index: Index of the filter (last activation axis) to maximize.
        learning_rate: Step size applied to the L2-normalized gradient.

    Returns:
        A ``(loss, img)`` tuple: the loss evaluated before the step and the
        updated image tensor.
    """
    with tf.GradientTape() as tape:
        # ``img`` is a plain tensor (not a tf.Variable), so the tape does not
        # track it automatically; without this watch, tape.gradient would
        # return None and l2_normalize below would fail.
        tape.watch(img)
        loss = compute_loss(feature_extractor, img, filter_index)
    # Compute gradients of the filter score w.r.t. the input pixels.
    grads = tape.gradient(loss, img)
    # Normalize so the step size does not depend on the gradient magnitude.
    grads = tf.math.l2_normalize(grads)
    img += learning_rate * grads
    return loss, img
def initialize_image():
    """Create the random starting image for gradient ascent.

    Returns:
        A float tensor of shape ``(1, IMG_HEIGHT, IMG_WIDTH, 3)`` with values
        uniformly distributed in ``[-0.125, 0.125)``.
    """
    # Keras image models use (batch, height, width, channels) under the default
    # channels_last data format, so height comes before width here.
    img = tf.random.uniform((1, IMG_HEIGHT, IMG_WIDTH, 3))
    # Shift [0, 1) to be zero-centred and shrink it to [-0.125, 0.125): a
    # near-uniform gray with a little noise. Keep it small because the
    # selectable keras.applications models expect normalized inputs.
    return (img - 0.5) * 0.25
def visualize_filter(feature_extractor, filter_index):
    """Synthesize an image that maximally activates one filter.

    Starts from a random image and runs ``ITERATIONS`` gradient-ascent steps
    with step size ``LEARNING_RATE``, then converts the result to a
    displayable uint8 array.

    Args:
        feature_extractor: Model mapping an image batch to a layer's activations.
        filter_index: Index of the filter (last activation axis) to maximize.

    Returns:
        A ``(loss, img)`` tuple: the loss from the last gradient step (``None``
        when ``ITERATIONS`` is 0) and the deprocessed uint8 image.
    """
    img = initialize_image()
    # Pre-bind so a zero-iteration run returns None instead of raising
    # UnboundLocalError.
    loss = None
    for _ in range(ITERATIONS):
        loss, img = gradient_ascent_step(
            feature_extractor, img, filter_index, LEARNING_RATE
        )
    # Drop the batch axis and convert to a displayable image.
    img = deprocess_image(img[0].numpy())
    return loss, img
def deprocess_image(img, crop=25):
    """Convert a raw optimized image into a displayable uint8 array.

    The image is standardized (zero mean, standard deviation 0.15), centre
    cropped, shifted to be centred on 0.5, clipped to ``[0, 1]`` and scaled
    to ``[0, 255]``.

    Args:
        img: Array indexed as (rows, cols, channels); presumably
            (height, width, 3) RGB — confirm with callers.
        crop: Number of pixels removed from each edge of the first two axes
            (default 25). ``0`` keeps the full image.

    Returns:
        A new ``uint8`` array. The input array is never modified.
    """
    # Work on a floating-point copy so the caller's array is left untouched
    # and integer input does not fail the in-place float arithmetic below.
    img = np.array(img, copy=True)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)

    # Standardize: zero mean, standard deviation 0.15 (1e-5 avoids division
    # by zero for constant images).
    img -= img.mean()
    img /= img.std() + 1e-5
    img *= 0.15

    # Centre crop with explicit end indices: a ``[0:-0]`` slice would be empty,
    # so crop=0 must not be written as ``[crop:-crop]``.
    rows, cols = img.shape[0], img.shape[1]
    img = img[crop:rows - crop, crop:cols - crop, :]

    # Centre on 0.5 and clip to [0, 1].
    img += 0.5
    img = np.clip(img, 0, 1)

    # Scale to [0, 255]; values are already in range, so no second clip is needed.
    return (img * 255).astype("uint8")
# The visualization function
def main():
# Streamlit page: pick a keras.applications model, one of its layers and a
# visualization mode, then show images that maximize that layer's filters.
# NOTE(review): this function is visibly truncated. Several calls lack their
# closing parenthesis or arguments (model constructor, layer_list
# comprehension, layer selectbox, keras.Model, visualize_filter call) and the
# last `for` loop has no body. Restore the missing lines before running; the
# comments below describe only what is visible.
# Initialize states
# NOTE(review): no initialization is visible here. st.session_state.model_name
# and st.session_state.layer_name are read below, so they are presumably set
# before main() runs — confirm.
# Model selector
# AVAILABLE_MODELS is read from model_names.txt in the __main__ block.
mn_option = st.selectbox("Select the model for visualization -", AVAILABLE_MODELS)
# Only rebuild the model when a different model is selected, not on every
# layer change; the built model is cached in st.session_state.
if mn_option != st.session_state.model_name:
# Instantiate the chosen keras.applications model with ImageNet weights and
# without its classification head (include_top=False).
model = getattr(keras.applications, mn_option)(
weights="imagenet", include_top=False
# Layer choices: a "<select layer>" placeholder followed by one entry per
# model layer (presumably layer names, since get_layer(name=...) is used below).
st.session_state.layer_list = ["<select layer>"] + [ for layer in model.layers
st.session_state.model = model
st.session_state.model_name = mn_option
# Layer selector: on a new layer choice, build and cache a feature-extractor
# model (model inputs -> that layer's output) so it is reused across filters,
# e.g. when visualizing up to 64 of them.
if st.session_state.model_name:
ln_option = st.selectbox(
"Select the target layer (best to pick somewhere in the middle of the model) -",
if ln_option != "<select layer>":
if ln_option != st.session_state.layer_name:
layer = st.session_state.model.get_layer(name=ln_option)
st.session_state.feat_extract = keras.Model(
inputs=st.session_state.model.inputs, outputs=layer.output
st.session_state.layer_name = ln_option
# Filter index selector
if st.session_state.layer_name:
# Empty placeholders, presumably for the warning messages below — confirm.
warn_ph = st.empty()
layer_ph = st.empty()
filter_select = st.selectbox("Visualize -", VIS_OPTION.keys())
# VIS_OPTION (from app_utils) maps menu labels to modes; mode 0 visualizes
# filter 0 only.
if VIS_OPTION[filter_select] == 0:
loss, img = visualize_filter(st.session_state.feat_extract, 0)
# NOTE(review): the branch boundary is not visible here; the code below
# presumably belongs to the multi-filter (grid) mode.
layer = st.session_state.model.get_layer(name=st.session_state.layer_name)
# Filter count = size of the last axis of the layer's output.
# NOTE(review): Layer.get_output_at was removed in Keras 3; layer.output.shape[-1]
# is the portable equivalent — confirm against the pinned TensorFlow version.
num_filters = layer.get_output_at(0).get_shape().as_list()[-1]
# NOTE(review): the bare strings below are either rendered by Streamlit "magic"
# or were arguments to warn_ph / layer_ph calls whose lines are missing — confirm.
":exclamation: Calculating the gradients can take a while.."
if num_filters < 64:
f"{st.session_state.layer_name} has only {num_filters} filters, visualizing only those filters.."
prog_bar = st.progress(0)
# 8x8 grid of subplots: at most 64 filters are visualized.
fig, axis = plt.subplots(nrows=8, ncols=8, figsize=(14, 14))
for filter_index, ax in enumerate(axis.ravel()[: min(num_filters, 64)]):
# Progress fraction in (0, 1] after each filter.
prog_bar.progress((filter_index + 1) / min(num_filters, 64))
loss, img = visualize_filter(
st.session_state.feat_extract, filter_index
# Subplot titles are 1-based filter numbers.
ax.set_title(filter_index + 1)
# Grid cells past num_filters (non-empty only when the layer has fewer than
# 64 filters); the loop body is not visible here.
for ax in axis.ravel()[num_filters:]:
if __name__ == "__main__":
with open("model_names.txt", "r") as op:
AVAILABLE_MODELS = [i.strip() for i in op.readlines()]