|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import climategan.strings as strings |
|
from climategan.blocks import InterpolateNearest2d, SPADEResnetBlock |
|
from climategan.norms import SpectralNorm |
|
|
|
|
|
def create_painter(opts, no_init=False, verbose=0):
    """Build and return the painter network.

    Args:
        opts: configuration object forwarded to PainterSpadeDecoder
        no_init (bool, optional): accepted for factory-API compatibility;
            not used by this painter
        verbose (int, optional): when > 0, print a creation message

    Returns:
        PainterSpadeDecoder: the painter module
    """
    painter = PainterSpadeDecoder(opts)
    if verbose > 0:
        print(" - Add PainterSpadeDecoder Painter")
    return painter
|
|
|
|
|
class PainterSpadeDecoder(nn.Module):
    """SPADE-based painter decoder.

    Decodes a latent tensor z into a 3-channel image; all along, z is
    conditioned (through SPADE normalization) on a 3-channel tensor ``cond``
    (in the original SPADE paper, conditioning is on a semantic map only).
    The first 3 SPADEResnetBlocks (SRB) do not shrink the channel dimension,
    and an upsampling is applied after each of the first two. For each
    remaining upsampling (w.r.t. spade_n_up), the SRB halves the channels.
    Before the final conv to 3 channels, the number of channels is therefore:

        final_nc = latent_dim // 2 ** (spade_n_up - 2)
    """

    def __init__(self, opts):
        """Create the decoder from the options tree.

        Args:
            opts: configuration object; the following keys of ``opts.gen.p``
                are read:
                - latent_dim (int): z's number of channels
                - spade_n_up (int): number of total upsamplings from z
                - spade_use_spectral_norm (bool): use spectral normalization?
                - spade_param_free_norm (str): norm to use before SPADE
                  de-normalization
                - use_final_shortcut (bool): condition the last SPADE block on
                  a 1x1 projection of the features instead of ``cond``
        """
        super().__init__()

        latent_dim = opts.gen.p.latent_dim
        cond_nc = 3  # conditioning tensor is a 3-channel image
        spade_n_up = opts.gen.p.spade_n_up
        spade_use_spectral_norm = opts.gen.p.spade_use_spectral_norm
        spade_param_free_norm = opts.gen.p.spade_param_free_norm
        spade_kernel_size = 3

        self.z_nc = latent_dim
        self.spade_n_up = spade_n_up

        # Spatial size of z when it has to be inferred from `cond` in
        # forward(); set through set_latent_shape().
        self.z_h = self.z_w = None

        # Projects the downsampled conditioning image to latent_dim channels
        # when forward() receives z=None.
        self.fc = nn.Conv2d(3, latent_dim, 3, padding=1)

        # First 3 SRBs keep the channel dimension.
        self.head_0 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        self.G_middle_0 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )
        self.G_middle_1 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        # Remaining (spade_n_up - 2) SRBs: each halves the channel dimension
        # and is preceded by an upsampling in forward().
        self.up_spades = nn.Sequential(
            *[
                SPADEResnetBlock(
                    self.z_nc // 2 ** i,
                    self.z_nc // 2 ** (i + 1),
                    cond_nc,
                    spade_use_spectral_norm,
                    spade_param_free_norm,
                    spade_kernel_size,
                )
                for i in range(spade_n_up - 2)
            ]
        )

        # Channel width reaching the final conv (see class docstring).
        self.final_nc = self.z_nc // 2 ** (spade_n_up - 2)

        self.final_spade = SPADEResnetBlock(
            self.final_nc,
            self.final_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        # Optional shortcut: project the decoded features themselves to
        # 3 channels and use them as the final SPADE conditioning.
        self.final_shortcut = None
        if opts.gen.p.use_final_shortcut:
            self.final_shortcut = nn.Sequential(
                *[
                    SpectralNorm(nn.Conv2d(self.final_nc, 3, 1)),
                    nn.BatchNorm2d(3),
                    nn.LeakyReLU(0.2, True),
                ]
            )

        # Final projection to an RGB image.
        self.conv_img = nn.Conv2d(self.final_nc, 3, 3, padding=1)

        self.upsample = InterpolateNearest2d(scale_factor=2)

    def set_latent_shape(self, shape, is_input=True):
        """
        Sets the latent shape to start the upsampling from, i.e. z_h and z_w.
        If is_input is True, then this is the actual input shape which should
        be divided by 2 ** spade_n_up
        Otherwise, just sets z_h and z_w from shape[-2] and shape[-1]

        Args:
            shape (tuple | list | int): The shape to start sampling from.
            is_input (bool, optional): Whether to divide shape by
                2 ** spade_n_up

        Raises:
            ValueError: if shape is neither a sequence nor an int
        """
        if isinstance(shape, (list, tuple)):
            self.z_h = shape[-2]
            self.z_w = shape[-1]
        elif isinstance(shape, int):
            self.z_h = self.z_w = shape
        else:
            raise ValueError(f"Unknown shape type: {shape}")

        if is_input:
            # z is upsampled spade_n_up times in forward(), so divide the
            # target output resolution accordingly.
            self.z_h = self.z_h // (2 ** self.spade_n_up)
            self.z_w = self.z_w // (2 ** self.spade_n_up)

    def _apply(self, fn):
        """Delegate device/dtype moves to nn.Module._apply.

        NOTE(review): currently a pure pass-through kept as a hook point;
        z_h/z_w are plain ints and need no device handling.
        """
        super()._apply(fn)
        return self

    def forward(self, z, cond):
        """Decode z (or `cond` alone) into a 3-channel image in [-1, 1].

        Args:
            z (torch.Tensor | None): latent tensor; if None, it is inferred
                by interpolating `cond` down to (z_h, z_w) and projecting it
                with self.fc — set_latent_shape() must have been called first
            cond (torch.Tensor): 3-channel conditioning tensor

        Returns:
            torch.Tensor: tanh-activated 3-channel output
        """
        if z is None:
            assert self.z_h is not None and self.z_w is not None
            z = self.fc(F.interpolate(cond, size=(self.z_h, self.z_w)))
        y = self.head_0(z, cond)
        y = self.upsample(y)
        y = self.G_middle_0(y, cond)
        y = self.upsample(y)
        y = self.G_middle_1(y, cond)

        # Each remaining SRB halves the channels after a x2 upsampling.
        for up in self.up_spades:
            y = self.upsample(y)
            y = up(y, cond)

        if self.final_shortcut is not None:
            # Condition the final SPADE block on the features themselves.
            cond = self.final_shortcut(y)
        y = self.final_spade(y, cond)
        y = self.conv_img(F.leaky_relu(y, 2e-1))
        y = torch.tanh(y)
        return y

    def __str__(self):
        return strings.spadedecoder(self)
|
|