Upload decisions_users.py
utils/decisions_users.py  ADDED  (+315 -0)
@@ -0,0 +1,315 @@
import os
import sys
import json
import shutil
import urllib.request
from pathlib import Path
import pathlib
import time
import urllib
from ast import literal_eval
import albumentations as A
import tensorflow as tf

import cv2
import numpy as np
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
from PIL import Image

import streamlit as st

import seaborn as sns

sys.path.append(f'{os.getcwd()}/utils')
from utils.eval_users import get_product_dev_page_layout

# print(os.getcwd())

# Hide GPU from visible devices
tf.config.set_visible_devices([], 'GPU')
# Enable GPU memory growth - avoid allocating all memory at start
# gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
# for gpu in gpus:
#     tf.config.experimental.set_memory_growth(device=gpu, enable=True)

from utils.control import show_tsne_vis, show_random_samples

from utils.annoy_sampling import load_annoy_tree

from utils.model_utils import load_model
from utils.model_utils import get_feature_vector, get_feature_extractor_model, get_predictions_and_roi

sns.set_style('darkgrid')
plt.rcParams['axes.grid'] = False

# st.set_page_config(layout="wide")

# https://github.com/IliaLarchenko/albumentations-demo/blob/3cb6528a513fe3b35dbb2c2a63cdcfbb9bb2a932/src/utils.py#L149

GRAD_CAM_IMAGE_DIR = f'{os.getcwd()}/data/gradcam_vis_data/'
TEST_CSV_FILE = f'{os.getcwd()}/data/test_set_pred_prop.csv'

annoy_tree_save_path = f'{os.getcwd()}/data/filtered_train_embedding/representative_samples_emb.annoy'
test_emb_path = f'{os.getcwd()}/data/filtered_train_embedding/test_embeddings.npy'
test_emb_id_path = f'{os.getcwd()}/data/filtered_train_embedding/test_ids.npy'

train_emb_id_path = f'{os.getcwd()}/data/filtered_train_embedding/representative_train_ids.npy'

repr_id_path = f'{os.getcwd()}/data/filtered_train_embedding/representative_train_ids.npy'
borderline_id_path = f'{os.getcwd()}/data/filtered_train_embedding/borderline_train_ids.npy'

MODEL_PATH = f'{os.getcwd()}/model/keras_model_0422/'

ROOT_FIG_DIR = f'{os.getcwd()}/figures/'

# Test-set embeddings and ids (class labels are encoded in the Windows-style paths).
test_emb = np.load(test_emb_path)
test_ids = np.load(test_emb_id_path)
test_id_list = list(test_ids)
test_labels = [_id.split("\\")[1] for _id in test_id_list]
print("Number of test samples:", len(test_labels))
test_features = test_emb.reshape(-1, 1792)

# train embedding list
train_ids = np.load(train_emb_id_path)
train_id_list = list(train_ids)
train_labels = [_id.split("\\")[1] for _id in train_id_list]
print("Number of training samples:", len(train_labels))

# Approximate nearest-neighbour index over the representative training embeddings.
annoy_tree = load_annoy_tree(test_features.shape[1], annoy_tree_save_path)


def annoy_matching(annoy_f, query_item, query_index, n=10):
    # query_index is unused; the lookup is purely by embedding vector.
    return annoy_f.get_nns_by_vector(query_item, n)
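
# --- Illustrative sketch, not part of the uploaded file ---------------------
# load_annoy_tree() comes from utils.annoy_sampling, which this commit does not
# include. For context, this is roughly how such an index is built and loaded
# with the `annoy` package; the 'euclidean' metric and the tree count are
# assumptions, not values taken from this repository.
from annoy import AnnoyIndex


def build_annoy_index_sketch(embeddings, save_path, metric='euclidean', n_trees=50):
    index = AnnoyIndex(embeddings.shape[1], metric)   # dimensionality is 1792 here
    for i, vec in enumerate(embeddings):
        index.add_item(i, vec)                        # item id i maps back into train_ids
    index.build(n_trees)
    index.save(save_path)


def load_annoy_index_sketch(dim, save_path, metric='euclidean'):
    index = AnnoyIndex(dim, metric)
    index.load(save_path)                             # memory-maps the .annoy file
    return index
# -----------------------------------------------------------------------------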

def get_img(fn, thumbnail=False):
    img = Image.open(fn)
    if thumbnail:
        img.thumbnail((100, 100))
    return img


def plot_n_similar(seed_id, similar_ids, test_path, n=10, scale=5):
    # One row of n+1 panels: the query image followed by its n nearest neighbours.
    f, ax = plt.subplots(1, n + 1, figsize=((n + 1) * scale, scale))
    # print(os.path.basename(test_labels[seed_id])[:-4])
    title = "SEED ID:{0}\nLabel:{1}".format(seed_id, os.path.basename(test_labels[seed_id]))
    # print("path:", test_labels[seed_id].replace("F:/","E:/"))
    ax[0].imshow(get_img(test_path.replace("F:/", "E:/")), cmap='gray')
    ax[0].set_title(title, fontsize=12)
    for i in range(len(similar_ids)):
        ax[i + 1].imshow(get_img(similar_ids[i].replace("F:/", "E:/")), cmap='gray')
        # NOTE: the distance shown here is a hard-coded placeholder (0.1223), not the real neighbour distance.
        title = "ID:{0}\nDistance: {1:.3f}\nLabel:{2}".format(i, 0.1223, os.path.basename(similar_ids[i])[:-4])
        ax[i + 1].set_title(title, fontsize=10)
    f.suptitle("Images similar to seed_id {0}".format(seed_id), fontsize=18)
    plt.subplots_adjust(top=0.5)
    return f
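
# --- Illustrative sketch, not part of the uploaded file ---------------------
# plot_n_similar() above prints a hard-coded 0.1223 as the neighbour distance.
# Annoy can return the true distances together with the ids, so a variant of
# annoy_matching() could look like this (the helper name is ours):
def annoy_matching_with_distances_sketch(annoy_f, query_item, n=10):
    # include_distances=True makes Annoy return ([ids], [distances]),
    # with distances measured in the metric the index was built with.
    ids, dists = annoy_f.get_nns_by_vector(query_item, n, include_distances=True)
    return ids, dists
# -----------------------------------------------------------------------------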

def load_image(filename, change_url=True):
    # if change_url:
    print(filename)
    print(os.path.exists(filename))
    img = cv2.imread(filename)
    return img


@st.cache(allow_output_mutation=True)
def get_model(model_path):
    new_model = tf.keras.models.load_model(model_path)
    # keras_model = load_model(model_path)
    return new_model


@st.cache(allow_output_mutation=True)
def get_feature_vector_model(model_path):
    # Reuse the classifier but expose the penultimate (1792-dim) layer as the embedding output.
    keras_model = tf.keras.models.load_model(model_path)
    feature_extractor = tf.keras.Model(keras_model.inputs, keras_model.layers[-3].output)
    return feature_extractor
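
# --- Illustrative sketch, not part of the uploaded file ---------------------
# get_feature_vector() is imported from utils.model_utils, which this commit
# does not include. Given the feature extractor above, it presumably does
# something along these lines; the 224x224 input size and the 1/255 scaling
# are assumptions, not the project's actual preprocessing.
def extract_feature_vector_sketch(img_path, feature_extractor, input_size=(224, 224)):
    img = cv2.imread(img_path)                               # BGR uint8 image from disk
    img = cv2.resize(img, input_size)                        # resize to the model's input resolution
    batch = np.expand_dims(img.astype('float32') / 255.0, axis=0)
    features = feature_extractor.predict(batch)              # output of keras_model.layers[-3]
    return features.reshape(-1, 1792)                        # same shape the Annoy queries expect
# -----------------------------------------------------------------------------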

def load_pd_data_frame(df_csv_path):
    return pd.read_csv(df_csv_path)


def get_path_list_from_df(df_data):
    return list(df_data['path'])


def get_class_probs_from_df(df_data):
    return list(df_data['class_probs'])


def visualize_bar_plot(df_data):
    fig = px.bar(df_data, x="probability", y="class", orientation='h')
    return fig


def run_instance_exp(img_path, img_path_list, prob_list, grad_vis_path_list):

    st.subheader('Instance Exploration')
    # st.columns((1,1,1)) with row4_2:
    LABELS = ['CNV', 'DRUSEN', 'DME', 'NORMAL']

    left_column, middle_column, right_column = st.columns((1, 1, 1))
    display_image = load_image(img_path)
    # fig = px.imshow(display_image)
    # left_column.plotly_chart(fig, use_container_width=True)
    left_column.image(cv2.resize(display_image, (180, 180)), caption="Selected Input")

    # Parse the class probabilities serialized into the CSV for this image.
    indx = img_path_list.index(img_path)
    print(img_path)
    prb_tmp = prob_list[indx]
    print(f"{prb_tmp[1:-1]}")
    clss_probs = literal_eval('"' + prb_tmp[1:-1] + '"')
    print(clss_probs[1:-1].split(' '))
    prob_cls = [float(p) for p in clss_probs[1:-1].split(' ')]
    tmp_df = pd.DataFrame.from_dict({'class': LABELS, 'probability': prob_cls})
    print(tmp_df.head())
    fig = plt.figure(figsize=(15, 13))
    sns.barplot(x='probability', y='class', data=tmp_df)
    middle_column.pyplot(fig)
    # st.caption('Predictions')

    # Pre-computed Grad-CAM visualisation for the selected image.
    tmp_grad_img = GRAD_CAM_IMAGE_DIR + img_path.split("\\")[-2] + '/' + img_path.split("\\")[-1]

    display_image = load_image(tmp_grad_img, change_url=False)  # load_image takes change_url, not replace
    # left_column.plotly_chart(fig, use_container_width=True)
    right_column.image(display_image, caption="ROI")

    # seed_id = 900
    seed_id = test_id_list.index(img_path)
    query_item = test_features[seed_id]
    print(query_item.shape)
    closest_idxs = annoy_matching(annoy_tree, query_item, seed_id, 10)
    closest_fns = [train_ids[close_i] for close_i in closest_idxs]
    st.subheader('Top-10 Similar Samples from Gallery Set')
    st.pyplot(plot_n_similar(seed_id, closest_fns, img_path, n=10, scale=4))
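
# --- Illustrative sketch, not part of the uploaded file ---------------------
# The literal_eval/string-slicing above unpacks a probability vector that was
# written to the CSV roughly as "[0.1 0.2 0.3 0.4]". Assuming that format, a
# more direct parse would be (the helper name is ours):
def parse_class_probs_sketch(prob_str):
    # Drop the surrounding brackets and split on whitespace -> list of floats.
    return [float(p) for p in prob_str.strip('[]').split()]
# -----------------------------------------------------------------------------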


def run_instance_exp_keras_model(img_path, new_model, feature_extractor_model):

    st.subheader('Instance Exploration')
    # st.columns((1,1,1)) with row4_2:
    LABELS = ['CNV', 'DRUSEN', 'DME', 'NORMAL']

    left_column, middle_column, right_column = st.columns((1, 1, 1))
    print(img_path)
    org_img_path = img_path

    # Map the Windows-style id stored in the CSV onto the local test-image folder.
    img_path = f'{os.getcwd()}/data/oct2017/test/' + img_path.split("\\")[-2] + '/' + img_path.split("\\")[-1]
    # img_path.replace("F:/XAI/data/OCT2017/","/home/hodor/dev/Learning/XAI/streamlit_demo/multipage-app/data/xai_framework_data/")
    display_image = load_image(img_path)
    # fig = px.imshow(display_image)
    # left_column.plotly_chart(fig, use_container_width=True)
    left_column.image(cv2.resize(display_image, (180, 180)), caption="Selected Input")

    # Run the model once to get both the decision ROI and the class probabilities.
    roi_img, probs = get_predictions_and_roi(img_path, new_model)

    # print(np.asarray(probs))
    # print(probs.shape)
    prob_cls = np.asarray(probs)[0]
    # print(prob_cls)
    tmp_df = pd.DataFrame.from_dict({'class': LABELS, 'probability': prob_cls})
    fig = plt.figure(figsize=(5, 4))
    sns.barplot(x='probability', y='class', data=tmp_df)
    middle_column.pyplot(fig)
    # middle_column.write("Probabilities")

    right_column.image(roi_img, caption="Decision ROI")

    # Embed the query image and fetch its nearest neighbours from the gallery index.
    # seed_id = 900
    seed_id = test_id_list.index(org_img_path)
    query_item = get_feature_vector(img_path, feature_extractor_model)
    query_item = query_item.reshape(-1, 1792)
    # print(query_item.shape)
    closest_idxs = annoy_matching(annoy_tree, query_item[0, :], seed_id, 10)
    closest_fns = [train_ids[close_i] for close_i in closest_idxs]

    closest_fns_tmp = [f'{os.getcwd()}/data/oct2017/train/' + each_fn.split("\\")[-2] + '/' + each_fn.split("\\")[-1]
                       for each_fn in closest_fns]
    # print(closest_fns)
    st.subheader('Top-10 Similar Samples from Gallery Set')
    st.pyplot(plot_n_similar(seed_id, closest_fns_tmp, img_path, n=10, scale=3))


def main():

    new_model = get_model(MODEL_PATH)
    feature_extractor_model = get_feature_vector_model(MODEL_PATH)

    row4_1, row4_2 = st.tabs(["Global Level Explanations", "Instance Level Explanations"])

    with row4_1:
        borderline_cases = np.load(borderline_id_path)
        representative_cases = np.load(repr_id_path)
        borderline_id_list = list(borderline_cases)
        # print(borderline_id_list)

        borderline_id_list = [f'{os.getcwd()}/data/oct2017/train/' + each_fn.split("\\")[-2] + '/' + each_fn.split("\\")[-1]
                              for each_fn in borderline_id_list]
        representative_id_list = list(representative_cases)
        representative_id_list = [f'{os.getcwd()}/data/oct2017/train/' + each_fn.split("\\")[-2] + '/' + each_fn.split("\\")[-1]
                                  for each_fn in representative_id_list]
        # st.info('GLOBAL EXPLANATION!!')
        option = st.selectbox('Please select to explore Representative or Borderline Samples', ["Representative Samples", "Borderline Cases"])
        if option:
            clss = st.selectbox('Select a category (class)', ["CNV", "DME", "NORMAL", "DRUSEN"])
            side_1, side_2 = st.columns(2)

            with side_1:
                check_emb = st.checkbox('Embedding Space Visualization')

            with side_2:
                check_samp = st.checkbox('Random Sample Visualization')

            if check_emb and check_samp:
                st.write("Emb and vis")
                if option.startswith("Rep"):
                    filter_lst = list(filter(lambda k: clss in k, representative_id_list))
                    show_random_samples(filter_lst, clss)
                    show_tsne_vis(f"{ROOT_FIG_DIR}/tsne_representative.png", title="Representative")
                else:
                    filter_lst = list(filter(lambda k: clss in k, borderline_id_list))
                    show_random_samples(filter_lst, clss)
                    show_tsne_vis(f"{ROOT_FIG_DIR}/tsne_borderline.png", title="Borderline")
            elif check_emb:
                st.write("embedding vis")
                if option.startswith("Rep"):
                    show_tsne_vis(f"{ROOT_FIG_DIR}/tsne_representative.png", title="Representative")
                else:
                    show_tsne_vis(f"{ROOT_FIG_DIR}/tsne_borderline.png", title="Borderline")
            elif check_samp:
                st.write("rand vis")
                if option.startswith("Rep"):
                    filter_lst = list(filter(lambda k: clss in k, representative_id_list))
                    show_random_samples(filter_lst, clss)
                else:
                    filter_lst = list(filter(lambda k: clss in k, borderline_id_list))
                    show_random_samples(filter_lst, clss)

    with row4_2:
        DF_TEST_PROP = load_pd_data_frame(TEST_CSV_FILE)
        IMG_PATH_LISTS = get_path_list_from_df(DF_TEST_PROP)
        IMG_CLSS_PROBS_LIST = get_class_probs_from_df(DF_TEST_PROP)
        grad_vis_path_list = None
        row2_col1, row2_col2 = st.columns(2)
        with row2_col1:
            option = st.selectbox('Please select a sample image, then click the Explain Me button', IMG_PATH_LISTS)
        with row2_col2:
            st.info("Press the button")
            pressed = st.button('Explain ME')

        if pressed:
            st.empty()
            st.sidebar.write('Please wait for the magic to happen! This may take up to a minute.')
            run_instance_exp_keras_model(option, new_model, feature_extractor_model)


# # new_model = load_model(MODEL_PATH)
# option = st.sidebar.selectbox('Please select a sample image, then click Explain Me button', IMG_PATH_LISTS)
# pressed = st.sidebar.button('Explain ME')

# main()
# expander_faq = st.expander("More About Our Project")
# expander_faq.write("Hi there! If you have any questions about our project, or simply want to check out the source code, please visit our github repo: https://github.com/kaplansinan/MLOPS")


def get_product_dev_page_layout():
    # Entry point used by the multipage app (shadows the name imported from utils.eval_users above).
    return main()