|
import json |
|
import os |
|
import pickle |
|
import random |
|
import time |
|
from collections import Counter |
|
from datetime import datetime |
|
from glob import glob |
|
|
|
import gdown |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pandas as pd |
|
import seaborn as sns |
|
import streamlit as st |
|
from PIL import Image |
|
|
|
import SessionState |
|
from download_utils import * |
|
from image_utils import * |
|
|
|
# Seed both RNGs from the wall clock so each session samples a different set of
# study images. Since Python 3.11 `random.seed` only accepts None, int, float,
# str, bytes or bytearray (a datetime raises TypeError), so pass the POSIX
# timestamp instead of the datetime object itself.
random.seed(datetime.now().timestamp())

np.random.seed(int(time.time()))
|
|
|
# --- Study configuration -------------------------------------------------------
# Queries per session; resmaple_queries() draws half of them from each bucket.
NUMBER_OF_TRIALS = 20
# Prefix of the prediction keys to read; main() sets it from the chosen tool.
CLASSIFIER_TAG = ""
# Callable producing the explanation image, or None for "no explanation".
selected_xai_tool = None
explaination_functions = [load_chm_nns, load_knn_nns]
# Sub-directory of the visualization root that queries are sampled from.
selected_dataset = "Final"

# Placeholders; the data-loading code further down replaces them.
folder_to_name = {}
class_descriptions = {}
classifier_predictions = {}

# --- Remote assets ---------------------------------------------------------------
# Visualization archive.
root_visualization_dir = "./visualizations/"
viz_url = "https://drive.google.com/uc?id=1LpmOc_nFBzApYWAokO2J-s9RRXsk3pBN"
viz_archivefile = "Final.zip"

# Demonstration archive (presumably unpacked into ./demonstrations/ — confirm).
demonstration_url = "https://drive.google.com/uc?id=1C92llG5VrlABrsIEvxfNlSDc_gIeLlls"
demonst_zipfile = "demonstrations.zip"

# Pickled classifier predictions.
picklefile_url = "https://drive.google.com/uc?id=1Yx4abA4VLZGO5JkzhXVGdy6mbPltMd68"
prediction_root = "./predictions/"
prediction_pickle = prediction_root + "predictions.pickle"

# Fetch all of the above (download_files comes from one of the star imports).
download_files(
    root_visualization_dir, viz_url, viz_archivefile,
    demonstration_url, demonst_zipfile,
    picklefile_url, prediction_root, prediction_pickle,
)
|
|
|
|
|
# Module-level placeholder; render_menu() tracks the selected page in a local
# variable of the same name.
app_mode = ""

# Class id (wnid) -> human-readable class name.
with open("imagenet-labels.json", "rb") as f:
    folder_to_name = json.load(f)

# gloss.txt holds one "<wnid>\t<definition>" entry per line. Strip the line
# ending so it does not leak into the rendered definition, split only on the
# first tab, and skip blank/malformed lines instead of raising IndexError.
with open("gloss.txt", "r", encoding="utf-8") as f:
    description_file = f.readlines()

class_descriptions = {
    wnid: definition
    for wnid, sep, definition in (
        line.rstrip("\r\n").partition("\t") for line in description_file
    )
    if sep
}

# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# This pickle is presumably fetched from `picklefile_url` by download_files(),
# so keep it only while that source is trusted; a JSON/CSV export would be safer.
with open(prediction_pickle, "rb") as f:
    classifier_predictions = pickle.load(f)

# Per-user state; the SessionState helper presumably keeps these values across
# Streamlit reruns (the first_run / page bookkeeping below relies on that).
session_state = SessionState.get(
    page=1,
    first_run=1,
    user_feedback={},
    queries=[],
    is_classifier_correct={},
    XAI_tool="Unselected",
)
|
|
|
|
|
|
|
def resmaple_queries():
    """Draw a fresh set of study queries when a new run has been requested.

    Acts only while ``session_state.first_run == 1`` (initial load, or after
    the "Resample" button). It then picks ``NUMBER_OF_TRIALS`` images: half
    from ``Both_correct`` and the rest from ``Both_wrong`` under the selected
    dataset directory. It shuffles them into ``session_state.queries``, clears
    the recorded answers, and sets ``first_run = -1``.

    Otherwise it is a no-op, so answers survive the Streamlit reruns that call
    this function on every interaction.

    Raises:
        ValueError: if a directory holds fewer images than requested
            (``np.random.choice`` with ``replace=False``).
    """
    if session_state.first_run != 1:
        return

    dataset_dir = root_visualization_dir + selected_dataset
    both_correct = glob(dataset_dir + "/Both_correct/*.JPEG")
    both_wrong = glob(dataset_dir + "/Both_wrong/*.JPEG")

    # Split so the total is exactly NUMBER_OF_TRIALS even when it is odd;
    # render_menu() pages through NUMBER_OF_TRIALS queries.
    n_correct = NUMBER_OF_TRIALS // 2
    n_wrong = NUMBER_OF_TRIALS - n_correct

    correct_samples = list(
        np.random.choice(a=both_correct, size=n_correct, replace=False)
    )
    wrong_samples = list(np.random.choice(a=both_wrong, size=n_wrong, replace=False))

    all_images = correct_samples + wrong_samples
    random.shuffle(all_images)

    session_state.queries = all_images
    session_state.first_run = -1

    # Reset answers only together with a new sample. main() calls this on
    # every rerun, so resetting unconditionally wiped all earlier answers.
    session_state.user_feedback = {}
    session_state.is_classifier_correct = {}
|
|
|
|
|
def render_experiment(query):
    """Show one trial: query image, model prediction, and the answer radio.

    Args:
        query: zero-based index into ``session_state.queries``.

    Side effects:
        Stores the classifier's "-Output" value for this query in
        ``session_state.is_classifier_correct`` and the radio answer
        ("-", "Correct" or "Wrong") in ``session_state.user_feedback``.
    """
    current_query = session_state.queries[query]
    query_id = os.path.basename(current_query)
    record = classifier_predictions[query_id]

    predicted_wnid = record[f"{CLASSIFIER_TAG}-predictions"]
    prediction_confidence = record[f"{CLASSIFIER_TAG}-confidence"]
    prediction_label = folder_to_name[predicted_wnid]
    class_def = class_descriptions[predicted_wnid]

    # NOTE(review): this key upper-cases the tag while the two above do not;
    # presumably that matches the pickle's schema — confirm.
    session_state.is_classifier_correct[query_id] = record[
        f"{CLASSIFIER_TAG.upper()}-Output"
    ]

    query_col, answer_col = st.columns(2)

    with query_col:
        st.image(load_query(current_query), caption=f"Query ID: {query_id}")

    with answer_col:
        with st.expander("Show Class Description"):
            st.write(f"**Name**: {prediction_label}")
            st.write("**Class Definition**:")
            st.markdown("`" + class_def + "`")
            st.image(
                Image.open(f"demonstrations/{predicted_wnid}.jpeg"),
                caption="Class Explanation",
                use_column_width=True,
            )

        # Pre-select whatever the user already answered for this query.
        answer_to_index = {"Correct": 1, "Wrong": 2}
        previous_answer = session_state.user_feedback.get(query_id)
        session_state.user_feedback[query_id] = st.radio(
            "What do you think about model's prediction?",
            ("-", "Correct", "Wrong"),
            key=query_id,
            index=answer_to_index.get(previous_answer, 0),
        )
        st.write(f"**Model Prediction**: {prediction_label}")
        st.write(f"**Model Confidence**: {prediction_confidence}")

    if selected_xai_tool is not None:
        st.image(
            selected_xai_tool(current_query),
            caption="Explaination",
            use_column_width=True,
        )

    if st.button("Debug: Show Everything"):
        st.image(Image.open(current_query))
|
|
|
|
|
def _offer_csv_download(df, file_name, key):
    """Render a "Press to Download" button serving *df* as CSV via ``convert_df``."""
    st.download_button(
        "Press to Download", convert_df(df), file_name, "text/csv", key=key
    )


def render_results():
    """Score the user's answers against the classifier's actual outcome.

    For each answered query, "Category" is "Correct" or "Wrong" according to
    ``session_state.is_classifier_correct``. The user's guess counts as
    correct when it equals that category. Shows a headline score, a per-query
    table, a per-category breakdown and a confusion matrix, each with a CSV
    download button.
    """
    feedback = session_state.user_feedback

    experiment_summary = []
    for q, user_answer in feedback.items():
        category = "Correct" if session_state.is_classifier_correct[q] else "Wrong"
        # An unanswered query ("-") matches neither category, so it counts as
        # a miss in both the headline and the table (previously the headline
        # treated "-" as a "Wrong" guess and disagreed with the table).
        is_user_correct = category == user_answer
        predictions = classifier_predictions[q]
        experiment_summary.append(
            [
                q,
                predictions["real-gts"],
                folder_to_name[predictions[f"{CLASSIFIER_TAG}-predictions"]],
                category,
                user_answer,
                is_user_correct,
            ]
        )

    user_correct_guess = sum(row[-1] for row in experiment_summary)
    st.write(
        f"User performance on {CLASSIFIER_TAG}: {user_correct_guess} out of {len(feedback)} Correct"
    )
    st.markdown("## User Performance Breakdown")

    # Nothing answered yet: skip the tables rather than aggregating an empty frame.
    if not experiment_summary:
        st.info("No answers recorded yet.")
        return

    experiment_summary_df = pd.DataFrame.from_records(
        experiment_summary,
        columns=[
            "Query",
            "GT Labels",
            f"{CLASSIFIER_TAG} Prediction",
            "Category",
            "User Prediction",
            "Is User Prediction Correct",
        ],
    )
    st.write("Summary", experiment_summary_df)
    _offer_csv_download(experiment_summary_df, "summary.csv", "download-records")

    # Per classifier outcome: number of queries, number the user judged
    # correctly, and the resulting mean.
    user_pf_by_model_pred = experiment_summary_df.groupby("Category").agg(
        {"Is User Prediction Correct": ["count", "sum", "mean"]}
    )
    user_pf_by_model_pred.columns = [
        "Count",
        "Correct User Guess",
        "Mean User Performance",
    ]
    user_pf_by_model_pred.index.name = "Model Prediction"
    st.write("User performance break down by Model prediction:", user_pf_by_model_pred)
    _offer_csv_download(
        user_pf_by_model_pred,
        "user-performance-by-model-prediction.csv",
        "download-performance-by-model-prediction",
    )

    confusion_matrix = pd.crosstab(
        experiment_summary_df["Category"],
        experiment_summary_df["User Prediction"],
        rownames=["Actual"],
        colnames=["Predicted"],
    )
    st.write("Confusion Matrix", confusion_matrix)
    _offer_csv_download(
        confusion_matrix, "confusion-matrix.csv", "download-confusiion-matrix"
    )
|
|
|
|
|
def _render_experiment_page():
    """Render the navigation buttons and the current trial.

    Returns:
        True if the user pressed "View" on the last trial to see the results.
    """
    page_id = session_state.page
    # Four columns; the second is an empty spacer between the buttons.
    prev_col, _spacer, next_col, resample_col = st.columns(4)

    if prev_col.button("Previous Image"):
        page_id = max(page_id - 1, 1)

    if next_col.button("Next Image"):
        page_id = min(page_id + 1, NUMBER_OF_TRIALS)

    view_results = False
    if page_id == NUMBER_OF_TRIALS:
        st.success(
            'You have reached the last image. Please go to the "Results" page to see your performance.'
        )
        view_results = st.button("View")

    if resample_col.button("Resample"):
        st.write("Restarting ...")
        page_id = 1
        session_state.first_run = 1
        resmaple_queries()

    session_state.page = page_id
    st.write(f"Render Experiment: {session_state.page}")
    render_experiment(session_state.page - 1)
    return view_results


def render_menu():
    """Draw the page selector and render the selected page.

    Pages: "Experiment Instruction" (instructions only), "Start Experiment"
    (paged trials with Previous/Next/Resample) and "See the Results".
    """
    readme_text = st.markdown(
        """
    # Instructions
    ```
    When testing this study, you should first see the class definition, then hide the expander and see the query.
    ```
    """
    )

    app_mode = st.selectbox(
        "Choose the page to show:",
        ["Experiment Instruction", "Start Experiment", "See the Results"],
    )

    if app_mode == "Experiment Instruction":
        st.success("To continue select an option in the dropdown menu.")
    elif app_mode == "Start Experiment":
        readme_text.empty()
        if _render_experiment_page():
            app_mode = "See the Results"

    # A separate `if` (not `elif`) so the "View" button above can switch to the
    # results; inside an if/elif chain that reassignment had no effect.
    if app_mode == "See the Results":
        readme_text.empty()
        st.write("Results Summary")
        render_results()
|
|
|
|
|
def main():
    """Entry point: choose the explanation tool under study, then show the app.

    The tool is chosen once per session (the radio is rendered only while the
    choice is still "Unselected"). The choice sets the module-level
    ``selected_xai_tool`` (explanation loader, or None) and ``CLASSIFIER_TAG``
    (prefix of the prediction keys read by the render functions).
    """
    global selected_xai_tool
    global CLASSIFIER_TAG

    st.set_page_config(layout="wide")
    st.title("Visual Correspondence Human Study - ImageNet")

    # Tool name -> (explanation loader, classifier tag). Built here rather than
    # at module level because the loaders come from the module's star imports.
    xai_tools = {
        "NOXAI": (None, "knn"),
        "KNN": (load_knn_nns, "knn"),
        "EMD-Corr Nearest Neighbors": (load_emd_nns, "EMD"),
        "EMD-Corr Correspondence": (load_emd_corrs, "EMD"),
        "CHM-Corr Nearest Neighbors": (load_chm_nns, "CHM"),
        "CHM-Corr Correspondence": (load_chm_corrs, "CHM"),
    }
    options = ["Unselected", *xai_tools]

    # Hide the first option of every radio group on the page ("Unselected"
    # here, "-" on the trial answer radio) so the placeholder cannot be picked.
    st.markdown(
        """ <style>
    div[role="radiogroup"] > :first-child{
    display: none !important;
    }
    </style>
    """,
        unsafe_allow_html=True,
    )

    if session_state.XAI_tool == "Unselected":
        session_state.XAI_tool = st.radio(
            "What explanation tool do you want to evaluate?",
            options,
            key="which_xai",
            index=options.index(session_state.XAI_tool),
        )

    if session_state.XAI_tool == "Unselected":
        # Without a tool CLASSIFIER_TAG stays "" and render_experiment would
        # fail looking up the "-predictions" key, so wait for a choice.
        st.info("Select an explanation tool above to start.")
        return

    st.markdown(f"## SELECTED METHOD ``{session_state.XAI_tool}``")
    selected_xai_tool, CLASSIFIER_TAG = xai_tools[session_state.XAI_tool]

    resmaple_queries()
    render_menu()


if __name__ == "__main__":
    main()
|
|