andy-wyx commited on
Commit
5579c05
·
1 Parent(s): 99ddcfc

debugging: xai output distortion

Browse files
Files changed (1) hide show
  1. explanations.py +22 -5
explanations.py CHANGED
@@ -28,7 +28,23 @@ def preprocess_image(image, output_size=(300, 300)):
28
 
29
  return image_resized
30
 
31
- def show(img, output_size,p=False, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  #img = preprocess_image(img, output_size=(output_size,output_size))
34
 
@@ -48,7 +64,8 @@ def show(img, output_size,p=False, **kwargs):
48
  # check if clip percentile
49
  if p is not False:
50
  img = np.clip(img, np.percentile(img, p), np.percentile(img, 100-p))
51
- img = preprocess_image(img, output_size=(output_size,output_size))
 
52
  plt.imshow(img, **kwargs)
53
  plt.axis('off')
54
 
@@ -56,7 +73,7 @@ def show(img, output_size,p=False, **kwargs):
56
 
57
 
58
 
59
- def explain(model, input_image,explain_method,nb_samples,size=600, n_classes=171) :
60
  """
61
  Generate explanations for a given model and dataset.
62
  :param model: The model to explain.
@@ -130,8 +147,8 @@ def explain(model, input_image,explain_method,nb_samples,size=600, n_classes=171
130
  phi = np.mean(phi, -1)
131
  #apply Gaussian smoothing
132
  phi_smoothed = cv2.GaussianBlur(phi, (5, 5), sigmaX=1.0, sigmaY=1.0)
133
- show(X[0],output_size = size)
134
- show(phi_smoothed, output_size = size,p=1, alpha=0.2)
135
  # show(X[0][:,size_repetitions:2*size_repetitions,:])
136
  # show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
137
  plt.savefig(f'phi_{e}{i}.png')
 
28
 
29
  return image_resized
30
 
31
+
32
+ def transform(image, original_size,output_size):
33
+ """
34
+ resize xai output back to original scale and pad to square-shape
35
+ """
36
+ h,w = original_size
37
+ image = cv2.resize(image,(h,w), interpolation = cv2.INTER_AREA)
38
+ if h > w:
39
+ padding = (h - w) // 2
40
+ image= cv2.copyMakeBorder(image, 0, 0, padding, padding, cv2.BORDER_CONSTANT, value=[0, 0, 0])
41
+ else:
42
+ padding = (w - h) // 2
43
+ image = cv2.copyMakeBorder(image, padding, padding, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
44
+ image = cv2.resize(image,output_size, interpolation = cv2.INTER_AREA)
45
+ return image
46
+
47
+ def show(img, original_size, output_size,p=False, **kwargs):
48
 
49
  #img = preprocess_image(img, output_size=(output_size,output_size))
50
 
 
64
  # check if clip percentile
65
  if p is not False:
66
  img = np.clip(img, np.percentile(img, p), np.percentile(img, 100-p))
67
+
68
+ img = transform(img,original_size=original_size,output_size=output_size)
69
  plt.imshow(img, **kwargs)
70
  plt.axis('off')
71
 
 
73
 
74
 
75
 
76
+ def explain(model, input_image,h,w,explain_method,nb_samples,size=600, n_classes=171) :
77
  """
78
  Generate explanations for a given model and dataset.
79
  :param model: The model to explain.
 
147
  phi = np.mean(phi, -1)
148
  #apply Gaussian smoothing
149
  phi_smoothed = cv2.GaussianBlur(phi, (5, 5), sigmaX=1.0, sigmaY=1.0)
150
+ show(X[0],original_size=(h,w),output_size = (size,size))
151
+ show(phi_smoothed, original_size=(h,w),output_size = (size,size),p=1, alpha=0.2)
152
  # show(X[0][:,size_repetitions:2*size_repetitions,:])
153
  # show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
154
  plt.savefig(f'phi_{e}{i}.png')