# fossil_app / app.py
# Copied from the Hugging Face Space file view: last commit 99ddcfc by andy-wyx,
# "debugging: xai output distortion" (raw / history / blame; 23.6 kB).
import os
import sys
from env import config_env
config_env()
import gradio as gr
from huggingface_hub import snapshot_download
import cv2
import dotenv
dotenv.load_dotenv()
import numpy as np
import gradio as gr
import glob
from inference_sam import segmentation_sam
from explanations import explain
from inference_resnet import get_triplet_model
from inference_resnet_v2 import get_resnet_model,inference_resnet_embedding_v2,inference_resnet_finer_v2
from inference_beit import get_triplet_model_beit
import pathlib
import tensorflow as tf
from closest_sample import get_images,get_diagram
# Fetch the example images and the fossil dataset from the Hugging Face Hub on
# first start-up; each download is skipped when its target folder already exists.
if not os.path.exists('images'):
    REPO_ID = 'Serrelab/image_examples_gradio'
    snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'), repo_type='dataset', local_dir='images')
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # Only report whether a token is present: printing the token itself would
    # leak the secret into the application logs.
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    else:
        print("Read token found in env variables.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')
# HTML banner rendered at the top of the UI via gr.Markdown(HEADER).
HEADER = '''
<div style='display: flex; align-items: baseline;'>
<h1 style='margin-right: 10px;'><b>Official Gradio Demo:</b></h1>
<h1>🍁 <a href='https://huggingface.co/spaces/Serrelab/fossil_app' target='_blank'><b>Identifying Florissant Leaf Fossils to Family using Deep Neural Networks</b></a></h1>
</div>
'''
# NOTE(review): the bare string below is never assigned or rendered; it looks
# like a placeholder draft for a project intro (the ArXiv link is still empty).
"""
**Fossil** a brief intro to the project.
# ❗️❗️❗️**Important Notes:**
# - some notes to users some notes to users some notes to users some notes to users some notes to users some notes to users .
# - some notes to users some notes to users some notes to users some notes to users some notes to users some notes to users.
Code: <a href='https://github.com/orgs/serre-lab/projects/2' target='_blank'>GitHub</a>. Paper: <a href='' target='_blank'>ArXiv</a>.
"""
# HTML user guide rendered on the first tab via gr.Markdown(USER_GUIDE).
USER_GUIDE = """
<div style='background-color: #f0f0f0; padding: 20px; border-radius: 10px;'>
<h2 style='font-size: 22px;'>❗️ User Guide</h2>
<p style='font-size: 16px;'>Welcome to the interactive fossil exploration tool. Here's how to get started:</p>
<ul style='font-size: 16px;'>
<li><strong>Upload an Image:</strong> Drag and drop or choose from given samples to upload images of fossils.</li>
<li><strong>Process Image:</strong> After uploading, click the 'Process Image' button to analyze the image.</li>
<li><strong>Explore Results:</strong> Switch to the 'Workbench' tab to check out detailed analysis and results.</li>
</ul>
<h3 style='font-size: 18px;'>Tips</h3>
<ul style='font-size: 16px;'>
<li>Zoom into images on the workbench for finer details.</li>
<li>Use the examples below as references for what types of images to upload.</li>
</ul>
<p style='font-size: 16px;'>Enjoy exploring! 🌟</p>
</div>
"""
# Markdown version of the tips. Not referenced elsewhere in this file; the same
# tips are already embedded in USER_GUIDE.
TIPS = """
## Tips
- Zoom into images on the workbench for finer details.
- Use the examples below as references for what types of images to upload.
Enjoy exploring!
"""
# Markdown contact note rendered in the UI via gr.Markdown(CITATION).
# NOTE(review): "[email protected]" below looks like an e-mail-obfuscation
# artefact from a web copy of this file; restore the real contact address.
CITATION = '''
📧 **Contact** <br>
If you have any questions, feel free to contact us at <b>[email protected]</b>.
'''
# Placeholder for the citation / license section. This bare string is never
# assigned or rendered.
"""
📝 **Citation**
cite using this bibtex:...
```
```
📋 **License**
"""
def get_model(model_name):
    """Build the classifier identified by ``model_name`` and load its weights.

    Args:
        model_name: one of 'Mummified 170', 'Rock 170' or 'Fossils 142'.

    Returns:
        ``(model, n_classes)``: the Keras model and the number of classes it
        was configured with.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    # The two ResNet50V2 triplet models share one architecture; only the class
    # count and the weight file differ.
    # NOTE(review): 'Rock 170' is built with 171 outputs, not 170. Kept as-is;
    # confirm against the checkpoint.
    triplet_configs = {
        'Mummified 170': (170, 'model_classification/mummified-170.h5'),
        'Rock 170': (171, 'model_classification/rock-170.h5'),
    }
    if model_name in ('Mummified 170', 'Rock 170'):
        n_classes, weights_path = triplet_configs[model_name]
        model = get_triplet_model(
            input_shape=(600, 600, 3),
            embedding_units=256,
            embedding_depth=2,
            backbone_class=tf.keras.applications.ResNet50V2,
            nb_classes=n_classes,
            load_weights=False,
            finer_model=True,
            backbone_name='Resnet50v2',
        )
        model.load_weights(weights_path)
        return model, n_classes
    if model_name == 'Fossils 142':
        # ResNet model whose weights are loaded by get_resnet_model itself.
        model, _, _ = get_resnet_model('model_classification/fossil-model.h5')
        return model, 142
    raise ValueError(f"Model name '{model_name}' is not recognized")
def segment_image(input_image):
    """Segment ``input_image`` with ``segmentation_sam`` and return its output unchanged."""
    return segmentation_sam(input_image)
def classify_image(input_image, model_name):
    """Classify ``input_image`` with the model named ``model_name``.

    Args:
        input_image: image passed from the UI (forwarded unchanged to the
            inference helper).
        model_name: 'Rock 170', 'Mummified 170', 'Fossils BEiT' or 'Fossils 142'.

    Returns:
        Whatever the matching inference helper returns, or ``None`` when
        ``model_name`` is not handled here.
    """
    if model_name in ('Rock 170', 'Mummified 170'):
        # Both ResNet50V2 triplet models take 600x600 inputs.
        from inference_resnet import inference_resnet_finer
        model, n_classes = get_model(model_name)
        return inference_resnet_finer(input_image, model, size=600, n_classes=n_classes)
    if model_name == 'Fossils BEiT':
        # NOTE(review): get_model() has no 'Fossils BEiT' branch, so this path
        # currently raises ValueError there.
        from inference_beit import inference_resnet_finer_beit
        model, n_classes = get_model(model_name)
        return inference_resnet_finer_beit(input_image, model, size=384, n_classes=n_classes)
    if model_name == 'Fossils 142':
        model, n_classes = get_model(model_name)
        return inference_resnet_finer_v2(input_image, model, size=384, n_classes=n_classes)
    return None
def get_embeddings(input_image, model_name):
    """Compute the embedding of ``input_image`` with the model named ``model_name``.

    Args:
        input_image: image passed from the UI (forwarded unchanged to the
            embedding helper).
        model_name: 'Rock 170', 'Mummified 170', 'Fossils BEiT' or 'Fossils 142'.

    Returns:
        Whatever the matching embedding helper returns, or ``None`` when
        ``model_name`` is not handled here.
    """
    if model_name in ('Rock 170', 'Mummified 170'):
        # Both ResNet50V2 triplet models take 600x600 inputs.
        from inference_resnet import inference_resnet_embedding
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding(input_image, model, size=600, n_classes=n_classes)
    if model_name == 'Fossils BEiT':
        # NOTE(review): get_model() has no 'Fossils BEiT' branch, so this path
        # currently raises ValueError there.
        from inference_beit import inference_resnet_embedding_beit
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding_beit(input_image, model, size=384, n_classes=n_classes)
    if model_name == 'Fossils 142':
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding_v2(input_image, model, size=384, n_classes=n_classes)
    return None
def find_closest(input_image, model_name):
    """Find dataset samples close to ``input_image`` in ``model_name``'s embedding space.

    Returns:
        ``(classes, paths)`` as unpacked from ``get_images`` for the embedding
        of ``input_image``.
    """
    query_embedding = get_embeddings(input_image, model_name)
    nearest_classes, nearest_paths = get_images(query_embedding, model_name)
    return nearest_classes, nearest_paths
def generate_diagram_closest(input_image, model_name, top_k):
    """Return ``get_diagram``'s output for the ``top_k`` samples closest to ``input_image``.

    The result is shown in the UI's 'Bar Chart' image; presumably it is a path
    to a rendered chart (confirm in ``closest_sample.get_diagram``).
    """
    return get_diagram(get_embeddings(input_image, model_name), top_k, model_name)
def explain_image(input_image, model_name, explain_method, nb_samples):
    """Compute explanation heatmaps for ``input_image`` under ``model_name``.

    Args:
        input_image: image to explain, forwarded to ``explain``.
        model_name: model identifier understood by ``get_model``.
        explain_method: attribution method name from the UI (e.g. 'Rise').
        nb_samples: sampling budget forwarded to ``explain``.

    Returns:
        ``(classes, exp_list)`` exactly as returned by ``explain``.
    """
    model, n_classes = get_model(model_name)
    # 'Fossils BEiT' and 'Fossils 142' run at 384x384; the ResNet50V2 triplet
    # models are built for 600x600. A membership test is required here: the
    # former `model_name == 'Fossils BEiT' or 'Fossils 142'` was always true
    # because a non-empty string literal is truthy.
    size = 384 if model_name in ('Fossils BEiT', 'Fossils 142') else 600
    classes, exp_list = explain(model, input_image, explain_method, nb_samples, size=size, n_classes=n_classes)
    print('done')
    return classes, exp_list
def setup_examples():
    """Attach fossil and leaf example galleries to the ``input_image`` component.

    Scans ``images/`` recursively for ``.jpg`` files (sorted by path). It keeps
    at most 23 whose path contains 'selected fossil examples' and at most 19
    whose path contains 'leaves'. Must be called after ``input_image`` exists.

    Returns:
        ``(examples_fossils, examples_leaves)``: the two ``gr.Examples``.
    """
    jpg_paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
    fossil_samples = [p.as_posix() for p in jpg_paths if 'selected fossil examples' in str(p)]
    leaf_samples = [[p.as_posix()] for p in jpg_paths if 'leaves' in str(p)]
    examples_fossils = gr.Examples(
        fossil_samples[:23],
        inputs=input_image,
        examples_per_page=8,
        label='Fossils Examples from the dataset',
    )
    examples_leaves = gr.Examples(
        leaf_samples[:19],
        inputs=input_image,
        examples_per_page=8,
        label='Leaves Examples from the dataset',
    )
    return examples_fossils, examples_leaves
def preprocess_image(image, output_size=(300, 300)):
    """Pad ``image`` with black borders to a square, then resize it.

    Args:
        image: array indexed as (height, width[, channels]), as read by
            ``image.shape[:2]``.
        output_size: target size passed to ``cv2.resize`` (OpenCV order is
            ``(width, height)``).

    Returns:
        The padded and resized image.
    """
    h, w = image.shape[:2]
    diff = abs(h - w)
    # Give the odd pixel (if any) to the second side so the padded image is
    # exactly square. Padding diff // 2 on both sides left it one pixel short,
    # and the resize below then slightly stretched the content.
    before = diff // 2
    after = diff - before
    if h > w:
        top, bottom, left, right = 0, 0, before, after
    else:
        top, bottom, left, right = before, after, 0, 0
    image_padded = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return cv2.resize(image_padded, output_size, interpolation=cv2.INTER_AREA)
def increase_brightness(img, value=30):
    """Brighten a BGR image by adding ``value`` to the V channel in HSV space.

    V values above ``255 - value`` are set to 255 instead of having ``value``
    added, so they do not overflow (presumably a uint8 image; confirm). Not
    referenced elsewhere in this file.
    """
    hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    hue, sat, val = cv2.split(hsv_image)
    ceiling = 255 - value
    val[val > ceiling] = 255
    val[val <= ceiling] += value
    brightened_hsv = cv2.merge((hue, sat, val))
    return cv2.cvtColor(brightened_hsv, cv2.COLOR_HSV2BGR)
def update_display(image):
    """Prepare an uploaded image for the workbench and reset the workbench widgets.

    Args:
        image: value of the input ``gr.Image`` (presumably a numpy array;
            confirm), or ``None`` when nothing has been uploaded.

    Returns:
        A 12-tuple in the order of ``process_button.click``'s outputs:
        original image, input preview, workbench image, instruction text,
        model name, explain method, sampling size, top-k, then four ``None``
        values that clear the prediction label, both galleries and the chart.
    """
    # Values the workbench controls are reset to each time an image is processed.
    model_name = "Fossils 142"
    explain_method = "Rise"
    sampling_size = 10
    top_k = 50
    # Clear the previous prediction, explanation/closest galleries and diagram.
    class_predicted = exp_gallery = closest_gallery = diagram = None
    if image is None:
        # Nothing uploaded: preprocess_image would fail on None, so ask the
        # user for an image instead of raising.
        instruction = "No image to process. Please upload or choose an image, then click 'Process Image'."
        return (None, None, None, instruction, model_name, explain_method, sampling_size, top_k,
                class_predicted, exp_gallery, closest_gallery, diagram)
    processed_image = preprocess_image(image)
    instruction = "Image ready. Please switch to the 'Specimen Workbench' tab to check out further analysis and outputs."
    return (image, processed_image, processed_image, instruction, model_name, explain_method,
            sampling_size, top_k, class_predicted, exp_gallery, closest_gallery, diagram)
def update_slider_visibility(explain_method):
    """Show the RISE sampling-size slider only when 'Rise' is selected.

    Args:
        explain_method: current value of the "Explain method" dropdown.

    Returns:
        A ``{sampling_size: update}`` mapping for Gradio. Only ``visible`` is
        changed, so the slider keeps its configured range and the user's
        current value. Rebuilding it as ``gr.Slider(1, 5000, value=2000, ...)``
        silently widened the range and reset the value on every method change.
    """
    is_rise = explain_method == "Rise"
    return {sampling_size: gr.update(visible=is_rise)}
# Two-tab Gradio UI: upload/process an image on the first tab, then classify,
# explain and search for similar specimens on the "Specimen Workbench" tab.
# minimalist theme
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    with gr.Tab(" Florissant Fossils"):
        gr.Markdown(HEADER)
        with gr.Row():
            with gr.Column():
                gr.Markdown(USER_GUIDE)
            with gr.Column(scale=2):
                with gr.Column(scale=2):
                    instruction_text = gr.Textbox(label="Instructions", value="Upload/Choose an image and click 'Process Image'.")
                    input_image = gr.Image(label="Input", width="100%", container=True)
                    process_button = gr.Button("Process Image")
                with gr.Column(scale=1):
                    # setup_examples() reads the module-level `input_image`, so
                    # it must run after that component is created.
                    examples_fossils, examples_leaves = setup_examples()
        gr.Markdown(CITATION)

    with gr.Tab("Specimen Workbench"):
        with gr.Row():
            with gr.Column():
                # Hidden copy of the raw upload; every analysis handler below
                # reads from it rather than from the preprocessed preview.
                original_image = gr.Image(visible=False)
                workbench_image = gr.Image(label="Workbench Image")
                classify_image_button = gr.Button("Classify Image")
            with gr.Column():
                model_name = gr.Dropdown(
                    ["Fossils 142"],  # "Mummified 170", "Rock 170", "Fossils BEiT" removed
                    multiselect=False,
                    value="Fossils 142",  # default option
                    label="Model",
                    interactive=True,
                    info="Choose the model you'd like to use"
                )
                explain_method = gr.Dropdown(
                    ["Sobol", "HSIC", "Rise", "Saliency"],
                    multiselect=False,
                    value="Rise",  # default option
                    label="Explain method",
                    interactive=True,
                    info="Choose one method to explain the model"
                )
                sampling_size = gr.Slider(10, 3000, value=10, label="Sampling Size in Rise", interactive=True, visible=True,
                                          info="Choose between 10 and 3000")
                top_k = gr.Slider(10, 200, value=50, label="Number of Closest Samples for Distribution Chart", interactive=True, info="Choose between 10 and 200")
                explain_method.change(
                    fn=update_slider_visibility,
                    inputs=explain_method,
                    outputs=sampling_size
                )
        with gr.Row():
            with gr.Column(scale=1):
                class_predicted = gr.Label(label='Plant Family Predicted', num_top_classes=10)
            with gr.Column(scale=4):
                with gr.Accordion("Explanations "):
                    gr.Markdown("Computing Explanations from the model for Top 5 Predicted Plant Families")
                    with gr.Column():
                        with gr.Row():
                            # Distinct elem_ids: both galleries previously shared
                            # elem_id="gallery", giving duplicate DOM ids.
                            exp_gallery = gr.Gallery(label="Explanation Heatmaps for top 5 predicted classes", show_label=False, elem_id="exp_gallery", columns=[5], rows=[1], height='auto', allow_preview=True, preview=None)
                        generate_explanations = gr.Button("Generate Explanations")
                with gr.Accordion('Closest Fossil Images'):
                    gr.Markdown("Finding 5 closest images in the dataset")
                    with gr.Row():
                        closest_gallery = gr.Gallery(label="Closest Images", show_label=False, elem_id="closest_gallery", columns=[5], rows=[1], height='auto', allow_preview=True, preview=None)
                    find_closest_btn = gr.Button("Find Closest Images")
                # Registered here (not with the other listeners below) to keep
                # the original event registration order.
                classify_image_button.click(classify_image, inputs=[original_image, model_name], outputs=class_predicted)
                with gr.Accordion("Family Distribution of Closest Samples "):
                    gr.Markdown("Visualize plant family distribution of top-k closest samples in our dataset")
                    with gr.Column():
                        with gr.Row():
                            diagram = gr.Image(label='Bar Chart')
                        generate_diagram = gr.Button("Generate Diagram")

    def update_exp_outputs(input_image, model_name, explain_method, nb_samples):
        """Return (heatmap, caption) pairs for the top 5 predicted classes.

        zip() stops at the shorter input, so fewer than five results no longer
        raise IndexError.
        """
        labels, images = explain_image(input_image, model_name, explain_method, nb_samples)
        return [(image, f"Predicted Plant Family {i}: {label}")
                for i, (image, label) in enumerate(zip(images[:5], labels[:5]))]

    generate_explanations.click(fn=update_exp_outputs, inputs=[original_image, model_name, explain_method, sampling_size], outputs=[exp_gallery])

    def update_closest_outputs(input_image, model_name):
        """Return up to five (image, class label) pairs of the closest dataset samples."""
        labels, images = find_closest(input_image, model_name)
        return [(image, label) for image, label in zip(images[:5], labels[:5])]

    find_closest_btn.click(fn=update_closest_outputs, inputs=[original_image, model_name], outputs=[closest_gallery])
    generate_diagram.click(generate_diagram_closest, inputs=[original_image, model_name, top_k], outputs=diagram)
    process_button.click(
        fn=update_display,
        inputs=input_image,
        outputs=[original_image, input_image, workbench_image, instruction_text, model_name, explain_method, sampling_size, top_k, class_predicted, exp_gallery, closest_gallery, diagram]
    )
# Enable request queuing (manage multiple incoming requests), then start the
# server. When SYSTEM == 'spaces' (presumably set on Hugging Face Spaces),
# launch() additionally receives width='40%'.
demo.queue()
launch_options = {'width': '40%'} if os.getenv('SYSTEM') == 'spaces' else {}
# Optional auth: auth=(os.environ.get('USERNAME'), os.environ.get('PASSWORD'))
demo.launch(**launch_options)