xai_framework / utils /tf_image_preprocessing.py
hodorfi's picture
Upload 1288 files
191195c
raw
history blame
1.99 kB
import tensorflow as tf
import numpy as np
def get_preprocess_func(config):
    """Return the Keras ``preprocess_input`` function for the backbone.

    The architecture name is read from ``config.MODEL.backbone_arch``.
    ``'mobilenetv2'`` and ``'xception'`` must match exactly; the ResNet,
    EfficientNet and DenseNet families are matched by substring
    (``'resnet'``, ``'eff'``, ``'dense'``), checked in that order.

    Args:
        config: Object exposing ``config.MODEL.backbone_arch`` as a string.

    Returns:
        The ``tf.keras.applications.<family>.preprocess_input`` callable
        for the selected architecture.

    Raises:
        NotImplementedError: If the architecture is not one of the
            supported families. (``NotImplementedError`` is a subclass of
            ``Exception``, so callers catching ``Exception`` still work.)
    """
    arch = config.MODEL.backbone_arch

    # NOTE: ``tf`` is only dereferenced inside the matching branch, so an
    # unsupported architecture fails fast without touching TensorFlow.
    if arch == 'mobilenetv2':
        # MobileNet V1 and V2 share the same "tf"-mode scaling to [-1, 1];
        # use the V2 module so the code matches the architecture name.
        return tf.keras.applications.mobilenet_v2.preprocess_input
    if 'resnet' in arch:
        return tf.keras.applications.resnet.preprocess_input
    if 'eff' in arch:
        return tf.keras.applications.efficientnet.preprocess_input
    if 'dense' in arch:
        return tf.keras.applications.densenet.preprocess_input
    if arch == 'xception':
        return tf.keras.applications.xception.preprocess_input

    raise NotImplementedError(f'{arch} is not yet implemented')
def get_unpreprocess_func(config):
    """Return a function that undoes the backbone's Keras preprocessing.

    The returned callable converts a preprocessed, channels-last image (or
    batch of images) back to a float NumPy array clipped to the 0-1 range.

    The architecture name is read from ``config.MODEL.backbone_arch`` and
    matched with the same rules as ``get_preprocess_func``:

    * ``'mobilenetv2'`` / ``'xception'`` -> inverse of "tf" mode
      (input in [-1, 1]).
    * ``'resnet'`` in name -> inverse of "caffe" mode
      (BGR, ImageNet mean subtracted, 0-255 scale).
    * ``'eff'`` in name -> input is raw 0-255 pixels, only rescaled.
    * ``'dense'`` in name -> inverse of "torch" mode
      (per-channel mean/std normalisation of 0-1 pixels).

    Args:
        config: Object exposing ``config.MODEL.backbone_arch`` as a string.

    Returns:
        A callable ``f(x)``. ``x`` may be any object with a ``.numpy()``
        method (e.g. an eager tensor) or anything ``np.array`` accepts.
        ``x`` is never modified; a new array is returned.

    Raises:
        NotImplementedError: If the architecture is not one of the
            supported families. (``NotImplementedError`` is a subclass of
            ``Exception``, so callers catching ``Exception`` still work.)
    """
    arch = config.MODEL.backbone_arch

    # Per-channel ImageNet statistics used by the Keras preprocessing modes.
    caffe_mean_bgr = (103.939, 116.779, 123.68)
    torch_mean = (0.485, 0.456, 0.406)
    torch_std = (0.229, 0.224, 0.225)

    def to_float_array(x):
        """Return a private floating-point NumPy copy of ``x``."""
        raw = x.numpy() if hasattr(x, 'numpy') else x
        arr = np.array(raw)  # np.array copies, so the caller's data is safe
        # Integer inputs (e.g. uint8 images) cannot take the float
        # arithmetic below in place; promote them once here.
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float32)
        return arr

    def clip(x):
        return np.clip(x, 0., 1.)

    def none_mode(x):
        # Input is untouched 0-255 pixels; only rescale.
        return clip(to_float_array(x) / 255.)

    def caffe_mode(x):
        arr = to_float_array(x)
        # Add back the per-channel mean (stored in BGR order).
        arr[..., :3] += np.asarray(caffe_mean_bgr, dtype=arr.dtype)
        # 'BGR'->'RGB', then 0-255 -> 0-1.
        return clip(arr[..., ::-1] / 255.)

    def tf_mode(x):
        # Forward transform is x / 127.5 - 1, i.e. [0, 255] -> [-1, 1];
        # its inverse onto [0, 1] is (x + 1) / 2.
        return clip((to_float_array(x) + 1.0) / 2.0)

    def torch_mode(x):
        arr = to_float_array(x)
        mean = np.asarray(torch_mean, dtype=arr.dtype)
        std = np.asarray(torch_std, dtype=arr.dtype)
        # Forward transform is (x - mean) / std on 0-1 pixels.
        arr[..., :3] = arr[..., :3] * std + mean
        return clip(arr)

    if arch == 'mobilenetv2':
        return tf_mode
    if 'resnet' in arch:
        return caffe_mode
    if 'eff' in arch:
        return none_mode
    if 'dense' in arch:
        return torch_mode
    if arch == 'xception':
        return tf_mode

    raise NotImplementedError(f'{arch} is not yet implemented')