hodorfi committed on
Commit ef2e3e8 · 1 Parent(s): bdea4e2

Upload decisions_users.py

Files changed (1)
  1. utils/decisions_users.py +315 -0
utils/decisions_users.py ADDED
@@ -0,0 +1,315 @@
+ import os
+ import sys
+ import json
+ import shutil
+ import urllib.request
+ from pathlib import Path
+ import pathlib
+ import time
+ import urllib
+ from ast import literal_eval
+ import albumentations as A
+ import tensorflow as tf
+
+ import cv2
+ import numpy as np
+ import pandas as pd
+ import plotly.express as px
+ import matplotlib.pyplot as plt
+ from PIL import Image
+
+ import streamlit as st
+
+ import seaborn as sns
+
+ sys.path.append(f'{os.getcwd()}/utils')
+ from utils.eval_users import get_product_dev_page_layout
+
+ # print(os.getcwd())
+
+ # Hide GPU from visible devices
+ tf.config.set_visible_devices([], 'GPU')
+ # Enable GPU memory growth - avoid allocating all memory at start
+ # gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
+ # for gpu in gpus:
+ #     tf.config.experimental.set_memory_growth(device=gpu, enable=True)
+
+ from utils.control import show_tsne_vis, show_random_samples
+
+ from utils.annoy_sampling import load_annoy_tree
+
+ from utils.model_utils import load_model
+ from utils.model_utils import get_feature_vector, get_feature_extractor_model, get_predictions_and_roi
+
+ sns.set_style('darkgrid')
+ plt.rcParams['axes.grid'] = False
+
+ # st.set_page_config(layout="wide")
+
+ # https://github.com/IliaLarchenko/albumentations-demo/blob/3cb6528a513fe3b35dbb2c2a63cdcfbb9bb2a932/src/utils.py#L149
+
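+ # Static paths to the cached Grad-CAM visualisations, test-set predictions, pre-computed embeddings,
+ # the saved Keras model and the figure directory, all resolved relative to the current working directory.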
+ GRAD_CAM_IMAGE_DIR = f'{os.getcwd()}/data/gradcam_vis_data/'
+ TEST_CSV_FILE = f'{os.getcwd()}/data/test_set_pred_prop.csv'
+
+ annoy_tree_save_path = f'{os.getcwd()}/data/filtered_train_embedding/representative_samples_emb.annoy'
+ test_emb_path = f'{os.getcwd()}/data/filtered_train_embedding/test_embeddings.npy'
+ test_emb_id_path = f'{os.getcwd()}/data/filtered_train_embedding/test_ids.npy'
+
+ train_emb_id_path = f'{os.getcwd()}/data/filtered_train_embedding/representative_train_ids.npy'
+
+ repr_id_path = f'{os.getcwd()}/data/filtered_train_embedding/representative_train_ids.npy'
+ borderline_id_path = f'{os.getcwd()}/data/filtered_train_embedding/borderline_train_ids.npy'
+
+ MODEL_PATH = f'{os.getcwd()}/model/keras_model_0422/'
+
+ ROOT_FIG_DIR = f'{os.getcwd()}/figures/'
+
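+ # Load the pre-computed test and gallery (training) embeddings and derive class labels from the
+ # Windows-style file paths stored alongside them.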
+ test_emb = np.load(test_emb_path)
+ test_ids = np.load(test_emb_id_path)
+ test_id_list = list(test_ids)
+ test_labels = [_id.split("\\")[1] for _id in test_id_list]
+ print("Number of test samples: ", len(test_labels))
+ test_features = test_emb.reshape(-1, 1792)
+
+ # train embedding list
+ train_ids = np.load(train_emb_id_path)
+ train_id_list = list(train_ids)
+ train_labels = [_id.split("\\")[1] for _id in train_id_list]
+ print("Number of training samples: ", len(train_labels))
+
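+ # Annoy index built over the representative training embeddings; annoy_matching returns the indices
+ # of the n approximate nearest neighbours of a query embedding (the query_index argument is unused).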
+ annoy_tree = load_annoy_tree(test_features.shape[1], annoy_tree_save_path)
+
+ def annoy_matching(annoy_f, query_item, query_index, n=10):
+     return annoy_f.get_nns_by_vector(query_item, n)
+
+ def get_img(fn, thumbnail=False):
+     img = Image.open(fn)
+     if thumbnail:
+         img.thumbnail((100, 100))
+     return img
+
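+ # Plot the query (seed) image next to its n most similar gallery images in a single matplotlib figure.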
+ def plot_n_similar(seed_id, similar_ids, test_path, n=10, scale=5):
+     f, ax = plt.subplots(1, n + 1, figsize=((n + 1) * scale, scale))
+     # print(os.path.basename(test_labels[seed_id])[:-4])
+     title = "SEED ID:{0}\nLabel:{1}".format(seed_id, os.path.basename(test_labels[seed_id]))
+     # print("path:", test_labels[seed_id].replace("F:/","E:/"))
+     ax[0].imshow(get_img(test_path.replace("F:/", "E:/")), cmap='gray')
+     ax[0].set_title(title, fontsize=12)
+     for i in range(len(similar_ids)):
+         ax[i + 1].imshow(get_img(similar_ids[i].replace("F:/", "E:/")), cmap='gray')
+         # NOTE: the displayed distance is a hard-coded placeholder, not the actual Annoy distance
+         title = "ID:{0}\nDistance: {1:.3f}\nLabel:{2}".format(i, 0.1223, os.path.basename(similar_ids[i])[:-4])
+         ax[i + 1].set_title(title, fontsize=10)
+     f.suptitle("Images similar to seed_id {0}".format(seed_id), fontsize=18)
+     plt.subplots_adjust(top=0.5)
+     return f
+
+ def load_image(filename, change_url=True):
+     # if change_url:
+     print(filename)
+     print(os.path.exists(filename))
+     img = cv2.imread(filename)
+     return img
+
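+ # Cache the classification model and its intermediate-layer feature extractor (output of the model's
+ # third-from-last layer) across Streamlit reruns so the Keras model is only loaded from disk once.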
+ @st.cache(allow_output_mutation=True)
+ def get_model(model_path):
+     new_model = tf.keras.models.load_model(model_path)
+     # keras_model = load_model(model_path)
+     return new_model
+
+ @st.cache(allow_output_mutation=True)
+ def get_feature_vector_model(model_path):
+     keras_model = tf.keras.models.load_model(model_path)
+     feature_extractor = tf.keras.Model(keras_model.inputs, keras_model.layers[-3].output)
+     return feature_extractor
+
+ def load_pd_data_frame(df_csv_path):
+     return pd.read_csv(df_csv_path)
+
+ def get_path_list_from_df(df_data):
+     return list(df_data['path'])
+
+ def get_class_probs_from_df(df_data):
+     return list(df_data['class_probs'])
+
+ def visualize_bar_plot(df_data):
+     fig = px.bar(df_data, x="probability", y="class", orientation='h')
+     return fig
+
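+ # Offline variant of instance exploration: reads pre-computed class probabilities and Grad-CAM images
+ # from disk instead of running the model, then retrieves the most similar gallery samples via Annoy.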
+ def run_instance_exp(img_path, img_path_list, prob_list, grad_vis_path_list):
+
+     st.subheader('Instance Exploration')
+     # st.columns((1,1,1)) with row4_2:
+     LABELS = ['CNV', 'DRUSEN', 'DME', 'NORMAL']
+
+     left_column, middle_column, right_column = st.columns((1, 1, 1))
+     display_image = load_image(img_path)
+     # fig = px.imshow(display_image)
+     # left_column.plotly_chart(fig, use_container_width=True)
+     left_column.image(cv2.resize(display_image, (180, 180)), caption="Selected Input")
+
+     # get class probabilities
+     indx = img_path_list.index(img_path)
+     print(img_path)
+     prb_tmp = prob_list[indx]
+     print(f"{prb_tmp[1:-1]}")
+     clss_probs = literal_eval('"' + prb_tmp[1:-1] + '"')
+     print(clss_probs[1:-1].split(' '))
+     prob_cls = [float(p) for p in clss_probs[1:-1].split(' ')]
+     tmp_df = pd.DataFrame.from_dict({'class': LABELS, 'probability': prob_cls})
+     print(tmp_df.head())
+     fig = plt.figure(figsize=(15, 13))
+     sns.barplot(x='probability', y='class', data=tmp_df)
+     middle_column.pyplot(fig)
+     # st.caption('Predictions')
+
+     tmp_grad_img = GRAD_CAM_IMAGE_DIR + img_path.split("\\")[-2] + '/' + img_path.split("\\")[-1]
+
+     display_image = load_image(tmp_grad_img, change_url=False)
+     # left_column.plotly_chart(fig, use_container_width=True)
+     right_column.image(display_image, caption="ROI")
+
+     # seed_id = 900
+     seed_id = test_id_list.index(img_path)
+     query_item = test_features[seed_id]
+     print(query_item.shape)
+     closest_idxs = annoy_matching(annoy_tree, query_item, seed_id, 10)
+     closest_fns = [train_ids[close_i] for close_i in closest_idxs]
+     st.subheader('Top-10 Similar Samples from Gallery Set')
+     st.pyplot(plot_n_similar(seed_id, closest_fns, img_path, n=10, scale=4))
+
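+ # Instance-level explanation with the live Keras model: shows the selected image, its predicted class
+ # probabilities, the decision ROI returned by get_predictions_and_roi, and the 10 most similar gallery images.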
+ def run_instance_exp_keras_model(img_path, new_model, feature_extractor_model):
+
+     st.subheader('Instance Exploration')
+     # st.columns((1,1,1)) with row4_2:
+     LABELS = ['CNV', 'DRUSEN', 'DME', 'NORMAL']
+
+     left_column, middle_column, right_column = st.columns((1, 1, 1))
+     print(img_path)
+     org_img_path = img_path
+
+     img_path = f'{os.getcwd()}/data/oct2017/test/' + img_path.split("\\")[-2] + '/' + img_path.split("\\")[-1]
+     # img_path.replace("F:/XAI/data/OCT2017/","/home/hodor/dev/Learning/XAI/streamlit_demo/multipage-app/data/xai_framework_data/")
+     display_image = load_image(img_path)
+     # fig = px.imshow(display_image)
+     # left_column.plotly_chart(fig, use_container_width=True)
+     left_column.image(cv2.resize(display_image, (180, 180)), caption="Selected Input")
+
+     roi_img, probs = get_predictions_and_roi(img_path, new_model)
+
+     ## probs
+     # print(np.asarray(probs))
+     # print(probs.shape)
+     prob_cls = np.asarray(probs)[0]
+     # print(prob_cls)
+     tmp_df = pd.DataFrame.from_dict({'class': LABELS, 'probability': prob_cls})
+     fig = plt.figure(figsize=(5, 4))
+     sns.barplot(x='probability', y='class', data=tmp_df)
+     middle_column.pyplot(fig)
+     # middle_column.write("Probabilities")
+
+     # grad img
+     right_column.image(roi_img, caption="Decision ROI")
+
+     # seed_id = 900
+     seed_id = test_id_list.index(org_img_path)
+     query_item = get_feature_vector(img_path, feature_extractor_model)
+     query_item = query_item.reshape(-1, 1792)
+     # print(query_item.shape)
+     closest_idxs = annoy_matching(annoy_tree, query_item[0, :], seed_id, 10)
+     closest_fns = [train_ids[close_i] for close_i in closest_idxs]
+
+     closest_fns_tmp = [f'{os.getcwd()}/data/oct2017/train/' + each_fn.split("\\")[-2] + '/' + each_fn.split("\\")[-1]
+                        for each_fn in closest_fns]
+     # print(closest_fns)
+     st.subheader('Top-10 Similar Samples from Gallery Set')
+     st.pyplot(plot_n_similar(seed_id, closest_fns_tmp, img_path, n=10, scale=3))
+
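+ # Page entry point: a 'Global Level Explanations' tab (representative/borderline samples and t-SNE plots)
+ # and an 'Instance Level Explanations' tab (per-image prediction, ROI and similar-sample retrieval).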
+ def main():
+
+     new_model = get_model(MODEL_PATH)
+     feature_extractor_model = get_feature_vector_model(MODEL_PATH)
+
+     row4_1, row4_2 = st.tabs(["Global Level Explanations", "Instance Level Explanations"])
+
+     with row4_1:
+         borderline_cases = np.load(borderline_id_path)
+         representative_cases = np.load(repr_id_path)
+         borderline_id_list = list(borderline_cases)
+         # print(borderline_id_list)
+
+         borderline_id_list = [f'{os.getcwd()}/data/oct2017/train/' + each_fn.split("\\")[-2] + '/' + each_fn.split("\\")[-1]
+                               for each_fn in borderline_id_list]
+         representative_id_list = list(representative_cases)
+         representative_id_list = [f'{os.getcwd()}/data/oct2017/train/' + each_fn.split("\\")[-2] + '/' + each_fn.split("\\")[-1]
+                                   for each_fn in representative_id_list]
+         # st.info('GLOBAL EXPLANATION!!')
+         option = st.selectbox('Please select to explore Representative or Borderline Samples', ["Representative Samples", "Borderline Cases"])
+         if option:
+             clss = st.selectbox('Select a category (class)', ["CNV", "DME", "NORMAL", "DRUSEN"])
+             side_1, side_2 = st.columns(2)
+
+             with side_1:
+                 check_emb = st.checkbox('Embedding Space Visualization')
+
+             with side_2:
+                 check_samp = st.checkbox('Random Sample Visualization')
+
+             if check_emb and check_samp:
+                 st.write("Emb and vis")
+                 if option.startswith("Rep"):
+                     filter_lst = list(filter(lambda k: clss in k, representative_id_list))
+                     show_random_samples(filter_lst, clss)
+                     show_tsne_vis(f"{ROOT_FIG_DIR}/tsne_representative.png", title="Representative")
+                 else:
+                     filter_lst = list(filter(lambda k: clss in k, borderline_id_list))
+                     show_random_samples(filter_lst, clss)
+                     show_tsne_vis(f"{ROOT_FIG_DIR}/tsne_borderline.png", title="Borderline")
+             elif check_emb:
+                 st.write("embedding vis")
+                 if option.startswith("Rep"):
+                     show_tsne_vis(f"{ROOT_FIG_DIR}/tsne_representative.png", title="Representative")
+                 else:
+                     show_tsne_vis(f"{ROOT_FIG_DIR}/tsne_borderline.png", title="Borderline")
+             elif check_samp:
+                 st.write("rand vis")
+                 if option.startswith("Rep"):
+                     filter_lst = list(filter(lambda k: clss in k, representative_id_list))
+                     show_random_samples(filter_lst, clss)
+                 else:
+                     filter_lst = list(filter(lambda k: clss in k, borderline_id_list))
+                     show_random_samples(filter_lst, clss)
+
+     with row4_2:
+         DF_TEST_PROP = load_pd_data_frame(TEST_CSV_FILE)
+         IMG_PATH_LISTS = get_path_list_from_df(DF_TEST_PROP)
+         IMG_CLSS_PROBS_LIST = get_class_probs_from_df(DF_TEST_PROP)
+         grad_vis_path_list = None
+         row2_col1, row2_col2 = st.columns(2)
+         with row2_col1:
+             option = st.selectbox('Please select a sample image, then click Explain Me button', IMG_PATH_LISTS)
+         with row2_col2:
+             st.info("Press the button")
+             pressed = st.button('Explain ME')
+
+         if pressed:
+             st.empty()
+             st.sidebar.write('Please wait for the magic to happen! This may take up to a minute.')
+             run_instance_exp_keras_model(option, new_model, feature_extractor_model)
+
+ # # new_model = load_model(MODEL_PATH)
+ # option = st.sidebar.selectbox('Please select a sample image, then click Explain Me button', IMG_PATH_LISTS)
+ # pressed = st.sidebar.button('Explain ME')
+
+ # main()
+ # expander_faq = st.expander("More About Our Project")
+ # expander_faq.write("Hi there! If you have any questions about our project, or simply want to check out the source code, please visit our github repo: https://github.com/kaplansinan/MLOPS")
+
+
+ def get_product_dev_page_layout():
+     return main()