# Spaces:
# Build error
# Build error
# File size: 4,607 Bytes
# d6466d7 f794606 d6466d7 ae5dd93 d6466d7 |
# (line-number gutter 1-145 omitted)
# Imports
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import streamlit as st
from app_utils import *
# The functions (except main) are taken straight from Keras Example
def compute_loss(feature_extractor, input_image, filter_index):
    """Return the mean activation of a single filter for ``input_image``.

    Args:
        feature_extractor: Callable applied to ``input_image``; its result is
            indexed with four subscripts, the last one selecting the filter.
        input_image: Value passed unchanged to ``feature_extractor``.
        filter_index: Position of the filter along the last axis.

    Returns:
        ``tf.reduce_mean`` of the selected filter's activation, computed
        without the two outermost positions on the second and third axes.
    """
    # Trimming two positions per side keeps border artifacts out of the loss.
    interior = feature_extractor(input_image)[:, 2:-2, 2:-2, filter_index]
    return tf.reduce_mean(interior)
@tf.function
def gradient_ascent_step(feature_extractor, img, filter_index, learning_rate):
    """Move ``img`` one step in the direction that increases the filter loss.

    Args:
        feature_extractor: Forwarded to ``compute_loss``.
        img: Image being optimized; explicitly watched by the gradient tape.
        filter_index: Forwarded to ``compute_loss``.
        learning_rate: Multiplier applied to the normalized gradient.

    Returns:
        A ``(loss, img)`` pair: the loss evaluated before the update and the
        updated image.
    """
    with tf.GradientTape() as tape:
        # ``img`` must be watched explicitly for the tape to differentiate
        # with respect to it.
        tape.watch(img)
        loss = compute_loss(feature_extractor, img, filter_index)

    # Only the direction of the gradient is used: it is L2-normalized before
    # being scaled by the learning rate.
    direction = tf.math.l2_normalize(tape.gradient(loss, img))
    img += learning_rate * direction
    return loss, img
def initialize_image():
    """Create the random starting image for gradient ascent.

    Returns:
        A tensor of shape ``(1, IMG_WIDTH, IMG_HEIGHT, 3)`` obtained by
        shifting ``tf.random.uniform`` samples by -0.5 and scaling by 0.25.
    """
    noise = tf.random.uniform((1, IMG_WIDTH, IMG_HEIGHT, 3))
    # Center the noise on zero, then shrink it so the start is a near-gray
    # image (the original note cites ResNet50V2's [-1, +1] input range and a
    # target range of [-0.125, +0.125]).
    centered = noise - 0.5
    return centered * 0.25
def visualize_filter(feature_extractor, filter_index):
    """Run gradient ascent for one filter and return a displayable image.

    Args:
        feature_extractor: Forwarded to ``gradient_ascent_step``.
        filter_index: Index of the filter whose activation is maximized.

    Returns:
        A ``(loss, img)`` pair: the loss reported by the last ascent step and
        the result of ``deprocess_image`` on the first batch element.
    """
    candidate = initialize_image()

    # The number of ascent steps and the step size are the module-level
    # ITERATIONS and LEARNING_RATE settings.
    for _step in range(ITERATIONS):
        loss, candidate = gradient_ascent_step(
            feature_extractor, candidate, filter_index, LEARNING_RATE
        )

    # Take the single image out of the batch and turn it into pixel values.
    picture = deprocess_image(candidate[0].numpy())
    return loss, picture
def deprocess_image(img):
    """Convert an optimized image array into a displayable uint8 image.

    The array is shifted to zero mean, scaled to a standard deviation of
    about 0.15, cropped by 25 positions on each side of its first two axes,
    shifted by 0.5, clipped to ``[0, 1]`` and finally scaled to ``0..255``.

    The input is never modified: all arithmetic happens on a private copy.
    Integer inputs are converted to ``float32`` first, so they no longer fail
    on the in-place float arithmetic.

    Args:
        img: Array-like with three axes, the last one being the channels.
            Because of the fixed crop, an input whose first or second axis
            has 50 or fewer entries yields an empty result.

    Returns:
        A new ``uint8`` array holding the cropped image.
    """
    pixels = np.asarray(img)
    if np.issubdtype(pixels.dtype, np.floating):
        # Keep the caller's precision, but work on our own buffer.
        pixels = pixels.copy()
    else:
        pixels = pixels.astype(np.float32)

    # Normalize array: center on 0., ensure variance is 0.15.
    # The small epsilon avoids dividing by zero for a constant image.
    pixels -= pixels.mean()
    pixels /= pixels.std() + 1e-5
    pixels *= 0.15

    # Center crop
    pixels = pixels[25:-25, 25:-25, :]

    # Shift to mid-gray and clip to [0, 1]
    pixels += 0.5
    pixels = np.clip(pixels, 0, 1)

    # Convert to an RGB byte array. Values are already within [0, 1], so the
    # scaled result needs no further clipping.
    return (pixels * 255).astype("uint8")
# The visualization function
def _filter_count(feature_extractor, fallback):
    """Return the size of the extractor's last output axis, or ``fallback``."""
    try:
        count = feature_extractor.output_shape[-1]
    except (AttributeError, IndexError, TypeError):
        return fallback
    # An unknown (non-integer) channel dimension cannot bound the loop.
    return count if isinstance(count, int) else fallback


def main():
    """Render the interactive filter-visualization page.

    The page lets the user pick a model from ``AVAILABLE_MODELS``, then a
    layer of that model, and finally whether to visualize a single filter or
    an 8x8 grid of filters. The loaded model, its layer names and the feature
    extractor for the chosen layer are kept in ``st.session_state`` so they
    are rebuilt only when the corresponding selection changes.
    """
    # Initialize states
    initialize_states()

    # Model selector
    mn_option = st.selectbox("Select the model for visualization -", AVAILABLE_MODELS)

    # Load the model only when the selection changes, not on every rerun.
    if mn_option != st.session_state.model_name:
        model = getattr(keras.applications, mn_option)(
            weights="imagenet", include_top=False
        )
        st.session_state.layer_list = ["<select layer>"] + [
            layer.name for layer in model.layers
        ]
        st.session_state.model = model
        st.session_state.model_name = mn_option
        # The stored layer and extractor belong to the previous model. Without
        # this reset, picking a layer with the same name in the new model
        # would silently reuse the old model's extractor.
        st.session_state.layer_name = None
        st.session_state.feat_extract = None

    if not st.session_state.model_name:
        return

    # Layer selector; the feature extractor is cached so that visualizing
    # many filters of the same layer does not rebuild it.
    ln_option = st.selectbox(
        "Select the target layer (best to pick somewhere in the middle of the model) -",
        st.session_state.layer_list,
    )
    # NOTE(review): the original nesting was not recoverable from the source;
    # here the filter section is shown only while a real layer is selected.
    if ln_option == "<select layer>":
        return

    if ln_option != st.session_state.layer_name:
        layer = st.session_state.model.get_layer(name=ln_option)
        st.session_state.feat_extract = keras.Model(
            inputs=st.session_state.model.inputs, outputs=layer.output
        )
        st.session_state.layer_name = ln_option

    feat_extract = st.session_state.feat_extract

    # Filter index selector
    filter_select = st.selectbox("Visualize -", list(VIS_OPTION))
    if VIS_OPTION[filter_select] == 0:
        _, img = visualize_filter(feat_extract, 0)
        st.image(img)
        return

    st.warning(":exclamation: Calculating the gradients can take a while..")
    prog_bar = st.progress(0)
    fig, axis = plt.subplots(nrows=8, ncols=8, figsize=(14, 14))
    cells = axis.ravel()
    # Layers with fewer channels than grid cells leave the extra cells blank
    # instead of indexing past the last filter.
    available = _filter_count(feat_extract, fallback=len(cells))

    for filter_index, ax in enumerate(cells):
        prog_bar.progress((filter_index + 1) / len(cells))
        if filter_index < available:
            _, img = visualize_filter(feat_extract, filter_index)
            ax.imshow(img)
            # Titles are 1-based for display; filter indices are 0-based.
            ax.set_title(filter_index + 1)
        ax.set_axis_off()

    st.write(fig)
    # Release the figure: pyplot keeps every open figure alive, and this
    # function runs again on each Streamlit rerun.
    plt.close(fig)
if __name__ == "__main__":
    # Read the selectable model names, one per line. Blank lines are skipped
    # so that an empty name is never offered in the model selector (it would
    # fail the attribute lookup on keras.applications).
    with open("model_names.txt", "r", encoding="utf-8") as op:
        AVAILABLE_MODELS = [
            name for name in (line.strip() for line in op) if name
        ]

    # Page configuration is issued before any other Streamlit call.
    st.set_page_config(layout="wide")

    # Static page header; the texts come from the app_utils star import.
    st.title(title)
    st.write(info_text)
    st.info(f"{credits}\n\n{replicate}\n\n{vit_info}")
    st.write(self_credit)

    main()
# |