# resefa/models/pggan_generator.py
# Hugging Face file-page metadata: akhaliq (HF staff), commit 8ca3a29
# ("add files"), 16 kB.
# python3.7
"""Contains the implementation of generator described in PGGAN.
Paper: https://arxiv.org/pdf/1710.10196.pdf
Official TensorFlow implementation:
https://github.com/tkarras/progressive_growing_of_gans
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Public API of this module.
__all__ = ['PGGANGenerator']
# Output resolutions accepted by `PGGANGenerator`; its constructor raises
# `ValueError` for any value not listed here.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
# pylint: disable=missing-function-docstring
class PGGANGenerator(nn.Module):
    """Defines the generator network in PGGAN.

    NOTE: The synthesized images are with `RGB` channel order and pixel range
    [-1, 1].

    Settings for the network:

    (1) resolution: The resolution of the output image.
    (2) init_res: The initial resolution to start with convolution. (default: 4)
    (3) z_dim: Dimension of the input latent space, Z. (default: 512)
    (4) image_channels: Number of channels of the output image. (default: 3)
    (5) final_tanh: Whether to use `tanh` to control the final pixel range.
        (default: False)
    (6) label_dim: Dimension of the additional label for conditional generation.
        In one-hot conditioning case, it is equal to the number of classes. If
        set to 0, conditioning training will be disabled. (default: 0)
    (7) fused_scale: Whether to fuse `upsample` and `conv2d` together,
        resulting in `conv2d_transpose`. (default: False)
    (8) use_wscale: Whether to use weight scaling. (default: True)
    (9) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0))
    (10) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (11) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    (12) eps: A small value to avoid divide overflow. (default: 1e-8)

    Sub-modules are registered per resolution block `i` as `layer{2i}`,
    `layer{2i + 1}` and `output{i}`. `pth_to_tf_var_mapping` maps each of their
    parameter names (and the `lod` buffer) to a TensorFlow variable name.
    """

    def __init__(self,
                 resolution,
                 init_res=4,
                 z_dim=512,
                 image_channels=3,
                 final_tanh=False,
                 label_dim=0,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=np.sqrt(2.0),
                 fmaps_base=16 << 10,
                 fmaps_max=512,
                 eps=1e-8):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')

        self.init_res = init_res
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.z_dim = z_dim
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.label_dim = label_dim
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.eps = eps

        # Dimension of latent space, which is convenient for sampling.
        self.latent_dim = (self.z_dim,)

        # Number of convolutional layers (two per resolution block).
        self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2

        # Level-of-details (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            res = 2 ** res_log2
            in_channels = self.get_nf(res // 2)
            out_channels = self.get_nf(res)
            block_idx = res_log2 - self.init_res_log2

            # First convolution layer for each resolution.
            if res == self.init_res:
                # The very first layer consumes the 1x1 latent "image"; with
                # kernel `init_res` and padding `init_res - 1` the convolution
                # output is `init_res` x `init_res`.
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvLayer(in_channels=z_dim + label_dim,
                              out_channels=out_channels,
                              kernel_size=init_res,
                              padding=init_res - 1,
                              add_bias=True,
                              upsample=False,
                              fused_scale=False,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              activation_type='lrelu',
                              eps=eps))
                tf_layer_name = 'Dense'
            else:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvLayer(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=3,
                              padding=1,
                              add_bias=True,
                              upsample=True,
                              fused_scale=fused_scale,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              activation_type='lrelu',
                              eps=eps))
                tf_layer_name = 'Conv0_up' if fused_scale else 'Conv0'
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')

            # Second convolution layer for each resolution.
            self.add_module(
                f'layer{2 * block_idx + 1}',
                ConvLayer(in_channels=out_channels,
                          out_channels=out_channels,
                          kernel_size=3,
                          padding=1,
                          add_bias=True,
                          upsample=False,
                          fused_scale=False,
                          use_wscale=use_wscale,
                          wscale_gain=wscale_gain,
                          activation_type='lrelu',
                          eps=eps))
            tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')

            # Output convolution layer for each resolution.
            self.add_module(
                f'output{block_idx}',
                ConvLayer(in_channels=out_channels,
                          out_channels=image_channels,
                          kernel_size=1,
                          padding=0,
                          add_bias=True,
                          upsample=False,
                          fused_scale=False,
                          use_wscale=use_wscale,
                          wscale_gain=1.0,
                          activation_type='linear',
                          eps=eps))
            self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
            self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')

    def get_nf(self, res):
        """Gets number of feature maps according to the given resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, z, label=None, lod=None):
        """Synthesizes images from latent codes.

        Args:
            z: Latent codes with shape `[batch_size, z_dim]`.
            label: Labels with shape `[batch_size, label_dim]`. Required when
                `label_dim` is non-zero, ignored otherwise. (default: None)
            lod: Level-of-details to synthesize at. If `None`, the value of
                the `lod` buffer is used. (default: None)

        Returns:
            A dictionary with keys `z` (the pixel-normalized latent code,
            concatenated with the label when conditioning is enabled),
            `label` (the label cast to float32, or the `label` argument as
            given when conditioning is disabled) and `image`.

        Raises:
            ValueError: If `z` or `label` has an unexpected shape, if a
                required label is missing, or if `lod` is out of range.
        """
        if z.ndim != 2 or z.shape[1] != self.z_dim:
            raise ValueError(f'Input latent code should be with shape '
                             f'[batch_size, latent_dim], where '
                             f'`latent_dim` equals to {self.z_dim}!\n'
                             f'But `{z.shape}` is received!')
        z = self.layer0.pixel_norm(z)

        if self.label_dim:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with size {self.label_dim}) as input, '
                                 f'but no label is received!')
            if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_dim], where '
                                 f'`batch_size` equals to that of '
                                 f'latent codes ({z.shape[0]}) and '
                                 f'`label_dim` equals to {self.label_dim}!\n'
                                 f'But `{label.shape}` is received!')
            label = label.to(dtype=torch.float32)
            z = torch.cat((z, label), dim=1)

        lod = self.lod.item() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-details (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')
        # With `lod <= -1` none of the output branches in the loop below ever
        # matches, so `image` would stay unassigned and the method would die
        # with an opaque `UnboundLocalError`. Values in (-1, 0) are still
        # accepted: they select the final-resolution output, as before.
        if lod <= -1:
            raise ValueError(f'Level-of-details (lod) should be greater than '
                             f'-1, but `{lod}` is received!')

        x = z.view(z.shape[0], self.z_dim + self.label_dim, 1, 1)
        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            current_lod = self.final_res_log2 - res_log2
            block_idx = res_log2 - self.init_res_log2
            # Blocks finer than the requested `lod` are not executed.
            if lod < current_lod + 1:
                x = getattr(self, f'layer{2 * block_idx}')(x)
                x = getattr(self, f'layer{2 * block_idx + 1}')(x)
            if current_lod - 1 < lod <= current_lod:
                # This block produces the (coarsest contributing) image.
                image = getattr(self, f'output{block_idx}')(x)
            elif current_lod < lod < current_lod + 1:
                # Fractional `lod`: blend this block's output with the 2x
                # upsampled image coming from the previous block.
                alpha = np.ceil(lod) - lod
                temp = getattr(self, f'output{block_idx}')(x)
                image = F.interpolate(image, scale_factor=2, mode='nearest')
                image = temp * alpha + image * (1 - alpha)
            elif lod >= current_lod + 1:
                # Block skipped: only upsample the image to this resolution.
                image = F.interpolate(image, scale_factor=2, mode='nearest')

        if self.final_tanh:
            image = torch.tanh(image)

        results = {
            'z': z,
            'label': label,
            'image': image,
        }
        return results
class PixelNormLayer(nn.Module):
    """Pixel-wise feature vector normalization.

    The input is multiplied by the reciprocal square root of its mean squared
    value taken along `dim` (with `eps` added before the square root).
    """

    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        self.dim = dim

    def extra_repr(self):
        return f'dim={self.dim}, epsilon={self.eps}'

    def forward(self, x):
        mean_sq = torch.mean(x.square(), dim=self.dim, keepdim=True)
        inv_norm = torch.rsqrt(mean_sq + self.eps)
        return x * inv_norm
class UpsamplingLayer(nn.Module):
    """Nearest-neighbor upsampling of feature maps.

    Inputs are returned untouched when `scale_factor` is not larger than 1.
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def extra_repr(self):
        return f'factor={self.scale_factor}'

    def forward(self, x):
        factor = self.scale_factor
        return x if factor <= 1 else F.interpolate(
            x, scale_factor=factor, mode='nearest')
class ConvLayer(nn.Module):
    """Implements the convolutional layer.

    Basically, this layer executes pixel-wise normalization, upsampling (if
    needed), convolution, and activation in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 padding,
                 add_bias,
                 upsample,
                 fused_scale,
                 use_wscale,
                 wscale_gain,
                 activation_type,
                 eps):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels.
            padding: Padding used in convolution. Overridden with 1 when both
                `upsample` and `fused_scale` are set.
            add_bias: Whether to add bias onto the convolutional result.
            upsample: Whether to upsample the input tensor before convolution.
            fused_scale: Whether to fuse `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`.
            use_wscale: Whether to use weight scaling.
            wscale_gain: Gain factor for weight scaling.
            activation_type: Type of activation. One of `linear`, `relu`,
                `lrelu`.
            eps: A small value to avoid divide overflow.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.add_bias = add_bias
        self.upsample = upsample
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.activation_type = activation_type
        self.eps = eps

        self.pixel_norm = PixelNormLayer(dim=1, eps=eps)

        # Explicit upsampling is only needed when it is not fused into the
        # (transposed) convolution.
        if upsample and not fused_scale:
            self.up = UpsamplingLayer(scale_factor=2)
        else:
            self.up = nn.Identity()

        if upsample and fused_scale:
            self.use_conv2d_transpose = True
            weight_shape = (in_channels, out_channels, kernel_size, kernel_size)
            self.stride = 2
            self.padding = 1
        else:
            self.use_conv2d_transpose = False
            weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
            self.stride = 1

        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            # Scale is applied at run time in `forward()`.
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            # Scale is baked into the initial weights.
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        assert activation_type in ['linear', 'relu', 'lrelu']

    def extra_repr(self):
        # NOTE: This layer has no `scale_factor` attribute (that one lives on
        # `UpsamplingLayer`); report the `upsample` flag instead so that
        # printing the module does not raise `AttributeError`.
        return (f'in_ch={self.in_channels}, '
                f'out_ch={self.out_channels}, '
                f'ksize={self.kernel_size}, '
                f'padding={self.padding}, '
                f'wscale_gain={self.wscale_gain:.3f}, '
                f'bias={self.add_bias}, '
                f'upsample={self.upsample}, '
                f'fused_scale={self.fused_scale}, '
                f'act={self.activation_type}')

    def forward(self, x):
        """Applies pixel norm, optional upsampling, convolution, activation."""
        x = self.pixel_norm(x)
        x = self.up(x)
        weight = self.weight
        if self.wscale != 1.0:
            weight = weight * self.wscale
        if self.use_conv2d_transpose:
            # Zero-pad the kernel by one on every spatial side and sum its
            # four one-pixel-shifted copies, yielding a kernel one larger in
            # each spatial dimension for the stride-2 transposed convolution.
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
            x = F.conv_transpose2d(x,
                                   weight=weight,
                                   bias=self.bias,
                                   stride=self.stride,
                                   padding=self.padding)
        else:
            x = F.conv2d(x,
                         weight=weight,
                         bias=self.bias,
                         stride=self.stride,
                         padding=self.padding)

        if self.activation_type == 'linear':
            pass
        elif self.activation_type == 'relu':
            x = F.relu(x, inplace=True)
        elif self.activation_type == 'lrelu':
            x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation type '
                                      f'`{self.activation_type}`!')
        return x
# pylint: enable=missing-function-docstring