import numpy as np
import einops
import torch
import torch.nn as nn
from torch.nn import functional as thf
import pytorch_lightning as pl
from copy import deepcopy
from contextlib import contextmanager

import kornia
from kornia import color

from ldm.modules.diffusionmodules.util import conv_nd, zero_module
from ldm.modules.diffusionmodules.model import Encoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.ema import LitEma
from ldm.util import instantiate_from_config

def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class View(nn.Module):
    """Reshape layer, so that nn.Sequential pipelines can change tensor shape."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

class SecretEncoder3(nn.Module):
    """Map a secret bit-string to a spatial residual at the target resolution."""

    def __init__(self, secret_len, base_res=16, resolution=64) -> None:
        super().__init__()
        log_resolution = int(np.log2(resolution))
        log_base = int(np.log2(base_res))
        self.secret_len = secret_len
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * 3),
            nn.SiLU(),
            View(-1, 3, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution - log_base), 2**(log_resolution - log_base))),  # nearest-neighbour upsampling to the working resolution
            zero_module(conv_nd(2, 3, 3, 3, padding=1))  # zero-init so the residual starts at exactly 0
        )

    def copy_encoder_weight(self, ae_model):
        # no autoencoder weights to copy for this variant
        return None

    def encode(self, x):
        x = self.secret_scaler(x)
        return x

    def forward(self, x, c):
        # x (the image) is unused; only the secret c is encoded
        c = self.encode(c)
        return c, None
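

# Minimal usage sketch (illustrative only; shapes follow the constructor
# defaults, and the batch size 4 is arbitrary):
#
#   enc = SecretEncoder3(secret_len=100)
#   secret = torch.randint(0, 2, (4, 100)).float()   # 4 random 100-bit secrets
#   residual, _ = enc(None, secret)                  # the image argument is ignored
#   # residual.shape == (4, 3, 64, 64); all-zero at init thanks to zero_module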


class SecretEncoder4(nn.Module):
    """Same as SecretEncoder3, but with a configurable channel count."""

    def __init__(self, secret_len, ch=3, base_res=16, resolution=64) -> None:
        super().__init__()
        log_resolution = int(np.log2(resolution))
        log_base = int(np.log2(base_res))
        self.secret_len = secret_len
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * ch),
            nn.SiLU(),
            View(-1, ch, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution - log_base), 2**(log_resolution - log_base))),
            zero_module(conv_nd(2, ch, ch, 3, padding=1))
        )

    def copy_encoder_weight(self, ae_model):
        # no autoencoder weights to copy for this variant
        return None

    def encode(self, x):
        x = self.secret_scaler(x)
        return x

    def forward(self, x, c):
        # x (the image) is unused; only the secret c is encoded
        c = self.encode(c)
        return c, None
class SecretEncoder6(nn.Module):
    """Join an image embedding with the secret embedding.

    emode selects how the two are combined:
        'c3' - concatenate a ch-channel secret map with the image (2*ch channels in)
        'c2' - concatenate a 2-channel secret map with the channel-mean of the image
        'm3' - multiply the image by a ch-channel secret map
    """

    def __init__(self, secret_len, ch=3, base_res=16, resolution=64, emode='c3') -> None:
        super().__init__()
        assert emode in ['c3', 'c2', 'm3']
        if emode == 'c3':
            secret_ch = ch
            join_ch = 2 * ch
        elif emode == 'c2':
            secret_ch = 2
            join_ch = ch
        elif emode == 'm3':
            secret_ch = ch
            join_ch = ch

        log_resolution = int(np.log2(resolution))
        log_base = int(np.log2(base_res))
        self.secret_len = secret_len
        self.emode = emode
        self.resolution = resolution
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * secret_ch),
            nn.SiLU(),
            View(-1, secret_ch, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution - log_base), 2**(log_resolution - log_base))),
        )
        self.join_encoder = nn.Sequential(
            conv_nd(2, join_ch, join_ch, 3, padding=1),
            nn.SiLU(),
            conv_nd(2, join_ch, ch, 3, padding=1),
            nn.SiLU(),
            conv_nd(2, ch, ch, 3, padding=1),
            nn.SiLU()
        )
        self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))

    def copy_encoder_weight(self, ae_model):
        # no autoencoder weights to copy for this variant
        return None

    def encode(self, x):
        x = self.secret_scaler(x)
        return x

    def forward(self, x, c):
        c = self.encode(c)
        if self.emode == 'c3':
            x = torch.cat([x, c], dim=1)
        elif self.emode == 'c2':
            x = torch.cat([x.mean(dim=1, keepdim=True), c], dim=1)
        elif self.emode == 'm3':
            x = x * c
        dx = self.join_encoder(x)
        dx = self.out_layer(dx)
        return dx, None
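

# Channel bookkeeping per emode (a sanity sketch with the defaults ch=3,
# resolution=64; not part of the model):
#
#   enc = SecretEncoder6(secret_len=100, emode='c2')
#   x = torch.randn(4, 3, 64, 64)                    # image/latent embedding
#   secret = torch.randint(0, 2, (4, 100)).float()
#   dx, _ = enc(x, secret)    # 1-ch mean(x) + 2-ch secret map -> join_ch = 3
#   # dx.shape == (4, 3, 64, 64); zero at init via the zero-initialised out_layer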


class SecretEncoder5(nn.Module):
    """Same as SecretEncoder4, plus an optional joint image/secret encoder."""

    def __init__(self, secret_len, ch=3, base_res=16, resolution=64, joint=False) -> None:
        super().__init__()
        log_resolution = int(np.log2(resolution))
        log_base = int(np.log2(base_res))
        self.secret_len = secret_len
        self.joint = joint
        self.resolution = resolution
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * ch),
            nn.SiLU(),
            View(-1, ch, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution - log_base), 2**(log_resolution - log_base))),
        )
        if joint:
            self.join_encoder = nn.Sequential(
                conv_nd(2, 2 * ch, 2 * ch, 3, padding=1),
                nn.SiLU(),
                conv_nd(2, 2 * ch, ch, 3, padding=1),
                nn.SiLU()
            )
        self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))

    def copy_encoder_weight(self, ae_model):
        # no autoencoder weights to copy for this variant
        return None

    def encode(self, x):
        x = self.secret_scaler(x)
        return x

    def forward(self, x, c):
        c = self.encode(c)
        if self.joint:
            x = thf.interpolate(x, size=(self.resolution, self.resolution), mode="bilinear",
                                align_corners=False, antialias=True)
            c = self.join_encoder(torch.cat([x, c], dim=1))
        c = self.out_layer(c)
        return c, None


class SecretEncoder2(nn.Module):
    def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False) -> None:
        super().__init__()
        log_resolution = int(np.log2(ddconfig.resolution))
        self.secret_len = secret_len
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.encoder.conv_out = zero_module(self.encoder.conv_out)  # zero-init output so the secret branch starts silent
        self.embed_dim = embed_dim

        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))

        if monitor is not None:
            self.monitor = monitor

        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32 * 32 * ddconfig.out_ch),
            nn.SiLU(),
            View(-1, ddconfig.out_ch, 32, 32),
            nn.Upsample(scale_factor=(2**(log_resolution - 5), 2**(log_resolution - 5))),  # 32 = 2**5
        )

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        misses, ignores = self.load_state_dict(sd, strict=False)
        print(f"[SecretEncoder2] Restored from {path}, misses: {misses}, ignores: {ignores}")

    def copy_encoder_weight(self, ae_model):
        # weight copying is intentionally disabled for this variant
        return None

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        return self.encoder(x)

    def forward(self, x, c):
        c = self.secret_scaler(c)
        x = torch.cat([x, c], dim=1)  # image and secret map enter the encoder jointly
        z = self.encode(x)
        return z, None


class SecretEncoder7(nn.Module):
    def __init__(self, secret_len, ddconfig, ckpt_path=None,
                 ignore_keys=[], embed_dim=3,
                 ema_decay=None) -> None:
        super().__init__()
        self.secret_len = secret_len
        self.encoder = Encoder(**ddconfig)
        self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)

        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32 * 32 * 2),
            nn.SiLU(),
            View(-1, 2, 32, 32),
        )

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        misses, ignores = self.load_state_dict(sd, strict=False)
        print(f"[SecretEncoder7] Restored from {path}, misses: {len(misses)}, ignores: {len(ignores)}. "
              "This is expected: the decoder is unused and the secret encoder is a new module.")

    def copy_encoder_weight(self, ae_model):
        self.encoder.load_state_dict(deepcopy(ae_model.encoder.state_dict()))
        self.quant_conv.load_state_dict(deepcopy(ae_model.quant_conv.state_dict()))

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def forward(self, x, c):
        c = self.secret_scaler(c)
        c = thf.interpolate(c, size=x.shape[-2:], mode="nearest")
        # ITU-R BT.709 luma conversion: collapse RGB to a single grey channel
        x = 0.2125 * x[:, 0, ...] + 0.7154 * x[:, 1, ...] + 0.0721 * x[:, 2, ...]
        x = torch.cat([x.unsqueeze(1), c], dim=1)  # 1 luma channel + 2 secret channels
        z = self.encode(x)
        return z, None
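

# Input assembly sketch for SecretEncoder7 (illustrative; assumes a ddconfig
# with in_channels=3 so that one luma channel plus two secret channels fit):
#
#   secret -> Linear -> (B, 2, 32, 32) -> nearest-resize to the image size
#   image  -> BT.709 luma              -> (B, 1, H, W)
#   concat -> (B, 3, H, W) -> Encoder -> quant_conv -> (B, embed_dim, h, w)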


class SecretEncoder(nn.Module):
    def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False) -> None:
        super().__init__()
        log_resolution = int(np.log2(ddconfig.resolution))
        self.secret_len = secret_len
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)  # outputs mean and logvar
        self.embed_dim = embed_dim

        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))

        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32 * 32 * ddconfig.out_ch),
            nn.SiLU(),
            View(-1, ddconfig.out_ch, 32, 32),
            nn.Upsample(scale_factor=(2**(log_resolution - 5), 2**(log_resolution - 5))),  # 32 = 2**5
            zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
        )

        self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        misses, ignores = self.load_state_dict(sd, strict=False)
        print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")

    def copy_encoder_weight(self, ae_model):
        self.encoder.load_state_dict(ae_model.encoder.state_dict())
        self.quant_conv.load_state_dict(ae_model.quant_conv.state_dict())

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def forward(self, x, c):
        c = self.secret_scaler(c)  # zero at init, so the input starts as a clean copy
        x = x + c
        posterior = self.encode(x)
        z = posterior.sample()
        z = self.out_layer(z)
        return z, posterior
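

# Posterior handling sketch (illustrative; assumes the standard ldm
# DiagonalGaussianDistribution API with sample(), mode() and kl()):
#
#   z, posterior = secret_encoder(image, secret)
#   kl = posterior.kl()        # per-sample KL to N(0, I), usable as a regulariser
#   z_det = posterior.mode()   # deterministic alternative to posterior.sample()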


class ControlAE(pl.LightningModule):
    def __init__(self,
                 first_stage_key,
                 first_stage_config,
                 control_key,
                 control_config,
                 decoder_config,
                 loss_config,
                 noise_config='__none__',
                 use_ema=False,
                 secret_warmup=False,
                 scale_factor=1.,
                 ckpt_path="__none__",
                 ):
        super().__init__()
        self.scale_factor = scale_factor
        self.control_key = control_key
        self.first_stage_key = first_stage_key
        self.ae = instantiate_from_config(first_stage_config)
        self.control = instantiate_from_config(control_config)
        self.decoder = instantiate_from_config(decoder_config)
        self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample")
        if noise_config != '__none__':
            print('Using noise')
            self.noise = instantiate_from_config(noise_config)

        self.control.copy_encoder_weight(self.ae)

        # freeze the first-stage autoencoder
        self.ae.eval()
        self.ae.train = disabled_train
        for p in self.ae.parameters():
            p.requires_grad = False

        self.loss_layer = instantiate_from_config(loss_config)

        # warmup state: train on a fixed batch until bit accuracy is high enough
        self.fixed_x = None
        self.fixed_img = None
        self.fixed_input_recon = None
        self.fixed_control = None
        self.register_buffer("fixed_input", torch.tensor(True))

        self.secret_warmup = secret_warmup
        self.secret_baselen = 2
        self.secret_len = control_config.params.secret_len
        if self.secret_warmup:
            assert self.secret_len == 2**(int(np.log2(self.secret_len))), 'secret_len must be a power of two for warmup'

        self.use_ema = use_ema
        if self.use_ema:
            print('Using EMA')
            self.control_ema = LitEma(self.control)
            self.decoder_ema = LitEma(self.decoder)
            print(f"Keeping EMAs of {len(list(self.control_ema.buffers()) + list(self.decoder_ema.buffers()))}.")

        if ckpt_path != '__none__':
            self.init_from_ckpt(ckpt_path, ignore_keys=[])

    def get_warmup_secret(self, old_secret):
        # during warmup, draw a short random base secret and tile it to full length,
        # which makes the decoding task easier early in training
        if self.secret_warmup:
            bsz = old_secret.shape[0]
            nrepeats = self.secret_len // self.secret_baselen
            new_secret = torch.zeros((bsz, self.secret_baselen), dtype=torch.float).random_(0, 2).repeat_interleave(nrepeats, dim=1)
            return new_secret.to(old_secret.device)
        else:
            return old_secret
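
    # Tiling example (illustrative): with secret_baselen=2 and secret_len=8, a
    # base secret [1, 0] expands via repeat_interleave to [1, 1, 1, 1, 0, 0, 0, 0],
    # i.e. each base bit is repeated secret_len // secret_baselen = 4 times.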

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.control_ema.store(self.control.parameters())
            self.decoder_ema.store(self.decoder.parameters())
            self.control_ema.copy_to(self.control)
            self.decoder_ema.copy_to(self.decoder)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.control_ema.restore(self.control.parameters())
                self.decoder_ema.restore(self.decoder.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.control_ema(self.control)
            self.decoder_ema(self.decoder)

    def compute_loss(self, pred, target):
        # NOTE: relies on self.lpips_loss (e.g. an lpips.LPIPS module) and
        # self.yuv_scales being set externally; neither is created in __init__,
        # and the training loss actually used is self.loss_layer
        lpips_loss = self.lpips_loss(pred, target).mean(dim=[1, 2, 3])
        pred_yuv = color.rgb_to_yuv((pred + 1) / 2)
        target_yuv = color.rgb_to_yuv((target + 1) / 2)
        yuv_loss = torch.mean((pred_yuv - target_yuv)**2, dim=[2, 3])
        yuv_loss = 1.5 * torch.mm(yuv_loss, self.yuv_scales).squeeze(1)
        return lpips_loss + yuv_loss

    def forward(self, x, image, c):
        # SecretEncoder6 conditions on the latent x; the other encoders condition on the RGB image
        if self.control.__class__.__name__ == 'SecretEncoder6':
            eps, posterior = self.control(x, c)
        else:
            eps, posterior = self.control(image, c)
        return x + eps, posterior
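
    # End-to-end embedding sketch (illustrative, mirroring shared_step below):
    #
    #   z = model.encode_first_stage(image)     # clean latent
    #   z_w, _ = model(z, image, secret)        # latent plus secret residual eps
    #   stego = model.decode_first_stage(z_w)   # watermarked image
    #   logits = model.decoder(stego)           # secret prediction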

    @torch.no_grad()
    def get_input(self, batch, return_first_stage=False, bs=None):
        image = batch[self.first_stage_key]
        control = batch[self.control_key]
        control = self.get_warmup_secret(control)
        if bs is not None:
            image = image[:bs]
            control = control[:bs]
        else:
            bs = image.shape[0]

        image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
        x = self.encode_first_stage(image).detach()
        image_rec = self.decode_first_stage(x).detach()

        if self.fixed_input:
            if self.fixed_x is None:  # first iteration: cache a fixed warmup batch
                print('[TRAINING] Warmup - using fixed input image for now!')
                self.fixed_x = x.detach().clone()[:bs]
                self.fixed_img = image.detach().clone()[:bs]
                self.fixed_input_recon = image_rec.detach().clone()[:bs]
                self.fixed_control = control.detach().clone()[:bs]
            x, image, image_rec = self.fixed_x, self.fixed_img, self.fixed_input_recon

        out = [x, control]
        if return_first_stage:
            out.extend([image, image_rec])
        return out

    def decode_first_stage(self, z):
        z = 1. / self.scale_factor * z
        image_rec = self.ae.decode(z)
        return image_rec

    def encode_first_stage(self, image):
        encoder_posterior = self.ae.encode(image)
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z

    def shared_step(self, batch):
        x, c, img, _ = self.get_input(batch, return_first_stage=True)
        x, posterior = self(x, img, c)
        image_rec = self.decode_first_stage(x)

        if img.shape[-1] > 256:
            img = thf.interpolate(img, size=(256, 256), mode='bilinear', align_corners=False).detach()
            image_rec = thf.interpolate(image_rec, size=(256, 256), mode='bilinear', align_corners=False)
        if hasattr(self, 'noise') and self.noise.is_activated():
            image_rec_noised = self.noise(image_rec, self.global_step, p=0.9)
        else:
            image_rec_noised = self.crop(image_rec)
        image_rec_noised = torch.clamp(image_rec_noised, -1, 1)
        pred = self.decoder(image_rec_noised)

        loss, loss_dict = self.loss_layer(img, image_rec, posterior, c, pred, self.global_step)
        bit_acc = loss_dict["bit_acc"]
        bit_acc_ = bit_acc.item()

        # curriculum: ramp up the image loss once decoding is nearly perfect
        if (bit_acc_ > 0.98) and (not self.fixed_input) and hasattr(self, 'noise') and self.noise.is_activated():
            self.loss_layer.activate_ramp(self.global_step)

        # enable noise augmentation once decoding is reliable
        if (bit_acc_ > 0.95) and (not self.fixed_input):
            if hasattr(self, 'noise') and (not self.noise.is_activated()):
                self.noise.activate(self.global_step)

        # leave the fixed-batch warmup once the secret is decodable
        if (bit_acc_ > 0.9) and self.fixed_input:
            print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switching to full image dataset training.')
            self.fixed_input = ~self.fixed_input
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)
        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        _, loss_dict_no_ema = self.shared_step(batch)
        loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    @torch.no_grad()
    def log_images(self, batch, fixed_input=False, **kwargs):
        log = dict()
        if fixed_input and self.fixed_img is not None:
            x, c, img, img_recon = self.fixed_x, self.fixed_control, self.fixed_img, self.fixed_input_recon
        else:
            x, c, img, img_recon = self.get_input(batch, return_first_stage=True)
        x, _ = self(x, img, c)
        image_out = self.decode_first_stage(x)
        if hasattr(self, 'noise') and self.noise.is_activated():
            img_noise = self.noise(image_out, self.global_step, p=1.0)
            log['noised'] = img_noise
        log['input'] = img
        log['output'] = image_out
        log['recon'] = img_recon
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.control.parameters()) + list(self.decoder.parameters())
        optimizer = torch.optim.AdamW(params, lr=lr)
        return optimizer
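

# Wiring sketch (illustrative; ae_cfg, secret_enc_cfg, extractor_cfg and
# loss_cfg are hypothetical OmegaConf targets that would normally come from a
# YAML config):
#
#   model = ControlAE(
#       first_stage_key='image', first_stage_config=ae_cfg,     # frozen autoencoder
#       control_key='secret', control_config=secret_enc_cfg,    # e.g. SecretEncoder5
#       decoder_config=extractor_cfg, loss_config=loss_cfg,
#   )
#   # batch = {'image': (B, H, W, 3) in [-1, 1], 'secret': (B, secret_len) bits}
#   loss, loss_dict = model.shared_step(batch)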