Spaces:
Sleeping
Sleeping
import os

# TensorFlow's native (C++) logger only honours TF_CPP_MIN_LOG_LEVEL if it is
# set before TensorFlow starts logging, so it has to precede
# `import tensorflow` below; setting it after the import can be a no-op.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow logging (1)

import sys
import json
import shutil
import urllib.request
from pathlib import Path
import pathlib
import time
import urllib
from ast import literal_eval
# import albumentations as A
import tensorflow as tf
import cv2
import numpy as np
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
from PIL import Image
import streamlit as st
import seaborn as sns

# Also put ./utils itself on sys.path; presumably modules inside utils/ import
# each other by bare name -- confirm before removing.
sys.path.append(f'{os.getcwd()}/utils')
# NOTE(review): this name is redefined by a local `def` at the end of this
# module, which shadows the imported function.
from utils.eval_users import get_product_dev_page_layout

# Hide GPUs from TensorFlow so everything runs on the CPU. This has to happen
# before TensorFlow initialises its devices.
tf.config.set_visible_devices([], 'GPU')
# Alternative (disabled): keep GPUs but enable memory growth so TensorFlow does
# not allocate all GPU memory at start.
# gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
# for gpu in gpus:
#     tf.config.experimental.set_memory_growth(device=gpu, enable=True)

from utils.control import show_tsne_vis, show_random_samples
from utils.annoy_sampling import load_annoy_tree
from utils.model_utils import load_model
from utils.model_utils import get_feature_vector, get_feature_extractor_model, get_predictions_and_roi

sns.set_style('darkgrid')
plt.rcParams['axes.grid'] = False
tf.get_logger().setLevel('ERROR')  # Suppress TensorFlow logging (2)
# st.set_page_config(layout="wide")
# Reference: https://github.com/IliaLarchenko/albumentations-demo/blob/3cb6528a513fe3b35dbb2c2a63cdcfbb9bb2a932/src/utils.py#L149

# --- Data / model locations (all relative to the working directory) ---
_CWD = os.getcwd()
_EMB_DIR = f'{_CWD}/data/filtered_train_embedding'
GRAD_CAM_IMAGE_DIR = f'{_CWD}/data/gradcam_vis_data/'
TEST_CSV_FILE = f'{_CWD}/data/test_set_pred_prop.csv'
annoy_tree_save_path = f'{_EMB_DIR}/representative_samples_emb.annoy'
test_emb_path = f'{_EMB_DIR}/test_embeddings.npy'
test_emb_id_path = f'{_EMB_DIR}/test_ids.npy'
# NOTE: the gallery ("train") ids and the representative ids are the same file.
train_emb_id_path = f'{_EMB_DIR}/representative_train_ids.npy'
repr_id_path = f'{_EMB_DIR}/representative_train_ids.npy'
borderline_id_path = f'{_EMB_DIR}/borderline_train_ids.npy'
MODEL_PATH = f'{_CWD}/model/keras_model_0422/'
ROOT_FIG_DIR = f'{_CWD}/figures/'

# Length of one embedding vector; the stored test embeddings are reshaped to
# rows of this width.
EMBEDDING_DIM = 1792

# --- Test-set embeddings (loaded once, at import time) ---
test_emb = np.load(test_emb_path)
test_ids = np.load(test_emb_id_path)
test_id_list = list(test_ids)
# Second backslash-separated component of each id, shown as the sample's label
# (presumably the class directory -- confirm against how the ids were saved).
test_labels = [_id.split("\\")[1] for _id in test_id_list]
print(" NUmber of test samples: ", len(test_labels))
test_features = test_emb.reshape(-1, EMBEDDING_DIM)

# --- Gallery (representative training) ids ---
train_ids = np.load(train_emb_id_path)
train_id_list = list(train_ids)
train_labels = [_id.split("\\")[1] for _id in train_id_list]
print(" NUmber of training samples: ", len(train_labels))
# Annoy index whose neighbour indices point into `train_ids`; its vectors have
# the same width as the test features.
annoy_tree = load_annoy_tree(test_features.shape[1], annoy_tree_save_path)
def annoy_matching(annoy_f, query_item, query_index, n=10):
    """Return the indices of the *n* items nearest to *query_item* in *annoy_f*.

    Args:
        annoy_f: an Annoy index (anything with ``get_nns_by_vector``).
        query_item: the query vector.
        query_index: not used; kept so existing call sites keep working.
        n: number of neighbours to return.
    """
    neighbours = annoy_f.get_nns_by_vector(query_item, n)
    return neighbours
def get_img(fn, thumbnail=False):
    """Open the image file *fn* with PIL and return it.

    When *thumbnail* is true, the image is first shrunk in place (PIL's
    ``Image.thumbnail``) to fit within 150x150 pixels.
    """
    picture = Image.open(fn)
    if not thumbnail:
        return picture
    picture.thumbnail((150, 150))
    return picture
def open_gray(fn, size=(224, 224)):
    """Load *fn* as grayscale, replicate it to 3 channels and resize it.

    Args:
        fn: path of the image file.
        size: output size as ``(width, height)`` (the ``dsize`` order used by
            ``cv2.resize``); defaults to 224x224.

    Returns:
        A ``(height, width, 3)`` array whose three channels all hold the
        grayscale intensity.

    Raises:
        OSError: if OpenCV cannot read or decode the file.
    """
    bgr = cv2.imread(fn)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the path rather than later inside cvtColor.
        raise OSError(f"cv2.imread could not read image: {fn}")
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    gray_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    return cv2.resize(gray_rgb, size)
def show_values(axs, orient="v", space=.01):
    """Write each bar's value next to it on one or more matplotlib Axes.

    Args:
        axs: a single Axes, or a numpy array of Axes (e.g. from plt.subplots).
        orient: "v" writes each bar's height (1 decimal), centred just above the
            bar; "h" writes each bar's width (2 decimals), left-aligned just past
            the bar's end. Any other value writes nothing.
        space: gap (in data units) added after a bar's end; used only for "h".
    """
    def _annotate(ax):
        if orient == "v":
            for patch in ax.patches:
                height = patch.get_height()
                x_mid = patch.get_x() + patch.get_width() / 2
                y_top = patch.get_y() + height + (height * 0.01)
                ax.text(x_mid, y_top, f"{height:.1f}", ha="center")
        elif orient == "h":
            for patch in ax.patches:
                width = patch.get_width()
                height = patch.get_height()
                x_end = patch.get_x() + width + float(space)
                y_mid = patch.get_y() + height - (height * 0.5)
                ax.text(x_end, y_mid, f"{width:.2f}", ha="left")

    # An ndarray of axes is walked element by element; anything else is one Axes.
    targets = axs.flat if isinstance(axs, np.ndarray) else (axs,)
    for ax in targets:
        _annotate(ax)
def plot_n_similar(seed_id, similar_ids, test_path, n=10, scale=5, distances=None):
    """Draw a query image followed by its most similar gallery images in one row.

    Args:
        seed_id: index into the module-level ``test_labels`` list; used only
            for the titles.
        similar_ids: paths of the neighbour images, plotted left to right in
            the given order. Only the first ``n`` are plotted.
        test_path: path of the query image, plotted in the first panel.
        n: number of neighbour panels; the figure has ``n + 1`` panels.
        scale: width and height, in inches, of each panel.
        distances: optional sequence parallel to ``similar_ids``. When given,
            each neighbour title includes its distance. When omitted, no
            distance is shown; the old version printed a hard-coded
            placeholder value in every title.

    Returns:
        The matplotlib Figure. It is not closed here; the caller should close
        it after rendering.
    """
    def _remap_drive(path):
        # Legacy fix-up from the original environment (F: -> E: drive); a
        # no-op for paths that do not contain "F:/".
        return path.replace("F:/", "E:/")

    fig, axes = plt.subplots(1, n + 1, figsize=((n + 1) * scale, scale), squeeze=False)
    axes = axes[0]  # squeeze=False keeps the axes indexable even when n == 0

    seed_title = "SEED ID:{0}\nLabel:{1}".format(seed_id, os.path.basename(test_labels[seed_id]))
    axes[0].imshow(get_img(_remap_drive(test_path)), cmap='gray')
    axes[0].set_title(seed_title, fontsize=12)

    # Never index past the n neighbour panels that were allocated.
    neighbours = list(similar_ids)[:n]
    for i, path in enumerate(neighbours):
        ax = axes[i + 1]
        ax.imshow(get_img(_remap_drive(path)), cmap='gray')
        # splitext strips any extension length (e.g. ".jpeg"), unlike [:-4].
        label = os.path.splitext(os.path.basename(path))[0]
        if distances is None:
            title = "ID:{0}\nLabel:{1}".format(i, label)
        else:
            title = "ID:{0}\nDistance: {1:.3f}\nLabel:{2}".format(i, distances[i], label)
        ax.set_title(title, fontsize=10)

    # Hide panels left empty when fewer than n neighbours were supplied.
    for ax in axes[len(neighbours) + 1:]:
        ax.axis('off')

    fig.suptitle("Images similar to seed_id {0}".format(seed_id), fontsize=18)
    # tight_layout recomputes all subplot margins, so a preceding
    # subplots_adjust(top=...) would have no effect and is not used.
    fig.tight_layout()
    return fig
def load_image(filename, change_url=True):
    """Read the image at *filename* with OpenCV.

    Args:
        filename: path of the image file.
        change_url: unused; kept so existing callers keep working.

    Returns:
        The decoded image array, in OpenCV's BGR channel order
        (``cv2.imread`` default mode).

    Raises:
        OSError: if OpenCV cannot read or decode the file.
    """
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the path instead of later on e.g. `img.shape`.
        raise OSError(f"cv2.imread could not read image: {filename}")
    return img
def get_model(model_path):
    """Load and return the complete saved Keras model stored at *model_path*."""
    return tf.keras.models.load_model(model_path)
def get_feature_vector_model(model_path):
    """Load the Keras model at *model_path* and return an embedding sub-model.

    The returned model shares the loaded model's inputs and outputs the
    activations of its third-from-last layer.
    NOTE(review): which layer is the embedding depends on the saved
    architecture (presumably the pooled features before the classifier head);
    confirm if the model changes.
    """
    base = tf.keras.models.load_model(model_path)
    embedding_layer = base.layers[-3]
    return tf.keras.Model(base.inputs, embedding_layer.output)
def load_pd_data_frame(df_csv_path):
    """Read the CSV at *df_csv_path* into a DataFrame with default ``read_csv`` options."""
    frame = pd.read_csv(df_csv_path)
    return frame
def get_path_list_from_df(df_data):
    """Return the values of the ``path`` column of *df_data* as a plain list."""
    return [*df_data['path']]
def get_class_probs_from_df(df_data):
    """Return the values of the ``class_probs`` column of *df_data* as a plain list."""
    return [*df_data['class_probs']]
def visualize_bar_plot(df_data):
    """Return a horizontal plotly bar chart of *df_data*.

    Uses the ``probability`` column as bar length and ``class`` as the category axis.
    """
    return px.bar(df_data, orientation='h', y="class", x="probability")
# def run_instance_exp(img_path, img_path_list,prob_list,grad_vis_path_list): | |
# st.subheader('Instance Exploration') | |
# # st.columns((1,1,1)) with row4_2: | |
# LABELS = ['CNV', 'DRUSEN', 'DME', 'NORMAL'] | |
# left_column, middle_column, right_column = st.columns((1,1,1)) | |
# display_image = load_image(img_path) | |
# # fig = px.imshow(display_image) | |
# # left_column.plotly_chart(fig, use_container_width=True) | |
# left_column.image(cv2.resize(display_image, (180,180)),caption = "Selected Input") | |
# # get class probabilities | |
# indx = img_path_list.index(img_path) | |
# print(img_path) | |
# prb_tmp = prob_list[indx] | |
# print(f"{prb_tmp[1:-1]}") | |
# clss_probs = literal_eval('"'+prb_tmp[1:-1]+'"') | |
# print(clss_probs[1:-1].split(' ')) | |
# prob_cls = [float(p) for p in clss_probs[1:-1].split(' ')] | |
# tmp_df = pd.DataFrame.from_dict({'class':LABELS,'probability':prob_cls}) | |
# print(tmp_df.head()) | |
# fig = plt.figure(figsize=(15, 13)) | |
# sns.barplot(x='probability', y='class', data=tmp_df) | |
# middle_column.pyplot(fig) | |
# # st.caption('Predictions') | |
# tmp_grad_img = GRAD_CAM_IMAGE_DIR + img_path.split("\\")[-2] +'/'+img_path.split("\\") [-1] | |
# display_image = load_image(tmp_grad_img,replace=False) | |
# # left_column.plotly_chart(fig, use_container_width=True) | |
# right_column.image(display_image,caption = "ROI") | |
# # seed_id = 900 | |
# seed_id = test_id_list.index(img_path) | |
# query_item = test_features[seed_id] | |
# print(query_item.shape) | |
# closest_idxs = annoy_matching(annoy_tree,query_item, seed_id, 10) | |
# closest_fns = [train_ids[close_i] for close_i in closest_idxs] | |
# st.subheader('Top-10 Similar Samples from Gallery Set') | |
# st.plotly_chart(plot_n_similar(seed_id,closest_fns, img_path,n=10, scale=4), use_container_width=True) | |
# # st.pyplot(plot_n_similar(seed_id,closest_fns, img_path,n=10, scale=4)) | |
def run_instance_exp_keras_model(img_path, new_model, feature_extractor_model):
    """Render instance-level explanations for one test image in Streamlit.

    Shows three columns: the input image, a bar chart of the predicted class
    probabilities, and the GradCAM region of interest. Below them it shows the
    10 nearest samples from the representative (gallery) training set.

    Args:
        img_path: the test-image path as listed in the test CSV. It must also
            appear in ``test_id_list``. It is backslash-separated; its last two
            components are treated as ``<class dir>/<file name>``.
        new_model: Keras classifier passed to ``get_predictions_and_roi``.
        feature_extractor_model: embedding model passed to ``get_feature_vector``.
    """
    # NOTE(review): this order must match the model's output units. Keras
    # directory loaders usually sort classes alphabetically
    # (CNV, DME, DRUSEN, NORMAL) -- confirm against the training code.
    LABELS = ['CNV', 'DRUSEN', 'DME', 'NORMAL']

    def _local_path(split, stored_path):
        # Stored ids use backslash separators; keep "<class dir>/<file name>"
        # and re-root it under the local data/oct2017/<split>/ copy.
        parts = stored_path.split("\\")
        return f'{os.getcwd()}/data/oct2017/{split}/' + parts[-2] + '/' + parts[-1]

    left_column, middle_column, right_column = st.columns((1, 1, 1))
    local_img_path = _local_path('test', img_path)

    display_image = load_image(local_img_path)
    if display_image is None:
        # cv2.imread returns None for a missing/undecodable file.
        st.error(f"Could not read image: {local_img_path}")
        return
    # NOTE(review): cv2 arrays are BGR while st.image assumes RGB; colours
    # would be swapped for non-grayscale inputs (pass channels="BGR" if needed).
    left_column.image(display_image)
    left_column.write("Input Image")

    roi_img, probs = get_predictions_and_roi(local_img_path, new_model)
    # presumably probs is (1, len(LABELS)); the first row holds this image's scores
    prob_cls = np.asarray(probs)[0]
    probs_df = pd.DataFrame.from_dict({'class': LABELS, 'probability': prob_cls})
    prob_fig = plt.figure(figsize=(8, 8.8))
    bar_ax = sns.barplot(x='probability', y='class', data=probs_df)
    show_values(bar_ax, "h", space=0.05)
    middle_column.pyplot(prob_fig)
    # st.pyplot renders but does not close the figure; without this every
    # click leaves another figure open in the long-running server process.
    plt.close(prob_fig)
    middle_column.write("Predicted Class Probabilities")

    # Match the GradCAM overlay to the input size; cv2.resize takes (width, height).
    height, width = display_image.shape[:2]
    right_column.image(cv2.resize(roi_img, (width, height)))
    right_column.write("GradCAM RoI")

    seed_id = test_id_list.index(img_path)
    query_item = get_feature_vector(local_img_path, feature_extractor_model).reshape(-1, 1792)
    closest_idxs = annoy_matching(annoy_tree, query_item[0, :], seed_id, 10)
    closest_fns = [_local_path('train', train_ids[idx]) for idx in closest_idxs]

    st.subheader('Top-10 Similar Samples from Gallery(representative) Set')
    similar_fig = plot_n_similar(seed_id, closest_fns, local_img_path, n=10, scale=4)
    st.pyplot(similar_fig)
    plt.close(similar_fig)

    if "load_state" not in st.session_state:
        st.session_state.load_state = True
def _local_train_path(stored_path):
    """Map a stored, backslash-separated training id to its local copy.

    Keeps only the last two components (``<class dir>/<file name>``) and
    re-roots them under ``<cwd>/data/oct2017/train/``.
    """
    parts = stored_path.split("\\")
    return f'{os.getcwd()}/data/oct2017/train/' + parts[-2] + '/' + parts[-1]


def _render_global_tab(user_type):
    """Render the "Global Level Explanations" tab.

    Lets the user browse representative or borderline training samples per
    class. Non-managers also get the sampling-algorithm figure and t-SNE plot.
    """
    option = st.selectbox('Please select to explore Representative/Borderline Samples',
                          ["Choose here", "Representative Samples", "Borderline Cases"], index=0)
    if option.startswith("Choose"):
        return
    is_representative = option.startswith("Rep")

    if user_type != 'Manager':
        if is_representative:
            with st.expander('Click to see Representative Sampling Algorithm'):
                st.image(f'{ROOT_FIG_DIR}/representativesampling.png')
            with st.expander('Click to see Manifold(t-SNE) Visualization of Representative Samples'):
                show_tsne_vis(f"{ROOT_FIG_DIR}/tsne_representative.png", title="Representative")
        else:
            with st.expander('Click to see Borderline Sampling Algorithm'):
                st.image(f'{ROOT_FIG_DIR}/borderlinesampling.png')
            with st.expander('Click to see Manifold(t-SNE) Visualization of Broderline Samples'):
                show_tsne_vis(f"{ROOT_FIG_DIR}/tsne_borderline.png", title="Borderline")

    clss = st.selectbox('Select a category(class)', ["CNV", "DME", "NORMAL", "DRUSEN"])
    # Load only the id list needed for the chosen view.
    id_path = repr_id_path if is_representative else borderline_id_path
    sample_paths = [_local_train_path(fn) for fn in np.load(id_path)]
    # NOTE(review): substring match on the full local path; a class name that
    # also occurs elsewhere in the path (e.g. in the cwd) would match everything.
    filter_lst = [k for k in sample_paths if clss in k]
    show_random_samples(filter_lst, clss)


def _render_instance_tab(user_type):
    """Render the "Instance Level Explanations" tab for one chosen test image."""
    df_test = load_pd_data_frame(TEST_CSV_FILE)
    img_path_list = get_path_list_from_df(df_test)

    row2_col1, row2_col2 = st.columns(2)
    with row2_col1:
        option = st.selectbox('Please select a sample image๐', img_path_list)
    with row2_col2:
        st.write("Click button")
        pressed = st.button('Explain ME')
    if not pressed:
        return

    st.write('Please wait for a while! This may take up to a minute.')
    # The models are only needed for an explanation, so load them here rather
    # than on every Streamlit rerun of the page.
    # NOTE(review): each call below loads MODEL_PATH from disk; consider
    # Streamlit resource caching if the installed version supports it.
    new_model = get_model(MODEL_PATH)
    feature_extractor_model = get_feature_vector_model(MODEL_PATH)
    run_instance_exp_keras_model(option, new_model, feature_extractor_model)

    if user_type == 'Manager':
        return

    def form_callback():
        st.write("Training set updated")

    # NOTE(review): the callback only shows a message; nothing is stored.
    # Submitting the form also triggers a rerun in which st.button returns
    # False, so this section (and the explanation above) disappears after submit.
    with st.expander('Correct the Decision'):
        with st.form("my_form"):
            st.info("Correct the decision if it is not valid. The sample will be added to next training bucket.")
            # Two columns so the radio only takes half the form width.
            label_col, _ = st.columns(2)
            with label_col:
                st.radio(
                    "Choose label ๐",
                    ["NONE", "CNV", "DME", "NORMAL", "DRUSEN"],
                    key="visibility",
                    horizontal=True)
            # Every form must have a submit button.
            st.form_submit_button(label='ADD TO TRAINING BUCKET', on_click=form_callback)


def main(user_type='Developer'):
    """Render the product-development page with its two explanation tabs.

    Args:
        user_type: viewer role. 'Manager' hides the sampling-algorithm and
            t-SNE expanders and the label-correction form; any other value
            shows them.
    """
    global_tab, instance_tab = st.tabs(["Global Level Explanations", "Instance Level Explanations"])
    with global_tab:
        _render_global_tab(user_type)
    with instance_tab:
        _render_instance_tab(user_type)
# # new_model = load_model(MODEL_PATH) | |
# option = st.sidebar.selectbox('Please select a sample image, then click Explain Me button', IMG_PATH_LISTS) | |
# pressed = st.sidebar.button('Explain ME') | |
# main() | |
# expander_faq = st.expander("More About Our Project") | |
# expander_faq.write("Hi there! If you have any questions about our project, or simply want to check out the source code, please visit our github repo: https://github.com/kaplansinan/MLOPS") | |
def get_product_dev_page_layout(user_type="Developer"):
    """Page entry point for the multipage app: render this page for *user_type*.

    Delegates to ``main`` and returns whatever it returns.
    """
    page_result = main(user_type)
    return page_result
# def get_product_dev_page_layout(): | |
# return main() | |