climateGAN / climategan /painter.py
vict0rsch's picture
initial commit from cc-ai/climateGAN
448ebbd
import torch
import torch.nn as nn
import torch.nn.functional as F
import climategan.strings as strings
from climategan.blocks import InterpolateNearest2d, SPADEResnetBlock
from climategan.norms import SpectralNorm
def create_painter(opts, no_init=False, verbose=0):
    """Build the painter network from the experiment options.

    Args:
        opts: Options object, forwarded unchanged to ``PainterSpadeDecoder``.
        no_init (bool, optional): Accepted for call-site compatibility; this
            function does not read it.
        verbose (int, optional): When strictly positive, a one-line notice is
            printed before the painter is built.

    Returns:
        PainterSpadeDecoder: The newly constructed painter.
    """
    announce = verbose > 0
    if announce:
        print(" - Add PainterSpadeDecoder Painter")
    painter = PainterSpadeDecoder(opts)
    return painter
class PainterSpadeDecoder(nn.Module):
    """SPADE-based decoder that turns a latent tensor into a 3-channel image.

    Every residual block receives the conditioning tensor ``cond`` alongside
    the features, so the features are conditioned on ``cond`` all along.
    """

    def __init__(self, opts):
        """Create a SPADE-based decoder, which forwards z and the conditioning
        tensor cond (in the original paper, conditioning is on a semantic map
        only).

        The first 3 SPADEResnetBlocks (SRB) do not shrink the channel
        dimension, and an upsampling is applied after each of the first two:
        2 upsamplings at this point. Then, for each of the remaining
        ``spade_n_up - 2`` upsamplings, an SRB halves the number of channels.
        Before the final conv to get 3 channels, the number of channels is
        therefore:

            final_nc = channels(z) // 2 ** (spade_n_up - 2)

        Note that ``forward`` always performs the first two upsamplings,
        whatever the value of ``spade_n_up``.

        Args:
            opts: Options object. The following fields of ``opts.gen.p`` are
                read:

                * ``latent_dim``: number of channels of z (used directly as a
                  channel count).
                * ``spade_n_up``: total number of upsamplings from z.
                * ``spade_use_spectral_norm``: forwarded to each SRB.
                * ``spade_param_free_norm``: forwarded to each SRB.
                * ``use_final_shortcut``: if truthy, build ``final_shortcut``,
                  whose output replaces ``cond`` for the last SRB.
        """
        super().__init__()

        latent_dim = opts.gen.p.latent_dim
        cond_nc = 3
        spade_n_up = opts.gen.p.spade_n_up
        spade_use_spectral_norm = opts.gen.p.spade_use_spectral_norm
        spade_param_free_norm = opts.gen.p.spade_param_free_norm
        spade_kernel_size = 3

        self.z_nc = latent_dim
        self.spade_n_up = spade_n_up
        # Spatial size of z when it has to be built from cond in forward();
        # stays None until set_latent_shape() is called.
        self.z_h = self.z_w = None

        def spade_block(in_nc, out_nc):
            # All SRBs share the same conditioning / normalization settings.
            return SPADEResnetBlock(
                in_nc,
                out_nc,
                cond_nc,
                spade_use_spectral_norm,
                spade_param_free_norm,
                spade_kernel_size,
            )

        # NOTE: modules are created in the same order as before so that
        # parameter initialisation consumes the RNG identically.

        # Maps the (resized) conditioning tensor to z when no z is provided.
        self.fc = nn.Conv2d(cond_nc, latent_dim, 3, padding=1)

        # Channel-preserving blocks.
        self.head_0 = spade_block(self.z_nc, self.z_nc)
        self.G_middle_0 = spade_block(self.z_nc, self.z_nc)
        self.G_middle_1 = spade_block(self.z_nc, self.z_nc)

        # Each remaining block halves the channel count.
        self.up_spades = nn.Sequential(
            *[
                spade_block(self.z_nc // 2 ** i, self.z_nc // 2 ** (i + 1))
                for i in range(spade_n_up - 2)
            ]
        )

        self.final_nc = self.z_nc // 2 ** (spade_n_up - 2)
        self.final_spade = spade_block(self.final_nc, self.final_nc)

        self.final_shortcut = None
        if opts.gen.p.use_final_shortcut:
            # Its output is used as the conditioning tensor of final_spade,
            # hence cond_nc output channels.
            self.final_shortcut = nn.Sequential(
                SpectralNorm(nn.Conv2d(self.final_nc, cond_nc, 1)),
                nn.BatchNorm2d(cond_nc),
                nn.LeakyReLU(0.2, True),
            )

        self.conv_img = nn.Conv2d(self.final_nc, 3, 3, padding=1)
        self.upsample = InterpolateNearest2d(scale_factor=2)

    def set_latent_shape(self, shape, is_input=True):
        """
        Sets the latent shape to start the upsampling from, i.e. z_h and z_w.
        If is_input is True, then this is the actual input shape which should
        be divided by 2 ** spade_n_up
        Otherwise, just sets z_h and z_w from shape[-2] and shape[-1]

        Args:
            shape (tuple | list | int): The shape to start sampling from. For
                a sequence, the last two entries are used as (height, width);
                an int is used for both.
            is_input (bool, optional): Whether to divide shape by
                2 ** spade_n_up

        Raises:
            ValueError: If shape is neither a list/tuple nor an int.
        """
        if isinstance(shape, (list, tuple)):
            self.z_h = shape[-2]
            self.z_w = shape[-1]
        elif isinstance(shape, int):
            self.z_h = self.z_w = shape
        else:
            raise ValueError(f"Unknown shape type: {shape!r}")

        if is_input:
            factor = 2 ** self.spade_n_up
            self.z_h = self.z_h // factor
            self.z_w = self.z_w // factor

    def _apply(self, fn, *args, **kwargs):
        """Delegate to ``nn.Module._apply`` and return ``self``.

        Extra positional / keyword arguments are forwarded untouched so the
        override stays compatible with PyTorch versions whose ``_apply``
        takes more than ``fn`` (e.g. a ``recurse`` flag).
        """
        super()._apply(fn, *args, **kwargs)
        return self

    def forward(self, z, cond):
        """Decode z, conditioned on cond, into a 3-channel image.

        Args:
            z: Latent tensor with ``z_nc`` channels, or None. When None, z is
                computed from ``cond`` resized to ``(z_h, z_w)``, which
                requires ``set_latent_shape`` to have been called before.
            cond: Conditioning tensor given to every SPADE block. When z is
                None it is also fed to a convolution expecting 3 channels.

        Returns:
            The decoded tensor, passed through tanh.
        """
        if z is None:
            assert self.z_h is not None and self.z_w is not None, (
                "Latent shape is not set: call set_latent_shape() "
                "before forward() when z is None"
            )
            z = self.fc(F.interpolate(cond, size=(self.z_h, self.z_w)))

        y = self.head_0(z, cond)
        y = self.upsample(y)
        y = self.G_middle_0(y, cond)
        y = self.upsample(y)
        y = self.G_middle_1(y, cond)

        for up in self.up_spades:
            y = self.upsample(y)
            y = up(y, cond)

        if self.final_shortcut is not None:
            # The last SPADE block is conditioned on a projection of the
            # current features instead of the external conditioning tensor.
            cond = self.final_shortcut(y)

        y = self.final_spade(y, cond)
        y = self.conv_img(F.leaky_relu(y, 2e-1))
        y = torch.tanh(y)
        return y

    def __str__(self):
        """Return the description built by ``strings.spadedecoder``."""
        return strings.spadedecoder(self)