Spaces:
Build error
Build error
# Imports | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import tensorflow as tf | |
from tensorflow import keras | |
import streamlit as st | |
from app_utils import * | |
# The functions below (except main) are taken from the Keras code example.
def compute_loss(feature_extractor, input_image, filter_index):
    """Return the mean activation of a single filter over the interior pixels.

    Args:
        feature_extractor: callable model that maps ``input_image`` to the
            target layer's activations.
        input_image: batch of input images fed to ``feature_extractor``.
        filter_index: index into the last axis of the activations
            (the channel axis for channels-last models).

    Returns:
        Scalar tensor: the mean of the selected filter's activations.
    """
    features = feature_extractor(input_image)
    # The slice below requires rank-4 activations; presumably
    # (batch, height, width, channels) — TODO confirm for every model.
    # Two pixels are trimmed from each spatial edge so border artifacts
    # do not dominate the objective.
    interior = features[:, 2:-2, 2:-2, filter_index]
    loss = tf.reduce_mean(interior)
    return loss
def gradient_ascent_step(feature_extractor, img, filter_index, learning_rate):
    """Perform one gradient-ascent update of ``img`` on the filter loss.

    Args:
        feature_extractor: model passed through to ``compute_loss``.
        img: current input image tensor.
        filter_index: filter whose mean activation is maximized.
        learning_rate: step size applied to the L2-normalized gradient.

    Returns:
        ``(loss, img)`` — the loss evaluated at the incoming image and the
        updated image.
    """
    with tf.GradientTape() as tape:
        # ``img`` is watched explicitly so the tape records operations on it
        # (it is not necessarily a trainable variable).
        tape.watch(img)
        loss = compute_loss(feature_extractor, img, filter_index)
    # L2-normalizing the gradient keeps the step size governed by
    # ``learning_rate`` regardless of the raw gradient magnitude.
    step_direction = tf.math.l2_normalize(tape.gradient(loss, img))
    img += learning_rate * step_direction
    return loss, img
def initialize_image():
    """Return a near-gray random starting image for gradient ascent.

    Uniform noise in [0, 1) is shifted to be centered on 0 and scaled to
    [-0.125, +0.125).

    NOTE(review): the tensor shape is ``(1, IMG_WIDTH, IMG_HEIGHT, 3)``; for
    channels-last models axis 1 is usually height, so this is only
    order-independent when IMG_WIDTH == IMG_HEIGHT — confirm in app_utils.
    NOTE(review): the upstream example targeted ResNet50V2 (inputs in
    [-1, +1]); other models in AVAILABLE_MODELS may expect a different range.
    """
    noise = tf.random.uniform((1, IMG_WIDTH, IMG_HEIGHT, 3))
    centered = noise - 0.5
    return centered * 0.25
def visualize_filter(feature_extractor, filter_index, iterations=None, learning_rate=None):
    """Synthesize an input image that maximally activates one filter.

    Starts from ``initialize_image()``, runs ``iterations`` steps of
    ``gradient_ascent_step`` and converts the result with ``deprocess_image``.

    Args:
        feature_extractor: callable model mapping an input batch to the
            target layer's activations.
        filter_index: index into the last axis of those activations.
        iterations: number of gradient-ascent steps; defaults to the
            module-level ``ITERATIONS``.
        learning_rate: step size; defaults to the module-level
            ``LEARNING_RATE``.

    Returns:
        ``(loss, img)`` — the loss from the last step (``None`` when no step
        was run) and the deprocessed ``uint8`` image.
    """
    if iterations is None:
        iterations = ITERATIONS
    if learning_rate is None:
        learning_rate = LEARNING_RATE

    img = initialize_image()
    # Pre-bind so that iterations == 0 returns None instead of raising
    # UnboundLocalError.
    loss = None
    for _ in range(iterations):
        loss, img = gradient_ascent_step(
            feature_extractor, img, filter_index, learning_rate
        )
    # Decode the resulting input image (drop the batch axis first).
    img = deprocess_image(img[0].numpy())
    return loss, img
def deprocess_image(img, crop=25):
    """Convert a raw optimized image into a displayable ``uint8`` array.

    The values are standardized (zero mean, standard deviation 0.15),
    center-cropped by ``crop`` pixels on each spatial edge, shifted to be
    centered on 0.5, clipped to [0, 1] and scaled to [0, 255].

    Args:
        img: rank-3 array-like; the crop assumes the first two axes are
            spatial (the app passes ``img[0].numpy()``). It is not modified;
            non-floating inputs are promoted to float32.
        crop: pixels removed from each spatial edge. ``0`` disables cropping.
            Defaults to 25, the previously hard-coded value.

    Returns:
        A ``uint8`` array of shape ``(d0 - 2*crop, d1 - 2*crop, d2)``.
    """
    # Work on a private floating-point copy: the in-place arithmetic below
    # must not mutate the caller's array, and fails on integer dtypes.
    img = np.array(img, copy=True)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float32)
    # Normalize array: center on 0., ensure the standard deviation is 0.15
    img -= img.mean()
    img /= img.std() + 1e-5  # epsilon avoids division by zero on flat images
    img *= 0.15
    # Center crop; guarded because img[0:-0] would be an empty slice.
    if crop:
        img = img[crop:-crop, crop:-crop, :]
    # Center on 0.5 and clip to [0, 1]
    img += 0.5
    img = np.clip(img, 0, 1)
    # Convert to an 8-bit RGB array
    img *= 255
    return np.clip(img, 0, 255).astype("uint8")
def _plot_filter_grid(feat_extract, nrows=8, ncols=8):
    """Visualize up to ``nrows * ncols`` filters of ``feat_extract`` in a grid and render it."""
    n_cells = nrows * ncols
    # Only visualize filters that exist: layers with fewer channels than grid
    # cells (e.g. the 3-channel input layer) would otherwise index out of range.
    n_filters = feat_extract.output.shape[-1]
    if n_filters is None:
        n_filters = n_cells
    prog_bar = st.progress(0)
    fig, axis = plt.subplots(nrows=nrows, ncols=ncols, figsize=(14, 14))
    for filter_index, ax in enumerate(axis.ravel()):
        prog_bar.progress((filter_index + 1) / n_cells)
        if filter_index < n_filters:
            _, img = visualize_filter(feat_extract, filter_index)
            ax.imshow(img)
            ax.set_title(filter_index + 1)
        ax.set_axis_off()
    st.write(fig)
    # Release the figure; pyplot otherwise keeps every figure alive across reruns.
    plt.close(fig)


# The visualization function
def main():
    """Render the Streamlit UI for visualizing filters of a Keras model.

    The user picks a ``keras.applications`` model (from ``AVAILABLE_MODELS``)
    and one of its layers, then views either a single filter or a grid of up
    to 64 filters produced by gradient ascent (``visualize_filter``).

    The loaded model, its layer names and the per-layer feature extractor
    are cached in ``st.session_state``; they are rebuilt only when the
    corresponding selection changes.
    """
    # Initialize states
    initialize_states()

    # Model selector
    mn_option = st.selectbox("Select the model for visualization -", AVAILABLE_MODELS)
    # Only (re)load the model when the selection changes, not on every rerun.
    if mn_option != st.session_state.model_name:
        model = getattr(keras.applications, mn_option)(
            weights="imagenet", include_top=False
        )
        st.session_state.layer_list = ["<select layer>"] + [
            layer.name for layer in model.layers
        ]
        st.session_state.model = model
        st.session_state.model_name = mn_option
        # The cached layer selection and extractor belong to the previous
        # model; reset them so a stale extractor is never visualized or reused
        # (e.g. when the new model has a layer with the same name).
        st.session_state.layer_name = None
        st.session_state.feat_extract = None

    # Layer selector; caches the feature extractor so it is not rebuilt on
    # every rerun (important for the 64-filter grid).
    if st.session_state.model_name:
        ln_option = st.selectbox(
            "Select the target layer (best to pick somewhere in the middle of the model) -",
            st.session_state.layer_list,
        )
        if ln_option != "<select layer>" and ln_option != st.session_state.layer_name:
            layer = st.session_state.model.get_layer(name=ln_option)
            st.session_state.feat_extract = keras.Model(
                inputs=st.session_state.model.inputs, outputs=layer.output
            )
            st.session_state.layer_name = ln_option

    # Filter index selector
    if st.session_state.layer_name:
        filter_select = st.selectbox("Visualize -", VIS_OPTION.keys())
        if VIS_OPTION[filter_select] == 0:
            _, img = visualize_filter(st.session_state.feat_extract, 0)
            st.image(img)
        else:
            st.warning(":exclamation: Calculating the gradients can take a while..")
            _plot_filter_grid(st.session_state.feat_extract)
if __name__ == "__main__":
    # One model name per line; each is looked up on keras.applications in main().
    with open("model_names.txt", "r", encoding="utf-8") as op:
        # Skip blank lines (e.g. a trailing newline) so no empty name reaches
        # getattr(keras.applications, ...).
        AVAILABLE_MODELS = [name for name in (line.strip() for line in op) if name]
    st.set_page_config(layout="wide")
    st.title(title)
    st.write(info_text)
    st.info(f"{credits}\n\n{replicate}\n\n{vit_info}")
    st.write(self_credit)
    main()