climateGAN / climategan /transforms.py
vict0rsch's picture
initial commit from cc-ai/climateGAN
448ebbd
"""Data transforms for the loaders
"""
import random
import traceback
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from skimage.color import rgba2rgb
from skimage.io import imread
from torchvision import transforms as trsfs
from torchvision.transforms.functional import (
adjust_brightness,
adjust_contrast,
adjust_saturation,
)
from climategan.tutils import normalize
def interpolation(task):
    """Return the ``F.interpolate`` keyword arguments to use for a task.

    Tasks ``"d"``, ``"m"`` and ``"s"`` are resized with nearest-neighbour
    sampling; every other task uses bilinear interpolation with aligned
    corners.

    Args:
        task (str): Key of the tensor being resized in the data dict.

    Returns:
        dict: Keyword arguments to splat into
            ``torch.nn.functional.interpolate``.
    """
    nearest_tasks = ("d", "m", "s")
    if task not in nearest_tasks:
        return {"mode": "bilinear", "align_corners": True}
    return {"mode": "nearest"}
class Resize:
    """Resize every tensor of a data dict with ``F.interpolate``.

    The interpolation keyword arguments used for each entry are given by
    ``interpolation(task)``.
    """

    def __init__(self, target_size, keep_aspect_ratio=False):
        """
        Resize transform. Target_size can be an int or a tuple of ints,
        depending on whether both height and width should have the same
        final size or not.

        target_size can also be a dict mapping task keys to a size (used for
        both height and width of that task), with a mandatory "default" entry
        used for every task which is not listed.

        If keep_aspect_ratio is specified then target_size must be an int:
        the smallest dimension of x will be set to target_size and the largest
        dimension will be computed (truncated to an int) keeping the original
        aspect ratio. e.g.

        >>> x = torch.rand(1, 3, 1200, 1800)
        >>> m = torch.rand(1, 1, 600, 600)
        >>> d = {"x": x, "m": m}
        >>> {k: v.shape for k, v in Resize(640, True)(d).items()}
        {"x": (1, 3, 640, 960), "m": (1, 1, 640, 960)}

        Args:
            target_size (int | tuple(int) | dict): New size for the tensor
            keep_aspect_ratio (bool, optional): Whether or not to keep aspect ratio
                when resizing. Requires target_size to be an int. If keeping aspect
                ratio, smallest dim will be set to target_size. Defaults to False.

        Raises:
            TypeError: If target_size is not an int, tuple, list or dict.
        """
        if isinstance(target_size, dict):
            assert (
                not keep_aspect_ratio
            ), "dict target_size not compatible with keep_aspect_ratio"
            self.sizes = {
                k: {"h": v, "w": v} for k, v in target_size.items() if k != "default"
            }
            self.default_h = int(target_size["default"])
            self.default_w = int(target_size["default"])
            # Keep h / w defined in dict mode too: the debug output of
            # __call__ reads them and must not fail on a missing attribute.
            self.h = self.default_h
            self.w = self.default_w
        elif isinstance(target_size, (int, tuple, list)):
            if not isinstance(target_size, int) and not keep_aspect_ratio:
                assert len(target_size) == 2
                self.h, self.w = target_size
            else:
                if keep_aspect_ratio:
                    assert isinstance(target_size, int)
                self.h = self.w = target_size
            self.default_h = int(self.h)
            self.default_w = int(self.w)
            self.sizes = {}
        else:
            # Fail at construction time instead of leaving a half-initialised
            # object which would only break later, inside __call__.
            raise TypeError(
                "target_size must be an int, a tuple, a list or a dict, "
                + "received {}".format(type(target_size).__name__)
            )

        self.keep_aspect_ratio = keep_aspect_ratio

    def compute_new_default_size(self, tensor):
        """
        compute the new size for a tensor depending on target size
        and keep_aspect_ratio

        Args:
            tensor (torch.Tensor): 4D tensor N x C x H x W.

        Returns:
            tuple(int): (new_height, new_width)
        """
        if self.keep_aspect_ratio:
            h, w = tensor.shape[-2:]
            # The smallest side becomes the target size, the other one is
            # scaled by the same factor and truncated to an int.
            if h < w:
                return (self.default_h, int(self.default_h * w / h))
            return (int(self.default_h * h / w), self.default_w)
        return (self.default_h, self.default_w)

    def compute_new_size_for_task(self, task):
        """
        Return the (height, width) configured for a task, falling back to the
        default size when the task has no dedicated entry.

        Args:
            task (str): Key of the tensor in the data dict.

        Returns:
            tuple: (new_height, new_width)
        """
        assert (
            not self.keep_aspect_ratio
        ), "compute_new_size_for_task is not compatible with keep aspect ratio"

        if task not in self.sizes:
            return (self.default_h, self.default_w)

        return (self.sizes[task]["h"], self.sizes[task]["w"])

    def __call__(self, data):
        """
        Resize a dict of tensors to the "x" key's new_size

        Without per-task sizes, the new size is computed once from data["x"]
        (or from the first value if there is no "x") and applied to every
        tensor. With per-task sizes, each tensor gets its own size.

        Args:
            data (dict[str:torch.Tensor]): The data dict to transform

        Returns:
            dict[str: torch.Tensor]: dict with all tensors resized to the
                new size of the data["x"] tensor

        Raises:
            Exception: Any failure is printed with debug information and
                re-raised as an Exception chained to the original error.
        """
        if not data:
            # Nothing to resize; same result in both size modes.
            return {}

        # Kept outside of the loops so that the debug output below can report
        # which item was being processed when the failure happened.
        task = tensor = new_size = None
        try:
            d = {}
            if not self.sizes:
                new_size = self.compute_new_default_size(
                    data["x"] if "x" in data else list(data.values())[0]
                )
                for task, tensor in data.items():
                    d[task] = F.interpolate(
                        tensor, size=new_size, **interpolation(task)
                    )
                return d

            for task, tensor in data.items():
                new_size = self.compute_new_size_for_task(task)
                d[task] = F.interpolate(tensor, size=new_size, **interpolation(task))
            return d

        except Exception as e:
            tb = traceback.format_exc()
            print("Debug: task, shape, interpolation, h, w, new_size")
            print(task)
            # tensor is still None if the failure happened before the loop
            print(getattr(tensor, "shape", None))
            print(interpolation(task))
            print(self.h, self.w)
            print(new_size)
            print(tb)
            raise Exception(e) from e
class RandomCrop:
    """Crop the same window out of every tensor of a data dict.

    The window is either drawn uniformly at random among all valid positions
    or, when ``center`` is True, centered in the input.
    """

    def __init__(self, size, center=False):
        """
        Args:
            size (int | tuple(int) | list(int)): Crop size. An int gives a
                square crop, a pair is read as (height, width).
            center (bool, optional): Crop in the center instead of at a
                random location. Defaults to False.
        """
        assert isinstance(size, (int, tuple, list))
        if not isinstance(size, int):
            assert len(size) == 2
            self.h, self.w = size
        else:
            self.h = self.w = size

        self.h = int(self.h)
        self.w = int(self.w)
        self.center = center

    def __call__(self, data):
        """
        Crop all tensors of ``data`` with the same window. The input height
        and width are read from the last two dims of data["x"], or of the
        first value when there is no "x".

        Args:
            data (dict[str: torch.Tensor]): Tensors indexed as N x C x H x W.

        Returns:
            dict[str: torch.Tensor]: The cropped tensors.

        Raises:
            ValueError: If the crop is larger than the input.
        """
        reference = data["x"] if "x" in data else list(data.values())[0]
        H, W = reference.size()[-2:]

        if self.h > H or self.w > W:
            # A negative offset would silently produce a wrong-sized slice.
            raise ValueError(
                "Cannot crop ({}, {}) out of an input of size ({}, {})".format(
                    self.h, self.w, H, W
                )
            )

        if not self.center:
            # randint's upper bound is exclusive: +1 keeps the last valid
            # offset reachable and allows a crop as large as the input.
            top = np.random.randint(0, H - self.h + 1)
            left = np.random.randint(0, W - self.w + 1)
        else:
            top = (H - self.h) // 2
            left = (W - self.w) // 2

        return {
            task: tensor[:, :, top : top + self.h, left : left + self.w]
            for task, tensor in data.items()
        }
class RandomHorizontalFlip:
    """Flip all tensors of a data dict along dim 3, with probability ``p``."""

    def __init__(self, p=0.5):
        """
        Args:
            p (float, optional): Probability of flipping. Defaults to 0.5.
        """
        self.p = p

    def __call__(self, data):
        """
        Draw one random number for the whole dict: either every tensor is
        flipped along its 4th dimension, or ``data`` is returned untouched.

        Args:
            data (dict[str: torch.Tensor]): 4D tensors to flip.

        Returns:
            dict[str: torch.Tensor]: New dict of flipped tensors, or the
                input dict itself when no flip happens.
        """
        do_flip = not (np.random.rand() > self.p)
        if do_flip:
            flipped = {}
            for task, tensor in data.items():
                flipped[task] = torch.flip(tensor, [3])
            return flipped
        return data
class ToTensor:
    """Convert the items of a data dict to tensors, depending on their task."""

    def __init__(self):
        self.ImagetoTensor = trsfs.ToTensor()
        # Maps currently share the image conversion
        self.MaptoTensor = self.ImagetoTensor

    def __call__(self, data):
        """
        Convert each known item of ``data``:

        - "x" and "a" go through ``self.ImagetoTensor``
        - "m" goes through ``self.MaptoTensor``
        - "s" is turned into a squeezed int64 tensor
        - "d" is kept as is

        Items with any other key are not copied to the output.

        Args:
            data (dict): Items to convert, indexed by task.

        Returns:
            dict: The converted items, indexed by task.
        """
        new_data = {}
        for task, im in data.items():
            if task in {"x", "a"}:
                new_data[task] = self.ImagetoTensor(im)
            elif task in {"m"}:
                new_data[task] = self.MaptoTensor(im)
            elif task == "s":
                new_data[task] = torch.squeeze(torch.from_numpy(np.array(im))).to(
                    torch.int64
                )
            elif task == "d":
                # Store under the task key: assigning to new_data itself would
                # throw away every item converted so far.
                new_data[task] = im

        return new_data
class Normalize:
    """Normalize the "x" tensor of a data dict, leaving the other tasks as is.

    Every tensor also has its dim 0 squeezed (``tensor.squeeze(0)``) before
    its task's normalization is applied.
    """

    def __init__(self, opts):
        """
        Args:
            opts: Options object; only ``opts.data.normalization`` is read.
                "HRNet" selects per-channel mean / std constants (presumably
                the usual ImageNet statistics), anything else selects a mean
                and std of 0.5 for the 3 channels.
        """
        if opts.data.normalization == "HRNet":
            # mean and std are two separate positional arguments: wrapping
            # them in a single tuple leaves ``std`` missing.
            self.normImage = trsfs.Normalize(
                (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            )
        else:
            self.normImage = trsfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # Depth, mask and segmentation are passed through unchanged
        self.normDepth = lambda x: x
        self.normMask = lambda x: x
        self.normSeg = lambda x: x

        self.normalize = {
            "x": self.normImage,
            "s": self.normSeg,
            "d": self.normDepth,
            "m": self.normMask,
        }

    def __call__(self, data):
        """
        Args:
            data (dict[str: torch.Tensor]): Tensors indexed by task.

        Returns:
            dict[str: torch.Tensor]: Squeezed (dim 0) tensors, normalized with
                their task's function; unknown tasks are only squeezed.
        """
        return {
            task: self.normalize.get(task, lambda x: x)(tensor.squeeze(0))
            for task, tensor in data.items()
        }
class RandBrightness:
    """Apply ``rand_brightness`` to the "x" tensor of a data dict.

    Note from the original author: input needs to be between -1 and 1.
    """

    def __call__(self, data):
        """
        Args:
            data (dict): Items indexed by task.

        Returns:
            dict: New dict where only "x" was altered.
        """
        jittered = {}
        for task, tensor in data.items():
            jittered[task] = tensor if task != "x" else rand_brightness(tensor)
        return jittered
class RandSaturation:
    """Apply ``rand_saturation`` to the "x" tensor of a data dict."""

    def __call__(self, data):
        """
        Args:
            data (dict): Items indexed by task.

        Returns:
            dict: New dict where only "x" was altered.
        """
        jittered = {}
        for task, tensor in data.items():
            jittered[task] = tensor if task != "x" else rand_saturation(tensor)
        return jittered
class RandContrast:
    """Apply ``rand_contrast`` to the "x" tensor of a data dict."""

    def __call__(self, data):
        """
        Args:
            data (dict): Items indexed by task.

        Returns:
            dict: New dict where only "x" was altered.
        """
        jittered = {}
        for task, tensor in data.items():
            jittered[task] = tensor if task != "x" else rand_contrast(tensor)
        return jittered
class BucketizeDepth:
    """Optionally replace the "d" tensor of a data dict by bucket indices.

    Bucketing is only active when ``opts.gen.d.classify.enable`` is truthy and
    the domain is "s" or "kitti"; otherwise the transform is the identity.
    """

    def __init__(self, opts, domain):
        """
        Args:
            opts: Options object; reads ``opts.gen.d.classify.enable`` and,
                when bucketing is active, ``opts.gen.d.classify.linspace``
                (``min``, ``max``, ``buckets``).
            domain (str): Domain of the data being transformed.
        """
        self.domain = domain
        self.transforms = {}

        active = opts.gen.d.classify.enable and domain in {"s", "kitti"}
        if not active:
            return

        bounds = opts.gen.d.classify.linspace
        # buckets - 1 boundaries, evenly spaced between min and max
        self.buckets = torch.linspace(bounds.min, bounds.max, bounds.buckets - 1)
        self.transforms = {"d": self._bucketize}

    def _bucketize(self, tensor):
        """Map each value of ``tensor`` to its int32 bucket index."""
        return torch.bucketize(tensor, self.buckets, out_int32=True, right=True)

    def __call__(self, data):
        """
        Args:
            data (dict[str: torch.Tensor]): Tensors indexed by task.

        Returns:
            dict[str: torch.Tensor]: Same items, with the registered transform
                applied to the tasks that have one.
        """
        result = {}
        for task, tensor in data.items():
            transform = self.transforms.get(task)
            result[task] = tensor if transform is None else transform(tensor)
        return result
class PrepareInference:
    """
    Transform which:
      - transforms a str or an array into a tensor
      - resizes the image to keep the aspect ratio
      - crops in the center of the resized image
      - normalize to 0:1
      - rescale to -1:1

    Labels (``is_label=True``) skip the float conversion, the normalization,
    the rescaling and the half-precision cast.
    """

    def __init__(self, target_size=640, half=False, is_label=False, enforce_128=True):
        """
        Args:
            target_size (int, optional): Size of the square output.
                Defaults to 640.
            half (bool, optional): Return half-precision tensors (ignored for
                labels). Defaults to False.
            is_label (bool, optional): Process the data as a label, under the
                "m" key, instead of an image under "x". Defaults to False.
            enforce_128 (bool, optional): Reject target sizes which are not
                multiples of 128. Defaults to True.

        Raises:
            ValueError: If enforce_128 is set and target_size is not a
                multiple of 128.
        """
        if enforce_128 and target_size % 2 ** 7 != 0:
            raise ValueError(
                f"Received a target_size of {target_size}, which is not a "
                + "multiple of 2^7 = 128. Set enforce_128 to False to disable "
                + "this error."
            )
        self.resize = Resize(target_size, keep_aspect_ratio=True)
        self.crop = RandomCrop((target_size, target_size), center=True)
        self.half = half
        self.is_label = is_label

    def process(self, t):
        """
        Turn a single item into a resized and center-cropped 4D tensor.

        Args:
            t (str | Path | np.ndarray | torch.Tensor): Data to process.

        Returns:
            torch.Tensor: The processed tensor.
        """
        # Paths are read from disk first
        if isinstance(t, (str, Path)):
            t = imread(str(t))

        # Arrays: drop the alpha channel, then move channels first
        if isinstance(t, np.ndarray):
            if t.shape[-1] == 4:
                t = rgba2rgb(t)
            t = torch.from_numpy(t)
            if t.ndim == 3:
                t = t.permute(2, 0, 1)

        # Add the missing leading dims to get a 4D tensor
        n_dims = t.ndim
        if n_dims == 2:
            t = t.unsqueeze(0).unsqueeze(0)
        elif n_dims == 3:
            t = t.unsqueeze(0)

        key = "m" if self.is_label else "x"
        if key == "x":
            t = (normalize(t.to(torch.float32)) - 0.5) * 2

        # The key selects the interpolation used when resizing
        t = self.crop(self.resize({key: t}))[key]

        if self.half and key == "x":
            t = t.half()
        return t

    def __call__(self, x):
        """
        normalize, rescale, resize, crop in the center

        x can be: dict {"task": data} list [data, ..] or data
        data ^ can be a str, a Path, a numpy array or a Tensor
        """
        if isinstance(x, dict):
            return {task: self.process(item) for task, item in x.items()}
        if isinstance(x, list):
            return [self.process(item) for item in x]
        return self.process(x)
class PrepareTest:
    """
    Transform which:
      - transforms a str or an array into a tensor
      - resizes the image to keep the aspect ratio
      - crops in the center of the resized image
      - normalize to 0:1 (optional)
      - rescale to -1:1 (optional)
    """

    def __init__(self, target_size=640, half=False):
        """
        Args:
            target_size (int, optional): Size of the square output.
                Defaults to 640.
            half (bool, optional): Return float16 tensors. Defaults to False.
        """
        self.resize = Resize(target_size, keep_aspect_ratio=True)
        self.crop = RandomCrop((target_size, target_size), center=True)
        self.half = half

    def process(self, t, normalize=False, rescale=False):
        """
        Turn a single item into a resized and center-cropped 4D float tensor.

        Args:
            t (str | Path | np.ndarray | torch.Tensor): Data to process.
                Images read from a path lose their alpha channel, and
                grayscale ones are repeated over 3 channels.
            normalize (bool, optional): Apply ``climategan.tutils.normalize``
                to the tensor. Defaults to False.
            rescale (bool, optional): Apply ``(t - 0.5) * 2``. Defaults to
                False.

        Returns:
            torch.Tensor: The processed tensor, float16 if ``self.half``.
        """
        if isinstance(t, (str, Path)):
            t = imread(str(t))
            # Only index the channel dim on images which actually have one
            if np.ndim(t) == 3 and t.shape[-1] == 4:
                t = t[:, :, :3]
            if np.ndim(t) == 2:
                t = np.repeat(t[:, :, np.newaxis], 3, axis=2)

        if isinstance(t, np.ndarray):
            t = torch.from_numpy(t)
            t = t.permute(2, 0, 1)

        if len(t.shape) == 3:
            t = t.unsqueeze(0)

        t = t.to(torch.float32)
        if normalize:
            # The ``normalize`` argument shadows the module-level helper of
            # the same name, so the helper is imported under another name.
            from climategan.tutils import normalize as normalize_tensor

            t = normalize_tensor(t)
        if rescale:
            t = (t - 0.5) * 2

        t = {"x": t}
        t = self.resize(t)
        t = self.crop(t)
        t = t["x"]

        if self.half:
            return t.to(torch.float16)

        return t

    def __call__(self, x, normalize=False, rescale=False):
        """
        Call process()

        x can be: dict {"task": data} list [data, ..] or data
        data ^ can be a str, a Path, a numpy array or a Tensor
        """
        if isinstance(x, dict):
            return {k: self.process(v, normalize, rescale) for k, v in x.items()}

        if isinstance(x, list):
            return [self.process(t, normalize, rescale) for t in x]

        return self.process(x, normalize, rescale)
def get_transform(transform_item, mode):
    """Returns the transform function associated to a
    transform_item listed in opts.data.transforms ; transform_item is
    an addict.Dict

    Args:
        transform_item (addict.Dict): Description of the transform. Reads
            ``name`` and ``ignore``, plus the fields specific to each name
            (``height``, ``width``, ``center``, ``new_size``,
            ``keep_aspect_ratio``, ``p``).
        mode (str): Current mode; compared to ``ignore`` and ``center``.

    Returns:
        The transform object, or None when the item is ignored, i.e. when
        ``ignore`` is True or equal to ``mode``.

    Raises:
        ValueError: If the item is not ignored and its name is unknown.
    """
    # An ignored item yields None whatever its name is
    if transform_item.ignore is True or transform_item.ignore == mode:
        return None

    name = transform_item.name

    if name == "crop":
        return RandomCrop(
            (transform_item.height, transform_item.width),
            center=transform_item.center == mode,
        )

    if name == "resize":
        return Resize(
            transform_item.new_size, transform_item.get("keep_aspect_ratio", False)
        )

    if name == "hflip":
        p = transform_item.p
        # Fall back to 0.5 for a missing / empty value only: a plain
        # ``p or 0.5`` would also turn an explicit probability of 0 into 0.5.
        if not p and p != 0:
            p = 0.5
        return RandomHorizontalFlip(p=p)

    if name == "brightness":
        return RandBrightness()

    if name == "saturation":
        return RandSaturation()

    if name == "contrast":
        return RandContrast()

    raise ValueError("Unknown transform_item {}".format(transform_item))
def get_transforms(opts, mode, domain):
    """Get all the transform functions listed in opts.data.transforms
    using get_transform(transform_item, mode)

    Color jittering transforms (brightness, saturation, contrast) come after
    all the other listed transforms, and only when "p" is not in opts.tasks
    and mode is "train". Normalize and BucketizeDepth always close the list.

    Args:
        opts: Options object; reads ``opts.data.transforms`` and
            ``opts.tasks`` and is forwarded to Normalize and BucketizeDepth.
        mode (str): Current mode, forwarded to get_transform.
        domain (str): Data domain, forwarded to BucketizeDepth.

    Returns:
        list: The transforms to apply, without the ignored (None) ones.
    """
    jitter_names = ["brightness", "saturation", "contrast"]
    items = opts.data.transforms

    pipeline = [get_transform(item, mode) for item in items if item.name not in jitter_names]

    use_jitter = "p" not in opts.tasks and mode == "train"
    if use_jitter:
        pipeline.extend(
            [get_transform(item, mode) for item in items if item.name in jitter_names]
        )

    pipeline.extend([Normalize(opts), BucketizeDepth(opts, domain)])

    # get_transform returns None for ignored items
    return [transform for transform in pipeline if transform is not None]
# ----- Adapted functions from https://github.com/mit-han-lab/data-efficient-gans -----#
def rand_brightness(tensor, is_diff_augment=False):
    """Randomly change the brightness of a tensor.

    Args:
        tensor (torch.Tensor): Tensor to alter; must be 4D for the
            differentiable version.
        is_diff_augment (bool, optional): Use the differentiable version,
            which adds one uniform offset in [-0.5, 0.5) per sample.
            Otherwise use ``adjust_brightness`` with a factor drawn in
            [0.5, 1.5]. Defaults to False.

    Returns:
        torch.Tensor: The altered tensor.
    """
    if not is_diff_augment:
        brightness = random.uniform(0.5, 1.5)
        result = adjust_brightness(tensor, brightness_factor=brightness)
        # dummy pixels to fool scaling and preserve range
        result[:, :, 0, 0] = 1.0
        result[:, :, -1, -1] = 0.0
        return result

    assert len(tensor.shape) == 4
    # One random value per sample, broadcast over channels and pixels
    noise = torch.rand(
        tensor.size(0), 1, 1, 1, dtype=tensor.dtype, device=tensor.device
    )
    return tensor + (noise - 0.5)
def rand_saturation(tensor, is_diff_augment=False):
    """Randomly change the saturation of a tensor.

    Args:
        tensor (torch.Tensor): Tensor to alter; must be 4D for the
            differentiable version.
        is_diff_augment (bool, optional): Use the differentiable version,
            which scales the deviation from the channel-wise mean by one
            uniform factor in [0, 2) per sample. Otherwise use
            ``adjust_saturation`` with a factor drawn in [0.5, 1.5].
            Defaults to False.

    Returns:
        torch.Tensor: The altered tensor.
    """
    if not is_diff_augment:
        saturation = random.uniform(0.5, 1.5)
        result = adjust_saturation(tensor, saturation_factor=saturation)
        # dummy pixels to fool scaling and preserve range
        result[:, :, 0, 0] = 1.0
        result[:, :, -1, -1] = 0.0
        return result

    assert len(tensor.shape) == 4
    # One random value per sample, broadcast over channels and pixels
    noise = torch.rand(
        tensor.size(0), 1, 1, 1, dtype=tensor.dtype, device=tensor.device
    )
    channel_mean = tensor.mean(dim=1, keepdim=True)
    return (tensor - channel_mean) * (noise * 2) + channel_mean
def rand_contrast(tensor, is_diff_augment=False):
    """Randomly change the contrast of a tensor.

    Args:
        tensor (torch.Tensor): Tensor to alter; must be 4D for the
            differentiable version.
        is_diff_augment (bool, optional): Use the differentiable version,
            which scales the deviation from the per-sample mean by one
            uniform factor in [0.5, 1.5) per sample. Otherwise use
            ``adjust_contrast`` with a factor drawn in [0.5, 1.5].
            Defaults to False.

    Returns:
        torch.Tensor: The altered tensor.
    """
    if not is_diff_augment:
        contrast = random.uniform(0.5, 1.5)
        result = adjust_contrast(tensor, contrast_factor=contrast)
        # dummy pixels to fool scaling and preserve range
        result[:, :, 0, 0] = 1.0
        result[:, :, -1, -1] = 0.0
        return result

    assert len(tensor.shape) == 4
    # One random value per sample, broadcast over channels and pixels
    noise = torch.rand(
        tensor.size(0), 1, 1, 1, dtype=tensor.dtype, device=tensor.device
    )
    sample_mean = tensor.mean(dim=[1, 2, 3], keepdim=True)
    return (tensor - sample_mean) * (noise + 0.5) + sample_mean
def rand_cutout(tensor, ratio=0.5):
    """Zero out one random rectangle per sample of a 4D tensor.

    The rectangle has a size of ``ratio`` times the height and width of the
    tensor (rounded); its position is drawn independently for each sample and
    the part falling outside of the tensor is clamped to the border.

    Args:
        tensor (torch.Tensor): 4D tensor to alter.
        ratio (float, optional): Relative size of the cutout. Defaults to 0.5.

    Returns:
        torch.Tensor: ``tensor`` multiplied by the cutout mask, which is
            shared by all the channels of a sample.
    """
    assert len(tensor.shape) == 4, "For rand cutout, tensor must be 4D."

    dev = tensor.device
    batch = tensor.size(0)
    height, width = tensor.size(-2), tensor.size(-1)
    cut_h = int(height * ratio + 0.5)
    cut_w = int(width * ratio + 0.5)

    # Indices of every pixel of the cutout, for every sample
    idx_b, idx_h, idx_w = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=dev),
        torch.arange(cut_h, dtype=torch.long, device=dev),
        torch.arange(cut_w, dtype=torch.long, device=dev),
    )

    # One random position per sample
    per_sample = [batch, 1, 1]
    off_h = torch.randint(
        0, height + (1 - cut_h % 2), size=per_sample, device=dev
    )
    off_w = torch.randint(
        0, width + (1 - cut_w % 2), size=per_sample, device=dev
    )

    # Center the cutout on its position and keep it inside of the tensor
    idx_h = torch.clamp(idx_h + off_h - cut_h // 2, min=0, max=height - 1)
    idx_w = torch.clamp(idx_w + off_w - cut_w // 2, min=0, max=width - 1)

    keep = torch.ones(
        batch, tensor.size(2), tensor.size(3), dtype=tensor.dtype, device=dev
    )
    keep[idx_b, idx_h, idx_w] = 0
    return tensor * keep.unsqueeze(1)
def rand_translation(tensor, ratio=0.125):
    """Randomly translate each sample of a 4D tensor, padding with zeros.

    The shift along each spatial dim is drawn independently for each sample,
    up to ``ratio`` times the size of that dim (rounded), in both directions.

    Args:
        tensor (torch.Tensor): 4D tensor to alter.
        ratio (float, optional): Maximum relative shift. Defaults to 0.125.

    Returns:
        torch.Tensor: The translated tensor, same shape as ``tensor``.
    """
    assert len(tensor.shape) == 4, "For rand translation, tensor must be 4D."

    dev = tensor.device
    batch, size_2, size_3 = tensor.size(0), tensor.size(2), tensor.size(3)
    max_shift_2 = int(size_2 * ratio + 0.5)
    max_shift_3 = int(size_3 * ratio + 0.5)

    # One random shift per sample and per spatial dim
    per_sample = [batch, 1, 1]
    shift_2 = torch.randint(-max_shift_2, max_shift_2 + 1, size=per_sample, device=dev)
    shift_3 = torch.randint(-max_shift_3, max_shift_3 + 1, size=per_sample, device=dev)

    idx_b, idx_2, idx_3 = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=dev),
        torch.arange(size_2, dtype=torch.long, device=dev),
        torch.arange(size_3, dtype=torch.long, device=dev),
    )

    # Indices into the padded tensor (hence the + 1); positions shifted out
    # of the tensor are clamped onto the zero border
    idx_2 = torch.clamp(idx_2 + shift_2 + 1, 0, size_2 + 1)
    idx_3 = torch.clamp(idx_3 + shift_3 + 1, 0, size_3 + 1)

    padded = F.pad(tensor, [1, 1, 1, 1, 0, 0, 0, 0])
    channels_last = padded.permute(0, 2, 3, 1).contiguous()
    gathered = channels_last[idx_b, idx_2, idx_3]
    return gathered.permute(0, 3, 1, 2)
class DiffTransforms:
    """Chain the differentiable augmentations enabled in the options.

    Order: color jittering (brightness, contrast, saturation), translation,
    then cutout.
    """

    def __init__(self, diff_aug_opts):
        """
        Args:
            diff_aug_opts: Options object; reads ``do_color_jittering``,
                ``do_cutout``, ``do_translation``, ``cutout_ratio`` and
                ``translation_ratio``.
        """
        self.do_color_jittering = diff_aug_opts.do_color_jittering
        self.do_cutout = diff_aug_opts.do_cutout
        self.do_translation = diff_aug_opts.do_translation
        self.cutout_ratio = diff_aug_opts.cutout_ratio
        self.translation_ratio = diff_aug_opts.translation_ratio

    def __call__(self, tensor):
        """
        Args:
            tensor (torch.Tensor): Tensor to augment.

        Returns:
            torch.Tensor: The augmented tensor; the input itself when every
                augmentation is disabled.
        """
        augmented = tensor

        if self.do_color_jittering:
            for jitter in (rand_brightness, rand_contrast, rand_saturation):
                augmented = jitter(augmented, is_diff_augment=True)

        if self.do_translation:
            augmented = rand_translation(augmented, ratio=self.translation_ratio)

        if self.do_cutout:
            augmented = rand_cutout(augmented, ratio=self.cutout_ratio)

        return augmented