andy-wyx commited on
Commit
86104a0
·
1 Parent(s): dd58475

feat:workbench page

Browse files
Files changed (4) hide show
  1. app.py +268 -95
  2. closest_sample.py +56 -3
  3. explanations.py +47 -20
  4. inference_resnet.py +1 -1
app.py CHANGED
@@ -21,7 +21,8 @@ from inference_resnet import get_triplet_model
21
  from inference_beit import get_triplet_model_beit
22
  import pathlib
23
  import tensorflow as tf
24
- from closest_sample import get_images
 
25
 
26
  if not os.path.exists('images'):
27
  REPO_ID='Serrelab/image_examples_gradio'
@@ -35,6 +36,57 @@ if not os.path.exists('dataset'):
35
  print("warning! A read token in env variables is needed for authentication.")
36
  snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def get_model(model_name):
39
 
40
 
@@ -61,6 +113,13 @@ def get_model(model_name):
61
  embedding_depth = 2,
62
  n_classes = n_classes)
63
  model.load_weights('model_classification/fossil-142.h5')
 
 
 
 
 
 
 
64
  else:
65
  raise ValueError(f"Model name '{model_name}' is not recognized")
66
  return model,n_classes
@@ -82,7 +141,12 @@ def classify_image(input_image, model_name):
82
  model, n_classes= get_model(model_name)
83
  result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
84
  return result
85
- if 'Fossils 142' ==model_name:
 
 
 
 
 
86
  from inference_beit import inference_resnet_finer_beit
87
  model,n_classes = get_model(model_name)
88
  result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
@@ -100,7 +164,12 @@ def get_embeddings(input_image,model_name):
100
  model, n_classes= get_model(model_name)
101
  result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
102
  return result
103
- if 'Fossils 142' ==model_name:
 
 
 
 
 
104
  from inference_beit import inference_resnet_embedding_beit
105
  model,n_classes = get_model(model_name)
106
  result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
@@ -114,30 +183,103 @@ def find_closest(input_image,model_name):
114
  #outputs = classes+paths
115
  return classes,paths
116
 
117
- def explain_image(input_image,model_name):
 
 
 
 
 
118
  model,n_classes= get_model(model_name)
119
- if model_name=='Fossils 142':
120
  size = 384
121
  else:
122
  size = 600
123
  #saliency, integrated, smoothgrad,
124
- exp_list = explain(model,input_image,size = size, n_classes=n_classes)
125
  #original = saliency + integrated + smoothgrad
126
  print('done')
127
- sobol1,sobol2,sobol3,sobol4,sobol5 = exp_list[0],exp_list[1],exp_list[2],exp_list[3],exp_list[4]
128
- rise1,rise2,rise3,rise4,rise5 = exp_list[5],exp_list[6],exp_list[7],exp_list[8],exp_list[9]
129
- hsic1,hsic2,hsic3,hsic4,hsic5 = exp_list[10],exp_list[11],exp_list[12],exp_list[13],exp_list[14]
130
- saliency1,saliency2,saliency3,saliency4,saliency5 = exp_list[15],exp_list[16],exp_list[17],exp_list[18],exp_list[19]
131
- return sobol1,sobol2,sobol3,sobol4,sobol5,rise1,rise2,rise3,rise4,rise5,hsic1,hsic2,hsic3,hsic4,hsic5,saliency1,saliency2,saliency3,saliency4,saliency5
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  #minimalist theme
134
  with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
135
 
136
  with gr.Tab(" Florrissant Fossils"):
137
-
138
  with gr.Row():
139
  with gr.Column():
140
- input_image = gr.Image(label="Input")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  classify_image_button = gr.Button("Classify Image")
142
 
143
  # with gr.Column():
@@ -148,21 +290,101 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
148
 
149
  with gr.Column():
150
  model_name = gr.Dropdown(
151
- ["Mummified 170", "Rock 170","Fossils 142"],
152
  multiselect=False,
153
- value="Fossils 142", # default option
154
  label="Model",
155
  interactive=True,
 
 
 
 
 
 
 
 
 
156
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  class_predicted = gr.Label(label='Class Predicted',num_top_classes=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- with gr.Row():
160
-
161
- paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
162
- samples=[[path.as_posix()] for path in paths if 'fossils' in str(path) ][:19]
163
- examples_fossils = gr.Examples(samples, inputs=input_image,examples_per_page=10,label='Fossils Examples from the dataset')
164
- samples=[[path.as_posix()] for path in paths if 'leaves' in str(path) ][:19]
165
- examples_leaves = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Leaves Examples from the dataset')
166
 
167
  # with gr.Accordion("Using Diffuser"):
168
  # with gr.Column():
@@ -173,80 +395,20 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
173
  # class_predicted2 = gr.Label(label='Class Predicted from diffuser')
174
  # classify_button = gr.Button("Classify Image")
175
 
176
-
177
- with gr.Accordion("Explanations "):
178
- gr.Markdown("Computing Explanations from the model")
179
- with gr.Column():
180
- with gr.Row():
181
-
182
- #original_input = gr.Image(label="Original Frame")
183
- #saliency = gr.Image(label="saliency")
184
- #gradcam = gr.Image(label='integraged gradients')
185
- #guided_gradcam = gr.Image(label='gradcam')
186
- #guided_backprop = gr.Image(label='guided backprop')
187
- sobol1 = gr.Image(label = 'Sobol1')
188
- sobol2= gr.Image(label = 'Sobol2')
189
- sobol3= gr.Image(label = 'Sobol3')
190
- sobol4= gr.Image(label = 'Sobol4')
191
- sobol5= gr.Image(label = 'Sobol5')
192
-
193
- with gr.Row():
194
- rise1 = gr.Image(label = 'Rise1')
195
- rise2 = gr.Image(label = 'Rise2')
196
- rise3 = gr.Image(label = 'Rise3')
197
- rise4 = gr.Image(label = 'Rise4')
198
- rise5 = gr.Image(label = 'Rise5')
199
-
200
- with gr.Row():
201
- hsic1 = gr.Image(label = 'HSIC1')
202
- hsic2 = gr.Image(label = 'HSIC2')
203
- hsic3 = gr.Image(label = 'HSIC3')
204
- hsic4 = gr.Image(label = 'HSIC4')
205
- hsic5 = gr.Image(label = 'HSIC5')
206
-
207
- with gr.Row():
208
- saliency1 = gr.Image(label = 'Saliency1')
209
- saliency2 = gr.Image(label = 'Saliency2')
210
- saliency3 = gr.Image(label = 'Saliency3')
211
- saliency4 = gr.Image(label = 'Saliency4')
212
- saliency5 = gr.Image(label = 'Saliency5')
213
-
214
-
215
- generate_explanations = gr.Button("Generate Explanations")
216
-
217
- # with gr.Accordion('Closest Images'):
218
- # gr.Markdown("Finding the closest images in the dataset")
219
- # with gr.Row():
220
- # with gr.Column():
221
- # label_closest_image_0 = gr.Markdown('')
222
- # closest_image_0 = gr.Image(label='Closest Image',image_mode='contain',width=200, height=200)
223
- # with gr.Column():
224
- # label_closest_image_1 = gr.Markdown('')
225
- # closest_image_1 = gr.Image(label='Second Closest Image',image_mode='contain',width=200, height=200)
226
- # with gr.Column():
227
- # label_closest_image_2 = gr.Markdown('')
228
- # closest_image_2 = gr.Image(label='Third Closest Image',image_mode='contain',width=200, height=200)
229
- # with gr.Column():
230
- # label_closest_image_3 = gr.Markdown('')
231
- # closest_image_3 = gr.Image(label='Forth Closest Image',image_mode='contain', width=200, height=200)
232
- # with gr.Column():
233
- # label_closest_image_4 = gr.Markdown('')
234
- # closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
235
- # find_closest_btn = gr.Button("Find Closest Images")
236
- with gr.Accordion('Closest Images'):
237
- gr.Markdown("Finding the closest images in the dataset")
238
-
239
- with gr.Row():
240
- gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
241
- #.style(grid=[1, 5], height=200, width=200)
242
 
243
- find_closest_btn = gr.Button("Find Closest Images")
244
-
245
- #segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
246
- classify_image_button.click(classify_image, inputs=[input_image,model_name], outputs=class_predicted)
247
- generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[sobol1,sobol2,sobol3,sobol4,sobol5,rise1,rise2,rise3,rise4,rise5,hsic1,hsic2,hsic3,hsic4,hsic5,saliency1,saliency2,saliency3,saliency4,saliency5]) #
 
 
 
 
 
 
248
  #find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
249
- def update_outputs(input_image,model_name):
250
  labels, images = find_closest(input_image,model_name)
251
  #labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
252
  #labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
@@ -255,8 +417,19 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
255
  image_caption.append((images[i],labels[i]))
256
  return image_caption
257
 
258
- find_closest_btn.click(fn=update_outputs, inputs=[input_image,model_name], outputs=[gallery])
259
  #classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
 
 
 
 
 
 
 
 
 
 
 
260
 
261
  demo.queue() # manage multiple incoming requests
262
 
 
21
  from inference_beit import get_triplet_model_beit
22
  import pathlib
23
  import tensorflow as tf
24
+ from closest_sample import get_images,get_diagram
25
+
26
 
27
  if not os.path.exists('images'):
28
  REPO_ID='Serrelab/image_examples_gradio'
 
36
  print("warning! A read token in env variables is needed for authentication.")
37
  snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
38
 
39
+ HEADER = '''
40
+ <h2><b>Official Gradio Demo</b></h2><h2><a href='https://huggingface.co/spaces/Serrelab/fossil_app' target='_blank'><b>Identifying Florissant Leaf Fossils to Family using Deep Neural Networks </b></a></h2>
41
+ Code: <a href='https://github.com/orgs/serre-lab/projects/2' target='_blank'>GitHub</a>. Paper: <a href='' target='_blank'>ArXiv</a>.
42
+
43
+
44
+ '''
45
+
46
+ """
47
+ **Fossil** a brief intro to the project.
48
+ # ❗️❗️❗️**Important Notes:**
49
+ # - some notes to users some notes to users some notes to users some notes to users some notes to users some notes to users .
50
+ # - some notes to users some notes to users some notes to users some notes to users some notes to users some notes to users.
51
+
52
+ """
53
+
54
+ USER_GUIDE = """
55
+ <div style='background-color: #f0f0f0; padding: 20px; border-radius: 10px;'>
56
+ <h2>❗️ User Guide</h2>
57
+ <p>Welcome to the interactive fossil exploration tool. Here's how to get started:</p>
58
+ <ul>
59
+ <li><strong>Upload an Image:</strong> Drag and drop or choose from given samples to upload images of fossils.</li>
60
+ <li><strong>Process Image:</strong> After uploading, click the 'Process Image' button to analyze the image.</li>
61
+ <li><strong>Explore Results:</strong> Switch to the 'Workbench' tab to check out detailed analysis and results.</li>
62
+ </ul>
63
+ <h3>Tips</h3>
64
+ <ul>
65
+ <li>Zoom into images on the workbench for finer details.</li>
66
+ <li>Use the examples below as references for what types of images to upload.</li>
67
+ </ul>
68
+ <p>Enjoy exploring! 🌟</p>
69
+ </div>
70
+ """
71
+
72
+ TIPS = """
73
+ ## Tips
74
+ - Zoom into images on the workbench for finer details.
75
+ - Use the examples below as references for what types of images to upload.
76
+
77
+ Enjoy exploring!
78
+ """
79
+ CITATION = '''
80
+ 📧 **Contact** <br>
81
+ If you have any questions, feel free to contact us at <b>[email protected]</b>.
82
+ '''
83
+ """
84
+ 📝 **Citation**
85
+ cite using this bibtex:...
86
+ ```
87
+ ```
88
+ 📋 **License**
89
+ """
90
  def get_model(model_name):
91
 
92
 
 
113
  embedding_depth = 2,
114
  n_classes = n_classes)
115
  model.load_weights('model_classification/fossil-142.h5')
116
+ elif model_name == 'Fossils new':
117
+ n_classes = 142
118
+ model = get_triplet_model_beit(input_shape = (384, 384, 3),
119
+ embedding_units = 256,
120
+ embedding_depth = 2,
121
+ n_classes = n_classes)
122
+ model.load_weights('model_classification/fossil-new.h5')
123
  else:
124
  raise ValueError(f"Model name '{model_name}' is not recognized")
125
  return model,n_classes
 
141
  model, n_classes= get_model(model_name)
142
  result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
143
  return result
144
+ elif 'Fossils 142' ==model_name:
145
+ from inference_beit import inference_resnet_finer_beit
146
+ model,n_classes = get_model(model_name)
147
+ result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
148
+ return result
149
+ elif 'Fossils new' ==model_name:
150
  from inference_beit import inference_resnet_finer_beit
151
  model,n_classes = get_model(model_name)
152
  result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
 
164
  model, n_classes= get_model(model_name)
165
  result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
166
  return result
167
+ elif 'Fossils 142' ==model_name:
168
+ from inference_beit import inference_resnet_embedding_beit
169
+ model,n_classes = get_model(model_name)
170
+ result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
171
+ return result
172
+ elif 'Fossils new' ==model_name:
173
  from inference_beit import inference_resnet_embedding_beit
174
  model,n_classes = get_model(model_name)
175
  result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
 
183
  #outputs = classes+paths
184
  return classes,paths
185
 
186
+ def generate_diagram_closest(input_image,model_name,top_k):
187
+ embedding = get_embeddings(input_image,model_name)
188
+ diagram_path = get_diagram(embedding,top_k)
189
+ return diagram_path
190
+
191
+ def explain_image(input_image,model_name,explain_method,nb_samples):
192
  model,n_classes= get_model(model_name)
193
+ if model_name=='Fossils 142' or 'Fossils new':
194
  size = 384
195
  else:
196
  size = 600
197
  #saliency, integrated, smoothgrad,
198
+ classes,exp_list = explain(model,input_image,explain_method,nb_samples,size = size, n_classes=n_classes)
199
  #original = saliency + integrated + smoothgrad
200
  print('done')
 
 
 
 
 
201
 
202
+ return classes,exp_list
203
+
204
+ def setup_examples():
205
+ paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
206
+ samples = [path.as_posix() for path in paths if 'fossils' in str(path)][:19]
207
+ examples_fossils = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Fossils Examples from the dataset')
208
+ samples=[[path.as_posix()] for path in paths if 'leaves' in str(path) ][:19]
209
+ examples_leaves = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Leaves Examples from the dataset')
210
+ return examples_fossils,examples_leaves
211
+
212
+ def preprocess_image(image, output_size=(300, 300)):
213
+ #shape (height, width, channels)
214
+ h, w = image.shape[:2]
215
+
216
+ #padding
217
+ if h > w:
218
+ padding = (h - w) // 2
219
+ image_padded = cv2.copyMakeBorder(image, 0, 0, padding, padding, cv2.BORDER_CONSTANT, value=[0, 0, 0])
220
+ else:
221
+ padding = (w - h) // 2
222
+ image_padded = cv2.copyMakeBorder(image, padding, padding, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
223
+
224
+ # resize
225
+ image_resized = cv2.resize(image_padded, output_size, interpolation=cv2.INTER_AREA)
226
+
227
+ return image_resized
228
+
229
+ def update_display(image):
230
+ processed_image = preprocess_image(image)
231
+ instruction = "Image ready. Please switch to the 'Specimen Workbench' tab to check out further analysis and outputs."
232
+ model_name = gr.Dropdown(
233
+ ["Mummified 170", "Rock 170","Fossils 142","Fossils new"],
234
+ multiselect=False,
235
+ value="Fossils new", # default option
236
+ label="Model",
237
+ interactive=True,
238
+ info="Choose the model you'd like to use"
239
+ )
240
+ explain_method = gr.Dropdown(
241
+ ["Sobol", "HSIC","Rise","Saliency"],
242
+ multiselect=False,
243
+ value="Rise", # default option
244
+ label="Explain method",
245
+ interactive=True,
246
+ info="Choose one method to explain the model"
247
+ )
248
+ sampling_size = gr.Slider(1, 5000, value=2000, label="Sampling Size in Rise",interactive=True,visible=True,
249
+ info="Choose between 1 and 5000")
250
+
251
+ top_k = gr.Slider(10,200,value=50,label="Number of Closest Samples for Distribution Chart",interactive=True,info="Choose between 10 and 200")
252
+ class_predicted = gr.Label(label='Class Predicted',num_top_classes=10)
253
+ exp_gallery = gr.Gallery(label="Explanation Heatmaps for top 5 predicted classes", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
254
+ closest_gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
255
+ diagram= gr.Image(label = 'Bar Chart')
256
+ return processed_image,processed_image,instruction,model_name,explain_method,sampling_size,top_k,class_predicted,exp_gallery,closest_gallery,diagram
257
+ def update_slider_visibility(explain_method):
258
+ bool = explain_method=="Rise"
259
+ return {sampling_size: gr.Slider(1, 5000, value=2000, label="Sampling Size in Rise", visible=bool, interactive=True)}
260
+
261
  #minimalist theme
262
  with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
263
 
264
  with gr.Tab(" Florrissant Fossils"):
265
+ gr.Markdown(HEADER)
266
  with gr.Row():
267
  with gr.Column():
268
+ gr.Markdown(USER_GUIDE)
269
+ with gr.Column(scale=2):
270
+ with gr.Column(scale=2):
271
+ instruction_text = gr.Textbox(label="Instructions", value="Upload/Choose an image and click 'Process Image'.")
272
+ input_image = gr.Image(label="Input",width="100%",container=True)
273
+ process_button = gr.Button("Process Image")
274
+ with gr.Column(scale=1):
275
+ examples_fossils,examples_leaves = setup_examples()
276
+
277
+ gr.Markdown(CITATION)
278
+
279
+ with gr.Tab("Specimen Workbench"):
280
+ with gr.Row():
281
+ with gr.Column():
282
+ workbench_image = gr.Image(label="Workbench Image")
283
  classify_image_button = gr.Button("Classify Image")
284
 
285
  # with gr.Column():
 
290
 
291
  with gr.Column():
292
  model_name = gr.Dropdown(
293
+ ["Mummified 170", "Rock 170","Fossils 142","Fossils new"],
294
  multiselect=False,
295
+ value="Fossils new", # default option
296
  label="Model",
297
  interactive=True,
298
+ info="Choose the model you'd like to use"
299
+ )
300
+ explain_method = gr.Dropdown(
301
+ ["Sobol", "HSIC","Rise","Saliency"],
302
+ multiselect=False,
303
+ value="Rise", # default option
304
+ label="Explain method",
305
+ interactive=True,
306
+ info="Choose one method to explain the model"
307
  )
308
+ # explain_method = gr.CheckboxGroup(["Sobol", "HSIC","Rise","Saliency"],
309
+ # label="explain method",
310
+ # value="Rise",
311
+ # multiselect=False,
312
+ # interactive=True,)
313
+ sampling_size = gr.Slider(1, 5000, value=2000, label="Sampling Size in Rise",interactive=True,visible=True,
314
+ info="Choose between 1 and 5000")
315
+
316
+ top_k = gr.Slider(10,200,value=50,label="Number of Closest Samples for Distribution Chart",interactive=True,info="Choose between 10 and 200")
317
+ explain_method.change(
318
+ fn=update_slider_visibility,
319
+ inputs=explain_method,
320
+ outputs=sampling_size
321
+ )
322
+ with gr.Row():
323
+ with gr.Column(scale=1):
324
  class_predicted = gr.Label(label='Class Predicted',num_top_classes=10)
325
+ with gr.Column(scale=4):
326
+ with gr.Accordion("Explanations "):
327
+ gr.Markdown("Computing Explanations from the model")
328
+ with gr.Column():
329
+ with gr.Row():
330
+
331
+ #original_input = gr.Image(label="Original Frame")
332
+ #saliency = gr.Image(label="saliency")
333
+ #gradcam = gr.Image(label='integraged gradients')
334
+ #guided_gradcam = gr.Image(label='gradcam')
335
+ #guided_backprop = gr.Image(label='guided backprop')
336
+ # exp1 = gr.Image(label = 'Class_name1')
337
+ # exp2= gr.Image(label = 'Class_name2')
338
+ # exp3= gr.Image(label = 'Class_name3')
339
+ # exp4= gr.Image(label = 'Class_name4')
340
+ # exp5= gr.Image(label = 'Class_name5')
341
+
342
+ exp_gallery = gr.Gallery(label="Explanation Heatmaps for top 5 predicted classes", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
343
+
344
+ generate_explanations = gr.Button("Generate Explanations")
345
+
346
+ # with gr.Accordion('Closest Images'):
347
+ # gr.Markdown("Finding the closest images in the dataset")
348
+ # with gr.Row():
349
+ # with gr.Column():
350
+ # label_closest_image_0 = gr.Markdown('')
351
+ # closest_image_0 = gr.Image(label='Closest Image',image_mode='contain',width=200, height=200)
352
+ # with gr.Column():
353
+ # label_closest_image_1 = gr.Markdown('')
354
+ # closest_image_1 = gr.Image(label='Second Closest Image',image_mode='contain',width=200, height=200)
355
+ # with gr.Column():
356
+ # label_closest_image_2 = gr.Markdown('')
357
+ # closest_image_2 = gr.Image(label='Third Closest Image',image_mode='contain',width=200, height=200)
358
+ # with gr.Column():
359
+ # label_closest_image_3 = gr.Markdown('')
360
+ # closest_image_3 = gr.Image(label='Forth Closest Image',image_mode='contain', width=200, height=200)
361
+ # with gr.Column():
362
+ # label_closest_image_4 = gr.Markdown('')
363
+ # closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
364
+ # find_closest_btn = gr.Button("Find Closest Images")
365
+ with gr.Accordion('Closest Fossil Images'):
366
+ gr.Markdown("Finding the closest images in the dataset")
367
+
368
+ with gr.Row():
369
+ closest_gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
370
+ #.style(grid=[1, 5], height=200, width=200)
371
+
372
+ find_closest_btn = gr.Button("Find Closest Images")
373
+
374
+ #segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
375
+ classify_image_button.click(classify_image, inputs=[input_image,model_name], outputs=class_predicted)
376
+ # generate_exp.click(exp_image, inputs=[input_image,model_name,explain_method,sampling_size], outputs=[exp1,exp2,exp3,exp4,exp5]) #
377
+ with gr.Accordion('Closest Leaves Images'):
378
+ gr.Markdown("5 closest leaves")
379
+ with gr.Accordion("Class Distribution of Closest Samples "):
380
+ gr.Markdown("Visualize class distribution of top-k closest samples in our dataset")
381
+ with gr.Column():
382
+ with gr.Row():
383
+ diagram= gr.Image(label = 'Bar Chart')
384
+
385
+ generate_diagram = gr.Button("Generate Diagram")
386
+
387
 
 
 
 
 
 
 
 
388
 
389
  # with gr.Accordion("Using Diffuser"):
390
  # with gr.Column():
 
395
  # class_predicted2 = gr.Label(label='Class Predicted from diffuser')
396
  # classify_button = gr.Button("Classify Image")
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
 
399
+ def update_exp_outputs(input_image,model_name,explain_method,nb_samples):
400
+ labels, images = explain_image(input_image,model_name,explain_method,nb_samples)
401
+ #labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
402
+ #labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
403
+ image_caption=[]
404
+ for i in range(5):
405
+ image_caption.append((images[i],"Predicted Class "+str(i)+": "+labels[i]))
406
+ return image_caption
407
+
408
+ generate_explanations.click(fn=update_exp_outputs, inputs=[input_image,model_name,explain_method,sampling_size], outputs=[exp_gallery])
409
+
410
  #find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
411
+ def update_closest_outputs(input_image,model_name):
412
  labels, images = find_closest(input_image,model_name)
413
  #labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
414
  #labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
 
417
  image_caption.append((images[i],labels[i]))
418
  return image_caption
419
 
420
+ find_closest_btn.click(fn=update_closest_outputs, inputs=[input_image,model_name], outputs=[closest_gallery])
421
  #classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
422
+
423
+ generate_diagram.click(generate_diagram_closest, inputs=[input_image,model_name,top_k], outputs=diagram)
424
+
425
+ process_button.click(
426
+ fn=update_display,
427
+ inputs=input_image,
428
+ outputs=[input_image,workbench_image,instruction_text,model_name,explain_method,sampling_size,top_k,class_predicted,exp_gallery,closest_gallery,diagram]
429
+ )
430
+
431
+
432
+
433
 
434
  demo.queue() # manage multiple incoming requests
435
 
closest_sample.py CHANGED
@@ -5,6 +5,8 @@ import pandas as pd
5
  import os
6
  from huggingface_hub import snapshot_download
7
  import requests
 
 
8
 
9
 
10
  pca_fossils = pk.load(open('pca_fossils_170_finer.pkl','rb'))
@@ -23,7 +25,7 @@ embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
23
 
24
  fossils_pd= pd.read_csv('fossils_paths.csv')
25
 
26
- def pca_distance(pca,sample,embedding):
27
  """
28
  Args:
29
  pca:fitted PCA model
@@ -35,7 +37,7 @@ def pca_distance(pca,sample,embedding):
35
  s = pca.transform(sample.reshape(1,-1))
36
  all = pca.transform(embedding[:,-1])
37
  distances = np.linalg.norm(all - s, axis=1)
38
- return np.argsort(distances)[:5]
39
 
40
  def return_paths(argsorted,files):
41
  paths= []
@@ -56,7 +58,7 @@ def get_images(embedding):
56
 
57
  #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
58
 
59
- pca_d =pca_distance(pca_fossils,embedding,embedding_fossils)
60
 
61
  fossils_paths = fossils_pd['file_name'].values
62
 
@@ -87,3 +89,54 @@ def get_images(embedding):
87
  # '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
88
 
89
  return classes, local_paths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import os
6
  from huggingface_hub import snapshot_download
7
  import requests
8
+ import matplotlib.pyplot as plt
9
+ from collections import Counter
10
 
11
 
12
  pca_fossils = pk.load(open('pca_fossils_170_finer.pkl','rb'))
 
25
 
26
  fossils_pd= pd.read_csv('fossils_paths.csv')
27
 
28
+ def pca_distance(pca,sample,embedding,top_k):
29
  """
30
  Args:
31
  pca:fitted PCA model
 
37
  s = pca.transform(sample.reshape(1,-1))
38
  all = pca.transform(embedding[:,-1])
39
  distances = np.linalg.norm(all - s, axis=1)
40
+ return np.argsort(distances)[:top_k]
41
 
42
  def return_paths(argsorted,files):
43
  paths= []
 
58
 
59
  #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
60
 
61
+ pca_d =pca_distance(pca_fossils,embedding,embedding_fossils,top_k=5)
62
 
63
  fossils_paths = fossils_pd['file_name'].values
64
 
 
89
  # '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
90
 
91
  return classes, local_paths
92
+
93
+ def get_diagram(embedding,top_k):
94
+
95
+ #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
96
+
97
+ pca_d =pca_distance(pca_fossils,embedding,embedding_fossils,top_k=top_k)
98
+
99
+ fossils_paths = fossils_pd['file_name'].values
100
+
101
+ paths = return_paths(pca_d,fossils_paths)
102
+ #print(paths)
103
+
104
+ folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
105
+ folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
106
+
107
+ classes = []
108
+ for i, path in enumerate(paths):
109
+ local_file_path = f'image_{i}.jpg'
110
+ if 'Florissant_Fossil/512/full/jpg/' in path:
111
+ public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
112
+ elif 'General_Fossil/512/full/jpg/' in path:
113
+ public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
114
+ else:
115
+ print("no match found")
116
+ print(public_path)
117
+ #download_public_image(public_path, local_file_path)
118
+ parts = [part for part in public_path.split('/') if part]
119
+ part = parts[-2]
120
+ classes.append(part)
121
+ #local_paths.append(local_file_path)
122
+ #paths= [path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/',
123
+ # '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
124
+ class_counts = Counter(classes)
125
+
126
+ sorted_class_counts = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)
127
+ sorted_classes, sorted_frequencies = zip(*sorted_class_counts)
128
+ colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))
129
+ fig, ax = plt.subplots()
130
+ ax.bar(sorted_classes, sorted_frequencies,color=colors)
131
+ ax.set_xlabel('Class Label')
132
+ ax.set_ylabel('Frequency')
133
+ ax.set_title('Distribution of '+str(top_k) +' Closest Sample Classes')
134
+ ax.set_xticklabels(class_counts.keys(), rotation=45, ha='right')
135
+
136
+ # Save the diagram to a file
137
+ diagram_path = 'class_distribution_chart.png'
138
+ plt.tight_layout() # Adjust layout to make room for rotated x-axis labels
139
+ plt.savefig(diagram_path)
140
+ plt.close() # Close the figure to free up memory
141
+
142
+ return diagram_path
explanations.py CHANGED
@@ -7,6 +7,7 @@ from xplique.attributions.global_sensitivity_analysis import LatinHypercube
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
  from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
 
10
  BATCH_SIZE = 1
11
 
12
  def show(img, p=False, **kwargs):
@@ -35,7 +36,7 @@ def show(img, p=False, **kwargs):
35
 
36
 
37
 
38
- def explain(model, input_image,size=600, n_classes=171) :
39
  """
40
  Generate explanations for a given model and dataset.
41
  :param model: The model to explain.
@@ -45,31 +46,55 @@ def explain(model, input_image,size=600, n_classes=171) :
45
  :param batch_size: The batch size to use.
46
  :return: The explanations.
47
  """
48
-
49
  # we only need the classification part of the model
50
  class_model = tf.keras.Model(model.input, model.output[1])
51
 
52
- explainers = [
53
- #Sobol, RISE, HSIC, Saliency
54
- #IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
55
- #SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
56
- #GradCAM(class_model),
57
- SobolAttributionMethod(class_model, grid_size=8, nb_design=32),
58
- Rise(class_model,nb_samples = 5000, batch_size = BATCH_SIZE,grid_size=15,
59
- preservation_probability=0.5),
60
- HsicAttributionMethod(class_model,
61
  grid_size=7, nb_design=1500,
62
- sampler = LatinHypercube(binary=True)),
63
- Saliency(class_model),
64
- #
65
- ]
66
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  cropped,repetitions = _clever_crop(input_image,(size,size))
68
- size_repetitions = int(size//(repetitions.numpy()+1))
 
 
 
 
 
 
 
 
69
  X = preprocess(cropped,size=size)
70
  predictions = class_model.predict(np.array([X]))
71
  #Y = np.argmax(predictions)
72
  top_5_indices = np.argsort(predictions[0])[-5:][::-1]
 
 
 
73
  #print(top_5_indices)
74
  X = np.expand_dims(X, 0)
75
  explanations = []
@@ -81,8 +106,10 @@ def explain(model, input_image,size=600, n_classes=171) :
81
  phi = np.abs(explainer(X, Y))[0]
82
  if len(phi.shape) == 3:
83
  phi = np.mean(phi, -1)
84
- show(X[0][:,size_repetitions:2*size_repetitions,:])
85
- show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
 
 
86
  plt.savefig(f'phi_{e}{i}.png')
87
  explanations.append(f'phi_{e}{i}.png')
88
  # avg=[]
@@ -101,4 +128,4 @@ def explain(model, input_image,size=600, n_classes=171) :
101
  if len(explanations)==1:
102
  explanations = explanations[0]
103
  # return explanations,avg
104
- return explanations
 
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
  from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
10
+ from labels import lookup_140
11
  BATCH_SIZE = 1
12
 
13
  def show(img, p=False, **kwargs):
 
36
 
37
 
38
 
39
+ def explain(model, input_image,explain_method,nb_samples,size=600, n_classes=171) :
40
  """
41
  Generate explanations for a given model and dataset.
42
  :param model: The model to explain.
 
46
  :param batch_size: The batch size to use.
47
  :return: The explanations.
48
  """
49
+ print('using explain_method:',explain_method)
50
  # we only need the classification part of the model
51
  class_model = tf.keras.Model(model.input, model.output[1])
52
 
53
+ explainers = []
54
+ if explain_method=="Sobol":
55
+ explainers.append(SobolAttributionMethod(class_model, grid_size=8, nb_design=32))
56
+ if explain_method=="HSIC":
57
+ explainers.append(HsicAttributionMethod(class_model,
 
 
 
 
58
  grid_size=7, nb_design=1500,
59
+ sampler = LatinHypercube(binary=True)))
60
+ if explain_method=="Rise":
61
+ explainers.append(Rise(class_model,nb_samples = nb_samples, batch_size = BATCH_SIZE,grid_size=15,
62
+ preservation_probability=0.5))
63
+ if explain_method=="Saliency":
64
+ explainers.append(Saliency(class_model))
65
+
66
+ # explainers = [
67
+ # #Sobol, RISE, HSIC, Saliency
68
+ # #IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
69
+ # #SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
70
+ # #GradCAM(class_model),
71
+ # SobolAttributionMethod(class_model, grid_size=8, nb_design=32),
72
+ # HsicAttributionMethod(class_model,
73
+ # grid_size=7, nb_design=1500,
74
+ # sampler = LatinHypercube(binary=True)),
75
+ # Saliency(class_model),
76
+ # Rise(class_model,nb_samples = 5000, batch_size = BATCH_SIZE,grid_size=15,
77
+ # preservation_probability=0.5),
78
+ # #
79
+ # ]
80
+
81
  cropped,repetitions = _clever_crop(input_image,(size,size))
82
+ # size_repetitions = int(size//(repetitions.numpy()+1))
83
+ # print(size)
84
+ # print(type(input_image))
85
+ # print(input_image.shape)
86
+ # size_repetitions = int(size//(repetitions+1))
87
+ # print(type(repetitions))
88
+ # print(repetitions)
89
+ # print(size_repetitions)
90
+ # print(type(size_repetitions))
91
  X = preprocess(cropped,size=size)
92
  predictions = class_model.predict(np.array([X]))
93
  #Y = np.argmax(predictions)
94
  top_5_indices = np.argsort(predictions[0])[-5:][::-1]
95
+ classes = []
96
+ for index in top_5_indices:
97
+ classes.append(lookup_140[index])
98
  #print(top_5_indices)
99
  X = np.expand_dims(X, 0)
100
  explanations = []
 
106
  phi = np.abs(explainer(X, Y))[0]
107
  if len(phi.shape) == 3:
108
  phi = np.mean(phi, -1)
109
+ show(X[0])
110
+ show(phi, p=1, alpha=0.4)
111
+ # show(X[0][:,size_repetitions:2*size_repetitions,:])
112
+ # show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
113
  plt.savefig(f'phi_{e}{i}.png')
114
  explanations.append(f'phi_{e}{i}.png')
115
  # avg=[]
 
128
  if len(explanations)==1:
129
  explanations = explanations[0]
130
  # return explanations,avg
131
+ return classes,explanations
inference_resnet.py CHANGED
@@ -7,7 +7,7 @@ else:
7
 
8
  from keras.applications import resnet
9
  import tensorflow.keras.layers as L
10
- import os
11
 
12
  from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
13
  import matplotlib.pyplot as plt
 
7
 
8
  from keras.applications import resnet
9
  import tensorflow.keras.layers as L
10
+ import os
11
 
12
  from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
13
  import matplotlib.pyplot as plt