"""Data transforms for the loaders"""
import random
import traceback
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from skimage.color import rgba2rgb
from skimage.io import imread
from torchvision import transforms as trsfs
from torchvision.transforms.functional import (
    adjust_brightness,
    adjust_contrast,
    adjust_saturation,
)

# aliased so that the `normalize` boolean argument of PrepareTest.process
# does not shadow the function
from climategan.tutils import normalize as normalize_tensor

def interpolation(task):
    """Return the F.interpolate kwargs suited to a task: nearest-neighbor
    for depth ("d"), mask ("m") and segmentation ("s") tensors, bilinear
    otherwise."""
    if task in ["d", "m", "s"]:
        return {"mode": "nearest"}
    return {"mode": "bilinear", "align_corners": True}

class Resize:
    def __init__(self, target_size, keep_aspect_ratio=False):
        """
        Resize transform. target_size can be an int or a tuple of ints,
        depending on whether both height and width should have the same
        final size or not.

        If keep_aspect_ratio is specified then target_size must be an int:
        the smallest dimension of x will be set to target_size and the largest
        dimension will be computed to the closest int keeping the original
        aspect ratio. e.g.
        >>> x = torch.rand(1, 3, 1200, 1800)
        >>> m = torch.rand(1, 1, 600, 600)
        >>> d = {"x": x, "m": m}
        >>> {k: v.shape for k, v in Resize(640, True)(d).items()}
        {"x": (1, 3, 640, 960), "m": (1, 1, 640, 960)}

        Args:
            target_size (int | tuple(int)): New size for the tensor
            keep_aspect_ratio (bool, optional): Whether or not to keep aspect ratio
                when resizing. Requires target_size to be an int. If keeping aspect
                ratio, smallest dim will be set to target_size. Defaults to False.
        """
        if isinstance(target_size, (int, tuple, list)):
            if not isinstance(target_size, int) and not keep_aspect_ratio:
                assert len(target_size) == 2
                self.h, self.w = target_size
            else:
                if keep_aspect_ratio:
                    assert isinstance(target_size, int)
                self.h = self.w = target_size

            self.default_h = int(self.h)
            self.default_w = int(self.w)
            self.sizes = {}
        elif isinstance(target_size, dict):
            assert (
                not keep_aspect_ratio
            ), "dict target_size not compatible with keep_aspect_ratio"

            self.sizes = {
                k: {"h": v, "w": v} for k, v in target_size.items() if k != "default"
            }
            self.default_h = int(target_size["default"])
            self.default_w = int(target_size["default"])
        else:
            raise ValueError(f"Unknown target_size type {type(target_size)}")

        self.keep_aspect_ratio = keep_aspect_ratio

    def compute_new_default_size(self, tensor):
        """
        Compute the new size for a tensor depending on target size
        and keep_aspect_ratio

        Args:
            tensor (torch.Tensor): 4D tensor N x C x H x W.

        Returns:
            tuple(int): (new_height, new_width)
        """
        if self.keep_aspect_ratio:
            h, w = tensor.shape[-2:]
            if h < w:
                return (self.default_h, int(self.default_h * w / h))
            return (int(self.default_w * h / w), self.default_w)
        return (self.default_h, self.default_w)

    def compute_new_size_for_task(self, task):
        assert (
            not self.keep_aspect_ratio
        ), "compute_new_size_for_task is not compatible with keep aspect ratio"

        if task not in self.sizes:
            return (self.default_h, self.default_w)

        return (self.sizes[task]["h"], self.sizes[task]["w"])

    def __call__(self, data):
        """
        Resize a dict of tensors to the "x" key's new_size

        Args:
            data (dict[str, torch.Tensor]): The data dict to transform

        Returns:
            dict[str, torch.Tensor]: dict with all tensors resized to the
                new size of the data["x"] tensor
        """
        task = tensor = new_size = None
        try:
            if not self.sizes:
                d = {}
                new_size = self.compute_new_default_size(
                    data["x"] if "x" in data else list(data.values())[0]
                )
                for task, tensor in data.items():
                    d[task] = F.interpolate(
                        tensor, size=new_size, **interpolation(task)
                    )
                return d

            d = {}
            for task, tensor in data.items():
                new_size = self.compute_new_size_for_task(task)
                d[task] = F.interpolate(tensor, size=new_size, **interpolation(task))
            return d

        except Exception:
            tb = traceback.format_exc()
            print("Debug: task, shape, interpolation, h, w, new_size")
            print(task)
            print(tensor.shape if tensor is not None else None)
            print(interpolation(task))
            print(self.default_h, self.default_w)
            print(new_size)
            print(tb)
            raise
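
# Illustrative sketch (not part of the original module): the dict form of
# target_size resizes each task to its own square size, with "default" as
# the fallback:
#
#   resize = Resize({"default": 640, "s": 320})
#   batch = {"x": torch.rand(1, 3, 720, 1280), "s": torch.rand(1, 1, 720, 1280)}
#   out = resize(batch)
#   # out["x"].shape == (1, 3, 640, 640); out["s"].shape == (1, 1, 320, 320)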

class RandomCrop:
    def __init__(self, size, center=False):
        assert isinstance(size, (int, tuple, list))
        if not isinstance(size, int):
            assert len(size) == 2
            self.h, self.w = size
        else:
            self.h = self.w = size

        self.h = int(self.h)
        self.w = int(self.w)
        self.center = center

    def __call__(self, data):
        H, W = (
            data["x"].size()[-2:] if "x" in data else list(data.values())[0].size()[-2:]
        )

        if not self.center:
            # + 1 because np.random.randint's upper bound is exclusive;
            # this also handles the case where the input is exactly crop-sized
            top = np.random.randint(0, H - self.h + 1)
            left = np.random.randint(0, W - self.w + 1)
        else:
            top = (H - self.h) // 2
            left = (W - self.w) // 2

        return {
            task: tensor[:, :, top : top + self.h, left : left + self.w]
            for task, tensor in data.items()
        }
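
# Illustrative sketch: all tensors in the dict are cropped at the same
# location so that pixel alignment across tasks is preserved:
#
#   crop = RandomCrop(256)  # random location; center=True for a center crop
#   batch = {"x": torch.rand(1, 3, 512, 512), "m": torch.rand(1, 1, 512, 512)}
#   out = crop(batch)  # both tensors now have spatial size 256 x 256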

class RandomHorizontalFlip:
    def __init__(self, p=0.5):
        """Flip all tensors in the batch dict horizontally (along the last,
        width dimension) with probability p"""
        self.p = p

    def __call__(self, data):
        if np.random.rand() > self.p:
            return data
        return {task: torch.flip(tensor, [3]) for task, tensor in data.items()}

class ToTensor:
    def __init__(self):
        self.ImagetoTensor = trsfs.ToTensor()
        self.MaptoTensor = self.ImagetoTensor

    def __call__(self, data):
        new_data = {}
        for task, im in data.items():
            if task in {"x", "a"}:
                new_data[task] = self.ImagetoTensor(im)
            elif task in {"m"}:
                new_data[task] = self.MaptoTensor(im)
            elif task == "s":
                new_data[task] = torch.squeeze(torch.from_numpy(np.array(im))).to(
                    torch.int64
                )
            elif task == "d":
                new_data[task] = im

        return new_data
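
# Illustrative sketch: "x" becomes a float tensor in [0, 1], "s" an int64
# tensor of class indices, and "d" is passed through unchanged:
#
#   tt = ToTensor()
#   out = tt({
#       "x": np.zeros((64, 64, 3), dtype=np.uint8),
#       "s": np.zeros((64, 64), dtype=np.uint8),
#   })
#   # out["x"].dtype == torch.float32, out["s"].dtype == torch.int64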

class Normalize:
    def __init__(self, opts):
        if opts.data.normalization == "HRNet":
            self.normImage = trsfs.Normalize(
                (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            )
        else:
            self.normImage = trsfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        self.normDepth = lambda x: x
        self.normMask = lambda x: x
        self.normSeg = lambda x: x

        self.normalize = {
            "x": self.normImage,
            "s": self.normSeg,
            "d": self.normDepth,
            "m": self.normMask,
        }

    def __call__(self, data):
        return {
            task: self.normalize.get(task, lambda x: x)(tensor.squeeze(0))
            for task, tensor in data.items()
        }
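
# Illustrative sketch, assuming addict-style opts as elsewhere in this module:
#
#   from addict import Dict
#   opts = Dict()
#   opts.data.normalization = "HRNet"  # any other value selects (0.5, 0.5, 0.5)
#   norm = Normalize(opts)
#   out = norm({"x": torch.rand(1, 3, 64, 64)})  # ImageNet stats, squeezed to 3D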

class RandBrightness:
    def __call__(self, data):
        return {
            task: rand_brightness(tensor) if task == "x" else tensor
            for task, tensor in data.items()
        }


class RandSaturation:
    def __call__(self, data):
        return {
            task: rand_saturation(tensor) if task == "x" else tensor
            for task, tensor in data.items()
        }


class RandContrast:
    def __call__(self, data):
        return {
            task: rand_contrast(tensor) if task == "x" else tensor
            for task, tensor in data.items()
        }

class BucketizeDepth:
    def __init__(self, opts, domain):
        self.domain = domain

        if opts.gen.d.classify.enable and domain in {"s", "kitti"}:
            self.buckets = torch.linspace(
                opts.gen.d.classify.linspace.min,
                opts.gen.d.classify.linspace.max,
                opts.gen.d.classify.linspace.buckets - 1,
            )

            self.transforms = {
                "d": lambda tensor: torch.bucketize(
                    tensor, self.buckets, out_int32=True, right=True
                )
            }
        else:
            self.transforms = {}

    def __call__(self, data):
        return {
            task: self.transforms.get(task, lambda x: x)(tensor)
            for task, tensor in data.items()
        }
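
# The bucketization itself is plain torch: with `buckets` classes, the
# `buckets - 1` boundaries span [min, max] and torch.bucketize maps every
# depth value to an integer in [0, buckets - 1]. For instance:
#
#   boundaries = torch.linspace(0.0, 10.0, 9)  # 10 classes
#   torch.bucketize(torch.tensor([0.5, 9.9]), boundaries, out_int32=True, right=True)
#   # tensor([1, 8], dtype=torch.int32)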

class PrepareInference:
    """
    Transform which:
      - transforms a str or an array into a tensor
      - resizes the image, keeping the aspect ratio
      - crops in the center of the resized image
      - normalizes to [0, 1]
      - rescales to [-1, 1]
    """

    def __init__(self, target_size=640, half=False, is_label=False, enforce_128=True):
        if enforce_128 and target_size % 2 ** 7 != 0:
            raise ValueError(
                f"Received a target_size of {target_size}, which is not a "
                "multiple of 2^7 = 128. Set enforce_128 to False to disable "
                "this error."
            )
        self.resize = Resize(target_size, keep_aspect_ratio=True)
        self.crop = RandomCrop((target_size, target_size), center=True)
        self.half = half
        self.is_label = is_label

    def process(self, t):
        if isinstance(t, (str, Path)):
            t = imread(str(t))

        if isinstance(t, np.ndarray):
            if t.shape[-1] == 4:
                t = rgba2rgb(t)

            t = torch.from_numpy(t)
            if t.ndim == 3:
                t = t.permute(2, 0, 1)

        if t.ndim == 3:
            t = t.unsqueeze(0)
        elif t.ndim == 2:
            t = t.unsqueeze(0).unsqueeze(0)

        if not self.is_label:
            t = t.to(torch.float32)
            t = normalize_tensor(t)
            t = (t - 0.5) * 2

        t = {"m": t} if self.is_label else {"x": t}
        t = self.resize(t)
        t = self.crop(t)
        t = t["m"] if self.is_label else t["x"]

        if self.half and not self.is_label:
            t = t.half()

        return t

    def __call__(self, x):
        """
        normalize, rescale, resize, crop in the center

        x can be: dict {"task": data}, list [data, ...] or data
        data can be a str, a Path, a numpy array or a Tensor
        """
        if isinstance(x, dict):
            return {k: self.process(v) for k, v in x.items()}

        if isinstance(x, list):
            return [self.process(t) for t in x]

        return self.process(x)
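
# Usage sketch (the path is a placeholder):
#
#   prepare = PrepareInference(target_size=640)  # 640 is a multiple of 128
#   x = prepare("path/to/image.png")  # float32 (1, 3, 640, 640) tensor in [-1, 1]
#   batch = prepare(["a.png", "b.png"])  # list inputs yield a list of tensors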

class PrepareTest:
    """
    Transform which:
      - transforms a str or an array into a tensor
      - resizes the image, keeping the aspect ratio
      - crops in the center of the resized image
      - normalizes to [0, 1] (optional)
      - rescales to [-1, 1] (optional)
    """

    def __init__(self, target_size=640, half=False):
        self.resize = Resize(target_size, keep_aspect_ratio=True)
        self.crop = RandomCrop((target_size, target_size), center=True)
        self.half = half

    def process(self, t, normalize=False, rescale=False):
        if isinstance(t, (str, Path)):
            t = imread(str(t))
            if t.shape[-1] == 4:
                t = t[:, :, :3]
            if np.ndim(t) == 2:
                t = np.repeat(t[:, :, np.newaxis], 3, axis=2)

        if isinstance(t, np.ndarray):
            t = torch.from_numpy(t)
            t = t.permute(2, 0, 1)

        if len(t.shape) == 3:
            t = t.unsqueeze(0)

        t = t.to(torch.float32)
        # `normalize` and `rescale` are booleans here, hence the aliased
        # import of the normalize function
        t = normalize_tensor(t) if normalize else t
        t = (t - 0.5) * 2 if rescale else t
        t = {"x": t}
        t = self.resize(t)
        t = self.crop(t)
        t = t["x"]

        if self.half:
            return t.to(torch.float16)

        return t

    def __call__(self, x, normalize=False, rescale=False):
        """
        Call process()

        x can be: dict {"task": data}, list [data, ...] or data
        data can be a str, a Path, a numpy array or a Tensor
        """
        if isinstance(x, dict):
            return {k: self.process(v, normalize, rescale) for k, v in x.items()}

        if isinstance(x, list):
            return [self.process(t, normalize, rescale) for t in x]

        return self.process(x, normalize, rescale)
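
# Usage sketch (the path is a placeholder): unlike PrepareInference,
# normalization and rescaling are opt-in here:
#
#   prepare = PrepareTest(target_size=640)
#   raw = prepare("path/to/image.png")  # float32, values left in [0, 255]
#   scaled = prepare("path/to/image.png", normalize=True, rescale=True)  # [-1, 1]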

def get_transform(transform_item, mode):
    """Returns the transform callable associated with a transform_item
    listed in opts.data.transforms ; transform_item is an addict.Dict
    """
    ignore = transform_item.ignore is True or transform_item.ignore == mode

    if transform_item.name == "crop" and not ignore:
        return RandomCrop(
            (transform_item.height, transform_item.width),
            center=transform_item.center == mode,
        )

    if transform_item.name == "resize" and not ignore:
        return Resize(
            transform_item.new_size, transform_item.get("keep_aspect_ratio", False)
        )

    if transform_item.name == "hflip" and not ignore:
        return RandomHorizontalFlip(p=transform_item.p or 0.5)

    if transform_item.name == "brightness" and not ignore:
        return RandBrightness()

    if transform_item.name == "saturation" and not ignore:
        return RandSaturation()

    if transform_item.name == "contrast" and not ignore:
        return RandContrast()

    if ignore:
        return None

    raise ValueError("Unknown transform_item {}".format(transform_item))

def get_transforms(opts, mode, domain):
    """Get all the transform functions listed in opts.data.transforms
    using get_transform(transform_item, mode)
    """
    transforms = []
    color_jittering_transforms = ["brightness", "saturation", "contrast"]

    for t in opts.data.transforms:
        if t.name not in color_jittering_transforms:
            transforms.append(get_transform(t, mode))

    # color jittering is only applied at train time, and only when the
    # painter ("p") is not among the tasks
    if "p" not in opts.tasks and mode == "train":
        for t in opts.data.transforms:
            if t.name in color_jittering_transforms:
                transforms.append(get_transform(t, mode))

    transforms += [Normalize(opts), BucketizeDepth(opts, domain)]
    transforms = [t for t in transforms if t is not None]

    return transforms
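
# Illustrative sketch of the expected opts layout (addict-style, as noted in
# get_transform's docstring); the transform item names follow the branches
# of get_transform above:
#
#   from addict import Dict
#   opts = Dict()
#   opts.tasks = ["m", "s", "d"]  # no "p": color jittering stays enabled
#   opts.data.normalization = "default"
#   opts.gen.d.classify.enable = False
#   opts.data.transforms = [
#       Dict({"name": "hflip", "ignore": False, "p": 0.5}),
#       Dict({"name": "resize", "ignore": False, "new_size": 640}),
#       Dict({"name": "crop", "ignore": "val", "height": 600, "width": 600}),
#   ]
#   train_transforms = get_transforms(opts, "train", "r")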

def rand_brightness(tensor, is_diff_augment=False):
    if is_diff_augment:
        assert len(tensor.shape) == 4
        type_ = tensor.dtype
        device_ = tensor.device
        rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_)
        return tensor + (rand_tens - 0.5)
    else:
        factor = random.uniform(0.5, 1.5)
        tensor = adjust_brightness(tensor, brightness_factor=factor)
        # pin two corner pixels to the range extrema so that any later
        # min/max-based rescaling preserves the tensor's value range
        tensor[:, :, 0, 0] = 1.0
        tensor[:, :, -1, -1] = 0.0
        return tensor

def rand_saturation(tensor, is_diff_augment=False):
    if is_diff_augment:
        assert len(tensor.shape) == 4
        type_ = tensor.dtype
        device_ = tensor.device
        rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_)
        x_mean = tensor.mean(dim=1, keepdim=True)
        return (tensor - x_mean) * (rand_tens * 2) + x_mean
    else:
        factor = random.uniform(0.5, 1.5)
        tensor = adjust_saturation(tensor, saturation_factor=factor)
        # pin range extrema (see rand_brightness)
        tensor[:, :, 0, 0] = 1.0
        tensor[:, :, -1, -1] = 0.0
        return tensor

def rand_contrast(tensor, is_diff_augment=False):
    if is_diff_augment:
        assert len(tensor.shape) == 4
        type_ = tensor.dtype
        device_ = tensor.device
        rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_)
        x_mean = tensor.mean(dim=[1, 2, 3], keepdim=True)
        return (tensor - x_mean) * (rand_tens + 0.5) + x_mean
    else:
        factor = random.uniform(0.5, 1.5)
        tensor = adjust_contrast(tensor, contrast_factor=factor)
        # pin range extrema (see rand_brightness)
        tensor[:, :, 0, 0] = 1.0
        tensor[:, :, -1, -1] = 0.0
        return tensor
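
# In diff-augment mode the three jitter functions operate per-sample on a
# 4D batch and stay differentiable, e.g.:
#
#   batch = torch.rand(8, 3, 64, 64)
#   batch = rand_brightness(batch, is_diff_augment=True)
#   batch = rand_saturation(batch, is_diff_augment=True)
#   batch = rand_contrast(batch, is_diff_augment=True)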

def rand_cutout(tensor, ratio=0.5):
    """Zero out one random patch of size ~(ratio * H, ratio * W) per sample
    (DiffAugment-style cutout)."""
    assert len(tensor.shape) == 4, "For rand cutout, tensor must be 4D."
    type_ = tensor.dtype
    device_ = tensor.device
    cutout_size = int(tensor.size(-2) * ratio + 0.5), int(tensor.size(-1) * ratio + 0.5)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(tensor.size(0), dtype=torch.long, device=device_),
        torch.arange(cutout_size[0], dtype=torch.long, device=device_),
        torch.arange(cutout_size[1], dtype=torch.long, device=device_),
    )
    size_ = [tensor.size(0), 1, 1]
    offset_x = torch.randint(
        0,
        tensor.size(-2) + (1 - cutout_size[0] % 2),
        size=size_,
        device=device_,
    )
    offset_y = torch.randint(
        0,
        tensor.size(-1) + (1 - cutout_size[1] % 2),
        size=size_,
        device=device_,
    )
    grid_x = torch.clamp(
        grid_x + offset_x - cutout_size[0] // 2, min=0, max=tensor.size(-2) - 1
    )
    grid_y = torch.clamp(
        grid_y + offset_y - cutout_size[1] // 2, min=0, max=tensor.size(-1) - 1
    )
    mask = torch.ones(
        tensor.size(0), tensor.size(2), tensor.size(3), dtype=type_, device=device_
    )
    mask[grid_batch, grid_x, grid_y] = 0
    return tensor * mask.unsqueeze(1)
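
# Illustrative sketch: with ratio=0.5 on 64x64 inputs, a random 32x32 patch
# of each sample is zeroed (patches are clamped at the image borders):
#
#   out = rand_cutout(torch.rand(4, 3, 64, 64), ratio=0.5)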

def rand_translation(tensor, ratio=0.125):
    """Randomly shift each sample by up to ±ratio of its spatial size,
    zero-padding the uncovered pixels (DiffAugment-style translation)."""
    assert len(tensor.shape) == 4, "For rand translation, tensor must be 4D."
    device_ = tensor.device
    shift_x, shift_y = (
        int(tensor.size(2) * ratio + 0.5),
        int(tensor.size(3) * ratio + 0.5),
    )
    translation_x = torch.randint(
        -shift_x, shift_x + 1, size=[tensor.size(0), 1, 1], device=device_
    )
    translation_y = torch.randint(
        -shift_y, shift_y + 1, size=[tensor.size(0), 1, 1], device=device_
    )
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(tensor.size(0), dtype=torch.long, device=device_),
        torch.arange(tensor.size(2), dtype=torch.long, device=device_),
        torch.arange(tensor.size(3), dtype=torch.long, device=device_),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, tensor.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, tensor.size(3) + 1)
    x_pad = F.pad(tensor, [1, 1, 1, 1, 0, 0, 0, 0])
    tensor = (
        x_pad.permute(0, 2, 3, 1)
        .contiguous()[grid_batch, grid_x, grid_y]
        .permute(0, 3, 1, 2)
    )
    return tensor

class DiffTransforms:
    def __init__(self, diff_aug_opts):
        self.do_color_jittering = diff_aug_opts.do_color_jittering
        self.do_cutout = diff_aug_opts.do_cutout
        self.do_translation = diff_aug_opts.do_translation
        self.cutout_ratio = diff_aug_opts.cutout_ratio
        self.translation_ratio = diff_aug_opts.translation_ratio

    def __call__(self, tensor):
        if self.do_color_jittering:
            tensor = rand_brightness(tensor, is_diff_augment=True)
            tensor = rand_contrast(tensor, is_diff_augment=True)
            tensor = rand_saturation(tensor, is_diff_augment=True)
        if self.do_translation:
            tensor = rand_translation(tensor, ratio=self.translation_ratio)
        if self.do_cutout:
            tensor = rand_cutout(tensor, ratio=self.cutout_ratio)
        return tensor
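
# Usage sketch, assuming addict-style options (field names as read in
# __init__ above):
#
#   from addict import Dict
#   aug_opts = Dict({
#       "do_color_jittering": True,
#       "do_cutout": True,
#       "do_translation": False,
#       "cutout_ratio": 0.5,
#       "translation_ratio": 0.125,
#   })
#   augment = DiffTransforms(aug_opts)
#   batch = augment(torch.rand(8, 3, 256, 256))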