"""
Main component: the trainer handles everything:

* initializations
* training
* saving
"""
import inspect
import warnings
from copy import deepcopy
from pathlib import Path
from time import time

import numpy as np
from comet_ml import ExistingExperiment, Experiment

# silence UserWarnings raised at import time by the libraries below
warnings.simplefilter("ignore", UserWarning)

import torch
import torch.nn as nn
from addict import Dict
from torch import autograd, sigmoid, softmax
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from climategan.data import get_all_loaders, decode_segmap_merged_labels
from climategan.discriminator import OmniDiscriminator, create_discriminator
from climategan.eval_metrics import accuracy, mIOU
from climategan.fid import compute_val_fid
from climategan.fire import add_fire
from climategan.generator import OmniGenerator, create_generator
from climategan.logger import Logger
from climategan.losses import get_losses
from climategan.optim import get_optimizer
from climategan.transforms import DiffTransforms
from climategan.tutils import (
    divide_pred,
    get_num_params,
    get_WGAN_gradient,
    lrgb2srgb,
    normalize,
    print_num_parameters,
    shuffle_batch_tuple,
    srgb2lrgb,
    tensor_to_uint8_numpy_image,
    vgg_preprocess,
    zero_grad,
)
from climategan.utils import (
    comet_kwargs,
    div_dict,
    find_target_size,
    flatten_opts,
    get_display_indices,
    get_existing_comet_id,
    get_latest_opts,
    merge,
    resolve,
    sum_dict,
    Timer,
)

# torch_xla is only available on TPU hosts
try:
    import torch_xla.core.xla_model as xm
except ImportError:
    pass


class Trainer:
    """Main trainer class"""

    def __init__(self, opts, comet_exp=None, verbose=0, device=None):
        """Trainer class to gather various model training procedures
        such as training, evaluating, saving and logging

        init:
        * creates a Logger
        * sets self.exp to `comet_exp` if a comet Experiment is passed
        * sets the device (1 GPU or CPU)

        Args:
            opts (addict.Dict): options to configure the trainer, the data,
                the models
            comet_exp (comet_ml.Experiment, optional): experiment to log to.
                Defaults to None.
            verbose (int, optional): printing level to debug. Defaults to 0.
            device (torch.device, optional): device to use. Defaults to
                cuda:0 if available, else cpu.
        """
        super().__init__()

        self.opts = opts
        self.verbose = verbose
        self.logger = Logger(self)

        self.losses = None
        self.G = self.D = None
        self.real_val_fid_stats = None
        self.use_pl4m = False
        self.is_setup = False
        self.loaders = self.all_loaders = None
        self.exp = None

        self.current_mode = "train"
        self.diff_transforms = None
        self.kitti_pretrain = self.opts.train.kitti.pretrain
        self.pseudo_training_tasks = set(self.opts.train.pseudo.tasks)

        self.lr_names = {}
        self.base_display_images = {}
        self.kitty_display_images = {}
        self.domain_labels = {"s": 0, "r": 1}

        self.device = device or torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )

        if isinstance(comet_exp, Experiment):
            self.exp = comet_exp

        if self.opts.train.amp:
            optimizers = [
                self.opts.gen.opt.optimizer.lower(),
                self.opts.dis.opt.optimizer.lower(),
            ]
            if "extraadam" in optimizers:
                raise ValueError(
                    "AMP does not work with ExtraAdam ({})".format(optimizers)
                )
            self.grad_scaler_d = GradScaler()
            self.grad_scaler_g = GradScaler()

        # depth-based fusions in the segmentation decoder require DADA
        if (
            self.opts.gen.s.depth_feat_fusion is True
            or self.opts.gen.s.depth_dada_fusion is True
        ):
            self.opts.gen.s.use_dada = True

    @torch.no_grad()
    def paint_and_mask(self, image_batch, mask_batch=None, resolution="approx"):
        """
        Paints a batch of images (or a single image with a batch dim of 1). If
        masks are not provided, they are inferred from the masker.
        Resolution can either be the train-time resolution or the closest
        multiple of 2 ** spade_n_up.

        Operations are performed without gradient.

        If resolution == "approx" then the output image has the shape:
            (dim // 2 ** spade_n_up) * 2 ** spade_n_up, for dim in [height, width]
            eg: (1000, 1300) => (896, 1280) for spade_n_up = 7
        If resolution == "exact" then the output image has the input's shape:
            we first process in "approx" mode then upsample bilinearly
        If resolution == "basic" the output shape is the train-time's
            (typically 640x640)
        If resolution == "upsample" the image is inferred as "basic" and
            then upsampled to the original size

        Args:
            image_batch (torch.Tensor): 4D batch of images to flood
            mask_batch (torch.Tensor, optional): Masks for the images.
                Defaults to None (infer with Masker).
            resolution (str, optional): "approx", "exact", "basic" or
                "upsample". Defaults to "approx".

        Returns:
            torch.Tensor: N x C x H x W where H and W depend on `resolution`
        """
        assert resolution in {"approx", "exact", "basic", "upsample"}
        previous_mode = self.current_mode
        if previous_mode == "train":
            self.eval_mode()

        if mask_batch is None:
            mask_batch = self.G.mask(x=image_batch)
        else:
            assert len(image_batch) == len(mask_batch)
            assert image_batch.shape[-2:] == mask_batch.shape[-2:]

        if resolution not in {"approx", "exact"}:
            painted = self.G.paint(mask_batch, image_batch)

            if resolution == "upsample":
                painted = nn.functional.interpolate(
                    painted, size=image_batch.shape[-2:], mode="bilinear"
                )
        else:
            # store the painter's latent shape to restore it afterwards
            zh = self.G.painter.z_h
            zw = self.G.painter.z_w

            self.G.painter.z_h = (
                image_batch.shape[-2] // 2 ** self.opts.gen.p.spade_n_up
            )
            self.G.painter.z_w = (
                image_batch.shape[-1] // 2 ** self.opts.gen.p.spade_n_up
            )

            painted = self.G.paint(mask_batch, image_batch)

            self.G.painter.z_h = zh
            self.G.painter.z_w = zw
            if resolution == "exact":
                painted = nn.functional.interpolate(
                    painted, size=image_batch.shape[-2:], mode="bilinear"
                )

        if previous_mode == "train":
            self.train_mode()

        return painted
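
    # Example (sketch): with resolution="approx" and spade_n_up = 7, each side
    # is rounded down to a multiple of 2 ** 7 = 128:
    #
    #   h = (1000 // 128) * 128  # -> 896
    #   w = (1300 // 128) * 128  # -> 1280
    #
    # so a 1000 x 1300 input yields a 896 x 1280 painting, and
    # resolution="exact" would upsample that result back to 1000 x 1300.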
|
|
|
    def _p(self, *args, **kwargs):
        """
        verbose-dependent print util
        """
        if self.verbose > 0:
            print(*args, **kwargs)

    @torch.no_grad()
    def infer_all(
        self,
        x,
        numpy=True,
        stores={},
        bin_value=-1,
        half=False,
        xla=False,
        cloudy=True,
        auto_resize_640=False,
        ignore_event=set(),
        return_intermediates=False,
    ):
        """
        Create a dictionary of events from a numpy or tensor,
        single or batch image data.

        stores is a dictionary of times for the Timer class.

        bin_value is used to binarize (or not) flood masks.

        All values in the output dictionary have 4 dimensions:
        BxHxWxC if numpy else BxCxHxW
        """
        assert self.is_setup
        assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"

        # convert numpy arrays to tensors
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device)

        # add a batch dimension to single images
        if len(x.shape) == 3:
            x.unsqueeze_(0)

        # reorder channels-last inputs to channels-first
        if x.shape[1] != 3:
            assert x.shape[-1] == 3, f"Unknown x shape to permute {x.shape}"
            x = x.permute(0, 3, 1, 2)

        if x.device != self.device:
            x = x.to(self.device)

        if auto_resize_640 and (x.shape[-1] != 640 or x.shape[-2] != 640):
            x = torch.nn.functional.interpolate(x, (640, 640), mode="bilinear")

        if half:
            x = x.half()

        self.G.painter.set_latent_shape(x.shape, True)

        with Timer(store=stores.get("all events", [])):
            with Timer(store=stores.get("encode", [])):
                z = self.G.encode(x)
                if xla:
                    xm.mark_step()

            with Timer(store=stores.get("depth", [])):
                depth, z_depth = self.G.decoders["d"](z)
                if xla:
                    xm.mark_step()
            with Timer(store=stores.get("segmentation", [])):
                segmentation = self.G.decoders["s"](z, z_depth)
                if xla:
                    xm.mark_step()
            with Timer(store=stores.get("mask", [])):
                cond = self.G.make_m_cond(depth, segmentation, x)
                mask = self.G.mask(z=z, cond=cond, z_depth=z_depth)
                if xla:
                    xm.mark_step()

            if "wildfire" not in ignore_event:
                with Timer(store=stores.get("wildfire", [])):
                    wildfire = self.compute_fire(x, seg_preds=segmentation)
            if "smog" not in ignore_event:
                with Timer(store=stores.get("smog", [])):
                    smog = self.compute_smog(x, d=depth, s=segmentation)
            if "flood" not in ignore_event:
                with Timer(store=stores.get("flood", [])):
                    flood = self.compute_flood(
                        x,
                        m=mask,
                        s=segmentation,
                        cloudy=cloudy,
                        bin_value=bin_value,
                    )

            if xla:
                xm.mark_step()

        output_data = {}

        if numpy:
            with Timer(store=stores.get("numpy", [])):
                if "flood" not in ignore_event:
                    flood = tensor_to_uint8_numpy_image(flood)
                if "wildfire" not in ignore_event:
                    wildfire = tensor_to_uint8_numpy_image(wildfire)
                if "smog" not in ignore_event:
                    smog = tensor_to_uint8_numpy_image(smog)

        # events are returned as tensors when numpy is False
        if "flood" not in ignore_event:
            output_data["flood"] = flood
        if "wildfire" not in ignore_event:
            output_data["wildfire"] = wildfire
        if "smog" not in ignore_event:
            output_data["smog"] = smog

        if return_intermediates:
            if numpy:
                output_data["mask"] = (
                    ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
                )
                output_data["depth"] = tensor_to_uint8_numpy_image(depth)
                output_data["segmentation"] = (
                    decode_segmap_merged_labels(segmentation, "r", False)
                    .cpu()
                    .permute(0, 2, 3, 1)
                    .numpy()
                    .astype(np.uint8)
                )
            else:
                output_data["mask"] = mask
                output_data["depth"] = depth
                output_data["segmentation"] = segmentation

        return output_data
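
    # Usage sketch (assuming a trainer set up in inference mode and `img` an
    # H x W x 3 array in the value range the model expects; names are
    # illustrative):
    #
    #   stores = {"all events": [], "encode": []}
    #   out = trainer.infer_all(img, numpy=True, stores=stores,
    #                           ignore_event={"smog"})
    #   out["flood"]      # B x H x W x C uint8 numpy array
    #   stores["encode"]  # encoder timings collected by Timer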
|
|
|
    @classmethod
    def resume_from_path(
        cls,
        path,
        overrides={},
        setup=True,
        inference=False,
        new_exp=False,
        device=None,
        verbose=1,
    ):
        """
        Resume and optionally setup a trainer from a specific path,
        using the latest opts and checkpoint. Requires path to contain
        opts.yaml (or an incremented version), url.txt (or an incremented
        version) and checkpoints/

        Args:
            path (str | pathlib.Path): Trainer to resume
            overrides (dict, optional): Override loaded opts with those.
                Defaults to {}.
            setup (bool, optional): Whether or not to setup the trainer before
                returning it. Defaults to True.
            inference (bool, optional): Whether setup should be done in
                inference mode. Defaults to False.
            new_exp (bool, optional): If True, create a new comet experiment;
                if False, re-use the existing experiment found in `path`;
                if None, do not create any experiment. Defaults to False.
            device (torch.device, optional): Device to use
            verbose (int, optional): Trainer's verbose level. Defaults to 1.

        Returns:
            climategan.Trainer: Loaded and resumed trainer
        """
        p = resolve(path)
        assert p.exists()

        c = p / "checkpoints"
        assert c.exists() and c.is_dir()

        opts = get_latest_opts(p)
        opts = Dict(merge(overrides, opts))
        opts.train.resume = True

        if new_exp is None:
            exp = None
        elif new_exp is True:
            exp = Experiment(project_name="climategan", **comet_kwargs)
            exp.log_asset_folder(
                str(resolve(Path(__file__)).parent),
                recursive=True,
                log_file_name=True,
            )
            exp.log_parameters(flatten_opts(opts))
        else:
            comet_id = get_existing_comet_id(p)
            exp = ExistingExperiment(previous_experiment=comet_id, **comet_kwargs)

        trainer = cls(opts, comet_exp=exp, device=device, verbose=verbose)

        if setup:
            trainer.setup(inference=inference)
        return trainer
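
    # Example (sketch, path illustrative): load a trained run for inference
    # without creating or resuming any comet experiment:
    #
    #   trainer = Trainer.resume_from_path(
    #       "output/run-0", inference=True, new_exp=None
    #   )
    #   painted = trainer.paint_and_mask(images, resolution="exact")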
|
|
|
    def save(self):
        save_dir = Path(self.opts.output_path) / Path("checkpoints")
        save_dir.mkdir(exist_ok=True)
        save_path = save_dir / "latest_ckpt.pth"

        save_dict = {
            "epoch": self.logger.epoch,
            "G": self.G.state_dict(),
            "g_opt": self.g_opt.state_dict(),
            "step": self.logger.global_step,
        }

        if self.D is not None and get_num_params(self.D) > 0:
            save_dict["D"] = self.D.state_dict()
            save_dict["d_opt"] = self.d_opt.state_dict()

        # periodically keep an epoch-stamped checkpoint
        if (
            self.logger.epoch >= self.opts.train.min_save_epoch
            and self.logger.epoch % self.opts.train.save_n_epochs == 0
        ):
            torch.save(save_dict, save_dir / f"epoch_{self.logger.epoch}_ckpt.pth")

        # always refresh the "latest" checkpoint
        torch.save(save_dict, save_path)

    def resume(self, inference=False):
        tpu = "xla" in str(self.device)
        if tpu:
            print("Resuming on TPU:", self.device)

        m_path = Path(self.opts.load_paths.m)
        p_path = Path(self.opts.load_paths.p)
        pm_path = Path(self.opts.load_paths.pm)
        output_path = Path(self.opts.output_path)

        map_loc = self.device if not tpu else "cpu"

        if "m" in self.opts.tasks and "p" in self.opts.tasks:
            # Resuming a P+M model: either from this run's latest checkpoint,
            # from a single P+M checkpoint, or from separate M and P checkpoints

            if all(str(p) == "none" for p in [m_path, p_path, pm_path]):
                checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"
                print("Resuming P+M model from", str(checkpoint_path))
                checkpoint = torch.load(checkpoint_path, map_location=map_loc)

            elif str(pm_path) != "none":
                assert pm_path.exists()

                if pm_path.is_dir():
                    checkpoint_path = pm_path / "checkpoints/latest_ckpt.pth"
                else:
                    assert pm_path.suffix == ".pth"
                    checkpoint_path = pm_path

                print("Resuming P+M model from", str(checkpoint_path))
                checkpoint = torch.load(checkpoint_path, map_location=map_loc)

            elif m_path != p_path:
                assert m_path.exists()
                assert p_path.exists()

                if m_path.is_dir():
                    m_path = m_path / "checkpoints/latest_ckpt.pth"

                if p_path.is_dir():
                    p_path = p_path / "checkpoints/latest_ckpt.pth"

                assert m_path.suffix == ".pth"
                assert p_path.suffix == ".pth"

                print(f"Resuming P+M model from \n -{p_path} \nand \n -{m_path}")
                m_checkpoint = torch.load(m_path, map_location=map_loc)
                p_checkpoint = torch.load(p_path, map_location=map_loc)
                checkpoint = merge(m_checkpoint, p_checkpoint)

            else:
                raise ValueError(
                    "Cannot resume a P+M model with provided load_paths:\n{}".format(
                        self.opts.load_paths
                    )
                )

        else:
            # Resuming a single model: M or P

            if str(m_path) != "none" and str(p_path) != "none":
                raise ValueError(
                    "Opts tasks are {} but received 2 values for the load_paths".format(
                        self.opts.tasks
                    )
                )

            elif str(m_path) != "none":
                assert m_path.exists()
                assert "m" in self.opts.tasks
                model = "M"
                if m_path.is_dir():
                    m_path = m_path / "checkpoints/latest_ckpt.pth"
                checkpoint_path = m_path

            elif str(p_path) != "none":
                assert p_path.exists()
                assert "p" in self.opts.tasks
                model = "P"
                if p_path.is_dir():
                    p_path = p_path / "checkpoints/latest_ckpt.pth"
                checkpoint_path = p_path

            else:
                model = "P" if "p" in self.opts.tasks else "M"
                checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"

            print(f"Resuming {model} model from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=map_loc)

        if tpu:
            checkpoint = xm.send_cpu_data_to_device(checkpoint, self.device)

        # in inference mode, allow partial state dicts
        # (e.g. masker-only or painter-only checkpoints)
        if inference:
            incompatible_keys = self.G.load_state_dict(checkpoint["G"], strict=False)
            if incompatible_keys.missing_keys:
                print("WARNING: Missing keys in self.G.load_state_dict, keeping inits")
                print(incompatible_keys.missing_keys)
            if incompatible_keys.unexpected_keys:
                print("WARNING: Ignoring Unexpected keys in self.G.load_state_dict")
                print(incompatible_keys.unexpected_keys)
        else:
            self.G.load_state_dict(checkpoint["G"])

        if inference:
            # only the generator is needed to infer
            print("Done loading checkpoints.")
            return

        self.g_opt.load_state_dict(checkpoint["g_opt"])

        # restore the logging state before fast-forwarding the schedulers
        self.logger.epoch = checkpoint["epoch"]
        self.logger.global_step = checkpoint["step"]

        # fast-forward the lr schedulers to the resumed epoch
        for _ in range(self.logger.epoch + 1):
            self.update_learning_rates()

        if self.D is not None and get_num_params(self.D) > 0:
            self.D.load_state_dict(checkpoint["D"])
            self.d_opt.load_state_dict(checkpoint["d_opt"])

        if self.exp is not None:
            self.exp.log_text(
                "Resuming from epoch {} & step {}".format(
                    checkpoint["epoch"], checkpoint["step"]
                )
            )

        # extragradient optimizers alternate extrapolation and step on
        # even/odd global steps: resume on an even step so that an
        # extrapolation step comes first
        if self.logger.global_step % 2 != 0:
            self.logger.global_step += 1

    def eval_mode(self):
        """
        Set trainer's models in eval mode
        """
        if self.G is not None:
            self.G.eval()
        if self.D is not None:
            self.D.eval()
        self.current_mode = "eval"

    def train_mode(self):
        """
        Set trainer's models in train mode
        """
        if self.G is not None:
            self.G.train()
        if self.D is not None:
            self.D.train()

        self.current_mode = "train"

    def assert_z_matches_x(self, x, z):
        assert x.shape[0] == (
            z.shape[0] if not isinstance(z, (list, tuple)) else z[0].shape[0]
        ), "x-> {}, z->{}".format(
            x.shape, z.shape if not isinstance(z, (list, tuple)) else z[0].shape
        )

    def batch_to_device(self, b):
        """sends the data in b to self.device

        Args:
            b (dict): the batch dictionary

        Returns:
            dict: the batch dictionary with its "data" field sent to self.device
        """
        for task, tensor in b["data"].items():
            b["data"][task] = tensor.to(self.device)
        return b

    def sample_painter_z(self, batch_size):
        return self.G.sample_painter_z(batch_size, self.device)

    @property
    def train_loaders(self):
        """Get a zip of all training loaders

        Returns:
            generator: zip generator yielding tuples:
                (batch_rf, batch_rn, batch_sf, batch_sn)
        """
        return zip(*list(self.loaders["train"].values()))
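
    # Sketch: with two train loaders {"r": loader_r, "s": loader_s}, iterating
    # `train_loaders` yields one batch per domain and stops with the shortest
    # loader:
    #
    #   for batch_r, batch_s in trainer.train_loaders:
    #       ...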
|
|
|
    @property
    def val_loaders(self):
        """Get a zip of all validation loaders

        Returns:
            generator: zip generator yielding tuples:
                (batch_rf, batch_rn, batch_sf, batch_sn)
        """
        return zip(*list(self.loaders["val"].values()))

    def compute_latent_shape(self):
        """Compute the latent shape, i.e. the Encoder's output shape,
        from a batch.

        Raises:
            ValueError: If no loader, the latent_shape cannot be inferred

        Returns:
            tuple: (c, h, w)
        """
        x = None
        for mode in self.all_loaders:
            for domain in self.all_loaders[mode]:
                x = (
                    self.all_loaders[mode][domain]
                    .dataset[0]["data"]["x"]
                    .to(self.device)
                )
                break
            if x is not None:
                break

        if x is None:
            raise ValueError("No batch found to compute_latent_shape")

        x = x.unsqueeze(0)
        z = self.G.encode(x)
        return z.shape[1:] if not isinstance(z, (list, tuple)) else z[0].shape[1:]

    def g_opt_step(self):
        """Run an optimization step; if using ExtraAdam, an extrapolation step
        is needed every other step
        """
        if "extra" in self.opts.gen.opt.optimizer.lower() and (
            self.logger.global_step % 2 == 0
        ):
            self.g_opt.extrapolation()
        else:
            self.g_opt.step()

    def d_opt_step(self):
        """Run an optimization step; if using ExtraAdam, an extrapolation step
        is needed every other step
        """
        if "extra" in self.opts.dis.opt.optimizer.lower() and (
            self.logger.global_step % 2 == 0
        ):
            self.d_opt.extrapolation()
        else:
            self.d_opt.step()
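
    # Extragradient schedule sketch (assuming an ExtraAdam-style optimizer
    # exposing .extrapolation() and .step()):
    #
    #   global_step 0: opt.extrapolation()  # probe point
    #   global_step 1: opt.step()           # update from the probe
    #   global_step 2: opt.extrapolation()
    #   ...
    #
    # This is why resume() and switch_data() bump odd global steps to even:
    # an extrapolation step must come first.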
|
|
|
    def update_learning_rates(self):
        if self.g_scheduler is not None:
            self.g_scheduler.step()
        if self.d_scheduler is not None:
            self.d_scheduler.step()

    def setup(self, inference=False):
        """Prepare the trainer before it can be used to train the models:
        * initialize G and D
        * create 2 optimizers
        """
        self.logger.global_step = 0
        start_time = time()
        self.logger.time.start_time = start_time
        verbose = self.verbose

        if not inference:
            self.all_loaders = get_all_loaders(self.opts)

        # create the generator
        __t = time()
        print("Creating generator...")

        self.G: OmniGenerator = create_generator(
            self.opts, device=self.device, no_init=inference, verbose=verbose
        )

        self.has_painter = get_num_params(self.G.painter) or self.G.load_val_painter()

        if self.has_painter:
            self.G.painter.set_latent_shape(find_target_size(self.opts, "x"), True)

        print(f"Generator OK in {time() - __t:.1f}s.")

        if inference:
            print("Inference mode: no Discriminator, no optimizers")
            print_num_parameters(self)
            self.switch_data(to="base")
            if self.opts.train.resume:
                self.resume(True)
            self.eval_mode()
            print("Trainer is in evaluation mode.")
            print("Setup done.")
            self.is_setup = True
            return

        # create the discriminator
        self.D: OmniDiscriminator = create_discriminator(
            self.opts, self.device, verbose=verbose
        )
        print("Discriminator OK.")

        print_num_parameters(self)

        # create the optimizers and lr schedulers
        self.g_opt, self.g_scheduler, self.lr_names["G"] = get_optimizer(
            self.G, self.opts.gen.opt, self.opts.tasks
        )

        if get_num_params(self.D) > 0:
            self.d_opt, self.d_scheduler, self.lr_names["D"] = get_optimizer(
                self.D, self.opts.dis.opt, self.opts.tasks, True
            )
        else:
            self.d_opt, self.d_scheduler = None, None

        self.losses = get_losses(self.opts, verbose, device=self.device)

        if "p" in self.opts.tasks and self.opts.gen.p.diff_aug.use:
            self.diff_transforms = DiffTransforms(self.opts.gen.p.diff_aug)

        if verbose > 0:
            for mode, mode_dict in self.all_loaders.items():
                for domain, domain_loader in mode_dict.items():
                    print(
                        "Loader {} {} : {}".format(
                            mode, domain, len(domain_loader.dataset)
                        )
                    )

        # select the images that will be logged to comet at eval time
        self.set_display_images()

        self.logger.log_architecture()

        # start from the kitti data source when pretraining on kitti
        if self.kitti_pretrain:
            self.switch_data(to="kitti")
        else:
            self.switch_data(to="base")

        print(" " * 50, end="\r")
        print("Done creating display images")

        if self.opts.train.resume:
            print("Resuming Model (inference: False)")
            self.resume(False)
        else:
            print("Not resuming: starting a new model")

        print("Setup done.")
        self.is_setup = True

    def switch_data(self, to="kitti"):
        caller = inspect.stack()[1].function
        print(f"[{caller}] Switching data source to", to)
        self.data_source = to
        if to == "kitti":
            self.display_images = self.kitty_display_images
            if self.all_loaders is not None:
                self.loaders = {
                    mode: {"s": self.all_loaders[mode]["kitti"]}
                    for mode in self.all_loaders
                }
        else:
            self.display_images = self.base_display_images
            if self.all_loaders is not None:
                self.loaders = {
                    mode: {
                        domain: self.all_loaders[mode][domain]
                        for domain in self.all_loaders[mode]
                        if domain != "kitti"
                    }
                    for mode in self.all_loaders
                }
        if (
            self.logger.global_step % 2 != 0
            and "extra" in self.opts.dis.opt.optimizer.lower()
        ):
            print(
                "Warning: artificially bumping step to run an extrapolation step first."
            )
            self.logger.global_step += 1

    def set_display_images(self, use_all=False):
        for mode, mode_dict in self.all_loaders.items():

            if self.kitti_pretrain:
                self.kitty_display_images[mode] = {}
            self.base_display_images[mode] = {}

            for domain in mode_dict:

                if self.kitti_pretrain and domain == "kitti":
                    target_dict = self.kitty_display_images
                else:
                    if domain == "kitti":
                        continue
                    target_dict = self.base_display_images

                dataset = self.all_loaders[mode][domain].dataset
                display_indices = (
                    get_display_indices(self.opts, domain, len(dataset))
                    if not use_all
                    else list(range(len(dataset)))
                )
                ldis = len(display_indices)
                print(
                    f" Creating {ldis} {mode} {domain} display images...",
                    end="\r",
                    flush=True,
                )
                images = []
                for i in display_indices:
                    print(f"({i})", end="\r")
                    if i < len(dataset):
                        images.append(Dict(dataset[i]))
                target_dict[mode][domain] = images

                if self.exp is not None:
                    for im_id, d in enumerate(target_dict[mode][domain]):
                        self.exp.log_parameter(
                            "display_image_{}_{}_{}".format(mode, domain, im_id),
                            d["paths"],
                        )

    def train(self):
        """For each epoch:
        * train
        * eval
        * save
        """
        assert self.is_setup

        for self.logger.epoch in range(
            self.logger.epoch, self.logger.epoch + self.opts.train.epochs
        ):
            # enable the painter loss for the masker (pl4m) past a given epoch
            if (
                self.logger.epoch == self.opts.gen.p.pl4m_epoch
                and get_num_params(self.G.painter) > 0
                and "p" in self.opts.tasks
                and self.opts.gen.m.use_pl4m
            ):
                print(
                    "\n\n >>> Enabling pl4m at epoch {}\n\n".format(self.logger.epoch)
                )
                self.use_pl4m = True

            self.run_epoch()
            self.run_evaluation(verbose=1)
            self.save()

            # end of kitti pretraining: switch back to the base data
            if self.logger.epoch == self.opts.train.kitti.epochs - 1:
                self.switch_data(to="base")
                self.kitti_pretrain = False

            # end of pseudo-label training
            if self.logger.epoch == self.opts.train.pseudo.epochs - 1:
                self.pseudo_training_tasks = set()

    def run_epoch(self):
        """Runs an epoch:
        * checks trainer is setup
        * gets a tuple of batches per domain
        * sends batches to device
        * updates sequentially G, D
        """
        assert self.is_setup
        self.train_mode()
        if self.exp is not None:
            self.exp.log_parameter("epoch", self.logger.epoch)
        epoch_len = min(len(loader) for loader in self.loaders["train"].values())
        epoch_desc = "Epoch {}".format(self.logger.epoch)
        self.logger.time.epoch_start = time()

        for multi_batch_tuple in tqdm(
            self.train_loaders,
            desc=epoch_desc,
            total=epoch_len,
            mininterval=0.5,
            unit="batch",
        ):
            self.logger.time.step_start = time()
            multi_batch_tuple = shuffle_batch_tuple(multi_batch_tuple)

            # map domain names to batches, all sent to self.device
            multi_domain_batch = {
                batch["domain"][0]: self.batch_to_device(batch)
                for batch in multi_batch_tuple
            }

            # freeze D while updating G
            if self.d_opt is not None:
                for param in self.D.parameters():
                    param.requires_grad = False

            self.update_G(multi_domain_batch)

            # unfreeze D and update it (no D update during kitti pretraining)
            if self.d_opt is not None and not self.kitti_pretrain:
                for param in self.D.parameters():
                    param.requires_grad = True

                self.update_D(multi_domain_batch)

            self.logger.global_step += 1
            self.logger.log_step_time(time())

        if not self.kitti_pretrain:
            self.update_learning_rates()

        self.logger.log_learning_rates()
        self.logger.log_epoch_time(time())

    def update_G(self, multi_domain_batch, verbose=0):
        """Perform an update on g from multi_domain_batch which is a dictionary
        domain => batch

        * uses automatic mixed precision according to self.opts.train.amp
        * compute loss for each task
        * loss.backward()
        * g_opt_step(): g_opt.step() or .extrapolation() depending on
          self.logger.global_step
        * logs losses on comet.ml with self.logger.log_losses(model_to_update="G")

        Args:
            multi_domain_batch (dict): dictionary of domain batches
        """
        zero_grad(self.G)
        if self.opts.train.amp:
            with autocast():
                g_loss = self.get_G_loss(multi_domain_batch, verbose)
            self.grad_scaler_g.scale(g_loss).backward()
            self.grad_scaler_g.step(self.g_opt)
            self.grad_scaler_g.update()
        else:
            g_loss = self.get_G_loss(multi_domain_batch, verbose)
            g_loss.backward()
            self.g_opt_step()

        self.logger.log_losses(model_to_update="G", mode="train")

    def update_D(self, multi_domain_batch, verbose=0):
        zero_grad(self.D)

        if self.opts.train.amp:
            with autocast():
                d_loss = self.get_D_loss(multi_domain_batch, verbose)
            self.grad_scaler_d.scale(d_loss).backward()
            self.grad_scaler_d.step(self.d_opt)
            self.grad_scaler_d.update()
        else:
            d_loss = self.get_D_loss(multi_domain_batch, verbose)
            d_loss.backward()
            self.d_opt_step()

        self.logger.losses.disc.total_loss = d_loss.item()
        self.logger.log_losses(model_to_update="D", mode="train")

    def get_D_loss(self, multi_domain_batch, verbose=0):
        """Compute the discriminators' losses:

        * for each domain-specific batch:
            * encode the image
            * get the conditioning tensor if using spade
            * source domain is the data's domain, sequentially r|s then f|n
            * get the target domain accordingly
            * compute the translated image from the data
            * compute the source domain discriminator's loss on the data
            * compute the target domain discriminator's loss on the translated image

        # ? In this setting, each D[decoder][domain] is updated twice towards
        # real or fake data

        See readme's update d section for details

        Args:
            multi_domain_batch (dict): dictionary mapping domain names to batches

        Returns:
            torch.Tensor: scalar loss tensor, weighted according to
                opts.train.lambdas
        """
        disc_loss = {
            "m": {"Advent": 0},
            "s": {"Advent": 0},
        }
        if self.opts.dis.p.use_local_discriminator:
            disc_loss["p"] = {"global": 0, "local": 0}
        else:
            disc_loss["p"] = {"gan": 0}

        for domain, batch in multi_domain_batch.items():
            x = batch["data"]["x"]

            # painter discriminator: updated on the real flooded ("rf") domain
            if domain == "rf" and self.has_painter:
                m = batch["data"]["m"]

                # the painter is not updated here: no gradient to the generator
                with torch.no_grad():
                    fake = self.G.paint(m, x)
                    if self.opts.gen.p.diff_aug.use:
                        fake = self.diff_transforms(fake)
                        x = self.diff_transforms(x)
                    fake = fake.detach()
                    fake.requires_grad_()

                if self.opts.dis.p.use_local_discriminator:
                    fake_d_global = self.D["p"]["global"](fake)
                    real_d_global = self.D["p"]["global"](x)

                    fake_d_local = self.D["p"]["local"](fake * m)
                    real_d_local = self.D["p"]["local"](x * m)

                    global_loss = self.losses["D"]["p"](fake_d_global, False, True)
                    global_loss += self.losses["D"]["p"](real_d_global, True, True)

                    local_loss = self.losses["D"]["p"](fake_d_local, False, True)
                    local_loss += self.losses["D"]["p"](real_d_local, True, True)

                    disc_loss["p"]["global"] += global_loss
                    disc_loss["p"]["local"] += local_loss
                else:
                    real_cat = torch.cat([m, x], axis=1)
                    fake_cat = torch.cat([m, fake], axis=1)
                    real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
                    real_fake_d = self.D["p"](real_fake_cat)
                    real_d, fake_d = divide_pred(real_fake_d)
                    disc_loss["p"]["gan"] = self.losses["D"]["p"](fake_d, False, True)
                    disc_loss["p"]["gan"] += self.losses["D"]["p"](real_d, True, True)

            # masker discriminators: Advent for segmentation and mask
            else:
                z = self.G.encode(x)
                s_pred = d_pred = cond = z_depth = None

                if "s" in batch["data"]:
                    if "d" in self.opts.tasks and self.opts.gen.s.use_dada:
                        d_pred, z_depth = self.G.decoders["d"](z)

                    step_loss, s_pred = self.masker_s_loss(
                        x, z, d_pred, z_depth, None, domain, for_="D"
                    )
                    step_loss *= self.opts.train.lambdas.advent.adv_main
                    disc_loss["s"]["Advent"] += step_loss

                if "m" in batch["data"]:
                    if "d" in self.opts.tasks:
                        if self.opts.gen.m.use_spade:
                            if d_pred is None:
                                d_pred, z_depth = self.G.decoders["d"](z)
                            cond = self.G.make_m_cond(d_pred, s_pred, x)
                        elif self.opts.gen.m.use_dada:
                            if d_pred is None:
                                d_pred, z_depth = self.G.decoders["d"](z)

                    step_loss, _ = self.masker_m_loss(
                        x,
                        z,
                        None,
                        domain,
                        for_="D",
                        cond=cond,
                        z_depth=z_depth,
                        depth_preds=d_pred,
                    )
                    step_loss *= self.opts.train.lambdas.advent.adv_main
                    disc_loss["m"]["Advent"] += step_loss

        self.logger.losses.disc.update(
            {
                dom: {
                    k: v.item() if isinstance(v, torch.Tensor) else v
                    for k, v in d.items()
                }
                for dom, d in disc_loss.items()
            }
        )

        loss = sum(v for d in disc_loss.values() for v in d.values())
        return loss
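
    # Sketch of the stacking trick used above: real and fake pairs are
    # concatenated along the batch dimension so the discriminator runs a
    # single forward pass, and `divide_pred` splits the output back into its
    # real and fake halves:
    #
    #   real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)  # 2B x C x H x W
    #   real_d, fake_d = divide_pred(self.D["p"](real_fake_cat))
    #
    # One motivation for this SPADE-style trick is that real and fake samples
    # then share the same normalization statistics in the discriminator.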
|
|
|
    def get_G_loss(self, multi_domain_batch, verbose=0):
        m_loss = p_loss = None

        g_loss = 0

        if any(t in self.opts.tasks for t in "msd"):
            m_loss = self.get_masker_loss(multi_domain_batch)
            self.logger.losses.gen.masker = m_loss.item()
            g_loss += m_loss

        if "p" in self.opts.tasks and not self.kitti_pretrain:
            p_loss = self.get_painter_loss(multi_domain_batch)
            self.logger.losses.gen.painter = p_loss.item()
            g_loss += p_loss

        assert g_loss != 0 and not isinstance(g_loss, int), "No update in get_G_loss!"

        self.logger.losses.gen.total_loss = g_loss.item()

        return g_loss

    def get_masker_loss(self, multi_domain_batch):
        """Only update the representation part of the model, meaning everything
        but the translation part

        * for each batch in available domains:
            * compute task-specific losses
            * compute the adaptation and translation decoders' auto-encoding losses
            * compute the adaptation decoder's translation losses (GAN and Cycle)

        Args:
            multi_domain_batch (dict): dictionary mapping domain names to batches
                from the trainer's loaders

        Returns:
            torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas
        """
        m_loss = 0
        for domain, batch in multi_domain_batch.items():
            # no masker loss on the real flooded domain (painter-only)
            if domain == "rf":
                continue

            x = batch["data"]["x"]
            z = self.G.encode(x)

            # task order matters: d is computed before s (which may use
            # z_depth), and s before m (which may use d_pred and s_pred)
            d_pred = s_pred = z_depth = None
            for task in ["d", "s", "m"]:
                if task not in batch["data"]:
                    continue

                target = batch["data"][task]

                if task == "d":
                    loss, d_pred, z_depth = self.masker_d_loss(
                        x, z, target, domain, "G"
                    )
                    m_loss += loss
                    self.logger.losses.gen.task["d"][domain] = loss.item()

                elif task == "s":
                    loss, s_pred = self.masker_s_loss(
                        x, z, d_pred, z_depth, target, domain, "G"
                    )
                    m_loss += loss
                    self.logger.losses.gen.task["s"][domain] = loss.item()

                elif task == "m":
                    cond = None
                    if self.opts.gen.m.use_spade:
                        if not self.opts.gen.m.detach:
                            d_pred = d_pred.clone()
                            s_pred = s_pred.clone()
                        cond = self.G.make_m_cond(d_pred, s_pred, x)

                    loss, _ = self.masker_m_loss(
                        x,
                        z,
                        target,
                        domain,
                        "G",
                        cond=cond,
                        z_depth=z_depth,
                        depth_preds=d_pred,
                    )
                    m_loss += loss
                    self.logger.losses.gen.task["m"][domain] = loss.item()

        return m_loss

    def get_painter_loss(self, multi_domain_batch):
        """Computes the translation loss when flooding/deflooding images

        Args:
            multi_domain_batch (dict): dictionary mapping domain names to batches
                from the trainer's loaders

        Returns:
            torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas
        """
        step_loss = 0

        lambdas = self.opts.train.lambdas
        batch_domain = "rf"
        batch = multi_domain_batch[batch_domain]

        x = batch["data"]["x"]
        m = batch["data"]["m"]
        fake_flooded = self.G.paint(m, x)

        # perceptual (VGG) loss on the masked area
        if lambdas.G.p.vgg != 0:
            loss = self.losses["G"]["p"]["vgg"](
                vgg_preprocess(fake_flooded * m), vgg_preprocess(x * m)
            )
            loss *= lambdas.G.p.vgg
            self.logger.losses.gen.p.vgg = loss.item()
            step_loss += loss

        # total variation loss on the masked area
        if lambdas.G.p.tv != 0:
            loss = self.losses["G"]["p"]["tv"](fake_flooded * m)
            loss *= lambdas.G.p.tv
            self.logger.losses.gen.p.tv = loss.item()
            step_loss += loss

        # context loss: preserve the image outside the mask
        if lambdas.G.p.context != 0:
            loss = self.losses["G"]["p"]["context"](fake_flooded, x, m)
            loss *= lambdas.G.p.context
            self.logger.losses.gen.p.context = loss.item()
            step_loss += loss

        # reconstruction loss
        if lambdas.G.p.reconstruction != 0:
            loss = self.losses["G"]["p"]["reconstruction"](fake_flooded, x, m)
            loss *= lambdas.G.p.reconstruction
            self.logger.losses.gen.p.reconstruction = loss.item()
            step_loss += loss

        # differentiable augmentations before the GAN loss
        if self.opts.gen.p.diff_aug.use:
            fake_flooded = self.diff_transforms(fake_flooded)
            x = self.diff_transforms(x)

        if self.opts.dis.p.use_local_discriminator:
            fake_d_global = self.D["p"]["global"](fake_flooded)
            fake_d_local = self.D["p"]["local"](fake_flooded * m)

            real_d_global = self.D["p"]["global"](x)

            # GAN loss from the global and local discriminators
            self.logger.losses.gen.p.gan = 0

            loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
            loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
            loss *= lambdas.G["p"]["gan"]

            self.logger.losses.gen.p.gan = loss.item()

            step_loss += loss

            # feature-matching loss on the discriminator's intermediate features
            if self.opts.dis.p.get_intermediate_features:
                loss = self.losses["G"]["p"]["featmatch"](real_d_global, fake_d_global)
                loss *= lambdas.G["p"]["featmatch"]

                if isinstance(loss, float):
                    self.logger.losses.gen.p.featmatch = loss
                else:
                    self.logger.losses.gen.p.featmatch = loss.item()

                step_loss += loss

        else:
            real_cat = torch.cat([m, x], axis=1)
            fake_cat = torch.cat([m, fake_flooded], axis=1)
            real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)

            real_fake_d = self.D["p"](real_fake_cat)
            real_d, fake_d = divide_pred(real_fake_d)

            loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
            self.logger.losses.gen.p.gan = loss.item()
            step_loss += loss

            # feature-matching loss on the discriminator's intermediate features
            if (
                self.opts.dis.p.get_intermediate_features
                and lambdas.G.p.featmatch != 0
            ):
                loss = self.losses["G"]["p"]["featmatch"](real_d, fake_d)
                loss *= lambdas.G.p.featmatch

                if isinstance(loss, float):
                    self.logger.losses.gen.p.featmatch = loss
                else:
                    self.logger.losses.gen.p.featmatch = loss.item()

                step_loss += loss

        return step_loss

    def masker_d_loss(self, x, z, target, domain, for_="G"):
        assert for_ in {"G", "D"}
        self.assert_z_matches_x(x, z)
        assert x.shape[0] == target.shape[0]
        zero_loss = torch.tensor(0.0, device=self.device)
        weight = self.opts.train.lambdas.G.d.main

        prediction, z_depth = self.G.decoders["d"](z)

        if self.opts.gen.d.classify.enable:
            target.squeeze_(1)

        full_loss = self.losses["G"]["tasks"]["d"](prediction, target)
        full_loss *= weight

        # on real data, the depth loss only applies when training
        # with depth pseudo labels
        if weight == 0 or (domain == "r" and "d" not in self.pseudo_training_tasks):
            return zero_loss, prediction, z_depth

        return full_loss, prediction, z_depth

    def masker_s_loss(self, x, z, depth_preds, z_depth, target, domain, for_="G"):
        assert for_ in {"G", "D"}
        assert domain in {"r", "s"}
        self.assert_z_matches_x(x, z)
        assert target is None or x.shape[0] == target.shape[0]
        full_loss = torch.tensor(0.0, device=self.device)
        softmax_preds = None

        # segmentation predictions are only needed for the generator update
        # or when the Advent discriminator is used
        pred = None
        if for_ == "G" or self.opts.gen.s.use_advent:
            pred = self.G.decoders["s"](z, z_depth)

        if for_ == "G":
            # supervised cross-entropy loss (sim labels or pseudo labels)
            if domain == "s" or "s" in self.pseudo_training_tasks:
                if domain == "s":
                    logger = self.logger.losses.gen.task["s"]["crossent"]
                    weight = self.opts.train.lambdas.G["s"]["crossent"]
                else:
                    logger = self.logger.losses.gen.task["s"]["crossent_pseudo"]
                    weight = self.opts.train.lambdas.G["s"]["crossent_pseudo"]

                if weight != 0:
                    loss_func = self.losses["G"]["tasks"]["s"]["crossent"]
                    loss = loss_func(pred, target.squeeze(1))
                    loss *= weight
                    full_loss += loss
                    logger[domain] = loss.item()

            # entropy minimization on real data
            if domain == "r":
                weight = self.opts.train.lambdas.G["s"]["minent"]
                if self.opts.gen.s.use_minent and weight != 0:
                    softmax_preds = softmax(pred, dim=1)

                    loss = self.losses["G"]["tasks"]["s"]["minent"](softmax_preds)
                    loss *= weight
                    full_loss += loss

                    self.logger.losses.gen.task["s"]["minent"]["r"] = loss.item()

        # adversarial entropy adaptation (AdvEnt)
        if self.opts.gen.s.use_advent:
            if self.opts.gen.s.use_dada and depth_preds is not None:
                depth_preds = depth_preds.detach()
            else:
                depth_preds = None

            if for_ == "D":
                domain_label = domain
                logger = {}
                loss_func = self.losses["D"]["advent"]
                pred = pred.detach()
                weight = self.opts.train.lambdas.advent.adv_main
            else:
                domain_label = "s"
                logger = self.logger.losses.gen.task["s"]["advent"]
                loss_func = self.losses["G"]["tasks"]["s"]["advent"]
                weight = self.opts.train.lambdas.G["s"]["advent"]

            if (for_ == "D" or domain == "r") and weight != 0:
                if softmax_preds is None:
                    softmax_preds = softmax(pred, dim=1)
                loss = loss_func(
                    softmax_preds,
                    self.domain_labels[domain_label],
                    self.D["s"]["Advent"],
                    depth_preds,
                )
                loss *= weight
                full_loss += loss
                logger[domain] = loss.item()

            if for_ == "D":
                # vanilla GAN and WGAN_norm need no extra constraint
                if self.opts.dis.s.gan_type in {"GAN", "WGAN_norm"}:
                    pass
                elif self.opts.dis.s.gan_type == "WGAN":
                    # weight clipping to enforce the Lipschitz constraint
                    for p in self.D["s"]["Advent"].parameters():
                        p.data.clamp_(
                            self.opts.dis.s.wgan_clamp_lower,
                            self.opts.dis.s.wgan_clamp_upper,
                        )
                elif self.opts.dis.s.gan_type == "WGAN_gp":
                    # gradient penalty
                    prob_need_grad = autograd.Variable(pred, requires_grad=True)
                    d_out = self.D["s"]["Advent"](prob_need_grad)
                    gp = get_WGAN_gradient(prob_need_grad, d_out)
                    gp_loss = gp * self.opts.train.lambdas.advent.WGAN_gp
                    full_loss += gp_loss
                else:
                    raise NotImplementedError

        return full_loss, pred

    def masker_m_loss(
        self, x, z, target, domain, for_="G", cond=None, z_depth=None, depth_preds=None
    ):
        assert for_ in {"G", "D"}
        assert domain in {"r", "s"}
        self.assert_z_matches_x(x, z)
        assert target is None or x.shape[0] == target.shape[0]
        full_loss = torch.tensor(0.0, device=self.device)

        pred_logits = self.G.decoders["m"](z, cond=cond, z_depth=z_depth)
        pred_prob = sigmoid(pred_logits)
        pred_prob_complementary = 1 - pred_prob
        prob = torch.cat([pred_prob, pred_prob_complementary], dim=1)

        if for_ == "G":
            # total variation loss
            weight = self.opts.train.lambdas.G.m.tv
            if weight != 0:
                loss = self.losses["G"]["tasks"]["m"]["tv"](pred_prob)
                loss *= weight
                full_loss += loss

                self.logger.losses.gen.task["m"]["tv"][domain] = loss.item()

            # supervised BCE loss on sim data
            weight = self.opts.train.lambdas.G.m.bce
            if domain == "s" and weight != 0:
                loss = self.losses["G"]["tasks"]["m"]["bce"](pred_logits, target)
                loss *= weight
                full_loss += loss
                self.logger.losses.gen.task["m"]["bce"]["s"] = loss.item()

            if domain == "r":
                # ground-intersection loss
                weight = self.opts.train.lambdas.G["m"]["gi"]
                if self.opts.gen.m.use_ground_intersection and weight != 0:
                    loss = self.losses["G"]["tasks"]["m"]["gi"](pred_prob, target)
                    loss *= weight
                    full_loss += loss
                    self.logger.losses.gen.task["m"]["gi"]["r"] = loss.item()

                # painter loss for the masker (pl4m)
                weight = self.opts.train.lambdas.G.m.pl4m
                if self.use_pl4m and weight != 0:
                    pl4m_loss = self.painter_loss_for_masker(x, pred_prob)
                    pl4m_loss *= weight
                    full_loss += pl4m_loss
                    self.logger.losses.gen.task.m.pl4m.r = pl4m_loss.item()

                # entropy minimization
                weight = self.opts.train.lambdas.advent.ent_main
                if self.opts.gen.m.use_minent and weight != 0:
                    loss = self.losses["G"]["tasks"]["m"]["minent"](prob)
                    loss *= weight
                    full_loss += loss
                    self.logger.losses.gen.task["m"]["minent"]["r"] = loss.item()

        # adversarial entropy adaptation (AdvEnt)
        if self.opts.gen.m.use_advent:
            if self.opts.gen.m.use_dada and depth_preds is not None:
                depth_preds = depth_preds.detach()
                depth_preds = torch.nn.functional.interpolate(
                    depth_preds, size=x.shape[-2:], mode="nearest"
                )
            else:
                depth_preds = None

            if for_ == "D":
                domain_label = domain
                logger = {}
                loss_func = self.losses["D"]["advent"]
                prob = prob.detach()
                weight = self.opts.train.lambdas.advent.adv_main
            else:
                domain_label = "s"
                logger = self.logger.losses.gen.task["m"]["advent"]
                loss_func = self.losses["G"]["tasks"]["m"]["advent"]
                weight = self.opts.train.lambdas.advent.adv_main

            if (for_ == "D" or domain == "r") and weight != 0:
                loss = loss_func(
                    prob.to(self.device),
                    self.domain_labels[domain_label],
                    self.D["m"]["Advent"],
                    depth_preds,
                )
                loss *= weight
                full_loss += loss
                logger[domain] = loss.item()

            if for_ == "D":
                # vanilla GAN and WGAN_norm need no extra constraint
                if self.opts.dis.m.gan_type in {"GAN", "WGAN_norm"}:
                    pass
                elif self.opts.dis.m.gan_type == "WGAN":
                    # weight clipping on the mask discriminator
                    for p in self.D["m"]["Advent"].parameters():
                        p.data.clamp_(
                            self.opts.dis.m.wgan_clamp_lower,
                            self.opts.dis.m.wgan_clamp_upper,
                        )
                elif self.opts.dis.m.gan_type == "WGAN_gp":
                    # gradient penalty on the mask discriminator
                    prob_need_grad = autograd.Variable(prob, requires_grad=True)
                    d_out = self.D["m"]["Advent"](prob_need_grad)
                    gp = get_WGAN_gradient(prob_need_grad, d_out)
                    gp_loss = self.opts.train.lambdas.advent.WGAN_gp * gp
                    full_loss += gp_loss
                else:
                    raise NotImplementedError

        return full_loss, prob

    def painter_loss_for_masker(self, x, m):
        # freeze the painter: only the masker should receive gradients here
        for param in self.G.painter.parameters():
            param.requires_grad = False

        fake_flooded = self.G.paint(m, x)

        if self.opts.dis.p.use_local_discriminator:
            fake_d_global = self.D["p"]["global"](fake_flooded)
            fake_d_local = self.D["p"]["local"](fake_flooded * m)

            pl4m_loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
            pl4m_loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
        else:
            real_cat = torch.cat([m, x], axis=1)
            fake_cat = torch.cat([m, fake_flooded], axis=1)
            real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)

            real_fake_d = self.D["p"](real_fake_cat)
            _, fake_d = divide_pred(real_fake_d)

            pl4m_loss = self.losses["G"]["p"]["gan"](fake_d, True, False)

        # unfreeze the painter if it is being trained
        if "p" in self.opts.tasks:
            for param in self.G.painter.parameters():
                param.requires_grad = True

        return pl4m_loss

    @torch.no_grad()
    def run_evaluation(self, verbose=0):
        print("******************* Running Evaluation ***********************")
        start_time = time()
        self.eval_mode()
        val_logger = None
        nb_of_batches = None
        for i, multi_batch_tuple in enumerate(self.val_loaders):
            # accumulate validation losses over all val batches
            nb_of_batches = i + 1
            multi_domain_batch = {
                batch["domain"][0]: self.batch_to_device(batch)
                for batch in multi_batch_tuple
            }
            self.get_G_loss(multi_domain_batch, verbose)

            if val_logger is None:
                val_logger = deepcopy(self.logger.losses.generator)
            else:
                val_logger = sum_dict(val_logger, self.logger.losses.generator)

        val_logger = div_dict(val_logger, nb_of_batches)
        self.logger.losses.generator = val_logger
        self.logger.log_losses(model_to_update="G", mode="val")

        for d in self.opts.domains:
            self.logger.log_comet_images("train", d)
            self.logger.log_comet_images("val", d)

        if "m" in self.opts.tasks and self.has_painter and not self.kitti_pretrain:
            self.logger.log_comet_combined_images("train", "r")
            self.logger.log_comet_combined_images("val", "r")

        if self.exp is not None:
            print()

        if "m" in self.opts.tasks or "s" in self.opts.tasks:
            self.eval_images("val", "r")
            self.eval_images("val", "s")

        if "p" in self.opts.tasks and not self.kitti_pretrain:
            val_fid = compute_val_fid(self)
            if self.exp is not None:
                self.exp.log_metric("val_fid", val_fid, step=self.logger.global_step)
            else:
                print("Validation FID Score", val_fid)

        self.train_mode()
        timing = int(time() - start_time)
        print("****************** Done in {}s *********************".format(timing))

    def eval_images(self, mode, domain):
        if domain == "s" and self.kitti_pretrain:
            domain = "kitti"
        if domain == "rf" or domain not in self.display_images[mode]:
            return

        metric_funcs = {"accuracy": accuracy, "mIOU": mIOU}
        metric_avg_scores = {"m": {}}
        if "s" in self.opts.tasks:
            metric_avg_scores["s"] = {}
        if (
            "d" in self.opts.tasks
            and domain == "s"
            and self.opts.gen.d.classify.enable
        ):
            metric_avg_scores["d"] = {}

        for key in metric_funcs:
            for task in metric_avg_scores:
                metric_avg_scores[task][key] = []

        for im_set in self.display_images[mode][domain]:
            x = im_set["data"]["x"].unsqueeze(0).to(self.device)
            z = self.G.encode(x)

            s_pred = d_pred = z_depth = None

            if "d" in metric_avg_scores:
                d_pred, z_depth = self.G.decoders["d"](z)
                d_pred = d_pred.detach().cpu()

                if domain == "s":
                    d = im_set["data"]["d"].unsqueeze(0).detach()

                    for metric in metric_funcs:
                        metric_score = metric_funcs[metric](d_pred, d)
                        metric_avg_scores["d"][metric].append(metric_score)

            if "s" in metric_avg_scores:
                if z_depth is None:
                    if self.opts.gen.s.use_dada and "d" in self.opts.tasks:
                        _, z_depth = self.G.decoders["d"](z)
                s_pred = self.G.decoders["s"](z, z_depth).detach().cpu()
                s = im_set["data"]["s"].unsqueeze(0).detach()

                for metric in metric_funcs:
                    metric_score = metric_funcs[metric](s_pred, s)
                    metric_avg_scores["s"][metric].append(metric_score)

            if "m" in self.opts.tasks:
                cond = None
                if s_pred is not None and d_pred is not None:
                    cond = self.G.make_m_cond(d_pred, s_pred, x)
                if z_depth is None:
                    if self.opts.gen.m.use_dada and "d" in self.opts.tasks:
                        _, z_depth = self.G.decoders["d"](z)

                pred_mask = (
                    (self.G.mask(z=z, cond=cond, z_depth=z_depth)).detach().cpu()
                )
                pred_mask = (pred_mask > 0.5).to(torch.float32)
                pred_prob = torch.cat([1 - pred_mask, pred_mask], dim=1)

                m = im_set["data"]["m"].unsqueeze(0).detach()

                for metric in metric_funcs:
                    if metric != "mIOU":
                        metric_score = metric_funcs[metric](pred_mask, m)
                    else:
                        metric_score = metric_funcs[metric](pred_prob, m)

                    metric_avg_scores["m"][metric].append(metric_score)

        metric_avg_scores = {
            task: {
                metric: np.mean(values) if values else float("nan")
                for metric, values in met_dict.items()
            }
            for task, met_dict in metric_avg_scores.items()
        }
        metric_avg_scores = {
            task: {
                metric: value if not np.isnan(value) else -1
                for metric, value in met_dict.items()
            }
            for task, met_dict in metric_avg_scores.items()
        }
        if self.exp is not None:
            self.exp.log_metrics(
                flatten_opts(metric_avg_scores),
                prefix=f"metrics_{mode}_{domain}",
                step=self.logger.global_step,
            )
        else:
            print(f"metrics_{mode}_{domain}")
            print(flatten_opts(metric_avg_scores))

        return 0

    def functional_test_mode(self):
        import atexit

        self.opts.output_path = (
            Path("~").expanduser() / "climategan" / "functional_tests"
        )
        Path(self.opts.output_path).mkdir(parents=True, exist_ok=True)
        with open(Path(self.opts.output_path) / "is_functional.test", "w") as f:
            f.write("trainer functional test - delete this dir")

        if self.exp is not None:
            self.exp.log_parameter("is_functional_test", True)
        atexit.register(self.del_output_path)

    def del_output_path(self, force=False):
        import shutil

        if not Path(self.opts.output_path).exists():
            return

        if (Path(self.opts.output_path) / "is_functional.test").exists() or force:
            shutil.rmtree(self.opts.output_path)

    def compute_fire(self, x, seg_preds=None, z=None, z_depth=None):
        """
        Transforms input tensor given wildfires event

        Args:
            x (torch.Tensor): Input tensor
            seg_preds (torch.Tensor): Semantic segmentation
                predictions for input tensor
            z (torch.Tensor): Latent vector of encoded "x".
                Can be None if seg_preds is given.
            z_depth (torch.Tensor, optional): Depth latent vector for the
                segmentation decoder. Defaults to None.

        Returns:
            torch.Tensor: Wildfire version of input tensor
        """
        if seg_preds is None:
            if z is None:
                z = self.G.encode(x)
            seg_preds = self.G.decoders["s"](z, z_depth)

        return add_fire(x, seg_preds, self.opts.events.fire)

    def compute_flood(
        self, x, z=None, z_depth=None, m=None, s=None, cloudy=None, bin_value=-1
    ):
        """
        Applies a flood (mask + paint) to an input image, with optionally
        pre-computed masker z or mask

        Args:
            x (torch.Tensor): B x C x H x W -1:1 input image
            z (torch.Tensor, optional): B x C x H x W Masker latent vector.
                Defaults to None.
            z_depth (torch.Tensor, optional): Depth latent vector, used when
                the masker uses DADA. Defaults to None.
            m (torch.Tensor, optional): B x 1 x H x W Mask. Defaults to None.
            s (torch.Tensor, optional): Segmentation predictions, required
                when `cloudy` is True. Defaults to None.
            cloudy (bool, optional): Whether to also paint a cloudy sky.
                Defaults to None.
            bin_value (float, optional): Mask binarization value.
                Set to -1 to use smooth masks (no binarization)

        Returns:
            torch.Tensor: B x 3 x H x W -1:1 flooded image
        """
        if m is None:
            if z is None:
                z = self.G.encode(x)
            if "d" in self.opts.tasks and self.opts.gen.m.use_dada and z_depth is None:
                _, z_depth = self.G.decoders["d"](z)
            m = self.G.mask(x=x, z=z, z_depth=z_depth)

        if bin_value >= 0:
            m = (m > bin_value).to(m.dtype)

        if cloudy:
            assert s is not None
            return self.G.paint_cloudy(m, x, s)

        return self.G.paint(m, x)
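
    # Usage sketch (assuming `x` is a -1:1 normalized B x 3 x H x W tensor):
    #
    #   flooded = trainer.compute_flood(x, bin_value=0.5)  # binarized mask
    #   flooded = trainer.compute_flood(x, bin_value=-1)   # smooth mask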
|
|
|
    def compute_smog(self, x, z=None, d=None, s=None, use_sky_seg=False):
        # compute depth (and sky segmentation) as needed
        sky_mask = None
        if d is None or (use_sky_seg and s is None):
            if z is None:
                z = self.G.encode(x)
            if d is None:
                d, _ = self.G.decoders["d"](z)
            if use_sky_seg and s is None:
                if "s" not in self.opts.tasks:
                    raise ValueError(
                        "Cannot have "
                        + "(use_sky_seg is True and s is None and 's' not in tasks)"
                    )
                s = self.G.decoders["s"](z)
                # note: `s` is currently not used to build `sky_mask`

        params = self.opts.events.smog

        airlight = params.airlight * torch.ones(3)
        airlight = airlight.view(1, -1, 1, 1).to(self.device)

        irradiance = srgb2lrgb(x)

        beta = torch.tensor([params.beta / params.vr] * 3)
        beta = beta.view(1, -1, 1, 1).to(self.device)

        d = normalize(d, mini=0.3, maxi=1.0)
        d = 1.0 / d
        d = normalize(d, mini=0.1, maxi=1)

        if sky_mask is not None:
            d[sky_mask] = 1

        d = torch.nn.functional.interpolate(
            d, size=x.shape[-2:], mode="bilinear", align_corners=True
        )

        d = d.repeat(1, 3, 1, 1)

        # atmospheric scattering: transmission decays exponentially with depth
        transmission = torch.exp(d * -beta)

        smogged = transmission * irradiance + (1 - transmission) * airlight

        smogged = lrgb2srgb(smogged)

        # blend in a yellow tint
        alpha = params.alpha / 255
        yellow_mask = torch.Tensor([params.yellow_color]) / 255
        yellow_filter = (
            yellow_mask.unsqueeze(2)
            .unsqueeze(2)
            .repeat(1, 1, smogged.shape[-2], smogged.shape[-1])
            .to(self.device)
        )

        smogged = smogged * (1 - alpha) + yellow_filter * alpha

        return smogged
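
    # Haze model sketch (standard atmospheric scattering, as implemented
    # above): with transmission T = exp(-beta * d), linear-RGB irradiance I
    # and airlight A,
    #
    #   smogged = T * I + (1 - T) * A
    #
    # then converted back to sRGB and blended with a yellow tint of opacity
    # alpha / 255.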