# Source: COCAM / aix/augmentation.py
# cerquide, commit bf62930 ("Moved aix")
import tensorflow as tf
class Augment(tf.keras.layers.Layer):
    """Apply the same random geometric augmentation to inputs and their labels.

    The layer holds two parallel sets of Keras random-preprocessing layers:
    one applied to ``inputs`` and one to ``labels``. All of them are
    constructed with the same ``seed``. The intent is that both sides draw
    the same random values when called in lockstep, so an input and its
    label are transformed identically.

    Args:
        seed: Seed passed to every random layer on both the input and the
            label side.
        rotation_factor: Optional ``RandomRotation`` factor, given as a
            fraction of a full turn (e.g. ``(-0.9, 0.9)``). ``None`` (the
            default) disables rotation. An earlier version built rotation
            layers with factor ``(-0.9, 0.9)`` but immediately overwrote them
            with the flip layers, so only the flips ever ran. The default
            keeps that flip-only behavior.
    """

    def __init__(self, seed=42, rotation_factor=None):
        super().__init__()
        self.rotate_inputs = None
        self.rotate_labels = None
        if rotation_factor is not None:
            self.rotate_inputs = tf.keras.layers.RandomRotation(
                factor=rotation_factor, fill_mode="constant",
                interpolation="bilinear", seed=seed, fill_value=0.0)
            # Labels use nearest-neighbour sampling so rotation does not blend
            # neighbouring label values. NOTE(review): assumes labels are
            # discrete masks; confirm against the dataset.
            self.rotate_labels = tf.keras.layers.RandomRotation(
                factor=rotation_factor, fill_mode="constant",
                interpolation="nearest", seed=seed, fill_value=0.0)
        # Same seed on both sides, so that they are intended to make the same
        # random flip decisions.
        self.augment_inputs = tf.keras.layers.RandomFlip(
            mode="horizontal_and_vertical", seed=seed)
        self.augment_labels = tf.keras.layers.RandomFlip(
            mode="horizontal_and_vertical", seed=seed)

    def call(self, inputs, labels):
        """Return ``(inputs, labels)`` after the optional rotation and the random flips."""
        if self.rotate_inputs is not None:
            inputs = self.rotate_inputs(inputs)
            labels = self.rotate_labels(labels)
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels
def augment_flip(image, label, axis=0, seed=42):
    """Randomly flip ``image`` and ``label`` together along one axis.

    One seed is drawn per call and passed to a stateless flip on both
    tensors. The image and its label therefore always receive the same flip
    decision, whether this runs eagerly, inside ``tf.function``, or in a
    parallel ``Dataset.map``. Two stateful flips sharing an op seed do not
    guarantee this: in eager mode the second call draws a different value.

    Args:
        image: Tensor accepted by ``tf.image`` flip ops.
        label: Tensor flipped with the same decision. It must have the same
            rank/batch size as ``image`` for the decisions to line up.
        axis: ``0`` flips left-right; any other value flips up-down.
        seed: Op-level seed for the per-call seed draw. It defaults to 42,
            the value previously hard-coded here.

    Returns:
        Tuple ``(image, label)``.
    """
    if axis == 0:
        flip = tf.image.stateless_random_flip_left_right
    else:
        flip = tf.image.stateless_random_flip_up_down
    # A fresh shape-[2] seed per call, shared by both flips. Stateless ops
    # give identical decisions for an identical seed and shape.
    shared_seed = tf.random.uniform(
        [2], maxval=tf.int32.max, dtype=tf.int32, seed=seed)
    image = flip(image, seed=shared_seed)
    label = flip(label, seed=shared_seed)
    return image, label
def augment_rot(image, label, kappa=1):
    """Rotate an image and its label by the same number of quarter turns.

    Args:
        image: Tensor passed to ``tf.image.rot90``.
        label: Tensor passed to ``tf.image.rot90`` with the same ``k``.
        kappa: Number of 90-degree rotations (the ``k`` of ``tf.image.rot90``).

    Returns:
        Tuple ``(rotated_image, rotated_label)``.
    """
    def quarter_turns(tensor):
        return tf.image.rot90(tensor, k=kappa)

    return quarter_turns(image), quarter_turns(label)
def augment(images, labels, seed=42):
    """Apply synchronized random flips plus a fixed 180-degree rotation.

    The left-right and up-down flips are each decided once per call. Each
    decision is applied identically to ``images`` and ``labels`` through
    stateless flips that share a seed. The two flips use different seeds,
    so they stay independent of each other.

    Args:
        images: Tensor accepted by ``tf.image`` flip/rotate ops.
        labels: Tensor transformed identically. It must have the same
            rank/batch size as ``images`` for the decisions to line up.
        seed: Op-level seed for the per-call seed draw.

    Returns:
        Tuple ``(images, labels)``.
    """
    # Row 0 drives the left-right flip and row 1 the up-down flip. Sharing a
    # row between images and labels keeps them in sync, while using distinct
    # rows keeps the two flips independent.
    flip_seeds = tf.random.uniform(
        [2, 2], maxval=tf.int32.max, dtype=tf.int32, seed=seed)
    images = tf.image.stateless_random_flip_left_right(images, seed=flip_seeds[0])
    labels = tf.image.stateless_random_flip_left_right(labels, seed=flip_seeds[0])
    images = tf.image.stateless_random_flip_up_down(images, seed=flip_seeds[1])
    labels = tf.image.stateless_random_flip_up_down(labels, seed=flip_seeds[1])
    # Deterministic 180-degree rotation, kept from the original pipeline.
    # NOTE(review): a 180-degree rotation equals an LR flip followed by a UD
    # flip, so after the two independent random flips above this does not
    # change the output distribution. It could be dropped.
    images = tf.image.rot90(images, k=2)
    labels = tf.image.rot90(labels, k=2)
    # TODO: a random crop to IMG_SIZE was sketched here. If re-enabled, use
    # tf.image.stateless_random_crop with a shared seed so images and labels
    # are cropped identically.
    return images, labels