# xai_framework/utils/tfexplain_utils.py

import sys
import tf_explain
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import maximum_filter
import cv2
sys.path.append('./')
from utils.datagenerator import _denorm
print("loading modules")


def get_tfexplain_callbacks(vis_image, vis_label):
    """Build tf-explain Keras callbacks: one Grad-CAM callback per class
    (0-3) plus an activations visualization on the "top_activation" layer.

    `vis_image` / `vis_label` are a visualization batch and its one-hot
    labels; up to five images per class are kept. tf-explain expects
    (images, labels) tuples, and the labels slot is unused here.
    """
    validation_class_zero = (np.array([
        el for el, label in zip(vis_image, vis_label)
        if np.argmax(label) == 0
    ][0:5]), None)
    validation_class_one = (np.array([
        el for el, label in zip(vis_image, vis_label)
        if np.argmax(label) == 1
    ][0:5]), None)
    validation_class_two = (np.array([
        el for el, label in zip(vis_image, vis_label)
        if np.argmax(label) == 2
    ][0:5]), None)
    validation_class_three = (np.array([
        el for el, label in zip(vis_image, vis_label)
        if np.argmax(label) == 3
    ][0:5]), None)
    # validation_class_four (class index 4) is currently unused.

    callbacks = [
        tf_explain.callbacks.GradCAMCallback(validation_class_zero, class_index=0),
        tf_explain.callbacks.GradCAMCallback(validation_class_one, class_index=1),
        tf_explain.callbacks.GradCAMCallback(validation_class_two, class_index=2),
        tf_explain.callbacks.GradCAMCallback(validation_class_three, class_index=3),
        tf_explain.callbacks.ActivationsVisualizationCallback(
            validation_class_zero, layers_name=["top_activation"]
        ),
    ]
    return callbacks
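
# Usage sketch for the callbacks above. `model`, `train_ds`, `x_vis` and
# `y_vis` are illustrative names, not defined in this module: `x_vis`/`y_vis`
# would be a small validation batch with one-hot labels. The tf-explain
# callbacks log their Grad-CAM / activation grids for TensorBoard.
#
#     callbacks = get_tfexplain_callbacks(x_vis, y_vis)
#     model.fit(train_ds, epochs=10, callbacks=callbacks)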


def get_post_gradcam_heatmap(model, img, class_index=1, layer_name="top_conv"):
    """Run tf-explain Grad-CAM on a single image and return the heatmap grid.

    `image_weight=0.0` keeps only the colour-mapped heatmap (the input image
    is not blended in). Colormap reference:
    https://docs.opencv.org/master/d3/d50/group__imgproc__colormap.html
    """
    data = ([img], None)
    explainer = tf_explain.core.grad_cam.GradCAM()
    grid = explainer.explain(data, model, class_index=class_index,
                             layer_name=layer_name, image_weight=0.0)
    return grid
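
# Usage sketch, assuming a trained Keras `model` whose last convolutional
# layer is named "top_conv" (EfficientNet-style) and a preprocessed image
# `img` of shape (H, W, C); both names are illustrative, not defined here.
#
#     heatmap_grid = get_post_gradcam_heatmap(model, img, class_index=1)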


def vis_hatmap_over_img(img, heatmap_grid):
    """Overlay a Grad-CAM heatmap on the (denormalized) input image."""
    fig, ax = plt.subplots(1, 1)
    ax.imshow(_denorm(img, np.min(img), np.max(img)), cmap='gray')
    ax.imshow(heatmap_grid, alpha=0.5, interpolation='bilinear')
    return ax
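
# Usage sketch, continuing from the Grad-CAM example above (hypothetical
# `img` and `heatmap_grid`); `_denorm` rescales the normalized input back to
# a displayable range before the heatmap is blended on top.
#
#     ax = vis_hatmap_over_img(img, heatmap_grid)
#     plt.show()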


def get_peak_location(heatmap):
    """Return the (x, y) pixel location of the strongest heatmap peak.

    Local maxima are found by keeping only pixels that equal their 9x9
    neighbourhood maximum (non-maximum suppression), following
    https://arvrjourney.com/human-pose-estimation-using-openpose-with-tensorflow-part-2-e78ab9104fc8
    """
    heat_gray = cv2.cvtColor(heatmap, cv2.COLOR_RGB2GRAY)
    heat_gray_norm = heat_gray / 255.
    part_candidates = heat_gray_norm * (
        heat_gray_norm == maximum_filter(heat_gray_norm, footprint=np.ones((9, 9)))
    )
    row, col = np.where(part_candidates == np.max(part_candidates))
    xc = col[0]
    yc = row[0]
    return xc, yc
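
# Usage sketch (hypothetical `heatmap_grid` as returned by
# get_post_gradcam_heatmap, i.e. an RGB uint8 image, and `ax` as returned by
# vis_hatmap_over_img):
#
#     xc, yc = get_peak_location(heatmap_grid)
#     ax.plot(xc, yc, marker='x', color='red')  # mark the peak on the overlay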


def get_thresholded_img_contours(heat_map):
    """Threshold a heatmap, close small gaps, and return the binary image
    plus the bounding rectangle [x, y, w, h] of the largest contour."""
    # convert to grayscale
    heat_gray = cv2.cvtColor(heat_map, cv2.COLOR_RGB2GRAY)
    # threshold at 80% of the maximum intensity
    thr = np.max(heat_gray) - np.max(heat_gray) / 5
    # simple binary thresholding; could be substituted by adaptive
    # thresholding (cv2.adaptiveThreshold) if a global threshold is too coarse
    ret, thresh1 = cv2.threshold(heat_gray, thr, 255, cv2.THRESH_BINARY)
    # morphological closing with an elliptical kernel to fill small holes
    kernel_ellps = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    clos_img = cv2.morphologyEx(thresh1, cv2.MORPH_CLOSE, kernel_ellps)
    # find contours and keep the one with the largest area
    contours, _ = cv2.findContours(clos_img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    cnt_max = max(contours, key=cv2.contourArea)
    # bounding rectangle of the largest contour
    x, y, w, h = cv2.boundingRect(cnt_max)
    return clos_img, [x, y, w, h]
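
# Usage sketch (hypothetical `heatmap_grid`): draw the bounding box of the
# most activated region with OpenCV.
#
#     clos_img, (x, y, w, h) = get_thresholded_img_contours(heatmap_grid)
#     boxed = cv2.rectangle(heatmap_grid.copy(), (x, y), (x + w, y + h),
#                           (255, 255, 255), 2)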


def get_max_area_contour(contours):
    """Return the contour with the largest area, or None if `contours` is empty."""
    max_area = -1
    cnt = None
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > max_area:
            cnt = contour
            max_area = area
    return cnt
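
# Usage sketch (hypothetical binary image `clos_img`): pick the largest
# contour from a cv2.findContours result, tolerating an empty result.
#
#     contours, _ = cv2.findContours(clos_img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
#     cnt = get_max_area_contour(contours)
#     if cnt is not None:
#         x, y, w, h = cv2.boundingRect(cnt)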