"""Discriminator architectures for ClimateGAN's GAN components
(painter ``p``, mask ``m`` and segmentation ``s`` tasks)
"""
import functools

import torch
import torch.nn as nn

from climategan.blocks import SpectralNorm
from climategan.tutils import init_weights

# from torch.optim import lr_scheduler
# mainly from https://github.com/sangwoomo/instagan/blob/master/models/networks.py


def create_discriminator(opts, device, no_init=False, verbose=0):
    disc = OmniDiscriminator(opts)
    if no_init:
        return disc

    for task, model in disc.items():
        if isinstance(model, nn.ModuleDict):
            for domain, domain_model in model.items():
                init_weights(
                    domain_model,
                    init_type=opts.dis[task].init_type,
                    init_gain=opts.dis[task].init_gain,
                    verbose=verbose,
                    caller=f"create_discriminator {task} {domain}",
                )
        else:
            init_weights(
                model,
                init_type=opts.dis[task].init_type,
                init_gain=opts.dis[task].init_gain,
                verbose=verbose,
                caller=f"create_discriminator {task}",
            )
    return disc.to(device)
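
# Example usage (illustrative sketch, not from the original file): `opts` is an
# addict-style dict with dotted access, as loaded from the project's YAML
# configs; the exact fields it must contain are defined by OmniDiscriminator
# below.
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     disc = create_discriminator(opts, device)
#     print(list(disc.keys()))  # subset of ["p", "m", "s"], depending on opts.tasks
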

def define_D(
    input_nc,
    ndf,
    n_layers=3,
    norm="batch",
    use_sigmoid=False,
    get_intermediate_features=False,
    num_D=1,
):
    norm_layer = get_norm_layer(norm_type=norm)
    net = MultiscaleDiscriminator(
        input_nc,
        ndf,
        n_layers=n_layers,
        norm_layer=norm_layer,
        use_sigmoid=use_sigmoid,
        get_intermediate_features=get_intermediate_features,
        num_D=num_D,
    )
    return net
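
# Example usage (illustrative sketch): a 3-scale PatchGAN over RGB inputs.
# With get_intermediate_features=True, each scale returns the list of its
# intermediate activations instead of a single logit map.
#
#     d = define_D(input_nc=3, ndf=64, n_layers=3, norm="instance",
#                  get_intermediate_features=True, num_D=3)
#     outs = d(torch.randn(2, 3, 256, 256))  # list of 3 lists (one per scale)
#     logits = outs[0][-1]                   # patch logits at the finest scale
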

def get_norm_layer(norm_type="instance"):
    if not norm_type:
        print(f"norm_type is {norm_type}, defaulting to instance")
        norm_type = "instance"
    if norm_type == "batch":
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == "instance":
        norm_layer = functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
    elif norm_type == "none":
        # NB: callers are expected to handle a None norm_layer themselves
        norm_layer = None
    else:
        raise NotImplementedError(
            f"normalization layer [{norm_type}] is not implemented"
        )
    return norm_layer
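
# Example usage (illustrative sketch): the returned partial is called with a
# channel count to instantiate the actual normalization layer.
#
#     norm_layer = get_norm_layer("instance")
#     layer = norm_layer(64)  # nn.InstanceNorm2d(64, affine=False, track_running_stats=False)
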

# PatchGAN discriminator (pix2pix-style) with the specified arguments, using
# spectral normalization on every convolution.
class NLayerDiscriminator(nn.Module):
    def __init__(
        self,
        input_nc=3,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        get_intermediate_features=True,
    ):
        super(NLayerDiscriminator, self).__init__()
        # InstanceNorm2d has no learnable affine parameters by default, so the
        # convolutions it follows need their own bias
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        self.get_intermediate_features = get_intermediate_features
        kw = 4
        padw = 1
        sequence = [
            [
                # Use spectral normalization
                SpectralNorm(
                    nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)
                ),
                nn.LeakyReLU(0.2, True),
            ]
        ]
        nf_mult = 1
        nf_mult_prev = 1
        # Double the channel multiplier at every stride-2 layer, capped at 8
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                [
                    # Use spectral normalization
                    SpectralNorm(  # TODO replace with Conv2dBlock
                        nn.Conv2d(
                            ndf * nf_mult_prev,
                            ndf * nf_mult,
                            kernel_size=kw,
                            stride=2,
                            padding=padw,
                            bias=use_bias,
                        )
                    ),
                    norm_layer(ndf * nf_mult),
                    nn.LeakyReLU(0.2, True),
                ]
            ]
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            [
                # Use spectral normalization
                SpectralNorm(
                    nn.Conv2d(
                        ndf * nf_mult_prev,
                        ndf * nf_mult,
                        kernel_size=kw,
                        stride=1,
                        padding=padw,
                        bias=use_bias,
                    )
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]
        ]
        # Use spectral normalization
        sequence += [
            [
                SpectralNorm(
                    nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
                )
            ]
        ]
        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]
        # We divide the layers into groups to extract intermediate layer outputs
        for n in range(len(sequence)):
            self.add_module("model" + str(n), nn.Sequential(*sequence[n]))
        # self.model = nn.Sequential(*sequence)

    def forward(self, input):
        results = [input]
        for submodel in self.children():
            intermediate_output = submodel(results[-1])
            results.append(intermediate_output)
        if self.get_intermediate_features:
            return results[1:]
        return results[-1]

    # def forward(self, input):
    #     return self.model(input)
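
# Example usage (illustrative sketch): with the default
# get_intermediate_features=True the forward pass returns one tensor per layer
# group ("model0", "model1", ...), final patch logits last.
#
#     d = NLayerDiscriminator(input_nc=3, n_layers=3)
#     feats = d(torch.randn(1, 3, 256, 256))
#     [f.shape[-1] for f in feats]  # [128, 64, 32, 31, 30]
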

# Source: https://github.com/NVIDIA/pix2pixHD
class MultiscaleDiscriminator(nn.Module):
    def __init__(
        self,
        input_nc=3,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        get_intermediate_features=True,
        num_D=3,
    ):
        super(MultiscaleDiscriminator, self).__init__()
        self.n_layers = n_layers
        self.ndf = ndf
        self.norm_layer = norm_layer
        self.use_sigmoid = use_sigmoid
        self.get_intermediate_features = get_intermediate_features
        self.num_D = num_D

        # One identical NLayerDiscriminator per scale
        for i in range(self.num_D):
            netD = NLayerDiscriminator(
                input_nc=input_nc,
                ndf=self.ndf,
                n_layers=self.n_layers,
                norm_layer=self.norm_layer,
                use_sigmoid=self.use_sigmoid,
                get_intermediate_features=self.get_intermediate_features,
            )
            self.add_module("discriminator_%d" % i, netD)

        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False
        )

    def forward(self, input):
        result = []
        for name, D in self.named_children():
            if "discriminator" not in name:
                continue
            out = D(input)
            if not self.get_intermediate_features:
                out = [out]
            result.append(out)
            # Each subsequent discriminator sees a 2x downsampled input
            input = self.downsample(input)
        return result
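
# Example usage (illustrative sketch): each of the num_D discriminators scores
# the input at a different resolution (full, 1/2, 1/4 for num_D=3).
#
#     msd = MultiscaleDiscriminator(input_nc=3, num_D=3)
#     out = msd(torch.randn(1, 3, 256, 256))
#     len(out)     # 3: one entry per scale
#     len(out[0])  # 5: intermediate feature maps (get_intermediate_features=True)
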

class OmniDiscriminator(nn.ModuleDict):
    """Maps each adversarial task in opts.tasks ("p", "m", "s") to its
    discriminator(s).
    """

    def __init__(self, opts):
        super().__init__()
        if "p" in opts.tasks:
            if opts.dis.p.use_local_discriminator:
                self["p"] = nn.ModuleDict(
                    {
                        "global": define_D(
                            input_nc=3,
                            ndf=opts.dis.p.ndf,
                            n_layers=opts.dis.p.n_layers,
                            norm=opts.dis.p.norm,
                            use_sigmoid=opts.dis.p.use_sigmoid,
                            get_intermediate_features=opts.dis.p.get_intermediate_features,  # noqa: E501
                            num_D=opts.dis.p.num_D,
                        ),
                        "local": define_D(
                            input_nc=3,
                            ndf=opts.dis.p.ndf,
                            n_layers=opts.dis.p.n_layers,
                            norm=opts.dis.p.norm,
                            use_sigmoid=opts.dis.p.use_sigmoid,
                            get_intermediate_features=opts.dis.p.get_intermediate_features,  # noqa: E501
                            num_D=opts.dis.p.num_D,
                        ),
                    }
                )
            else:
                self["p"] = define_D(
                    input_nc=4,  # image + mask
                    ndf=opts.dis.p.ndf,
                    n_layers=opts.dis.p.n_layers,
                    norm=opts.dis.p.norm,
                    use_sigmoid=opts.dis.p.use_sigmoid,
                    get_intermediate_features=opts.dis.p.get_intermediate_features,
                    num_D=opts.dis.p.num_D,
                )
        if "m" in opts.tasks:
            if opts.gen.m.use_advent:
                if opts.dis.m.architecture == "base":
                    if opts.dis.m.gan_type == "WGAN_norm":
                        self["m"] = nn.ModuleDict(
                            {
                                "Advent": get_fc_discriminator(
                                    num_classes=2, use_norm=True
                                )
                            }
                        )
                    else:
                        self["m"] = nn.ModuleDict(
                            {
                                "Advent": get_fc_discriminator(
                                    num_classes=2, use_norm=False
                                )
                            }
                        )
                elif opts.dis.m.architecture == "OmniDiscriminator":
                    self["m"] = nn.ModuleDict(
                        {
                            "Advent": define_D(
                                input_nc=2,
                                ndf=opts.dis.m.ndf,
                                n_layers=opts.dis.m.n_layers,
                                norm=opts.dis.m.norm,
                                use_sigmoid=opts.dis.m.use_sigmoid,
                                get_intermediate_features=opts.dis.m.get_intermediate_features,  # noqa: E501
                                num_D=opts.dis.m.num_D,
                            )
                        }
                    )
                else:
                    raise ValueError(
                        "Unsupported mask discriminator architecture: "
                        f"{opts.dis.m.architecture}"
                    )
        if "s" in opts.tasks:
            if opts.gen.s.use_advent:
                if opts.dis.s.gan_type == "WGAN_norm":
                    self["s"] = nn.ModuleDict(
                        {"Advent": get_fc_discriminator(num_classes=11, use_norm=True)}
                    )
                else:
                    self["s"] = nn.ModuleDict(
                        {"Advent": get_fc_discriminator(num_classes=11, use_norm=False)}
                    )
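
# Example usage (illustrative sketch): the resulting structure mirrors
# opts.tasks; e.g. with tasks = ["m", "s"] and use_advent enabled:
#
#     disc = OmniDiscriminator(opts)
#     disc["m"]["Advent"]  # mask discriminator over 2 classes
#     disc["s"]["Advent"]  # segmentation discriminator over 11 classes
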

def get_fc_discriminator(num_classes=2, ndf=64, use_norm=False):
    # ADVENT-style fully convolutional discriminator: five stride-2 4x4
    # convolutions, optionally wrapped in spectral normalization
    if use_norm:
        return nn.Sequential(
            SpectralNorm(
                nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
            ),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(
                nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)
            ),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(
                nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
            ),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(
                nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
            ),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(
                nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1)
            ),
        )
    return nn.Sequential(
        nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1),
    )
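
# Example usage (illustrative sketch): five stride-2 convolutions downsample
# the input by a factor of 32, producing a single-channel logit map.
#
#     d = get_fc_discriminator(num_classes=2, use_norm=True)
#     scores = d(torch.randn(1, 2, 256, 256))
#     scores.shape  # torch.Size([1, 1, 8, 8])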