""" |
Main component: the trainer handles everything: |
* initializations |
* training |
* saving |
""" |
import inspect |
import warnings |
from copy import deepcopy |
from pathlib import Path |
from time import time |
import numpy as np |
from comet_ml import ExistingExperiment, Experiment |
warnings.simplefilter("ignore", UserWarning) |
import torch |
import torch.nn as nn |
from addict import Dict |
from torch import autograd, sigmoid, softmax |
from torch.cuda.amp import GradScaler, autocast |
from tqdm import tqdm |
from climategan.data import get_all_loaders, decode_segmap_merged_labels |
from climategan.discriminator import OmniDiscriminator, create_discriminator |
from climategan.eval_metrics import accuracy, mIOU |
from climategan.fid import compute_val_fid |
from climategan.fire import add_fire |
from climategan.generator import OmniGenerator, create_generator |
from climategan.logger import Logger |
from climategan.losses import get_losses |
from climategan.optim import get_optimizer |
from climategan.transforms import DiffTransforms |
from climategan.tutils import ( |
divide_pred, |
get_num_params, |
get_WGAN_gradient, |
lrgb2srgb, |
normalize, |
print_num_parameters, |
shuffle_batch_tuple, |
srgb2lrgb, |
tensor_to_uint8_numpy_image, |
vgg_preprocess, |
zero_grad, |
) |
from climategan.utils import ( |
comet_kwargs, |
div_dict, |
find_target_size, |
flatten_opts, |
get_display_indices, |
get_existing_comet_id, |
get_latest_opts, |
merge, |
resolve, |
sum_dict, |
Timer, |
) |
try: |
import torch_xla.core.xla_model as xm |
except ImportError: |
pass |
class Trainer: |
"""Main trainer class""" |
def __init__(self, opts, comet_exp=None, verbose=0, device=None):
    """Gather the model training procedures: training, evaluation, saving
    and logging.

    ``__init__`` only stores configuration and placeholders; the models,
    optimizers, schedulers, losses and data loaders are created by
    ``setup()``.

    init:
    * creates the trainer's ``Logger``
    * stores ``comet_exp`` as ``self.exp`` if it is a comet_ml ``Experiment``
    * sets the device (``device`` if given, else 1 GPU or CPU)
    * creates one gradient scaler per model if ``opts.train.amp``
    * forces ``opts.gen.s.use_dada`` when a depth fusion mode is enabled

    Args:
        opts (addict.Dict): options to configure the trainer, the data, the
            models
        comet_exp (comet_ml.Experiment, optional): experiment to log to.
            Values that are not ``Experiment`` instances are ignored.
            Defaults to None.
        verbose (int, optional): printing level to debug. Defaults to 0.
        device (torch.device, optional): device to use. Defaults to None,
            meaning ``cuda:0`` if available, else ``cpu``.

    Raises:
        ValueError: if ``opts.train.amp`` is set while the generator or the
            discriminator uses the ExtraAdam optimizer.
    """
    super().__init__()
    self.opts = opts
    self.verbose = verbose
    self.logger = Logger(self)

    self.losses = None
    self.G = self.D = None
    # Created by setup(); declared here so these attributes exist before
    # setup() has run.
    self.g_opt = self.d_opt = None
    self.g_scheduler = self.d_scheduler = None
    self.has_painter = False
    self.real_val_fid_stats = None
    self.use_pl4m = False
    self.is_setup = False
    self.loaders = self.all_loaders = None
    self.exp = None

    self.current_mode = "train"
    self.diff_transforms = None
    self.kitti_pretrain = self.opts.train.kitti.pretrain
    self.pseudo_training_tasks = set(self.opts.train.pseudo.tasks)

    self.lr_names = {}
    self.base_display_images = {}
    self.kitty_display_images = {}
    self.domain_labels = {"s": 0, "r": 1}

    self.device = device or torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu"
    )

    if isinstance(comet_exp, Experiment):
        self.exp = comet_exp

    if self.opts.train.amp:
        optimizers = [
            self.opts.gen.opt.optimizer.lower(),
            self.opts.dis.opt.optimizer.lower(),
        ]
        if "extraadam" in optimizers:
            raise ValueError(
                "AMP does not work with ExtraAdam ({})".format(optimizers)
            )
        self.grad_scaler_d = GradScaler()
        self.grad_scaler_g = GradScaler()

    # Either depth fusion mode of the segmentation decoder implies use_dada
    if (
        self.opts.gen.s.depth_feat_fusion is True
        or self.opts.gen.s.depth_dada_fusion is True
    ):
        self.opts.gen.s.use_dada = True
@torch.no_grad()
def paint_and_mask(self, image_batch, mask_batch=None, resolution="approx"):
    """
    Paint a batch of images (or a single image with a batch dim of 1).

    If masks are not provided, they are inferred from the masker.
    Operations are performed without gradient, with the models temporarily
    in eval mode. The previous mode and the painter's latent size are
    restored even if painting raises.

    If resolution == "approx" then the output image has the shape:
        (dim // 2 ** spade_n_up) * 2 ** spade_n_up, for dim in [height, width]
        eg: (1000, 1300) => (896, 1280) for spade_n_up = 7
    If resolution == "exact" then the output image has the same shape:
        we first process in "approx" mode then upsample bilinear
    If resolution == "basic" image output shape is the train-time's
        (typically 640x640)
    If resolution == "upsample" image is inferred as "basic" and
        then upsampled to original size

    Args:
        image_batch (torch.Tensor): 4D batch of images to flood
        mask_batch (torch.Tensor, optional): Masks for the images.
            Defaults to None (infer with Masker).
        resolution (str, optional): one of "approx", "exact", "basic" or
            "upsample". Defaults to "approx".

    Returns:
        torch.Tensor: N x C x H x W where H and W depend on `resolution`
    """
    assert resolution in {"approx", "exact", "basic", "upsample"}
    previous_mode = self.current_mode
    if previous_mode == "train":
        self.eval_mode()

    try:
        if mask_batch is None:
            mask_batch = self.G.mask(x=image_batch)
        else:
            assert len(image_batch) == len(mask_batch)
            assert image_batch.shape[-2:] == mask_batch.shape[-2:]

        if resolution not in {"approx", "exact"}:
            # "basic" / "upsample": paint at the painter's configured size
            painted = self.G.paint(mask_batch, image_batch)
            if resolution == "upsample":
                painted = nn.functional.interpolate(
                    painted, size=image_batch.shape[-2:], mode="bilinear"
                )
        else:
            # Temporarily resize the painter's latent grid to match the input,
            # rounded down to a multiple of 2 ** spade_n_up
            painter = self.G.painter
            zh, zw = painter.z_h, painter.z_w
            factor = 2**self.opts.gen.p.spade_n_up
            painter.z_h = image_batch.shape[-2] // factor
            painter.z_w = image_batch.shape[-1] // factor
            try:
                painted = self.G.paint(mask_batch, image_batch)
            finally:
                painter.z_h = zh
                painter.z_w = zw

            if resolution == "exact":
                painted = nn.functional.interpolate(
                    painted, size=image_batch.shape[-2:], mode="bilinear"
                )
    finally:
        if previous_mode == "train":
            self.train_mode()

    return painted
def _p(self, *args, **kwargs): |
""" |
verbose-dependant print util |
""" |
if self.verbose > 0: |
print(*args, **kwargs) |
@torch.no_grad()
def infer_all(
    self,
    x,
    numpy=True,
    stores=None,
    bin_value=-1,
    half=False,
    xla=False,
    cloudy=True,
    auto_resize_640=False,
    ignore_event=None,
    return_intermediates=False,
):
    """
    Create a dictionary of events from a numpy or tensor, single or batch
    image data.

    Args:
        x (np.ndarray | torch.Tensor): 3D (single image) or 4D (batch) image
            data, channels first or last (3 channels). The caller's tensor is
            never modified.
        numpy (bool, optional): convert the outputs to numpy arrays.
            Defaults to True.
        stores (dict, optional): dictionary of times for the Timer class,
            keyed by stage ("all events", "encode", "depth", "segmentation",
            "mask", "wildfire", "smog", "flood", "numpy"). Missing stages are
            timed into throw-away lists. Defaults to None.
        bin_value (float, optional): forwarded to compute_flood to binarize
            (or not) flood masks; also the threshold of the numpy
            intermediate mask. Defaults to -1.
        half (bool, optional): cast the input to half precision.
        xla (bool, optional): call ``xm.mark_step()`` between stages
            (requires torch_xla).
        cloudy (bool, optional): forwarded to compute_flood.
        auto_resize_640 (bool, optional): bilinearly resize the input to
            640x640 if it is not already that size.
        ignore_event (container of str, optional): events among "flood",
            "wildfire" and "smog" not to compute. Defaults to None
            (compute all).
        return_intermediates (bool, optional): also return "mask", "depth"
            and "segmentation".

    Returns:
        dict: event name -> 4D batch: BxHxWxC uint8 numpy arrays if numpy
            else BxCxHxW tensors (see the NOTE on the numpy mask below).
    """
    # None defaults instead of shared mutable `{}` / `set()` defaults
    if stores is None:
        stores = {}
    if ignore_event is None:
        ignore_event = set()

    assert self.is_setup
    assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"

    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x, device=self.device)

    if len(x.shape) == 3:
        # Out-of-place: an in-place unsqueeze_ would reshape the caller's
        # tensor
        x = x.unsqueeze(0)

    if x.shape[1] != 3:
        assert x.shape[-1] == 3, f"Unknown x shape to permute {x.shape}"
        x = x.permute(0, 3, 1, 2)

    if x.device != self.device:
        x = x.to(self.device)

    if auto_resize_640 and (x.shape[-1] != 640 or x.shape[-2] != 640):
        x = torch.nn.functional.interpolate(x, (640, 640), mode="bilinear")

    if half:
        x = x.half()

    self.G.painter.set_latent_shape(x.shape, True)

    events = {}
    with Timer(store=stores.get("all events", [])):
        with Timer(store=stores.get("encode", [])):
            z = self.G.encode(x)
        if xla:
            xm.mark_step()

        with Timer(store=stores.get("depth", [])):
            depth, z_depth = self.G.decoders["d"](z)
        if xla:
            xm.mark_step()

        with Timer(store=stores.get("segmentation", [])):
            segmentation = self.G.decoders["s"](z, z_depth)
        if xla:
            xm.mark_step()

        with Timer(store=stores.get("mask", [])):
            cond = self.G.make_m_cond(depth, segmentation, x)
            mask = self.G.mask(z=z, cond=cond, z_depth=z_depth)
        if xla:
            xm.mark_step()

        if "wildfire" not in ignore_event:
            with Timer(store=stores.get("wildfire", [])):
                events["wildfire"] = self.compute_fire(x, seg_preds=segmentation)

        if "smog" not in ignore_event:
            with Timer(store=stores.get("smog", [])):
                events["smog"] = self.compute_smog(x, d=depth, s=segmentation)

        if "flood" not in ignore_event:
            with Timer(store=stores.get("flood", [])):
                events["flood"] = self.compute_flood(
                    x,
                    m=mask,
                    s=segmentation,
                    cloudy=cloudy,
                    bin_value=bin_value,
                )

        if xla:
            xm.mark_step()

    output_data = {}
    event_order = ("flood", "wildfire", "smog")
    if numpy:
        with Timer(store=stores.get("numpy", [])):
            for name in event_order:
                if name in events:
                    output_data[name] = tensor_to_uint8_numpy_image(events[name])
    else:
        # Return the raw BxCxHxW tensors when no numpy conversion is requested
        for name in event_order:
            if name in events:
                output_data[name] = events[name]

    if return_intermediates:
        if numpy:
            # NOTE(review): unlike the other numpy outputs this mask is not
            # permuted to channels-last, and with bin_value=-1 every value
            # > -1 maps to 255 — confirm against G.mask's output range.
            output_data["mask"] = (
                ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
            )
            output_data["depth"] = tensor_to_uint8_numpy_image(depth)
            output_data["segmentation"] = (
                decode_segmap_merged_labels(segmentation, "r", False)
                .cpu()
                .permute(0, 2, 3, 1)
                .numpy()
                .astype(np.uint8)
            )
        else:
            output_data["mask"] = mask
            output_data["depth"] = depth
            output_data["segmentation"] = segmentation

    return output_data
@classmethod
def resume_from_path(
    cls,
    path,
    overrides=None,
    setup=True,
    inference=False,
    new_exp=False,
    device=None,
    verbose=1,
):
    """
    Resume and optionally setup a trainer from a specific path,
    using the latest opts and checkpoint. Requires path to contain opts.yaml
    (or an incremented version of it), url.txt (or an incremented version of
    it) and checkpoints/

    Args:
        path (str | pathlib.Path): Trainer to resume
        overrides (dict, optional): Override loaded opts with those.
            Defaults to None (no override).
        setup (bool, optional): Whether or not to setup the trainer before
            returning it. Defaults to True.
        inference (bool, optional): Setup should be done in inference mode or
            not. Defaults to False.
        new_exp (bool | None, optional): comet experiment to attach:
            None -> no experiment; True -> a new Experiment (the code folder
            and the flattened opts are logged to it); False -> the existing
            experiment whose id is found in `path`. Defaults to False.
        device (torch.device, optional): Device to use
        verbose (int, optional): trainer verbosity. Defaults to 1.

    Returns:
        climategan.Trainer: Loaded and resumed trainer
    """
    # Fresh dict per call instead of a shared mutable default
    if overrides is None:
        overrides = {}

    p = resolve(path)
    assert p.exists()

    c = p / "checkpoints"
    assert c.exists() and c.is_dir()

    opts = get_latest_opts(p)
    opts = Dict(merge(overrides, opts))
    opts.train.resume = True

    if new_exp is None:
        exp = None
    elif new_exp is True:
        exp = Experiment(project_name="climategan", **comet_kwargs)
        exp.log_asset_folder(
            str(resolve(Path(__file__)).parent),
            recursive=True,
            log_file_name=True,
        )
        exp.log_parameters(flatten_opts(opts))
    else:
        comet_id = get_existing_comet_id(p)
        exp = ExistingExperiment(previous_experiment=comet_id, **comet_kwargs)

    trainer = cls(opts, comet_exp=exp, device=device, verbose=verbose)

    if setup:
        trainer.setup(inference=inference)
    return trainer
def save(self):
    """Checkpoint G and its optimizer, plus D and its optimizer when D has
    parameters, to ``<opts.output_path>/checkpoints``.

    ``latest_ckpt.pth`` is overwritten on every call. An extra
    ``epoch_{epoch}_ckpt.pth`` copy is kept when the current epoch is at
    least ``opts.train.min_save_epoch`` and a multiple of
    ``opts.train.save_n_epochs``.
    """
    save_dir = Path(self.opts.output_path) / Path("checkpoints")
    # parents=True: output_path itself may not have been created yet
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / "latest_ckpt.pth"

    save_dict = {
        "epoch": self.logger.epoch,
        "G": self.G.state_dict(),
        "g_opt": self.g_opt.state_dict(),
        "step": self.logger.global_step,
    }

    if self.D is not None and get_num_params(self.D) > 0:
        save_dict["D"] = self.D.state_dict()
        save_dict["d_opt"] = self.d_opt.state_dict()

    if (
        self.logger.epoch >= self.opts.train.min_save_epoch
        and self.logger.epoch % self.opts.train.save_n_epochs == 0
    ):
        torch.save(save_dict, save_dir / f"epoch_{self.logger.epoch}_ckpt.pth")

    # Write to a temporary file then swap it in, so that a crash during the
    # write never leaves a truncated latest_ckpt.pth (the default file
    # loaded when resuming).
    tmp_path = save_path.with_name(save_path.name + ".tmp")
    torch.save(save_dict, tmp_path)
    tmp_path.replace(save_path)
def resume(self, inference=False):
    """Restore the trainer's state from a checkpoint.

    The checkpoint is chosen from ``opts.tasks`` and ``opts.load_paths``.
    Each of ``load_paths.m``, ``.p`` and ``.pm`` is either "none", a run
    directory (then ``checkpoints/latest_ckpt.pth`` inside it is used) or a
    checkpoint file:

    * tasks include both "m" and "p":
        - m, p and pm all "none": ``<output_path>/checkpoints/latest_ckpt.pth``
        - pm set: the P+M checkpoint at ``pm`` (a ``.pth`` file or a run dir)
        - m and p differ: both checkpoints are loaded and combined with
          ``merge(m_checkpoint, p_checkpoint)``
        - otherwise: ValueError
    * otherwise (single model): at most one of m / p may be set (ValueError
      if both are); with neither, ``<output_path>/checkpoints/latest_ckpt.pth``

    On TPU (``"xla"`` in ``str(self.device)``) the checkpoint is loaded on
    CPU and then sent to the device with ``xm.send_cpu_data_to_device``.

    Args:
        inference (bool, optional): if True, only G is loaded, non-strictly
            (missing keys keep their initial values, unexpected keys are
            ignored), and the method returns right after. Otherwise G is
            loaded strictly, then g_opt, the schedulers are stepped, D and
            d_opt are loaded if D has parameters, and the logger's epoch and
            step are restored. Defaults to False.

    Raises:
        ValueError: when ``opts.load_paths`` is inconsistent with
            ``opts.tasks``.
    """
    tpu = "xla" in str(self.device)
    if tpu:
        print("Resuming on TPU:", self.device)

    m_path = Path(self.opts.load_paths.m)
    p_path = Path(self.opts.load_paths.p)
    pm_path = Path(self.opts.load_paths.pm)
    output_path = Path(self.opts.output_path)

    # On TPU, load on CPU first; tensors are sent to the XLA device below
    map_loc = self.device if not tpu else "cpu"

    if "m" in self.opts.tasks and "p" in self.opts.tasks:
        # -----  Masker + Painter model  -----
        if all([str(p) == "none" for p in [m_path, p_path, pm_path]]):
            # No explicit path: resume this run's own latest checkpoint
            checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"
            print("Resuming P+M model from", str(checkpoint_path))
            checkpoint = torch.load(checkpoint_path, map_location=map_loc)
        elif str(pm_path) != "none":
            # Single joint P+M checkpoint: a run directory or a .pth file
            assert pm_path.exists()
            if pm_path.is_dir():
                checkpoint_path = pm_path / "checkpoints/latest_ckpt.pth"
            else:
                assert pm_path.suffix == ".pth"
                checkpoint_path = pm_path
            print("Resuming P+M model from", str(checkpoint_path))
            checkpoint = torch.load(checkpoint_path, map_location=map_loc)
        elif m_path != p_path:
            # Separate M and P checkpoints combined into one; both paths are
            # expected to exist
            assert m_path.exists()
            assert p_path.exists()
            if m_path.is_dir():
                m_path = m_path / "checkpoints/latest_ckpt.pth"
            if p_path.is_dir():
                p_path = p_path / "checkpoints/latest_ckpt.pth"
            assert m_path.suffix == ".pth"
            assert p_path.suffix == ".pth"
            print(f"Resuming P+M model from \n -{p_path} \nand \n -{m_path}")
            m_checkpoint = torch.load(m_path, map_location=map_loc)
            p_checkpoint = torch.load(p_path, map_location=map_loc)
            checkpoint = merge(m_checkpoint, p_checkpoint)
        else:
            raise ValueError(
                "Cannot resume a P+M model with provided load_paths:\n{}".format(
                    self.opts.load_paths
                )
            )
    else:
        # -----  Single model (M or P)  -----
        if str(m_path) != "none" and str(p_path) != "none":
            raise ValueError(
                "Opts tasks are {} but received 2 values for the load_paths".format(
                    self.opts.tasks
                )
            )
        elif str(m_path) != "none":
            assert m_path.exists()
            assert "m" in self.opts.tasks
            model = "M"
            if m_path.is_dir():
                m_path = m_path / "checkpoints/latest_ckpt.pth"
            checkpoint_path = m_path
        elif str(p_path) != "none":
            assert p_path.exists()
            assert "p" in self.opts.tasks
            model = "P"
            if p_path.is_dir():
                p_path = p_path / "checkpoints/latest_ckpt.pth"
            checkpoint_path = p_path
        else:
            # No explicit path: resume this run's own latest checkpoint
            model = "P" if "p" in self.opts.tasks else "M"
            checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"

        print(f"Resuming {model} model from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=map_loc)

    if tpu:
        checkpoint = xm.send_cpu_data_to_device(checkpoint, self.device)

    if inference:
        # Non-strict: tolerate architecture differences at inference time
        incompatible_keys = self.G.load_state_dict(checkpoint["G"], strict=False)
        if incompatible_keys.missing_keys:
            print("WARNING: Missing keys in self.G.load_state_dict, keeping inits")
            print(incompatible_keys.missing_keys)
        if incompatible_keys.unexpected_keys:
            print("WARNING: Ignoring Unexpected keys in self.G.load_state_dict")
            print(incompatible_keys.unexpected_keys)
    else:
        self.G.load_state_dict(checkpoint["G"])

    if inference:
        # Inference only needs G: no optimizers, D or trainer state
        print("Done loading checkpoints.")
        return

    self.g_opt.load_state_dict(checkpoint["g_opt"])

    # NOTE(review): this runs before self.logger.epoch is restored from the
    # checkpoint (below), so the schedulers are stepped according to the
    # logger's pre-resume epoch, not checkpoint["epoch"]. Confirm whether the
    # intent is to fast-forward to the checkpoint epoch, and whether the
    # loaded optimizer state already carries the scheduled learning rate.
    for _ in range(self.logger.epoch + 1):
        self.update_learning_rates()

    if self.D is not None and get_num_params(self.D) > 0:
        self.D.load_state_dict(checkpoint["D"])
        self.d_opt.load_state_dict(checkpoint["d_opt"])

    self.logger.epoch = checkpoint["epoch"]
    self.logger.global_step = checkpoint["step"]
    # NOTE(review): assumes a comet experiment is attached; raises
    # AttributeError when self.exp is None
    self.exp.log_text(
        "Resuming from epoch {} & step {}".format(
            checkpoint["epoch"], checkpoint["step"]
        )
    )

    # Keep global_step even: g_opt_step / d_opt_step run an ExtraAdam
    # extrapolation on even steps, so the next step is an extrapolation
    if self.logger.global_step % 2 != 0:
        self.logger.global_step += 1
def eval_mode(self):
    """Put every instantiated model (G, then D) in evaluation mode and
    record ``"eval"`` as the trainer's current mode."""
    for model in (self.G, self.D):
        if model is None:
            continue
        model.eval()
    self.current_mode = "eval"
def train_mode(self):
    """Put every instantiated model (G, then D) in training mode and
    record ``"train"`` as the trainer's current mode."""
    for model in (self.G, self.D):
        if model is None:
            continue
        model.train()
    self.current_mode = "train"
def assert_z_matches_x(self, x, z):
    """Assert that ``x`` and the latent ``z`` have the same batch size.

    ``z`` may be a tensor or a list/tuple of tensors; in the latter case its
    first element is the one compared.
    """
    z_ref = z[0] if isinstance(z, (list, tuple)) else z
    assert x.shape[0] == z_ref.shape[0], "x-> {}, z->{}".format(
        x.shape, z_ref.shape
    )
def batch_to_device(self, b):
    """Move every entry of ``b["data"]`` to ``self.device``, in place.

    Args:
        b (dict): the batch dictionary; only its "data" field is touched

    Returns:
        dict: the same ``b`` object, with its "data" entries on self.device
    """
    data = b["data"]
    for key in list(data):
        data[key] = data[key].to(self.device)
    return b
def sample_painter_z(self, batch_size):
    """Sample painter latents for ``batch_size`` images on the trainer's
    device, by delegating to ``self.G.sample_painter_z``."""
    generator = self.G
    device = self.device
    return generator.sample_painter_z(batch_size, device)
@property
def train_loaders(self):
    """Zip all the training loaders together.

    Returns:
        zip: iterator of tuples holding one batch per training domain, in
            the order of ``self.loaders["train"]``; it stops with the
            shortest loader.
    """
    per_domain = self.loaders["train"].values()
    return zip(*per_domain)
@property
def val_loaders(self):
    """Zip all the validation loaders together.

    Returns:
        zip: iterator of tuples holding one batch per validation domain, in
            the order of ``self.loaders["val"]``; it stops with the shortest
            loader.
    """
    per_domain = self.loaders["val"].values()
    return zip(*per_domain)
def compute_latent_shape(self):
    """Compute the latent shape, i.e. the Encoder's output shape, from the
    first sample of the first non-empty loader.

    Raises:
        ValueError: If there is no loader (e.g. ``all_loaders`` is None in
            inference mode), the latent_shape cannot be inferred

    Returns:
        torch.Size: the encoder output's shape without the batch dimension
            (presumably (c, h, w)). If the encoder returns a list/tuple, the
            shape of its first element is used.
    """
    x = None
    # all_loaders maps mode -> domain -> loader (it is indexed that way
    # everywhere else in the trainer)
    for mode_loaders in (self.all_loaders or {}).values():
        for loader in mode_loaders.values():
            x = loader.dataset[0]["data"]["x"].to(self.device)
            break
        if x is not None:
            break

    if x is None:
        raise ValueError("No batch found to compute_latent_shape")

    x = x.unsqueeze(0)
    # Only the output's shape is needed: no autograd graph
    with torch.no_grad():
        z = self.G.encode(x)
    return z.shape[1:] if not isinstance(z, (list, tuple)) else z[0].shape[1:]
def g_opt_step(self):
    """Step the generator's optimizer.

    With an "extra" optimizer (ExtraAdam), even global steps run an
    extrapolation step instead of a regular update.
    """
    uses_extra = "extra" in self.opts.gen.opt.optimizer.lower()
    if uses_extra and self.logger.global_step % 2 == 0:
        self.g_opt.extrapolation()
        return
    self.g_opt.step()
def d_opt_step(self):
    """Step the discriminator's optimizer.

    With an "extra" optimizer (ExtraAdam), even global steps run an
    extrapolation step instead of a regular update.
    """
    uses_extra = "extra" in self.opts.dis.opt.optimizer.lower()
    if uses_extra and self.logger.global_step % 2 == 0:
        self.d_opt.extrapolation()
        return
    self.d_opt.step()
def update_learning_rates(self):
    """Advance the generator's then the discriminator's learning-rate
    scheduler by one step, skipping those that are None."""
    for attr in ("g_scheduler", "d_scheduler"):
        scheduler = getattr(self, attr)
        if scheduler is None:
            continue
        scheduler.step()
def setup(self, inference=False):
    """Prepare the trainer before it can be used.

    Always:
        * resets ``logger.global_step`` to 0 and records the start time
        * creates the generator G (without weight init in inference mode)
          and sets the painter's latent shape when G has a painter
    If ``inference``:
        * no data loaders, discriminator, optimizers or losses are created
        * selects the "base" data source, resumes G if
          ``opts.train.resume`` and puts the models in eval mode
    Otherwise:
        * creates the data loaders, the discriminator D, the generator's
          optimizer/scheduler (and D's only if D has parameters), the losses
          and, if enabled, the painter's differentiable augmentations
        * creates the display images, logs the architecture, selects the
          kitti or base data source and resumes if ``opts.train.resume``

    Args:
        inference (bool, optional): set up for inference only.
            Defaults to False.
    """
    self.logger.global_step = 0
    start_time = time()
    self.logger.time.start_time = start_time
    verbose = self.verbose

    # Loaders are only needed for training
    if not inference:
        self.all_loaders = get_all_loaders(self.opts)

    # -----------------------
    # -----  Generator  -----
    # -----------------------
    __t = time()
    print("Creating generator...")

    self.G: OmniGenerator = create_generator(
        self.opts, device=self.device, no_init=inference, verbose=verbose
    )

    # G.load_val_painter() is only called when the painter has no parameters
    # (presumably it loads a separate painter; confirm in the generator)
    self.has_painter = get_num_params(self.G.painter) or self.G.load_val_painter()

    if self.has_painter:
        self.G.painter.set_latent_shape(find_target_size(self.opts, "x"), True)

    print(f"Generator OK in {time() - __t:.1f}s.")

    if inference:
        # Inference: G only, in eval mode
        print("Inference mode: no Discriminator, no optimizers")
        print_num_parameters(self)
        self.switch_data(to="base")
        if self.opts.train.resume:
            self.resume(True)
        self.eval_mode()
        print("Trainer is in evaluation mode.")
        print("Setup done.")
        self.is_setup = True
        return

    # ---------------------------
    # -----  Discriminator  -----
    # ---------------------------

    self.D: OmniDiscriminator = create_discriminator(
        self.opts, self.device, verbose=verbose
    )
    print("Discriminator OK.")

    print_num_parameters(self)

    # --------------------------
    # -----  Optimization  -----
    # --------------------------
    self.g_opt, self.g_scheduler, self.lr_names["G"] = get_optimizer(
        self.G, self.opts.gen.opt, self.opts.tasks
    )

    # A parameter-less D gets no optimizer: run_epoch then skips D updates
    if get_num_params(self.D) > 0:
        self.d_opt, self.d_scheduler, self.lr_names["D"] = get_optimizer(
            self.D, self.opts.dis.opt, self.opts.tasks, True
        )
    else:
        self.d_opt, self.d_scheduler = None, None

    self.losses = get_losses(self.opts, verbose, device=self.device)

    if "p" in self.opts.tasks and self.opts.gen.p.diff_aug.use:
        self.diff_transforms = DiffTransforms(self.opts.gen.p.diff_aug)

    if verbose > 0:
        for mode, mode_dict in self.all_loaders.items():
            for domain, domain_loader in mode_dict.items():
                print(
                    "Loader {} {} : {}".format(
                        mode, domain, len(domain_loader.dataset)
                    )
                )

    # ----------------------------
    # -----  Display images  -----
    # ----------------------------
    self.set_display_images()

    # -------------------------------
    # -----  Log Architectures  -----
    # -------------------------------
    self.logger.log_architecture()

    # -----------------------------
    # -----  Set data source  -----
    # -----------------------------
    if self.kitti_pretrain:
        self.switch_data(to="kitti")
    else:
        self.switch_data(to="base")

    # -------------------------
    # -----  Setup Done.  -----
    # -------------------------
    # Overwrite the last "\r"-terminated progress line
    print(" " * 50, end="\r")
    print("Done creating display images")

    if self.opts.train.resume:
        print("Resuming Model (inference: False)")
        self.resume(False)
    else:
        print("Not resuming: starting a new model")

    print("Setup done.")
    self.is_setup = True
def switch_data(self, to="kitti"):
    """Select which data ``self.loaders`` and ``self.display_images`` use.

    Args:
        to (str, optional): "kitti" to use only the kitti loaders, exposed
            under the "s" domain of every mode; anything else to use every
            non-kitti domain. Defaults to "kitti".

    If ``logger.global_step`` is odd and the discriminator's optimizer is an
    "extra" (ExtraAdam) optimizer, global_step is bumped by one so that the
    next discriminator step is an extrapolation step.
    """
    # Only the direct caller's name is needed for the log line: read it from
    # the frame chain rather than inspect.stack(), which builds a FrameInfo
    # (and reads source context from disk) for every frame on the stack.
    frame = inspect.currentframe()
    try:
        if frame is not None and frame.f_back is not None:
            caller = frame.f_back.f_code.co_name
        else:  # interpreters without stack frame support
            caller = inspect.stack()[1].function
    finally:
        del frame  # avoid a reference cycle through the frame
    print(f"[{caller}] Switching data source to", to)

    self.data_source = to
    if to == "kitti":
        self.display_images = self.kitty_display_images
        if self.all_loaders is not None:
            self.loaders = {
                mode: {"s": self.all_loaders[mode]["kitti"]}
                for mode in self.all_loaders
            }
    else:
        self.display_images = self.base_display_images
        if self.all_loaders is not None:
            self.loaders = {
                mode: {
                    domain: self.all_loaders[mode][domain]
                    for domain in self.all_loaders[mode]
                    if domain != "kitti"
                }
                for mode in self.all_loaders
            }

    if (
        self.logger.global_step % 2 != 0
        and "extra" in self.opts.dis.opt.optimizer.lower()
    ):
        print(
            "Warning: artificially bumping step to run an extrapolation step first."
        )
        self.logger.global_step += 1
def set_display_images(self, use_all=False):
    """Build the per-mode, per-domain display images from the datasets.

    Each display image is a dataset sample wrapped in an addict ``Dict``.
    kitti samples go to ``self.kitty_display_images`` when
    ``self.kitti_pretrain`` is set and are skipped otherwise; every other
    domain goes to ``self.base_display_images``. When a comet experiment is
    attached, each display image's "paths" entry is logged as a parameter.

    Args:
        use_all (bool, optional): use every sample of each dataset instead of
            the indices returned by get_display_indices. Defaults to False.
    """
    for mode, mode_dict in self.all_loaders.items():
        if self.kitti_pretrain:
            self.kitty_display_images[mode] = {}
        self.base_display_images[mode] = {}

        for domain in mode_dict:
            is_kitti = domain == "kitti"
            if is_kitti and not self.kitti_pretrain:
                continue
            target_dict = (
                self.kitty_display_images if is_kitti else self.base_display_images
            )

            dataset = self.all_loaders[mode][domain].dataset
            if use_all:
                display_indices = list(range(len(dataset)))
            else:
                display_indices = get_display_indices(self.opts, domain, len(dataset))

            n_display = len(display_indices)
            print(
                f" Creating {n_display} {mode} {domain} display images...",
                end="\r",
                flush=True,
            )

            images = []
            for idx in display_indices:
                # progress indicator, printed even for out-of-range indices
                print(f"({idx})", end="\r")
                if idx < len(dataset):
                    images.append(Dict(dataset[idx]))
            target_dict[mode][domain] = images

            if self.exp is None:
                continue
            for im_id, sample in enumerate(target_dict[mode][domain]):
                self.exp.log_parameter(
                    "display_image_{}_{}_{}".format(mode, domain, im_id),
                    sample["paths"],
                )
def train(self):
    """Run ``opts.train.epochs`` epochs starting at ``self.logger.epoch``.

    For each epoch:
        * enable pl4m at ``opts.gen.p.pl4m_epoch`` if the painter has
          parameters, "p" is a task and ``opts.gen.m.use_pl4m`` is set
        * train (run_epoch), evaluate (run_evaluation) and save
        * after epoch ``opts.train.kitti.epochs - 1``: switch to the base data
          and stop kitti pre-training
        * after epoch ``opts.train.pseudo.epochs - 1``: clear the pseudo
          training tasks
    """
    assert self.is_setup

    first_epoch = self.logger.epoch
    last_epoch_excluded = first_epoch + self.opts.train.epochs
    for epoch in range(first_epoch, last_epoch_excluded):
        self.logger.epoch = epoch

        pl4m_now = (
            self.logger.epoch == self.opts.gen.p.pl4m_epoch
            and get_num_params(self.G.painter) > 0
            and "p" in self.opts.tasks
            and self.opts.gen.m.use_pl4m
        )
        if pl4m_now:
            print(
                "\n\n >>> Enabling pl4m at epoch {}\n\n".format(self.logger.epoch)
            )
            self.use_pl4m = True

        self.run_epoch()
        self.run_evaluation(verbose=1)
        self.save()

        # end of kitti pre-training
        if self.logger.epoch == self.opts.train.kitti.epochs - 1:
            self.switch_data(to="base")
            self.kitti_pretrain = False

        # end of pseudo-label training
        if self.logger.epoch == self.opts.train.pseudo.epochs - 1:
            self.pseudo_training_tasks = set()
def run_epoch(self):
    """Train for one epoch over the zipped training loaders.

    * checks the trainer is set up and puts the models in train mode
    * for every tuple of batches (one per domain): shuffles the tuple, sends
      each batch to the device, updates G with D's parameters frozen, then
      updates D unless there is no D optimizer or kitti pre-training is on
    * steps the learning-rate schedulers (unless kitti pre-training) and logs
      the learning rates and timings
    """
    assert self.is_setup
    self.train_mode()
    if self.exp is not None:
        self.exp.log_parameter("epoch", self.logger.epoch)

    def set_d_requires_grad(flag):
        # toggle gradient tracking for every discriminator parameter
        for param in self.D.parameters():
            param.requires_grad = flag

    n_batches = min(len(loader) for loader in self.loaders["train"].values())
    description = "Epoch {}".format(self.logger.epoch)
    self.logger.time.epoch_start = time()

    progress = tqdm(
        self.train_loaders,
        desc=description,
        total=n_batches,
        mininterval=0.5,
        unit="batch",
    )
    for batch_tuple in progress:
        self.logger.time.step_start = time()

        multi_domain_batch = {}
        for batch in shuffle_batch_tuple(batch_tuple):
            domain = batch["domain"][0]
            multi_domain_batch[domain] = self.batch_to_device(batch)

        # D acts as a fixed critic while G is updated
        if self.d_opt is not None:
            set_d_requires_grad(False)
        self.update_G(multi_domain_batch)

        if self.d_opt is not None and not self.kitti_pretrain:
            set_d_requires_grad(True)
            self.update_D(multi_domain_batch)

        self.logger.global_step += 1
        self.logger.log_step_time(time())

    if not self.kitti_pretrain:
        self.update_learning_rates()

    self.logger.log_learning_rates()
    self.logger.log_epoch_time(time())
def update_G(self, multi_domain_batch, verbose=0):
    """Run one optimization step of the generator G.

    Gradients are zeroed, ``get_G_loss`` computes the loss (under
    ``autocast`` when ``opts.train.amp``) and the loss is back-propagated.
    The optimizer is then stepped, either through the AMP gradient scaler or
    with ``g_opt_step`` (which alternates ExtraAdam extrapolation and update
    steps). Generator losses are logged with
    ``self.logger.log_losses(model_to_update="G")``.

    Args:
        multi_domain_batch (dict): dictionary of domain batches
        verbose (int, optional): forwarded to get_G_loss. Defaults to 0.
    """
    zero_grad(self.G)
    if not self.opts.train.amp:
        g_loss = self.get_G_loss(multi_domain_batch, verbose)
        g_loss.backward()
        self.g_opt_step()
    else:
        with autocast():
            g_loss = self.get_G_loss(multi_domain_batch, verbose)
        scaler = self.grad_scaler_g
        scaler.scale(g_loss).backward()
        scaler.step(self.g_opt)
        scaler.update()

    self.logger.log_losses(model_to_update="G", mode="train")
def update_D(self, multi_domain_batch, verbose=0):
    """Run one optimization step of the discriminator D.

    Mirrors update_G: zero gradients, compute ``get_D_loss`` (under
    ``autocast`` when ``opts.train.amp``), back-propagate, and step through
    the AMP gradient scaler or ``d_opt_step``. The total loss is stored in
    ``self.logger.losses.disc.total_loss`` before the discriminator losses
    are logged.

    Args:
        multi_domain_batch (dict): dictionary of domain batches
        verbose (int, optional): forwarded to get_D_loss. Defaults to 0.
    """
    zero_grad(self.D)
    if not self.opts.train.amp:
        d_loss = self.get_D_loss(multi_domain_batch, verbose)
        d_loss.backward()
        self.d_opt_step()
    else:
        with autocast():
            d_loss = self.get_D_loss(multi_domain_batch, verbose)
        scaler = self.grad_scaler_d
        scaler.scale(d_loss).backward()
        scaler.step(self.d_opt)
        scaler.update()

    self.logger.losses.disc.total_loss = d_loss.item()
    self.logger.log_losses(model_to_update="D", mode="train")
def get_D_loss(self, multi_domain_batch, verbose=0):
    """Compute the discriminators' losses on ``multi_domain_batch``.

    For the "rf" domain, when the generator has a painter:
        * paint ``x`` with its mask ``m`` without gradient, optionally apply
          the differentiable augmentations to both fake and real images
        * with local discriminators: score full images with the "global"
          discriminator and masked images with the "local" one
        * otherwise: score mask-concatenated real and fake images in a
          single forward pass of ``self.D["p"]``
    For every other domain:
        * encode ``x``
        * if the batch has "s" data: the "Advent" loss from
          ``masker_s_loss(..., for_="D")`` (depth predictions are computed
          first when "d" is a task and ``opts.gen.s.use_dada``)
        * if the batch has "m" data: the "Advent" loss from
          ``masker_m_loss(..., for_="D")``, conditioned on depth and
          segmentation predictions when ``opts.gen.m.use_spade``
    "Advent" losses are scaled by ``opts.train.lambdas.advent.adv_main``.
    Each loss term is logged (as a float) in ``self.logger.losses.disc``.

    Args:
        multi_domain_batch (dict): domain name -> batch dict (with a "data"
            entry holding "x" and optional "m" / "s" targets)
        verbose (int, optional): unused. Defaults to 0.

    Returns:
        torch.Tensor: the sum of every loss term (0 if no term was computed)
    """
    disc_loss = {
        "m": {"Advent": 0},
        "s": {"Advent": 0},
    }
    if self.opts.dis.p.use_local_discriminator:
        disc_loss["p"] = {"global": 0, "local": 0}
    else:
        disc_loss["p"] = {"gan": 0}

    for domain, batch in multi_domain_batch.items():
        x = batch["data"]["x"]

        # ---------------------
        # -----  Painter  -----
        # ---------------------
        if domain == "rf" and self.has_painter:
            m = batch["data"]["m"]
            # Fake images are inputs to D only: no gradient towards G
            with torch.no_grad():
                fake = self.G.paint(m, x)
                if self.opts.gen.p.diff_aug.use:
                    fake = self.diff_transforms(fake)
                    x = self.diff_transforms(x)
                fake = fake.detach()
                fake.requires_grad_()

            if self.opts.dis.p.use_local_discriminator:
                fake_d_global = self.D["p"]["global"](fake)
                real_d_global = self.D["p"]["global"](x)

                fake_d_local = self.D["p"]["local"](fake * m)
                real_d_local = self.D["p"]["local"](x * m)

                global_loss = self.losses["D"]["p"](fake_d_global, False, True)
                global_loss += self.losses["D"]["p"](real_d_global, True, True)

                local_loss = self.losses["D"]["p"](fake_d_local, False, True)
                local_loss += self.losses["D"]["p"](real_d_local, True, True)

                disc_loss["p"]["global"] += global_loss
                disc_loss["p"]["local"] += local_loss
            else:
                # Real and fake go through D in one batch, then are split
                real_cat = torch.cat([m, x], axis=1)
                fake_cat = torch.cat([m, fake], axis=1)
                real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
                real_fake_d = self.D["p"](real_fake_cat)
                real_d, fake_d = divide_pred(real_fake_d)
                # Assigned, not accumulated: only the "rf" domain gets here
                disc_loss["p"]["gan"] = self.losses["D"]["p"](fake_d, False, True)
                disc_loss["p"]["gan"] += self.losses["D"]["p"](real_d, True, True)

        # --------------------
        # -----  Masker  -----
        # --------------------
        else:
            z = self.G.encode(x)
            s_pred = d_pred = cond = z_depth = None

            if "s" in batch["data"]:
                if "d" in self.opts.tasks and self.opts.gen.s.use_dada:
                    d_pred, z_depth = self.G.decoders["d"](z)

                step_loss, s_pred = self.masker_s_loss(
                    x, z, d_pred, z_depth, None, domain, for_="D"
                )
                step_loss *= self.opts.train.lambdas.advent.adv_main
                disc_loss["s"]["Advent"] += step_loss

            if "m" in batch["data"]:
                # Depth predictions are reused if the "s" step computed them
                if "d" in self.opts.tasks:
                    if self.opts.gen.m.use_spade:
                        if d_pred is None:
                            d_pred, z_depth = self.G.decoders["d"](z)
                        cond = self.G.make_m_cond(d_pred, s_pred, x)
                    elif self.opts.gen.m.use_dada:
                        if d_pred is None:
                            d_pred, z_depth = self.G.decoders["d"](z)

                step_loss, _ = self.masker_m_loss(
                    x,
                    z,
                    None,
                    domain,
                    for_="D",
                    cond=cond,
                    z_depth=z_depth,
                    depth_preds=d_pred,
                )
                step_loss *= self.opts.train.lambdas.advent.adv_main
                disc_loss["m"]["Advent"] += step_loss

    # Log plain floats rather than tensors
    self.logger.losses.disc.update(
        {
            dom: {
                k: v.item() if isinstance(v, torch.Tensor) else v
                for k, v in d.items()
            }
            for dom, d in disc_loss.items()
        }
    )

    loss = sum(v for d in disc_loss.values() for k, v in d.items())
    return loss
def get_G_loss(self, multi_domain_batch, verbose=0):
    """Sum the generator's losses for the enabled tasks.

    * masker loss (any of the tasks "m", "s", "d"), from get_masker_loss
    * painter loss (task "p", skipped during kitti pre-training), from
      get_painter_loss

    Each term and the total are logged in ``self.logger.losses.gen``.

    Args:
        multi_domain_batch (dict): domain name -> batch
        verbose (int, optional): unused. Defaults to 0.

    Returns:
        torch.Tensor: scalar generator loss

    Raises:
        AssertionError: "No update in get_G_loss!" if no loss term produced a
            tensor.
    """

    def as_float(value):
        # The loss helpers start from the int 0 and return it unchanged when
        # no term was computed; log that as 0.0 instead of crashing on .item()
        return value.item() if isinstance(value, torch.Tensor) else float(value)

    g_loss = 0
    if any(t in self.opts.tasks for t in "msd"):
        m_loss = self.get_masker_loss(multi_domain_batch)
        self.logger.losses.gen.masker = as_float(m_loss)
        g_loss += m_loss

    if "p" in self.opts.tasks and not self.kitti_pretrain:
        p_loss = self.get_painter_loss(multi_domain_batch)
        self.logger.losses.gen.painter = as_float(p_loss)
        g_loss += p_loss

    # "No update" means no term produced a tensor. Checking the type (rather
    # than `g_loss != 0`) lets a legitimately zero-valued loss through.
    assert isinstance(g_loss, torch.Tensor), "No update in get_G_loss!"

    self.logger.losses.gen.total_loss = g_loss.item()

    return g_loss
def get_masker_loss(self, multi_domain_batch):
    """Compute the generator-side depth, segmentation and mask losses.

    For every domain except "rf", ``x`` is encoded once. Then, for each task
    present in the batch's data, in the order depth -> segmentation -> mask:
        * "d": masker_d_loss, which also provides the depth predictions and
          depth features reused by the later tasks
        * "s": masker_s_loss, given the depth predictions if any
        * "m": masker_m_loss, conditioned on the depth and segmentation
          predictions when ``opts.gen.m.use_spade``
    Each per-domain loss value is logged in
    ``self.logger.losses.gen.task[task][domain]``.

    Args:
        multi_domain_batch (dict): dictionary mapping domain names to batches
            from the trainer's loaders

    Returns:
        torch.Tensor: sum of the task losses (the int 0 if none was computed)
    """
    total = 0
    for domain, batch in multi_domain_batch.items():
        # the "rf" batch only feeds the painter
        if domain == "rf":
            continue

        data = batch["data"]
        x = data["x"]
        z = self.G.encode(x)
        d_pred = s_pred = z_depth = None

        if "d" in data:
            target = data["d"]
            loss, d_pred, z_depth = self.masker_d_loss(x, z, target, domain, "G")
            total += loss
            self.logger.losses.gen.task["d"][domain] = loss.item()

        if "s" in data:
            target = data["s"]
            loss, s_pred = self.masker_s_loss(
                x, z, d_pred, z_depth, target, domain, "G"
            )
            total += loss
            self.logger.losses.gen.task["s"][domain] = loss.item()

        if "m" in data:
            target = data["m"]
            cond = None
            if self.opts.gen.m.use_spade:
                if not self.opts.gen.m.detach:
                    d_pred = d_pred.clone()
                    s_pred = s_pred.clone()
                cond = self.G.make_m_cond(d_pred, s_pred, x)

            loss, _ = self.masker_m_loss(
                x,
                z,
                target,
                domain,
                "G",
                cond=cond,
                z_depth=z_depth,
                depth_preds=d_pred,
            )
            total += loss
            self.logger.losses.gen.task["m"][domain] = loss.item()

    return total
def get_painter_loss(self, multi_domain_batch):
    """Compute the painter's generator loss on the ``"rf"`` batch.

    The ``"rf"`` images ``x`` are painted with their masks ``m`` and the
    following terms are summed, each multiplied by its
    ``opts.train.lambdas.G.p`` lambda:

    * VGG, TV, context and reconstruction losses (skipped when their lambda
      is 0)
    * the adversarial loss, either against a global + local discriminator
      pair (``opts.dis.p.use_local_discriminator``) or against a single
      discriminator fed ``[m, image]`` channel concatenations; the ``gan``
      lambda is applied in both set-ups
    * feature matching between real and fake discriminator outputs, when
      ``opts.dis.p.get_intermediate_features`` is set and the ``featmatch``
      lambda is not 0

    Each term's value is logged in ``self.logger.losses.gen.p``.

    Args:
        multi_domain_batch (dict): dictionary mapping domain names to batches from
            the trainer's loaders; must contain an ``"rf"`` batch

    Returns:
        torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas
    """
    step_loss = 0
    lambdas = self.opts.train.lambdas

    batch = multi_domain_batch["rf"]
    x = batch["data"]["x"]
    m = batch["data"]["m"]
    fake_flooded = self.G.paint(m, x)

    # -----  Non-adversarial terms (restricted to the mask where relevant)  -----
    if lambdas.G.p.vgg != 0:
        loss = self.losses["G"]["p"]["vgg"](
            vgg_preprocess(fake_flooded * m), vgg_preprocess(x * m)
        )
        loss *= lambdas.G.p.vgg
        self.logger.losses.gen.p.vgg = loss.item()
        step_loss += loss

    if lambdas.G.p.tv != 0:
        loss = self.losses["G"]["p"]["tv"](fake_flooded * m)
        loss *= lambdas.G.p.tv
        self.logger.losses.gen.p.tv = loss.item()
        step_loss += loss

    if lambdas.G.p.context != 0:
        loss = self.losses["G"]["p"]["context"](fake_flooded, x, m)
        loss *= lambdas.G.p.context
        self.logger.losses.gen.p.context = loss.item()
        step_loss += loss

    if lambdas.G.p.reconstruction != 0:
        loss = self.losses["G"]["p"]["reconstruction"](fake_flooded, x, m)
        loss *= lambdas.G.p.reconstruction
        self.logger.losses.gen.p.reconstruction = loss.item()
        step_loss += loss

    # -----  Adversarial terms  -----
    if self.opts.gen.p.diff_aug.use:
        # differentiable augmentations of real and fake images before the
        # discriminator
        fake_flooded = self.diff_transforms(fake_flooded)
        x = self.diff_transforms(x)

    if self.opts.dis.p.use_local_discriminator:
        fake_d = self.D["p"]["global"](fake_flooded)
        fake_d_local = self.D["p"]["local"](fake_flooded * m)
        real_d = self.D["p"]["global"](x)
        loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
        loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
    else:
        real_cat = torch.cat([m, x], axis=1)
        fake_cat = torch.cat([m, fake_flooded], axis=1)
        real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
        real_fake_d = self.D["p"](real_fake_cat)
        real_d, fake_d = divide_pred(real_fake_d)
        loss = self.losses["G"]["p"]["gan"](fake_d, True, False)

    # the gan lambda is applied regardless of the discriminator set-up
    loss *= lambdas.G.p.gan
    self.logger.losses.gen.p.gan = loss.item()
    step_loss += loss

    # -----  Feature matching  -----
    if self.opts.dis.p.get_intermediate_features and lambdas.G.p.featmatch != 0:
        loss = self.losses["G"]["p"]["featmatch"](real_d, fake_d)
        loss *= lambdas.G.p.featmatch
        # the featmatch loss may come back as a plain float
        self.logger.losses.gen.p.featmatch = (
            loss if isinstance(loss, float) else loss.item()
        )
        step_loss += loss

    return step_loss
def masker_d_loss(self, x, z, target, domain, for_="G"):
    """Compute the depth decoder's weighted loss.

    The depth decoder is always run (its outputs are needed downstream). A
    zero loss is returned, without evaluating the loss function, when the
    depth lambda (``opts.train.lambdas.G.d.main``) is 0, or when ``domain`` is
    ``"r"`` and ``"d"`` is not in ``self.pseudo_training_tasks``.

    Args:
        x (torch.Tensor): input images
        z (torch.Tensor): encoder output for ``x``
        target (torch.Tensor): depth target; when depth is trained as
            classification (``opts.gen.d.classify.enable``) its dim 1 is
            squeezed before the loss (the caller's tensor is not modified)
        domain (str): domain of the batch
        for_ (str, optional): "G" or "D". Defaults to "G".

    Returns:
        tuple: (loss, prediction, z_depth) where ``prediction`` and
            ``z_depth`` are the two outputs of the depth decoder
    """
    assert for_ in {"G", "D"}
    self.assert_z_matches_x(x, z)
    assert x.shape[0] == target.shape[0]

    weight = self.opts.train.lambdas.G.d.main
    prediction, z_depth = self.G.decoders["d"](z)

    if weight == 0 or (domain == "r" and "d" not in self.pseudo_training_tasks):
        # no supervision for this batch: skip computing a loss that would be
        # discarded anyway
        return torch.tensor(0.0, device=self.device), prediction, z_depth

    if self.opts.gen.d.classify.enable:
        # out-of-place squeeze: the target belongs to the caller's batch
        target = target.squeeze(1)

    loss = self.losses["G"]["tasks"]["d"](prediction, target) * weight
    return loss, prediction, z_depth
def masker_s_loss(self, x, z, depth_preds, z_depth, target, domain, for_="G"):
    """Compute segmentation losses for the generator or the ADVENT discriminator.

    With ``for_="G"`` (each term only when its lambda is not 0):

    * cross-entropy against ``target``: ``crossent`` for the ``"s"`` domain,
      ``crossent_pseudo`` for ``"r"`` when ``"s"`` is a pseudo-training task
    * entropy minimization for ``"r"`` when ``opts.gen.s.use_minent``
    * with ``opts.gen.s.use_advent``, the adversarial loss for ``"r"``
      (labelled as the ``"s"`` domain)

    With ``for_="D"`` and ``opts.gen.s.use_advent``: the ADVENT discriminator
    loss on detached predictions labelled with ``domain``, followed by the
    constraint selected by ``opts.dis.s.gan_type``: nothing for "GAN" /
    "WGAN_norm", weight clipping for "WGAN", a gradient penalty for "WGAN_gp".

    Args:
        x (torch.Tensor): input images
        z (torch.Tensor): encoder output for ``x``
        depth_preds (torch.Tensor | None): depth predictions, passed (detached)
            to the ADVENT loss when ``opts.gen.s.use_dada``
        z_depth (torch.Tensor | None): depth features forwarded to the
            segmentation decoder
        target (torch.Tensor | None): segmentation labels (dim 1 squeezed for
            the cross-entropy)
        domain (str): "r" or "s"
        for_ (str, optional): "G" or "D". Defaults to "G".

    Raises:
        NotImplementedError: for an unknown ``opts.dis.s.gan_type`` when
            ``for_ == "D"`` and ADVENT is used

    Returns:
        tuple: (full_loss, pred) where pred is the segmentation decoder output
            (detached for the discriminator), or None if it was not computed
    """
    assert for_ in {"G", "D"}
    assert domain in {"r", "s"}
    self.assert_z_matches_x(x, z)
    assert x.shape[0] == target.shape[0] if target is not None else True
    full_loss = torch.tensor(0.0, device=self.device)
    softmax_preds = None

    # the discriminator only needs the predictions when ADVENT is used
    pred = None
    if for_ == "G" or self.opts.gen.s.use_advent:
        pred = self.G.decoders["s"](z, z_depth)

    if for_ == "G":
        # supervised loss: ground truth for "s", pseudo labels for "r"
        if domain == "s" or "s" in self.pseudo_training_tasks:
            key = "crossent" if domain == "s" else "crossent_pseudo"
            logger = self.logger.losses.gen.task["s"][key]
            weight = self.opts.train.lambdas.G["s"][key]
            if weight != 0:
                loss_func = self.losses["G"]["tasks"]["s"]["crossent"]
                loss = loss_func(pred, target.squeeze(1))
                loss *= weight
                full_loss += loss
                logger[domain] = loss.item()

        if domain == "r":
            weight = self.opts.train.lambdas.G["s"]["minent"]
            if self.opts.gen.s.use_minent and weight != 0:
                softmax_preds = softmax(pred, dim=1)
                loss = self.losses["G"]["tasks"]["s"]["minent"](softmax_preds)
                loss *= weight
                full_loss += loss
                self.logger.losses.gen.task["s"]["minent"]["r"] = loss.item()

    if self.opts.gen.s.use_advent:
        if self.opts.gen.s.use_dada and depth_preds is not None:
            depth_preds = depth_preds.detach()
        else:
            depth_preds = None

        if for_ == "D":
            domain_label = domain
            logger = {}
            loss_func = self.losses["D"]["advent"]
            pred = pred.detach()
            weight = self.opts.train.lambdas.advent.adv_main
        else:
            # the generator tries to make real predictions look synthetic
            domain_label = "s"
            logger = self.logger.losses.gen.task["s"]["advent"]
            loss_func = self.losses["G"]["tasks"]["s"]["advent"]
            weight = self.opts.train.lambdas.G["s"]["advent"]

        if (for_ == "D" or domain == "r") and weight != 0:
            if softmax_preds is None:
                softmax_preds = softmax(pred, dim=1)
            loss = loss_func(
                softmax_preds,
                self.domain_labels[domain_label],
                self.D["s"]["Advent"],
                depth_preds,
            )
            loss *= weight
            full_loss += loss
            logger[domain] = loss.item()

        if for_ == "D":
            # membership test: `gan_type == "GAN" or "WGAN_norm"` is always
            # truthy and would make the branches below unreachable
            gan_type = self.opts.dis.s.gan_type
            if gan_type in {"GAN", "WGAN_norm"}:
                pass  # no extra constraint on the discriminator
            elif gan_type == "WGAN":
                for p in self.D["s"]["Advent"].parameters():
                    p.data.clamp_(
                        self.opts.dis.s.wgan_clamp_lower,
                        self.opts.dis.s.wgan_clamp_upper,
                    )
            elif gan_type == "WGAN_gp":
                # NOTE(review): the penalty is evaluated on the raw decoder
                # output while the ADVENT loss feeds the discriminator softmax
                # probabilities — confirm this is intended
                prob_need_grad = autograd.Variable(pred, requires_grad=True)
                d_out = self.D["s"]["Advent"](prob_need_grad)
                gp = get_WGAN_gradient(prob_need_grad, d_out)
                gp_loss = gp * self.opts.train.lambdas.advent.WGAN_gp
                full_loss += gp_loss
            else:
                raise NotImplementedError(
                    f"Unknown ADVENT gan_type for segmentation: {gan_type}"
                )

    return full_loss, pred
def masker_m_loss(
    self, x, z, target, domain, for_="G", cond=None, z_depth=None, depth_preds=None
):
    """Compute mask losses for the generator or the ADVENT discriminator.

    The mask decoder's logits are turned into probabilities ``p`` and a
    2-channel tensor ``prob = [p, 1 - p]``.

    With ``for_="G"`` (each term only when its lambda is not 0):

    * TV on ``p`` (any domain)
    * BCE on the logits against ``target`` (``"s"`` domain)
    * for ``"r"``: ground intersection (``opts.gen.m.use_ground_intersection``),
      painter loss for masker (``self.use_pl4m``) and entropy minimization
      (``opts.gen.m.use_minent``)
    * with ``opts.gen.m.use_advent``, the adversarial loss for ``"r"``
      (labelled as the ``"s"`` domain)

    With ``for_="D"`` and ``opts.gen.m.use_advent``: the ADVENT discriminator
    loss on detached ``prob`` labelled with ``domain``, followed by the
    constraint selected by ``opts.dis.m.gan_type``: nothing for "GAN" /
    "WGAN_norm", weight clipping for "WGAN", a gradient penalty for "WGAN_gp",
    all applied to the mask discriminator ``self.D["m"]["Advent"]``.

    Args:
        x (torch.Tensor): input images
        z (torch.Tensor): encoder output for ``x``
        target (torch.Tensor | None): ground-truth mask
        domain (str): "r" or "s"
        for_ (str, optional): "G" or "D". Defaults to "G".
        cond (torch.Tensor, optional): conditioning forwarded to the mask decoder
        z_depth (torch.Tensor, optional): depth features forwarded to the mask
            decoder
        depth_preds (torch.Tensor, optional): depth predictions, detached and
            resized to ``x`` for the ADVENT loss when ``opts.gen.m.use_dada``

    Raises:
        NotImplementedError: for an unknown ``opts.dis.m.gan_type`` when
            ``for_ == "D"`` and ADVENT is used

    Returns:
        tuple: (full_loss, prob) — ``prob`` is detached in the ADVENT
            discriminator case
    """
    assert for_ in {"G", "D"}
    assert domain in {"r", "s"}
    self.assert_z_matches_x(x, z)
    assert x.shape[0] == target.shape[0] if target is not None else True
    full_loss = torch.tensor(0.0, device=self.device)

    pred_logits = self.G.decoders["m"](z, cond=cond, z_depth=z_depth)
    pred_prob = sigmoid(pred_logits)
    pred_prob_complementary = 1 - pred_prob
    prob = torch.cat([pred_prob, pred_prob_complementary], dim=1)

    if for_ == "G":
        weight = self.opts.train.lambdas.G.m.tv
        if weight != 0:
            loss = self.losses["G"]["tasks"]["m"]["tv"](pred_prob)
            loss *= weight
            full_loss += loss
            self.logger.losses.gen.task["m"]["tv"][domain] = loss.item()

        weight = self.opts.train.lambdas.G.m.bce
        if domain == "s" and weight != 0:
            loss = self.losses["G"]["tasks"]["m"]["bce"](pred_logits, target)
            loss *= weight
            full_loss += loss
            self.logger.losses.gen.task["m"]["bce"]["s"] = loss.item()

        if domain == "r":
            weight = self.opts.train.lambdas.G["m"]["gi"]
            if self.opts.gen.m.use_ground_intersection and weight != 0:
                loss = self.losses["G"]["tasks"]["m"]["gi"](pred_prob, target)
                loss *= weight
                full_loss += loss
                self.logger.losses.gen.task["m"]["gi"]["r"] = loss.item()

            weight = self.opts.train.lambdas.G.m.pl4m
            if self.use_pl4m and weight != 0:
                pl4m_loss = self.painter_loss_for_masker(x, pred_prob)
                pl4m_loss *= weight
                full_loss += pl4m_loss
                self.logger.losses.gen.task.m.pl4m.r = pl4m_loss.item()

            weight = self.opts.train.lambdas.advent.ent_main
            if self.opts.gen.m.use_minent and weight != 0:
                loss = self.losses["G"]["tasks"]["m"]["minent"](prob)
                loss *= weight
                full_loss += loss
                self.logger.losses.gen.task["m"]["minent"]["r"] = loss.item()

    if self.opts.gen.m.use_advent:
        if self.opts.gen.m.use_dada and depth_preds is not None:
            depth_preds = depth_preds.detach()
            depth_preds = torch.nn.functional.interpolate(
                depth_preds, size=x.shape[-2:], mode="nearest"
            )
        else:
            depth_preds = None

        # both G and D use advent.adv_main as the adversarial weight here
        weight = self.opts.train.lambdas.advent.adv_main
        if for_ == "D":
            domain_label = domain
            logger = {}
            loss_func = self.losses["D"]["advent"]
            prob = prob.detach()
        else:
            # the generator tries to make real masks look synthetic
            domain_label = "s"
            logger = self.logger.losses.gen.task["m"]["advent"]
            loss_func = self.losses["G"]["tasks"]["m"]["advent"]

        if (for_ == "D" or domain == "r") and weight != 0:
            loss = loss_func(
                prob.to(self.device),
                self.domain_labels[domain_label],
                self.D["m"]["Advent"],
                depth_preds,
            )
            loss *= weight
            full_loss += loss
            logger[domain] = loss.item()

        if for_ == "D":
            # membership test: `gan_type == "GAN" or "WGAN_norm"` is always
            # truthy and would make the branches below unreachable
            gan_type = self.opts.dis.m.gan_type
            if gan_type in {"GAN", "WGAN_norm"}:
                pass  # no extra constraint on the discriminator
            elif gan_type == "WGAN":
                # clip the *mask* discriminator (the one fed `prob` above)
                for p in self.D["m"]["Advent"].parameters():
                    p.data.clamp_(
                        self.opts.dis.m.wgan_clamp_lower,
                        self.opts.dis.m.wgan_clamp_upper,
                    )
            elif gan_type == "WGAN_gp":
                prob_need_grad = autograd.Variable(prob, requires_grad=True)
                d_out = self.D["m"]["Advent"](prob_need_grad)
                gp = get_WGAN_gradient(prob_need_grad, d_out)
                gp_loss = self.opts.train.lambdas.advent.WGAN_gp * gp
                full_loss += gp_loss
            else:
                raise NotImplementedError(
                    f"Unknown ADVENT gan_type for masker: {gan_type}"
                )

    return full_loss, prob
def painter_loss_for_masker(self, x, m):
    """Adversarial painter loss used as a training signal for the masker.

    ``x`` is painted with the (soft) mask ``m`` while the painter's parameters
    have ``requires_grad=False``, so the painter weights receive no gradient
    from this loss. The painted image is scored by the painter discriminator(s)
    exactly as in the generator's painter GAN loss: global + local
    discriminators when ``opts.dis.p.use_local_discriminator``, otherwise a
    single discriminator on ``[m, image]`` concatenations.

    The painter's ``requires_grad`` is restored afterwards (also when an
    exception is raised) only if ``"p"`` is a trained task; otherwise the
    painter stays frozen.

    Args:
        x (torch.Tensor): input images
        m (torch.Tensor): mask probabilities predicted by the masker

    Returns:
        torch.Tensor: the (unweighted) painter GAN loss
    """
    for param in self.G.painter.parameters():
        param.requires_grad = False
    try:
        fake_flooded = self.G.paint(m, x)
        if self.opts.dis.p.use_local_discriminator:
            fake_d_global = self.D["p"]["global"](fake_flooded)
            fake_d_local = self.D["p"]["local"](fake_flooded * m)
            pl4m_loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
            pl4m_loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
        else:
            real_cat = torch.cat([m, x], axis=1)
            fake_cat = torch.cat([m, fake_flooded], axis=1)
            real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
            real_fake_d = self.D["p"](real_fake_cat)
            _, fake_d = divide_pred(real_fake_d)
            pl4m_loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
    finally:
        # never leave a trained painter frozen, even if painting/scoring fails
        if "p" in self.opts.tasks:
            for param in self.G.painter.parameters():
                param.requires_grad = True
    return pl4m_loss
@torch.no_grad()
def run_evaluation(self, verbose=0):
    """Evaluate the generator on the validation loaders and log the results.

    * runs ``get_G_loss`` on every validation batch and logs the average of
      the values it records in ``self.logger.losses.gen`` as ``"val"`` losses
    * logs train / val comet images for every domain, plus combined images
      when the masker and painter are both available
    * computes masker / segmentation metrics on the val display images
    * computes and logs the validation FID when the painter is trained

    Gradients are disabled throughout. Models are put in eval mode at the
    start and back in train mode at the end.

    Args:
        verbose (int, optional): forwarded to ``get_G_loss``. Defaults to 0.
    """
    print("******************* Running Evaluation ***********************")
    start_time = time()
    self.eval_mode()

    val_logger = None
    nb_of_batches = 0
    for multi_batch_tuple in self.val_loaders:
        nb_of_batches += 1
        multi_domain_batch = {
            batch["domain"][0]: self.batch_to_device(batch)
            for batch in multi_batch_tuple
        }
        self.get_G_loss(multi_domain_batch, verbose)
        # get_G_loss records its values in self.logger.losses.gen, so that is
        # what has to be accumulated across batches
        if val_logger is None:
            val_logger = deepcopy(self.logger.losses.gen)
        else:
            val_logger = sum_dict(val_logger, self.logger.losses.gen)

    if val_logger is not None:  # an empty val loader leaves nothing to average
        # wrap in addict.Dict so later attribute-style writes
        # (self.logger.losses.gen.<key> = ...) keep working
        self.logger.losses.gen = Dict(div_dict(val_logger, nb_of_batches))
    self.logger.log_losses(model_to_update="G", mode="val")

    for d in self.opts.domains:
        self.logger.log_comet_images("train", d)
        self.logger.log_comet_images("val", d)

    if "m" in self.opts.tasks and self.has_painter and not self.kitti_pretrain:
        self.logger.log_comet_combined_images("train", "r")
        self.logger.log_comet_combined_images("val", "r")

    if self.exp is not None:
        print()

    if "m" in self.opts.tasks or "s" in self.opts.tasks:
        self.eval_images("val", "r")
        self.eval_images("val", "s")

    if "p" in self.opts.tasks and not self.kitti_pretrain:
        val_fid = compute_val_fid(self)
        if self.exp is not None:
            self.exp.log_metric("val_fid", val_fid, step=self.logger.global_step)
        else:
            print("Validation FID Score", val_fid)

    self.train_mode()
    timing = int(time() - start_time)
    print("****************** Done in {}s *********************".format(timing))
def eval_images(self, mode, domain):
    """Compute accuracy / mIOU on the stored display images and log them.

    Scores are computed per task over ``self.display_images[mode][domain]``:

    * ``"m"``: always reported. Computed when ``"m"`` is in ``opts.tasks``,
      from the predicted mask binarized at 0.5.
    * ``"s"``: when ``"s"`` is in ``opts.tasks``.
    * ``"d"``: for the ``"s"`` domain, when ``"d"`` is a task and
      ``opts.gen.d.classify.enable`` is set.

    Per-metric averages that are empty or NaN are reported as -1. Scores are
    logged to comet with the ``metrics_{mode}_{domain}`` prefix, or printed
    when there is no experiment. With ``self.kitti_pretrain``, the ``"s"``
    domain is replaced by ``"kitti"``. ``"rf"`` and domains without display
    images are skipped.

    Args:
        mode (str): display-image split (e.g. "train" or "val")
        domain (str): domain name

    Returns:
        int | None: 0 when metrics were computed, None when skipped
    """
    if domain == "s" and self.kitti_pretrain:
        domain = "kitti"
    if domain == "rf" or domain not in self.display_images[mode]:
        return

    metric_funcs = {"accuracy": accuracy, "mIOU": mIOU}
    metric_avg_scores = {"m": {}}
    if "s" in self.opts.tasks:
        metric_avg_scores["s"] = {}
    if "d" in self.opts.tasks and domain == "s" and self.opts.gen.d.classify.enable:
        metric_avg_scores["d"] = {}

    for key in metric_funcs:
        for task in metric_avg_scores:
            metric_avg_scores[task][key] = []

    for im_set in self.display_images[mode][domain]:
        x = im_set["data"]["x"].unsqueeze(0).to(self.device)
        z = self.G.encode(x)
        # CPU copies feed the metrics; device copies feed the mask conditioning
        # (x stays on self.device)
        s_pred = d_pred = z_depth = None
        s_pred_dev = d_pred_dev = None

        if "d" in metric_avg_scores:  # implies domain == "s"
            d_pred_dev, z_depth = self.G.decoders["d"](z)
            d_pred_dev = d_pred_dev.detach()
            d_pred = d_pred_dev.cpu()
            d = im_set["data"]["d"].unsqueeze(0).detach()
            for metric in metric_funcs:
                metric_score = metric_funcs[metric](d_pred, d)
                metric_avg_scores["d"][metric].append(metric_score)

        if "s" in metric_avg_scores:
            if z_depth is None and self.opts.gen.s.use_dada and "d" in self.opts.tasks:
                _, z_depth = self.G.decoders["d"](z)
            s_pred_dev = self.G.decoders["s"](z, z_depth).detach()
            s_pred = s_pred_dev.cpu()
            s = im_set["data"]["s"].unsqueeze(0).detach()
            for metric in metric_funcs:
                metric_score = metric_funcs[metric](s_pred, s)
                metric_avg_scores["s"][metric].append(metric_score)

        # membership in the *tasks* list (not in the options dict's keys)
        if "m" in self.opts.tasks:
            cond = None
            if s_pred is not None and d_pred is not None:
                cond = self.G.make_m_cond(d_pred_dev, s_pred_dev, x)
            if z_depth is None and self.opts.gen.m.use_dada and "d" in self.opts.tasks:
                _, z_depth = self.G.decoders["d"](z)
            pred_mask = self.G.mask(z=z, cond=cond, z_depth=z_depth).detach().cpu()
            pred_mask = (pred_mask > 0.5).to(torch.float32)
            pred_prob = torch.cat([1 - pred_mask, pred_mask], dim=1)
            m = im_set["data"]["m"].unsqueeze(0).detach()
            for metric in metric_funcs:
                # mIOU expects per-class channels, accuracy the binary mask
                prediction = pred_prob if metric == "mIOU" else pred_mask
                metric_score = metric_funcs[metric](prediction, m)
                metric_avg_scores["m"][metric].append(metric_score)

    def _summarize(values):
        # average of the collected scores; -1 when there are none or it is NaN
        avg = np.mean(values) if values else float("nan")
        return -1 if np.isnan(avg) else avg

    metric_avg_scores = {
        task: {metric: _summarize(values) for metric, values in met_dict.items()}
        for task, met_dict in metric_avg_scores.items()
    }

    if self.exp is not None:
        self.exp.log_metrics(
            flatten_opts(metric_avg_scores),
            prefix=f"metrics_{mode}_{domain}",
            step=self.logger.global_step,
        )
    else:
        print(f"metrics_{mode}_{domain}")
        print(flatten_opts(metric_avg_scores))

    return 0
def functional_test_mode(self):
    """Redirect outputs to a throw-away directory for functional tests.

    Sets ``opts.output_path`` to ``~/climategan/functional_tests``, creates
    that directory and writes an ``is_functional.test`` marker file in it.
    If a comet experiment exists, it gets ``is_functional_test=True``.
    ``del_output_path`` is registered to run at interpreter exit.
    """
    import atexit

    out_dir = Path("~").expanduser() / "climategan" / "functional_tests"
    self.opts.output_path = out_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # the marker lets del_output_path recognize a directory it may delete
    marker = out_dir / "is_functional.test"
    marker.write_text("trainer functional test - delete this dir")

    if self.exp is not None:
        self.exp.log_parameter("is_functional_test", True)

    atexit.register(self.del_output_path)
def del_output_path(self, force=False):
    """Recursively delete ``opts.output_path`` when it is safe to do so.

    A missing directory is ignored. An existing one is removed only if it
    contains the ``is_functional.test`` marker (written by
    ``functional_test_mode``) or if ``force`` is True.

    Args:
        force (bool, optional): delete even without the marker file.
            Defaults to False.
    """
    import shutil

    out_dir = Path(self.opts.output_path)
    if not out_dir.exists():
        return
    is_test_dir = (out_dir / "is_functional.test").exists()
    if is_test_dir or force:
        shutil.rmtree(self.opts.output_path)
def compute_fire(self, x, seg_preds=None, z=None, z_depth=None):
    """Transform an input tensor into its wildfire version.

    Args:
        x (torch.Tensor): input tensor
        seg_preds (torch.Tensor, optional): semantic segmentation predictions
            for ``x``. When None they are computed with the segmentation
            decoder.
        z (torch.Tensor, optional): latent vector of the encoded ``x``. Only
            used when ``seg_preds`` is None; computed from ``x`` if also None.
        z_depth (torch.Tensor, optional): forwarded to the segmentation decoder
            when ``seg_preds`` has to be computed

    Returns:
        torch.Tensor: result of ``add_fire`` with ``opts.events.fire``
    """
    if seg_preds is None:
        latent = z if z is not None else self.G.encode(x)
        seg_preds = self.G.decoders["s"](latent, z_depth)
    return add_fire(x, seg_preds, self.opts.events.fire)
def compute_flood(
    self, x, z=None, z_depth=None, m=None, s=None, cloudy=None, bin_value=-1
):
    """Apply a flood (mask + paint) to an input image.

    Args:
        x (torch.Tensor): B x C x H x W -1:1 input image
        z (torch.Tensor, optional): masker latent vector, computed from ``x``
            if needed and None. Defaults to None.
        z_depth (torch.Tensor, optional): depth features for the mask decoder.
            Computed with the depth decoder when None, ``"d"`` is a task and
            ``opts.gen.m.use_dada`` is set. Defaults to None.
        m (torch.Tensor, optional): B x 1 x H x W mask. Predicted when None.
            Defaults to None.
        s (torch.Tensor, optional): segmentation passed to
            ``G.paint_cloudy``. Required when ``cloudy`` is truthy.
        cloudy (bool, optional): use ``G.paint_cloudy`` instead of
            ``G.paint``. Defaults to None (falsy).
        bin_value (float, optional): mask binarization threshold. Set to -1
            to use smooth masks (no binarization).

    Returns:
        torch.Tensor: B x 3 x H x W -1:1 flooded image
    """
    if m is None:
        latent = z if z is not None else self.G.encode(x)
        need_depth = (
            "d" in self.opts.tasks and self.opts.gen.m.use_dada and z_depth is None
        )
        if need_depth:
            _, z_depth = self.G.decoders["d"](latent)
        m = self.G.mask(x=x, z=latent, z_depth=z_depth)

    if bin_value >= 0:
        # keep the mask's dtype so painting sees the same tensor type
        m = (m > bin_value).to(m.dtype)

    if not cloudy:
        return self.G.paint(m, x)
    assert s is not None
    return self.G.paint_cloudy(m, x, s)
def compute_smog(self, x, z=None, d=None, s=None, use_sky_seg=False):
    """Apply a smog / haze effect to an image from its depth prediction.

    The hazy image is ``T * srgb2lrgb(x) + (1 - T) * A`` with transmission
    ``T = exp(-beta * d')``, where:

    * ``A`` is the constant ``opts.events.smog.airlight``
    * ``beta = opts.events.smog.beta / opts.events.smog.vr``
    * ``d'`` is ``d`` passed through ``normalize(mini=0.3, maxi=1.0)``,
      inverted, passed through ``normalize(mini=0.1, maxi=1)``, then
      bilinearly resized to ``x``'s spatial size

    The result is converted back with ``lrgb2srgb`` and blended with
    ``opts.events.smog.yellow_color / 255`` using weight ``alpha / 255``.

    Args:
        x (torch.Tensor): input image tensor
        z (torch.Tensor, optional): latent of ``x``, computed when needed and None
        d (torch.Tensor, optional): depth prediction, computed with the depth
            decoder when None
        s (torch.Tensor, optional): segmentation prediction, computed when
            ``use_sky_seg`` is True and ``s`` is None
        use_sky_seg (bool, optional): request a sky segmentation.
            Defaults to False.

    Raises:
        ValueError: if ``use_sky_seg`` is True, ``s`` is None and ``"s"`` is
            not a task

    Returns:
        torch.Tensor: smogged image

    NOTE(review): ``sky_mask`` is never derived from ``s``. ``use_sky_seg``
    therefore only costs an extra decoder pass and has no visual effect yet.
    """
    sky_mask = None

    need_depth = d is None
    need_seg = use_sky_seg and s is None
    if need_depth or need_seg:
        if z is None:
            z = self.G.encode(x)
        if need_depth:
            d, _ = self.G.decoders["d"](z)
        if need_seg:
            if "s" not in self.opts.tasks:
                raise ValueError(
                    "Cannot have "
                    + "(use_sky_seg is True and s is None and 's' not in tasks)"
                )
            s = self.G.decoders["s"](z)

    params = self.opts.events.smog
    device = self.device

    airlight = (params.airlight * torch.ones(3)).view(1, -1, 1, 1).to(device)
    irradiance = srgb2lrgb(x)
    beta = torch.tensor([params.beta / params.vr] * 3).view(1, -1, 1, 1).to(device)

    depth = normalize(d, mini=0.3, maxi=1.0)
    depth = normalize(1.0 / depth, mini=0.1, maxi=1)
    if sky_mask is not None:
        depth[sky_mask] = 1
    depth = torch.nn.functional.interpolate(
        depth, size=x.shape[-2:], mode="bilinear", align_corners=True
    ).repeat(1, 3, 1, 1)

    transmission = torch.exp(depth * -beta)
    smogged = lrgb2srgb(transmission * irradiance + (1 - transmission) * airlight)

    alpha = params.alpha / 255
    height, width = smogged.shape[-2], smogged.shape[-1]
    yellow = torch.Tensor([params.yellow_color]) / 255
    yellow_filter = yellow[:, :, None, None].repeat(1, 1, height, width).to(device)

    return smogged * (1 - alpha) + yellow_filter * alpha