Upload app.py
app.py
ADDED
@@ -0,0 +1,140 @@
# Imports
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

import streamlit as st

from app_utils import *

# The functions (except main) are taken straight from the Keras example
# "Visualizing what convnets learn"
def compute_loss(feature_extractor, input_image, filter_index):
    activation = feature_extractor(input_image)
    # We avoid border artifacts by only involving non-border pixels in the loss.
    filter_activation = activation[:, 2:-2, 2:-2, filter_index]
    return tf.reduce_mean(filter_activation)


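# Gradient ascent on the input image: each step nudges the pixels in the direction
# that increases the mean activation of the chosen filter.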
@tf.function
def gradient_ascent_step(feature_extractor, img, filter_index, learning_rate):
    with tf.GradientTape() as tape:
        tape.watch(img)
        loss = compute_loss(feature_extractor, img, filter_index)
    # Compute gradients.
    grads = tape.gradient(loss, img)
    # Normalize gradients.
    grads = tf.math.l2_normalize(grads)
    img += learning_rate * grads
    return loss, img


def initialize_image():
    # We start from a gray image with some random noise
    img = tf.random.uniform((1, IMG_WIDTH, IMG_HEIGHT, 3))
    # ResNet50V2 expects inputs in the range [-1, +1].
    # Here we scale our random inputs to [-0.125, +0.125]
    return (img - 0.5) * 0.25


def visualize_filter(feature_extractor, filter_index):
    # We run gradient ascent for ITERATIONS steps
    img = initialize_image()
    for _ in range(ITERATIONS):
        loss, img = gradient_ascent_step(
            feature_extractor, img, filter_index, LEARNING_RATE
        )

    # Decode the resulting input image
    img = deprocess_image(img[0].numpy())
    return loss, img


def deprocess_image(img):
    # Normalize array: center on 0., ensure std is 0.15
    img -= img.mean()
    img /= img.std() + 1e-5
    img *= 0.15

    # Center crop
    img = img[25:-25, 25:-25, :]

    # Clip to [0, 1]
    img += 0.5
    img = np.clip(img, 0, 1)

    # Convert to RGB array
    img *= 255
    img = np.clip(img, 0, 255).astype("uint8")
    return img


# The visualization function
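# Note: main() reads st.session_state.model_name and st.session_state.layer_name
# before ever assigning them, so these keys are presumably initialized elsewhere
# (e.g. in app_utils); on a fresh session an unset key would raise an error here.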
def main():
    # Model selector
    mn_option = st.selectbox("Select the model for visualization -", AVAILABLE_MODELS)

    # Check so we don't reload the model on every layer change
    if mn_option != st.session_state.model_name:
        model = getattr(keras.applications, mn_option)(
            weights="imagenet", include_top=False
        )
        st.session_state.layer_list = ["<select layer>"] + [
            layer.name for layer in model.layers
        ]
        st.session_state.model = model
        st.session_state.model_name = mn_option

    # Layer selector; saves the feature extractor in case all 64 filters are to be shown
    if st.session_state.model_name:
        ln_option = st.selectbox(
            "Select the target layer (best to pick somewhere in the middle of the model) -",
            st.session_state.layer_list,
        )
        if ln_option != "<select layer>":
            if ln_option != st.session_state.layer_name:
                layer = st.session_state.model.get_layer(name=ln_option)
                st.session_state.feat_extract = keras.Model(
                    inputs=st.session_state.model.inputs, outputs=layer.output
                )
                st.session_state.layer_name = ln_option

    # Filter index selector
    if st.session_state.layer_name:
        filter_select = st.selectbox("Visualize -", VIS_OPTION.keys())

        if VIS_OPTION[filter_select] == 0:
            loss, img = visualize_filter(st.session_state.feat_extract, 0)
            st.image(img)
        else:
            st.warning(":exclamation: Calculating the gradients can take a while..")
            prog_bar = st.progress(0)
            fig, axis = plt.subplots(nrows=8, ncols=8, figsize=(14, 14))
            for filter_index, ax in enumerate(axis.ravel()):
                prog_bar.progress((filter_index + 1) / 64)
                if filter_index < 64:
                    loss, img = visualize_filter(
                        st.session_state.feat_extract, filter_index
                    )
                    ax.imshow(img)
                    ax.set_title(filter_index + 1)
                    ax.set_axis_off()
                else:
                    ax.set_axis_off()

            st.write(fig)


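# Names used here but not defined in this file (IMG_WIDTH, IMG_HEIGHT, ITERATIONS,
# LEARNING_RATE, VIS_OPTION above; title, info_text, credits, replicate, vit_info
# below) are expected to come from the star import of app_utils.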
if __name__ == "__main__":

    with open("model_names.txt", "r") as op:
        AVAILABLE_MODELS = [i.strip() for i in op.readlines()]

    st.set_page_config(layout="wide")

    st.title(title)
    st.write(info_text)
    st.info(f"{credits}\n\n{replicate}\n\n{vit_info}")

    main()
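Note: model_names.txt is read at startup but is not part of this commit. For
getattr(keras.applications, mn_option) to succeed, each line presumably needs to be
the exact name of a keras.applications model constructor, one model per line, e.g.:

ResNet50V2
VGG16
Xception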