|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from climategan.blocks import ( |
|
BaseDecoder, |
|
Conv2dBlock, |
|
InterpolateNearest2d, |
|
SPADEResnetBlock, |
|
) |
|
|
|
|
|
def create_mask_decoder(opts, no_init=False, verbose=0):
    """Build the masker's decoder according to the options.

    Selects a SPADE-conditioned decoder when ``opts.gen.m.use_spade`` is
    set (which requires the depth "d" or segmentation "s" task to be
    enabled, since SPADE conditioning needs one of them), and the plain
    ``MaskBaseDecoder`` otherwise.

    Args:
        opts: addict-style options; reads ``opts.gen.m.use_spade`` and
            ``opts.tasks``.
        no_init (bool): unused here; kept for interface compatibility with
            the other ``create_*`` factories — TODO confirm against callers.
        verbose (int): print which decoder is created when > 0.

    Returns:
        nn.Module: the instantiated mask decoder.
    """
    if not opts.gen.m.use_spade:
        if verbose > 0:
            print(" - Add Base Mask Decoder")
        return MaskBaseDecoder(opts)

    if verbose > 0:
        print(" - Add Spade Mask Decoder")
    assert "d" in opts.tasks or "s" in opts.tasks
    return MaskSpadeDecoder(opts)
|
|
|
|
|
class MaskBaseDecoder(BaseDecoder):
    """Plain (non-SPADE) mask decoder.

    Thin configuration wrapper around ``BaseDecoder``: it derives the
    decoder's input channel count (and, optionally, the low-level skip
    feature dimension) from the encoder architecture/backbone, then
    delegates everything else to the parent constructor.
    """

    def __init__(self, opts):
        """Configure and build the base decoder from options.

        Args:
            opts: addict-style options; reads ``opts.gen.encoder``,
                ``opts.gen.deeplabv3.backbone``, ``opts.gen.m.*`` and
                ``opts.tasks``.
        """
        is_v3 = opts.gen.encoder.architecture == "deeplabv3"
        is_mobilenet = opts.gen.deeplabv3.backbone == "mobilenet"
        wants_low_feats = opts.gen.m.use_low_level_feats
        # DADA-style depth conditioning is only meaningful when the depth
        # task is enabled.
        dada = ("d" in opts.tasks) and opts.gen.m.use_dada

        # Channel counts produced by each supported encoder; -1 disables
        # the low-level skip connection in BaseDecoder.
        if is_v3 and is_mobilenet:
            feats_dim, low_feats_dim = 320, 24
        elif is_v3:
            feats_dim, low_feats_dim = 2048, 256
        else:
            feats_dim, low_feats_dim = 2048, -1
        if not wants_low_feats:
            low_feats_dim = -1

        super().__init__(
            n_upsample=opts.gen.m.n_upsample,
            n_res=opts.gen.m.n_res,
            input_dim=feats_dim,
            proj_dim=opts.gen.m.proj_dim,
            output_dim=opts.gen.m.output_dim,
            norm=opts.gen.m.norm,
            activ=opts.gen.m.activ,
            pad_type=opts.gen.m.pad_type,
            output_activ="none",
            low_level_feats_dim=low_feats_dim,
            use_dada=dada,
        )
|
|
|
|
|
class MaskSpadeDecoder(nn.Module):
    """SPADE-conditioned decoder producing a single-channel mask logit map."""

    def __init__(self, opts):
        """Create a SPADE-based mask decoder.

        The encoder output ``z`` is first projected to ``latent_dim``
        channels — via backbone-specific low/high-level fusion convs for
        deeplabv3, or a single conv (``fc_conv``) for deeplabv2 — then
        passed through ``num_layers`` SPADEResnetBlocks, each conditioned
        on the segmentation/depth tensor and followed by a x2
        nearest-neighbor upsampling. Each block halves the channel count,
        so the final conv sees ``latent_dim / 2 ** num_layers`` channels
        and maps them to 1 output channel (no final activation).

        Args:
            opts: addict-style options. Reads ``opts.gen.m.spade.*``
                (latent_dim, cond_nc, spade_use_spectral_norm,
                spade_param_free_norm, activations.all_lrelu, num_layers),
                ``opts.gen.m.use_proj`` / ``proj_dim``, and
                ``opts.gen.encoder.architecture`` /
                ``opts.gen.deeplabv3.backbone`` to select the input
                projection layers.

        Raises:
            ValueError: if ``opts.gen.encoder.architecture`` is neither
                "deeplabv3" nor "deeplabv2".
        """
        super().__init__()
        self.opts = opts
        latent_dim = opts.gen.m.spade.latent_dim
        cond_nc = opts.gen.m.spade.cond_nc
        spade_use_spectral_norm = opts.gen.m.spade.spade_use_spectral_norm
        spade_param_free_norm = opts.gen.m.spade.spade_param_free_norm
        # Force leaky-ReLU inside the SPADE blocks; None presumably falls
        # back to SPADEResnetBlock's default activation — TODO confirm in
        # climategan.blocks.
        if self.opts.gen.m.spade.activations.all_lrelu:
            spade_activation = "lrelu"
        else:
            spade_activation = None
        spade_kernel_size = 3
        self.num_layers = opts.gen.m.spade.num_layers
        self.z_nc = latent_dim

        if (
            opts.gen.encoder.architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "mobilenet"
        ):
            # MobileNet backbone: [high-level, low-level] feature channels.
            self.input_dim = [320, 24]
            # Lift low-level feats to the high-level channel count so the
            # two can be concatenated in forward().
            self.low_level_conv = Conv2dBlock(
                self.input_dim[1],
                self.input_dim[0],
                3,
                padding=1,
                activation="lrelu",
                pad_type="reflect",
                norm="spectral_batch",
            )
            # Fuse the concatenated (2 * high-level) feats down to z_nc.
            self.merge_feats_conv = Conv2dBlock(
                self.input_dim[0] * 2,
                self.z_nc,
                3,
                padding=1,
                activation="lrelu",
                pad_type="reflect",
                norm="spectral_batch",
            )
        elif (
            opts.gen.encoder.architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "resnet"
        ):
            # ResNet backbone: [high-level, low-level] feature channels.
            self.input_dim = [2048, 256]
            if self.opts.gen.m.use_proj:
                # Project BOTH levels to a common proj_dim before merging.
                proj_dim = self.opts.gen.m.proj_dim
                self.low_level_conv = Conv2dBlock(
                    self.input_dim[1],
                    proj_dim,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
                self.high_level_conv = Conv2dBlock(
                    self.input_dim[0],
                    proj_dim,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
                self.merge_feats_conv = Conv2dBlock(
                    proj_dim * 2,
                    self.z_nc,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
            else:
                # No projection: lift low-level feats to the high-level
                # channel count, as in the mobilenet branch.
                self.low_level_conv = Conv2dBlock(
                    self.input_dim[1],
                    self.input_dim[0],
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
                self.merge_feats_conv = Conv2dBlock(
                    self.input_dim[0] * 2,
                    self.z_nc,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )

        elif opts.gen.encoder.architecture == "deeplabv2":
            # deeplabv2 yields a single feature map: one conv to z_nc.
            self.input_dim = 2048
            self.fc_conv = Conv2dBlock(
                self.input_dim,
                self.z_nc,
                3,
                padding=1,
                activation="lrelu",
                pad_type="reflect",
                norm="spectral_batch",
            )
        else:
            raise ValueError("Unknown encoder type")

        self.spade_blocks = []

        # Channel count halves at every block: z_nc / 2**i -> z_nc / 2**(i+1).
        for i in range(self.num_layers):
            self.spade_blocks.append(
                SPADEResnetBlock(
                    int(self.z_nc / (2**i)),
                    int(self.z_nc / (2 ** (i + 1))),
                    cond_nc,
                    spade_use_spectral_norm,
                    spade_param_free_norm,
                    spade_kernel_size,
                    spade_activation,
                )
            )
        # Wrap in nn.Sequential so the blocks are registered as submodules
        # (they are still called one by one in forward(), with cond).
        self.spade_blocks = nn.Sequential(*self.spade_blocks)

        self.final_nc = int(self.z_nc / (2**self.num_layers))
        # Final 1-channel projection; activation="none" -> raw logits.
        self.mask_conv = Conv2dBlock(
            self.final_nc,
            1,
            3,
            padding=1,
            activation="none",
            pad_type="reflect",
            norm="spectral",
        )
        self.upsample = InterpolateNearest2d(scale_factor=2)

    def forward(self, z, cond, z_depth=None):
        """Decode ``z``, conditioned on ``cond``, into mask logits.

        Args:
            z: encoder features — a (high_level, low_level) list/tuple for
                deeplabv3 encoders, or a single tensor for deeplabv2.
            cond: conditioning tensor fed to every SPADE block (expected to
                have ``cond_nc`` channels).
            z_depth: unused in this method; presumably kept for interface
                compatibility with other decoders — TODO confirm callers.

        Returns:
            Tensor with 1 channel and no final activation (raw logits),
            spatially upsampled by ``2 ** num_layers`` relative to the
            merged feature map.
        """
        if isinstance(z, (list, tuple)):
            z_h, z_l = z
            if self.opts.gen.m.use_proj:
                # NOTE(review): high_level_conv is only created in the
                # resnet + use_proj branch of __init__; a mobilenet encoder
                # with use_proj=True would raise AttributeError here —
                # confirm this configuration is excluded upstream.
                z_l = self.low_level_conv(z_l)
                z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
                z_h = self.high_level_conv(z_h)
            else:
                # Lift low-level feats and match the high-level spatial size.
                z_l = self.low_level_conv(z_l)
                z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
            z = torch.cat([z_h, z_l], axis=1)
            y = self.merge_feats_conv(z)
        else:
            y = self.fc_conv(z)

        # Each SPADE block halves channels; each upsample doubles H and W.
        for i in range(self.num_layers):
            y = self.spade_blocks[i](y, cond)
            y = self.upsample(y)
        y = self.mask_conv(y)
        return y

    def __str__(self) -> str:
        return "MaskerSpadeDecoder"
|
|