andy-wyx commited on
Commit
f8b140a
·
1 Parent(s): 00eaf70

update image padding for xai

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. explanations.py +6 -3
app.py CHANGED
@@ -289,7 +289,7 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
289
  with gr.Tab("Specimen Workbench"):
290
  with gr.Row():
291
  with gr.Column():
292
- original_image = gr.Image(visible = True)
293
  workbench_image = gr.Image(label="Workbench Image")
294
  classify_image_button = gr.Button("Classify Image")
295
 
 
289
  with gr.Tab("Specimen Workbench"):
290
  with gr.Row():
291
  with gr.Column():
292
+ original_image = gr.Image(visible = False)
293
  workbench_image = gr.Image(label="Workbench Image")
294
  classify_image_button = gr.Button("Classify Image")
295
 
explanations.py CHANGED
@@ -8,11 +8,14 @@ import numpy as np
8
  import matplotlib.pyplot as plt
9
  from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
10
  from labels import lookup_140
 
11
  BATCH_SIZE = 1
12
 
13
- def show(img, p=False, **kwargs):
14
  img = np.array(img, dtype=np.float32)
15
 
 
 
16
  # check if channel first
17
  if img.shape[0] == 1:
18
  img = img[0]
@@ -106,8 +109,8 @@ def explain(model, input_image,explain_method,nb_samples,size=600, n_classes=171
106
  phi = np.abs(explainer(X, Y))[0]
107
  if len(phi.shape) == 3:
108
  phi = np.mean(phi, -1)
109
- show(X[0])
110
- show(phi, p=1, alpha=0.4)
111
  # show(X[0][:,size_repetitions:2*size_repetitions,:])
112
  # show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
113
  plt.savefig(f'phi_{e}{i}.png')
 
8
  import matplotlib.pyplot as plt
9
  from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
10
  from labels import lookup_140
11
+ from app import preprocess_image
12
  BATCH_SIZE = 1
13
 
14
+ def show(img, p=False,output_size, **kwargs):
15
  img = np.array(img, dtype=np.float32)
16
 
17
+ img = preprocess_image(img, output_size=output_size)
18
+
19
  # check if channel first
20
  if img.shape[0] == 1:
21
  img = img[0]
 
109
  phi = np.abs(explainer(X, Y))[0]
110
  if len(phi.shape) == 3:
111
  phi = np.mean(phi, -1)
112
+ show(X[0],output_size = size)
113
+ show(phi, p=1, alpha=0.4, output_size = size)
114
  # show(X[0][:,size_repetitions:2*size_repetitions,:])
115
  # show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
116
  plt.savefig(f'phi_{e}{i}.png')