|
|
|
import tensorflow as tf |
|
|
|
class Augment(tf.keras.layers.Layer):
    """Randomly rotate and flip an image batch together with its labels.

    Inputs and labels each get their own preprocessing pipeline (rotation,
    then flip). The two pipelines are built with identical seeds so that the
    same transformation is drawn for both. This relies on identically seeded
    Keras random preprocessing layers producing the same sequence of random
    draws when called the same number of times.
    NOTE(review): confirm this holds for the Keras version in use.
    """

    def __init__(self, seed=42):
        """Build the paired augmentation pipelines.

        Args:
            seed: Base seed shared by the input and label pipelines. ``None``
                leaves the layers unseeded (inputs and labels are then not
                guaranteed to receive the same transformation).
        """
        super().__init__()

        # Previously the RandomRotation layers were assigned here and then
        # immediately overwritten by RandomFlip layers, so rotation was never
        # applied. Compose both into a single pipeline per tensor instead.
        self.augment_inputs = self._build_pipeline(seed)
        self.augment_labels = self._build_pipeline(seed)

    @staticmethod
    def _build_pipeline(seed):
        """Return a rotation-then-flip ``Sequential`` pipeline seeded from ``seed``."""
        # Give the flip a distinct (but deterministic) seed so its decision is
        # not drawn from the same random stream as the rotation angle.
        flip_seed = None if seed is None else seed + 1
        return tf.keras.Sequential([
            # factor is a fraction of a full turn, so (-0.9, 0.9) allows
            # rotations of up to +/-324 degrees.
            # NOTE(review): if labels are integer class masks, "bilinear"
            # interpolation blends class ids at edges; "nearest" may be
            # more appropriate for the label pipeline -- confirm.
            tf.keras.layers.RandomRotation(factor=(-0.9, 0.9), fill_mode="constant",
                                           interpolation="bilinear", seed=seed, fill_value=0.0),
            tf.keras.layers.RandomFlip(mode="horizontal_and_vertical", seed=flip_seed),
        ])

    def call(self, inputs, labels):
        """Apply the paired augmentation and return ``(inputs, labels)``."""
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels
|
|
|
|
|
def augment_flip(image, label, axis=0, seed=42):
    """Randomly flip ``image`` and ``label`` together along one axis.

    Both tensors receive the *same* random flip decision, so a label that is
    spatially aligned with its image stays aligned after augmentation.

    Args:
        image: Tensor accepted by the ``tf.image`` flip ops (3-D
            ``[height, width, channels]`` or 4-D ``[batch, height, width,
            channels]``).
        label: Tensor with the same spatial layout (and batch size, if
            batched) as ``image``.
        axis: ``0`` flips left/right; any other value flips up/down.
        seed: Op-level seed used to draw the per-call flip seed.

    Returns:
        Tuple ``(image, label)`` after the (possibly no-op) flip.
    """
    # The stateful tf.image.random_flip_* ops draw independently on every
    # call. Even with the same op seed, consecutive eager calls yield
    # different values, so flipping image and label with two separate calls
    # could flip one but not the other. Draw one fresh stateless seed per
    # call and reuse it, which gives both tensors identical decisions.
    flip_seed = tf.random.uniform([2], maxval=tf.int32.max, dtype=tf.int32, seed=seed)

    if axis == 0:
        flip = tf.image.stateless_random_flip_left_right
    else:
        flip = tf.image.stateless_random_flip_up_down

    return flip(image, seed=flip_seed), flip(label, seed=flip_seed)
|
|
|
|
|
def augment_rot(image, label, kappa=1):
    """Rotate ``image`` and ``label`` by the same multiple of 90 degrees.

    Args:
        image: Tensor accepted by ``tf.image.rot90``.
        label: Tensor accepted by ``tf.image.rot90``; rotated identically.
        kappa: Number of 90-degree counter-clockwise turns (``k`` of
            ``tf.image.rot90``).

    Returns:
        Tuple ``(image, label)`` after rotation.
    """
    rotated_image, rotated_label = (tf.image.rot90(tensor, k=kappa) for tensor in (image, label))
    return rotated_image, rotated_label
|
|
|
|
|
def augment(images, labels, seed=42):
    """Randomly flip images and labels (both axes), then rotate them by 180 degrees.

    Every random flip decision is shared between ``images`` and ``labels``,
    so spatially aligned labels stay aligned.

    Args:
        images: Tensor accepted by the ``tf.image`` flip/rot90 ops (3-D or
            4-D).
        labels: Tensor with the same spatial layout (and batch size, if
            batched) as ``images``.
        seed: Op-level seed used to draw the per-call flip seeds.

    Returns:
        Tuple ``(images, labels)`` after augmentation.
    """
    # Debug ``print`` calls were removed. Inside a tf.data pipeline they
    # only fire at trace time and otherwise just add noise.

    # Stateful tf.image.random_flip_* calls on images and labels draw
    # independently, which can desynchronise them. Draw one stateless seed
    # pair per flip and reuse each pair for both tensors.
    flip_seeds = tf.random.uniform([2, 2], maxval=tf.int32.max, dtype=tf.int32, seed=seed)
    lr_seed = flip_seeds[0]
    ud_seed = flip_seeds[1]

    images = tf.image.stateless_random_flip_left_right(images, seed=lr_seed)
    labels = tf.image.stateless_random_flip_left_right(labels, seed=lr_seed)

    images = tf.image.stateless_random_flip_up_down(images, seed=ud_seed)
    labels = tf.image.stateless_random_flip_up_down(labels, seed=ud_seed)

    # NOTE(review): a fixed 180-degree rotation equals flipping both axes.
    # After the two independent random flips above, it does not change the
    # distribution of outputs. It is kept only to preserve the existing
    # mapping; confirm whether a random rotation was intended instead.
    images = tf.image.rot90(images, k=2)
    labels = tf.image.rot90(labels, k=2)

    return images, labels
|
|
|
|