Spaces:
Runtime error
Runtime error
""" The code is based on https://github.com/apple/ml-gsn/ with adaption. """ | |
import math | |
import torch | |
import torch.nn as nn | |
import math | |
import torch.nn.functional as F | |
from lib.torch_utils.ops.native_ops import FusedLeakyReLU, fused_leaky_relu, upfirdn2d | |
class DiscriminatorHead(nn.Module):
    """Final discriminator stage that maps a feature map to one logit per sample.

    Optionally concatenates a minibatch standard-deviation channel, then applies a 3x3
    conv layer and two equalized-learning-rate linear layers.

    Args:
    ----
        in_channel: int
            Channels of the incoming feature map. The first linear layer expects
            ``in_channel * 4 * 4`` flattened features, i.e. a 4x4 spatial input.
        disc_stddev: bool
            If True, append one minibatch standard-deviation channel before the conv.
    """

    def __init__(self, in_channel, disc_stddev=False):
        super().__init__()
        self.disc_stddev = disc_stddev
        # cat_stddev is called with its default stddev_feat=1, so it adds exactly one channel.
        stddev_dim = 1 if disc_stddev else 0
        self.conv_stddev = ConvLayer2d(in_channel=in_channel + stddev_dim,
                                       out_channel=in_channel,
                                       kernel_size=3,
                                       activate=True)
        self.final_linear = nn.Sequential(
            nn.Flatten(),
            EqualLinear(in_channel=in_channel * 4 * 4, out_channel=in_channel, activate=True),
            EqualLinear(in_channel=in_channel, out_channel=1),
        )

    def cat_stddev(self, x, stddev_group=4, stddev_feat=1):
        """Append minibatch standard-deviation feature map(s) to ``x``.

        The batch is viewed as ``(group, batch // group, ...)``. The std over the leading
        ``group`` axis is averaged over channels and space, then broadcast back to every
        sample as ``stddev_feat`` extra channel(s).

        Args:
        ----
            x: Tensor
                Input of shape (batch, channel, height, width).
            stddev_group: int
                Maximum group size. It is reduced to the largest value that evenly divides
                the batch, because the reshaping view requires ``batch % group == 0``.
            stddev_feat: int
                Number of stddev channels to append; ``channel`` must be divisible by it.

        Returns:
        -------
            Tensor of shape (batch, channel + stddev_feat, height, width), with the original
            sample order preserved.
        """
        batch, channel, height, width = x.shape
        group = max(1, min(batch, stddev_group))
        while batch % group:
            group -= 1
        # Shuffle inputs so that all views of a single trajectory don't get put together.
        perm = torch.randperm(batch, device=x.device)
        inv_perm = torch.argsort(perm)
        x = x[perm]
        stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        # Undo the shuffle so every row lines up with the caller's sample order again.
        stddev = stddev[inv_perm]
        x = x[inv_perm]
        return torch.cat([x, stddev], 1)

    def forward(self, x):
        """Return one logit per sample, shaped (batch, 1)."""
        if self.disc_stddev:
            x = self.cat_stddev(x)
        x = self.conv_stddev(x)
        return self.final_linear(x)
class ConvDecoder(nn.Module):
    """Stack of upsampling conv layers that grows a feature map from ``in_res`` to ``out_res``.

    Each upsampling stage halves the channel count. A final stride-1 3x3 conv without an
    activation projects to ``out_channel``.

    Args:
    ----
        in_channel: int
            Channels of the input feature map.
        out_channel: int
            Channels of the output.
        in_res: int
            Input resolution. The number of upsampling stages is
            ``int(log2(out_res)) - int(log2(in_res))`` (none if that is not positive).
        out_res: int
            Target resolution.
    """

    def __init__(self, in_channel, out_channel, in_res, out_res):
        super().__init__()
        num_up = int(math.log(out_res, 2)) - int(math.log(in_res, 2))
        stages = []
        width = in_channel
        for _ in range(num_up):
            narrower = width // 2
            stages.append(
                ConvLayer2d(in_channel=width,
                            out_channel=narrower,
                            kernel_size=3,
                            upsample=True,
                            bias=True,
                            activate=True))
            width = narrower
        # Output projection: no resampling and no activation.
        stages.append(
            ConvLayer2d(in_channel=width,
                        out_channel=out_channel,
                        kernel_size=3,
                        bias=True,
                        activate=False))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply every stage in order."""
        return self.layers(x)
class StyleDiscriminator(nn.Module):
    """Residual discriminator that reduces an ``in_res`` image to 4x4 and scores it.

    Args:
    ----
        in_channel: int
            Channels of the input image.
        in_res: int
            Input resolution; ``int(log2(in_res)) - 2`` downsampling blocks are built.
        ch_mul: int
            Channels produced by the input conv.
        ch_max: int
            Upper bound on channels; each block doubles channels until this cap.
        **kwargs:
            Accepted and ignored.
    """

    def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs):
        super().__init__()
        head_log_res = 2  # the discriminator head operates on 4x4 (= 2**2) maps
        num_down = int(math.log(in_res, 2)) - head_log_res
        self.conv_in = ConvLayer2d(in_channel=in_channel, out_channel=ch_mul, kernel_size=3)
        # Each residual block halves the resolution and doubles the features (capped at ch_max).
        blocks = []
        width = ch_mul
        for _ in range(num_down):
            wider = int(min(width * 2, ch_max))
            blocks.append(ConvResBlock2d(in_channel=width, out_channel=wider, downsample=True))
            width = wider
        self.layers = nn.Sequential(*blocks)
        self.disc_out = DiscriminatorHead(in_channel=width, disc_stddev=True)

    def forward(self, x):
        """Return discriminator logits for the batch ``x``."""
        features = self.layers(self.conv_in(x))
        return self.disc_out(features)
def make_kernel(k):
    """Build a normalized float32 2D filter from ``k``.

    A 1D input is expanded into its outer product with itself (a separable kernel). A 2D
    input is used as-is. Either way, the result is divided by its sum so the taps add up to 1.
    """
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    return kernel / kernel.sum()
class Blur(nn.Module):
    """Blur layer.

    Filters the input with a fixed FIR kernel via ``upfirdn2d``. Blurring feature maps after
    convolutional upsampling, or before convolutional downsampling, helps produce models that
    are more robust to shifted inputs (https://richzhang.github.io/antialiased-cnns/). In GANs
    this can give cleaner gradients and therefore more stable training.

    Args:
    ----
        kernel: list, int
            Integer taps of the blur kernel, for example [1, 3, 3, 1].
        pad: tuple, int
            Padding (rows/columns) added to the top/left and bottom/right respectively.
        upsample_factor: int
            Upsampling factor of the preceding layer. When greater than 1, the kernel is
            scaled by ``upsample_factor ** 2``.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        taps = make_kernel(kernel)
        # Presumably compensates for the zeros inserted by the preceding upsampling step.
        if upsample_factor > 1:
            taps = taps * (upsample_factor**2)
        self.register_buffer("kernel", taps)
        self.pad = pad

    def forward(self, input):
        """Filter ``input`` with the stored kernel and padding."""
        return upfirdn2d(input, self.kernel, pad=self.pad)
class Upsample(nn.Module):
    """Upsampling layer.

    Performs upsampling with a blur kernel via ``upfirdn2d``.

    Args:
    ----
        kernel: sequence of int
            Integer taps of the blur kernel, for example (1, 3, 3, 1).
        factor: int
            Upsampling factor.
    """

    # Tuple default: an immutable default cannot be shared and mutated across calls.
    def __init__(self, kernel=(1, 3, 3, 1), factor=2):
        super().__init__()
        self.factor = factor
        # Kernel is scaled by factor**2, presumably to offset the zeros inserted when upsampling.
        kernel = make_kernel(kernel) * (factor**2)
        self.register_buffer("kernel", kernel)
        # Split the padding between the leading and trailing edges; the leading edge gets
        # an extra ``factor - 1``.
        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        """Upsample ``input`` by ``self.factor`` and filter it with the stored kernel."""
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
        return out
class Downsample(nn.Module):
    """Downsampling layer.

    Performs downsampling with a blur kernel via ``upfirdn2d``.

    Args:
    ----
        kernel: sequence of int
            Integer taps of the blur kernel, for example (1, 3, 3, 1).
        factor: int
            Downsampling factor.
    """

    # Tuple default: an immutable default cannot be shared and mutated across calls.
    def __init__(self, kernel=(1, 3, 3, 1), factor=2):
        super().__init__()
        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)
        # Split the padding between the leading and trailing edges (leading gets the odd pixel).
        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        """Filter ``input`` with the stored kernel and downsample it by ``self.factor``."""
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
        return out
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He constant (i.e.
    sqrt(in_dim)) to prevent vanishing gradients and accelerate training. This constant only
    works for ReLU or LeakyReLU activation functions.

    Args:
    ----
        in_channel: int
            Input channels.
        out_channel: int
            Output channels.
        bias: bool
            Use bias term. When False, no bias is applied on either forward path.
        bias_init: float
            Initial value for the bias.
        lr_mul: float
            Learning rate multiplier. By scaling weights and the bias we can proportionally
            scale the magnitude of the gradients, effectively increasing/decreasing the
            learning rate for this layer.
        activate: bool
            Apply leakyReLU activation (via ``fused_leaky_relu``).
    """

    def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
        super().__init__()
        # Stored weights are divided by lr_mul; forward multiplies them back via self.scale.
        self.weight = nn.Parameter(torch.randn(out_channel, in_channel).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel).fill_(bias_init))
        else:
            self.bias = None
        self.activate = activate
        self.scale = (1 / math.sqrt(in_channel)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        """Apply the scaled linear map, then the optional fused leaky ReLU.

        Previously, ``bias=False`` crashed here with ``None * lr_mul``. The bias is now
        optional on both paths.
        """
        weight = self.weight * self.scale
        bias = self.bias * self.lr_mul if self.bias is not None else None
        if self.activate:
            out = F.linear(input, weight)
            if bias is None:
                # fused_leaky_relu is always given a bias operand here; presumably it adds the
                # bias before the activation, so a zero bias is equivalent to no bias.
                bias = out.new_zeros(self.weight.shape[0])
            return fused_leaky_relu(out, bias)
        return F.linear(input, weight, bias=bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
class EqualConv2d(nn.Module):
    """2D convolution with an equalized learning rate.

    The stored weights are drawn from N(0, 1) and multiplied at every forward pass by
    ``1 / sqrt(in_channel * kernel_size**2)`` (the inverse He constant). This runtime scaling
    is intended for ReLU / LeakyReLU networks.

    Args:
    ----
        in_channel: int
            Input channels.
        out_channel: int
            Output channels.
        kernel_size: int
            Square kernel size.
        stride: int
            Convolution stride.
        padding: int
            Zero padding applied to both sides of the input.
        bias: bool
            Use a (zero-initialized) bias term.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        fan_in = in_channel * kernel_size**2
        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(fan_in)
        self.stride = stride
        self.padding = padding
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        """Convolve ``input`` with the runtime-scaled weights."""
        scaled_weight = self.weight * self.scale
        return F.conv2d(input, scaled_weight, bias=self.bias,
                        stride=self.stride, padding=self.padding)

    def __repr__(self):
        out_ch, in_ch, ksize = self.weight.shape[:3]
        return (f"{self.__class__.__name__}({in_ch}, {out_ch},"
                f" {ksize}, stride={self.stride}, padding={self.padding})")
class EqualConvTranspose2d(nn.Module):
    """2D transposed convolution with an equalized learning rate.

    The stored weights are drawn from N(0, 1) and multiplied at every forward pass by
    ``1 / sqrt(in_channel * kernel_size**2)`` (the inverse He constant). This runtime scaling
    is intended for ReLU / LeakyReLU networks.

    Args:
    ----
        in_channel: int
            Input channels.
        out_channel: int
            Output channels.
        kernel_size: int
            Square kernel size.
        stride: int
            Convolution stride.
        padding: int
            Padding argument forwarded to ``F.conv_transpose2d``.
        output_padding: int
            Extra size added to one side of the output, forwarded to ``F.conv_transpose2d``.
        bias: bool
            Use a (zero-initialized) bias term.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 bias=True):
        super().__init__()
        # Transposed-conv weights are laid out as (in_channel, out_channel, k, k).
        self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
        self.stride, self.padding, self.output_padding = stride, padding, output_padding
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        """Apply the transposed convolution with runtime-scaled weights."""
        scaled_weight = self.weight * self.scale
        return F.conv_transpose2d(input,
                                  scaled_weight,
                                  bias=self.bias,
                                  stride=self.stride,
                                  padding=self.padding,
                                  output_padding=self.output_padding)

    def __repr__(self):
        in_ch, out_ch, ksize = self.weight.shape[:3]
        return (f'{self.__class__.__name__}({in_ch}, {out_ch},'
                f' {ksize}, stride={self.stride}, padding={self.padding})')
class ConvLayer2d(nn.Sequential):
    """Equalized-lr 2D convolution with optional blur-based resampling and leaky ReLU.

    Exactly one of three paths is built:
      * ``upsample``: a stride-2 ``EqualConvTranspose2d`` followed by a ``Blur``
        (``upsample_factor=2``);
      * ``downsample``: a ``Blur`` followed by a stride-2 ``EqualConv2d``;
      * otherwise: a stride-1 ``EqualConv2d`` with ``kernel_size // 2`` zero padding.

    Args:
    ----
        in_channel: int
            Input channels.
        out_channel: int
            Output channels.
        kernel_size: int
            Square kernel size.
        upsample: bool
            Use the transposed-conv upsampling path.
        downsample: bool
            Use the blur + strided-conv downsampling path.
        blur_kernel: sequence of int
            1D taps passed to ``Blur``. The tuple default avoids a shared mutable default.
        bias: bool
            Use a bias term. When ``activate`` is True the bias is owned by
            ``FusedLeakyReLU`` and the convolution is built without one.
        activate: bool
            Append ``FusedLeakyReLU(out_channel)``.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        upsample=False,
        downsample=False,
        blur_kernel=(1, 3, 3, 1),
        bias=True,
        activate=True,
    ):
        assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
        # The conv carries its own bias only when no activation (which holds the bias) follows.
        conv_bias = bias and not activate
        layers = []
        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
            layers.append(
                EqualConvTranspose2d(in_channel,
                                     out_channel,
                                     kernel_size,
                                     padding=0,
                                     stride=2,
                                     bias=conv_bias))
            layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))
        elif downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            # Blur before the strided conv (anti-aliased downsampling).
            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
            layers.append(
                EqualConv2d(in_channel,
                            out_channel,
                            kernel_size,
                            padding=0,
                            stride=2,
                            bias=conv_bias))
        else:
            layers.append(
                EqualConv2d(in_channel,
                            out_channel,
                            kernel_size,
                            padding=kernel_size // 2,
                            stride=1,
                            bias=conv_bias))
        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))
        super().__init__(*layers)
class ConvResBlock2d(nn.Module):
    """2D residual block built from equalized-lr 3x3 ``ConvLayer2d`` layers.

    The main path is ``conv1 -> conv2``. The shortcut is the identity, or a 1x1
    ``ConvLayer2d`` (no bias, no activation) whenever the channel count changes or the block
    resamples. The two paths are summed and divided by sqrt(2).

    Args:
    ----
        in_channel: int
            Input channels.
        out_channel: int
            Output channels.
        upsample: bool
            Upsample in the first conv (transposed-conv path of ``ConvLayer2d``).
        downsample: bool
            Downsample in the second conv (blur + strided-conv path of ``ConvLayer2d``).
    """

    def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
        super().__init__()
        assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
        # When downsampling, conv2 performs the channel change; otherwise conv1 does.
        hidden = out_channel if not downsample else in_channel
        self.conv1 = ConvLayer2d(in_channel, hidden, upsample=upsample, kernel_size=3)
        self.conv2 = ConvLayer2d(hidden, out_channel, downsample=downsample, kernel_size=3)
        needs_projection = upsample or downsample or in_channel != out_channel
        if needs_projection:
            self.skip = ConvLayer2d(
                in_channel,
                out_channel,
                upsample=upsample,
                downsample=downsample,
                kernel_size=1,
                activate=False,
                bias=False,
            )

    def forward(self, input):
        """Return ``(main(input) + shortcut(input)) / sqrt(2)``."""
        main = self.conv2(self.conv1(input))
        shortcut = self.skip(input) if hasattr(self, 'skip') else input
        return (main + shortcut) / math.sqrt(2)