import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from jax import jit
import numpy as np
from functools import partial
from typing import Any
import h5py


#------------------------------------------------------
# Other
#------------------------------------------------------
def minibatch_stddev_layer(x, group_size=None, num_new_features=1):
    if group_size is None:
        group_size = x.shape[0]
    else:
        # Minibatch must be divisible by (or smaller than) group_size.
        group_size = min(group_size, x.shape[0])

    G = group_size
    F = num_new_features
    _, H, W, C = x.shape
    c = C // F

    # [NHWC] Cast to FP32.
    y = x.astype(jnp.float32)
    # [GnHWFc] Split minibatch N into n groups of size G, and channels C into F groups of size c.
    y = jnp.reshape(y, (G, -1, H, W, F, c))
    # [GnHWFc] Subtract mean over group.
    y -= jnp.mean(y, axis=0)
    # [nHWFc] Calc variance over group.
    y = jnp.mean(jnp.square(y), axis=0)
    # [nHWFc] Calc stddev over group.
    y = jnp.sqrt(y + 1e-8)
    # [nF] Take average over channels and pixels.
    y = jnp.mean(y, axis=(1, 2, 4))
    # [nF] Cast back to original data type.
    y = y.astype(x.dtype)
    # [n11F] Add missing dimensions.
    y = jnp.reshape(y, (-1, 1, 1, F))
    # [NHWC] Replicate over group and pixels.
    y = jnp.tile(y, (G, H, W, 1))
    return jnp.concatenate((x, y), axis=3)
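
# Example (illustrative, not part of the original module): appending the
# minibatch-stddev feature map to an activation tensor; shapes are assumptions.
#     x = jnp.ones((8, 4, 4, 16))
#     y = minibatch_stddev_layer(x, group_size=4)   # y.shape == (8, 4, 4, 17)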

#------------------------------------------------------
# Activation
#------------------------------------------------------
def apply_activation(x, activation='linear', alpha=0.2, gain=np.sqrt(2)):
    gain = jnp.array(gain, dtype=x.dtype)
    if activation == 'relu':
        return jax.nn.relu(x) * gain
    if activation == 'leaky_relu':
        return jax.nn.leaky_relu(x, negative_slope=alpha) * gain
    return x
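
# Example (illustrative, not part of the original module): leaky ReLU with the
# StyleGAN2-style sqrt(2) gain applied on top of the nonlinearity.
#     h = apply_activation(jnp.array([-1.0, 2.0]), activation='leaky_relu')
#     # -> approximately [-0.2 * sqrt(2), 2.0 * sqrt(2)]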

#------------------------------------------------------
# Weights
#------------------------------------------------------
def get_weight(shape, lr_multiplier=1, bias=True, param_dict=None, layer_name='', key=None):
    if param_dict is None:
        w = random.normal(key, shape=shape, dtype=jnp.float32) / lr_multiplier
        if bias: b = jnp.zeros(shape=(shape[-1],), dtype=jnp.float32)
    else:
        w = jnp.array(param_dict[layer_name]['weight']).astype(jnp.float32)
        if bias: b = jnp.array(param_dict[layer_name]['bias']).astype(jnp.float32)

    if bias: return w, b
    return w

def equalize_lr_weight(w, lr_multiplier=1):
    """
    Equalized learning rate, see: https://arxiv.org/pdf/1710.10196.pdf.

    Args:
        w (tensor): Weight parameter. Shape [kernel, kernel, fmaps_in, fmaps_out]
                    for convolutions and shape [in, out] for MLPs.
        lr_multiplier (float): Learning rate multiplier.

    Returns:
        (tensor): Scaled weight parameter.
    """
    in_features = np.prod(w.shape[:-1])
    gain = lr_multiplier / np.sqrt(in_features)
    w *= gain
    return w


def equalize_lr_bias(b, lr_multiplier=1):
    """
    Equalized learning rate, see: https://arxiv.org/pdf/1710.10196.pdf.

    Args:
        b (tensor): Bias parameter.
        lr_multiplier (float): Learning rate multiplier.

    Returns:
        (tensor): Scaled bias parameter.
    """
    gain = lr_multiplier
    b *= gain
    return b
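
# Example (illustrative, not part of the original module): runtime weight scaling
# for a 3x3 conv with 512 input feature maps and lr_multiplier=1, the weight is
# multiplied by 1 / sqrt(3 * 3 * 512) on every forward pass.
#     w = random.normal(random.PRNGKey(0), (3, 3, 512, 512))
#     w_scaled = equalize_lr_weight(w)   # std shrinks by ~1 / sqrt(4608)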

#------------------------------------------------------
# Normalization
#------------------------------------------------------
def normalize_2nd_moment(x, eps=1e-8):
    return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=1, keepdims=True) + eps)
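
# Example (illustrative, not part of the original module): normalizing a batch of
# latent vectors of shape [N, C] so that each row has unit second moment.
#     z = random.normal(random.PRNGKey(0), (4, 512))
#     z = normalize_2nd_moment(z)   # jnp.mean(z**2, axis=1) is ~1 for every row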

#------------------------------------------------------
# Upsampling
#------------------------------------------------------
def setup_filter(f, normalize=True, flip_filter=False, gain=1, separable=None):
    """
    Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f (tensor): Tensor or python list of shape [filter_height, filter_width] (non-separable),
                    [filter_taps] (separable), or None (identity).
        normalize (bool): Normalize the filter so that it retains the magnitude
                          for constant input signal (DC)? (default: True).
        flip_filter (bool): Flip the filter? (default: False).
        gain (int): Overall scaling factor for signal magnitude (default: 1).
        separable (bool): Return a separable filter? (default: select automatically).

    Returns:
        (tensor): Output filter of shape [filter_height, filter_width] or [filter_taps].
    """
    # Validate.
    if f is None:
        f = 1
    f = jnp.array(f, dtype=jnp.float32)
    assert f.ndim in [0, 1, 2]
    assert f.size > 0
    if f.ndim == 0:
        f = f[jnp.newaxis]

    # Separable?
    if separable is None:
        separable = (f.ndim == 1 and f.size >= 8)
    if f.ndim == 1 and not separable:
        f = jnp.outer(f, f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        f /= jnp.sum(f)
    if flip_filter:
        for i in range(f.ndim):
            f = jnp.flip(f, axis=i)
    f = f * (gain ** (f.ndim / 2))
    return f
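
# Example (illustrative, not part of the original module): building the binomial
# [1, 3, 3, 1] FIR filter used for antialiased up/downsampling in StyleGAN2.
#     f = setup_filter([1, 3, 3, 1])   # 4x4 outer product, normalized to sum 1
#     # f.shape == (4, 4)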

def upfirdn2d(x, f, padding=(2, 1, 2, 1), up=1, down=1, strides=(1, 1), flip_filter=False, gain=1):
    if f is None:
        f = jnp.ones((1, 1), dtype=jnp.float32)

    B, H, W, C = x.shape
    padx0, padx1, pady0, pady1 = padding

    # Upsample by inserting zeros.
    x = jnp.reshape(x, (B, H, 1, W, 1, C))
    x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, up - 1), (0, 0), (0, up - 1), (0, 0)))
    x = jnp.reshape(x, (B, H * up, W * up, C))

    # Pad or crop.
    x = jnp.pad(x, pad_width=((0, 0), (max(pady0, 0), max(pady1, 0)), (max(padx0, 0), max(padx1, 0)), (0, 0)))
    x = x[:, max(-pady0, 0) : x.shape[1] - max(-pady1, 0), max(-padx0, 0) : x.shape[2] - max(-padx1, 0)]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    if not flip_filter:
        for i in range(f.ndim):
            f = jnp.flip(f, axis=i)

    # Convolve with the filter.
    f = jnp.repeat(jnp.expand_dims(f, axis=(-2, -1)), repeats=C, axis=-1)
    if f.ndim == 4:
        x = jax.lax.conv_general_dilated(x,
                                         f.astype(x.dtype),
                                         window_strides=strides or (1,) * (x.ndim - 2),
                                         padding='valid',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=C)
    else:
        x = jax.lax.conv_general_dilated(x,
                                         jnp.expand_dims(f, axis=0).astype(x.dtype),
                                         window_strides=strides or (1,) * (x.ndim - 2),
                                         padding='valid',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=C)
        x = jax.lax.conv_general_dilated(x,
                                         jnp.expand_dims(f, axis=1).astype(x.dtype),
                                         window_strides=strides or (1,) * (x.ndim - 2),
                                         padding='valid',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=C)

    # Downsample by discarding pixels.
    x = x[:, ::down, ::down]
    return x
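
# Example (illustrative, not part of the original module): FIR filtering followed
# by 2x decimation; the padding shown matches what a downsample op would compute
# for a 4-tap filter. Shapes are assumptions.
#     f = setup_filter([1, 3, 3, 1])
#     x = random.normal(random.PRNGKey(0), (1, 8, 8, 3))
#     y = upfirdn2d(x, f, down=2, padding=(1, 1, 1, 1))   # y.shape == (1, 4, 4, 3)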

def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1):
    if f.ndim == 1:
        fh, fw = f.shape[0], f.shape[0]
    elif f.ndim == 2:
        fh, fw = f.shape[0], f.shape[1]
    else:
        raise ValueError('Invalid filter shape:', f.shape)
    padx0 = padding + (fw + up - 1) // 2
    padx1 = padding + (fw - up) // 2
    pady0 = padding + (fh + up - 1) // 2
    pady1 = padding + (fh - up) // 2
    return upfirdn2d(x, f=f, up=up, padding=(padx0, padx1, pady0, pady1), flip_filter=flip_filter, gain=gain * up * up)
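
# Example (illustrative, not part of the original module): FIR-filtered 2x
# upsampling of an NHWC feature map; shapes are assumptions.
#     f = setup_filter([1, 3, 3, 1])
#     x = random.normal(random.PRNGKey(0), (1, 8, 8, 3))
#     y = upsample2d(x, f)   # y.shape == (1, 16, 16, 3)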

#------------------------------------------------------
# Linear
#------------------------------------------------------
class LinearLayer(nn.Module):
    """
    Linear Layer.

    Attributes:
        in_features (int): Input dimension.
        out_features (int): Output dimension.
        use_bias (bool): If True, use bias.
        bias_init (int): Bias init.
        lr_multiplier (float): Learning rate multiplier.
        activation (str): Activation function: 'relu', 'leaky_relu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters.
        layer_name (str): Layer name.
        dtype (str): Data type.
        rng (jax.random.PRNGKey): Random seed for initialization.
    """
    in_features: int
    out_features: int
    use_bias: bool = True
    bias_init: int = 0
    lr_multiplier: float = 1
    activation: str = 'linear'
    param_dict: h5py.Group = None
    layer_name: str = None
    dtype: str = 'float32'
    rng: Any = random.PRNGKey(0)

    @nn.compact
    def __call__(self, x):
        """
        Run Linear Layer.

        Args:
            x (tensor): Input tensor of shape [N, in_features].

        Returns:
            (tensor): Output tensor of shape [N, out_features].
        """
        w_shape = [self.in_features, self.out_features]
        params = get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)

        if self.use_bias:
            w, b = params
        else:
            w = params

        w = self.param(name='weight', init_fn=lambda *_: w)
        w = equalize_lr_weight(w, self.lr_multiplier)
        x = jnp.matmul(x, w.astype(x.dtype))

        if self.use_bias:
            b = self.param(name='bias', init_fn=lambda *_: b)
            b = equalize_lr_bias(b, self.lr_multiplier)
            x += b.astype(x.dtype)
            x += self.bias_init

        x = apply_activation(x, activation=self.activation)
        return x
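
# Example (illustrative, not part of the original module): initializing and
# applying the layer through the usual Flax init/apply pattern.
#     layer = LinearLayer(in_features=512, out_features=512, activation='leaky_relu')
#     params = layer.init(random.PRNGKey(0), jnp.ones((1, 512)))
#     out = layer.apply(params, jnp.ones((4, 512)))   # out.shape == (4, 512)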

#------------------------------------------------------
# Convolution
#------------------------------------------------------
def conv_downsample_2d(x, w, k=None, factor=2, gain=1, padding=0):
    """
    Fused downsample convolution.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard ops. It supports gradients of arbitrary order.

    Args:
        x (tensor): Input tensor of the shape [N, H, W, C].
        w (tensor): Weight tensor of the shape [filterH, filterW, inChannels, outChannels].
                    Grouped convolution can be performed by inChannels = x.shape[3] // numGroups.
        k (tensor): FIR filter of the shape [firH, firW] or [firN].
                    The default is `[1] * factor`, which corresponds to average pooling.
        factor (int): Downsampling factor (default: 2).
        gain (float): Scaling factor for signal magnitude (default: 1.0).
        padding (int): Number of pixels to pad or crop the output on each side (default: 0).

    Returns:
        (tensor): Output of the shape [N, H // factor, W // factor, C].
    """
    assert isinstance(factor, int) and factor >= 1
    assert isinstance(padding, int)

    # Check weight shape.
    ch, cw, _inC, _outC = w.shape
    assert cw == ch

    # Setup filter kernel.
    k = setup_filter(k, gain=gain)
    assert k.shape[0] == k.shape[1]

    # Execute.
    pad0 = (k.shape[0] - factor + cw) // 2 + padding * factor
    pad1 = (k.shape[0] - factor + cw - 1) // 2 + padding * factor
    x = upfirdn2d(x=x, f=k, padding=(pad0, pad0, pad1, pad1))
    x = jax.lax.conv_general_dilated(x,
                                     w,
                                     window_strides=(factor, factor),
                                     padding='VALID',
                                     dimension_numbers=nn.linear._conv_dimension_numbers(x.shape))
    return x
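
# Example (illustrative, not part of the original module): fused 2x downsample
# with a 3x3 convolution; filter, shapes, and weights are assumptions.
#     f = jnp.array([1., 3., 3., 1.])
#     w = random.normal(random.PRNGKey(0), (3, 3, 16, 32))
#     x = random.normal(random.PRNGKey(1), (1, 16, 16, 16))
#     y = conv_downsample_2d(x, w, k=f)   # y.shape == (1, 8, 8, 32)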

def upsample_conv_2d(x, w, k=None, factor=2, gain=1, padding=0):
    """
    Fused upsample convolution.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard ops. It supports gradients of arbitrary order.

    Args:
        x (tensor): Input tensor of the shape [N, H, W, C].
        w (tensor): Weight tensor of the shape [filterH, filterW, inChannels, outChannels].
                    Grouped convolution can be performed by inChannels = x.shape[3] // numGroups.
        k (tensor): FIR filter of the shape [firH, firW] or [firN].
                    The default is [1] * factor, which corresponds to nearest-neighbor upsampling.
        factor (int): Integer upsampling factor (default: 2).
        gain (float): Scaling factor for signal magnitude (default: 1.0).
        padding (int): Number of pixels to pad or crop the output on each side (default: 0).

    Returns:
        (tensor): Output of the shape [N, H * factor, W * factor, C].
    """
    assert isinstance(factor, int) and factor >= 1
    assert isinstance(padding, int)

    # Check weight shape.
    ch, cw, inC, outC = w.shape
    assert cw == ch

    # Fast path for 1x1 convolution.
    if cw == 1 and ch == 1:
        x = jax.lax.conv_general_dilated(x,
                                         w,
                                         window_strides=(1, 1),
                                         padding='VALID',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape))
        k = setup_filter(k, gain=gain * (factor ** 2))
        pad0 = (k.shape[0] + factor - cw) // 2 + padding
        pad1 = (k.shape[0] - factor) // 2 + padding
        x = upfirdn2d(x, f=k, up=factor, padding=(pad0, pad1, pad0, pad1))
        return x

    # Setup filter kernel.
    k = setup_filter(k, gain=gain * (factor ** 2))
    assert k.shape[0] == k.shape[1]

    # Determine data dimensions.
    stride = (factor, factor)
    output_shape = ((x.shape[1] - 1) * factor + ch, (x.shape[2] - 1) * factor + cw)
    num_groups = x.shape[3] // inC

    # Transpose weights.
    w = jnp.reshape(w, (ch, cw, inC, num_groups, -1))
    w = jnp.transpose(w[::-1, ::-1], (0, 1, 4, 3, 2))
    w = jnp.reshape(w, (ch, cw, -1, num_groups * inC))

    # Execute.
    x = gradient_based_conv_transpose(lhs=x,
                                      rhs=w,
                                      strides=stride,
                                      padding='VALID',
                                      output_padding=(0, 0, 0, 0),
                                      output_shape=output_shape)
    pad0 = (k.shape[0] + factor - cw) // 2 + padding
    pad1 = (k.shape[0] - factor - cw + 3) // 2 + padding
    x = upfirdn2d(x=x, f=k, padding=(pad0, pad1, pad0, pad1))
    return x
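
# Example (illustrative, not part of the original module): fused 2x upsample with
# a 3x3 convolution; filter, shapes, and weights are assumptions.
#     f = jnp.array([1., 3., 3., 1.])
#     w = random.normal(random.PRNGKey(0), (3, 3, 16, 32))
#     x = random.normal(random.PRNGKey(1), (1, 8, 8, 16))
#     y = upsample_conv_2d(x, w, k=f)   # y.shape == (1, 16, 16, 32)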

def conv2d(x, w, up=False, down=False, resample_kernel=None, padding=0):
    assert not (up and down)
    kernel = w.shape[0]
    assert w.shape[1] == kernel
    assert kernel >= 1 and kernel % 2 == 1

    num_groups = x.shape[3] // w.shape[2]
    w = w.astype(x.dtype)
    if up:
        x = upsample_conv_2d(x, w, k=resample_kernel, padding=padding)
    elif down:
        x = conv_downsample_2d(x, w, k=resample_kernel, padding=padding)
    else:
        padding_mode = {0: 'SAME', -(kernel // 2): 'VALID'}[padding]
        x = jax.lax.conv_general_dilated(x,
                                         w,
                                         window_strides=(1, 1),
                                         padding=padding_mode,
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=num_groups)
    return x
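
# Example (illustrative, not part of the original module): plain 'SAME' 3x3
# convolution through the shared conv2d wrapper; shapes are assumptions.
#     w = random.normal(random.PRNGKey(0), (3, 3, 16, 32))
#     x = random.normal(random.PRNGKey(1), (1, 8, 8, 16))
#     y = conv2d(x, w)   # y.shape == (1, 8, 8, 32)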

def modulated_conv2d_layer(x, w, s, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None, fused_modconv=False):
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1

    # Get weight.
    wshape = (kernel, kernel, x.shape[3], fmaps)
    if x.dtype.name == 'float16' and not fused_modconv and demodulate:
        w *= jnp.sqrt(1 / np.prod(wshape[:-1])) / jnp.max(jnp.abs(w), axis=(0, 1, 2))  # Pre-normalize to avoid float16 overflow.
    ww = w[jnp.newaxis]  # [BkkIO] Introduce minibatch dimension.

    # Modulate.
    if x.dtype.name == 'float16' and not fused_modconv and demodulate:
        s *= 1 / jnp.max(jnp.abs(s))  # Pre-normalize to avoid float16 overflow.
    ww *= s[:, jnp.newaxis, jnp.newaxis, :, jnp.newaxis].astype(w.dtype)  # [BkkIO] Scale input feature maps.

    # Demodulate.
    if demodulate:
        d = jax.lax.rsqrt(jnp.sum(jnp.square(ww), axis=(1, 2, 3)) + 1e-8)  # [BO] Scaling factor.
        ww *= d[:, jnp.newaxis, jnp.newaxis, jnp.newaxis, :]  # [BkkIO] Scale output feature maps.

    # Reshape/scale input.
    if fused_modconv:
        x = jnp.transpose(x, axes=(0, 3, 1, 2))
        x = jnp.reshape(x, (1, -1, x.shape[2], x.shape[3]))  # Fused => reshape minibatch to convolution groups.
        x = jnp.transpose(x, axes=(0, 2, 3, 1))
        w = jnp.reshape(jnp.transpose(ww, (1, 2, 3, 0, 4)), (ww.shape[1], ww.shape[2], ww.shape[3], -1))
    else:
        x *= s[:, jnp.newaxis, jnp.newaxis].astype(x.dtype)  # [NHWC] Not fused => scale input activations.

    # 2D convolution.
    x = conv2d(x, w.astype(x.dtype), up=up, down=down, resample_kernel=resample_kernel)

    # Reshape/scale output.
    if fused_modconv:
        x = jnp.transpose(x, axes=(0, 3, 1, 2))
        x = jnp.reshape(x, (-1, fmaps, x.shape[2], x.shape[3]))  # Fused => reshape convolution groups back to minibatch.
        x = jnp.transpose(x, axes=(0, 2, 3, 1))
    elif demodulate:
        x *= d[:, jnp.newaxis, jnp.newaxis].astype(x.dtype)  # [NHWC] Not fused => scale output activations.

    return x
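
# Example (illustrative, not part of the original module): style-modulated 3x3
# convolution in the non-fused code path; all shapes and values are assumptions.
#     x = random.normal(random.PRNGKey(0), (4, 8, 8, 16))       # activations, NHWC
#     w = random.normal(random.PRNGKey(1), (3, 3, 16, 32))      # base conv weight
#     s = random.normal(random.PRNGKey(2), (4, 16)) + 1.0       # per-sample styles
#     y = modulated_conv2d_layer(x, w, s, fmaps=32, kernel=3)   # y.shape == (4, 8, 8, 32)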

def _deconv_output_length(input_length, filter_size, padding, output_padding=None, stride=0, dilation=1):
    """
    Taken from: https://github.com/google/jax/pull/5772/commits

    Determines the output length of a transposed convolution given the input length.
    Function modified from Keras.

    Arguments:
        input_length: Integer.
        filter_size: Integer.
        padding: one of `"SAME"`, `"VALID"`, or a 2-integer tuple.
        output_padding: Integer, amount of padding along the output dimension. Can
            be set to `None` in which case the output length is inferred.
        stride: Integer.
        dilation: Integer.

    Returns:
        The output length (integer).
    """
    if input_length is None:
        return None

    # Get the dilated kernel size.
    filter_size = filter_size + (filter_size - 1) * (dilation - 1)

    # Infer length if output padding is None, else compute the exact length.
    if output_padding is None:
        if padding == 'VALID':
            length = input_length * stride + max(filter_size - stride, 0)
        elif padding == 'SAME':
            length = input_length * stride
        else:
            length = ((input_length - 1) * stride + filter_size - padding[0] - padding[1])
    else:
        if padding == 'SAME':
            pad = filter_size // 2
            total_pad = pad * 2
        elif padding == 'VALID':
            total_pad = 0
        else:
            total_pad = padding[0] + padding[1]
        length = ((input_length - 1) * stride + filter_size - total_pad + output_padding)

    return length
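
# Worked example (illustrative, not part of the original module): a stride-2
# 'VALID' transposed convolution with a 3-tap kernel maps length 8 to
# (8 - 1) * 2 + 3 = 17, matching the output_shape computed in upsample_conv_2d above.
#     _deconv_output_length(8, 3, padding='VALID', output_padding=0, stride=2)  # -> 17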

def _compute_adjusted_padding(input_size, output_size, kernel_size, stride, padding, dilation=1):
    """
    Taken from: https://github.com/google/jax/pull/5772/commits

    Computes adjusted padding for desired ConvTranspose `output_size`.
    Ported from DeepMind Haiku.
    """
    kernel_size = (kernel_size - 1) * dilation + 1
    if padding == 'VALID':
        expected_input_size = (output_size - kernel_size + stride) // stride
        if input_size != expected_input_size:
            raise ValueError(f'The expected input size with the current set of input '
                             f'parameters is {expected_input_size} which doesn\'t '
                             f'match the actual input size {input_size}.')
        padding_before = 0
    elif padding == 'SAME':
        expected_input_size = (output_size + stride - 1) // stride
        if input_size != expected_input_size:
            raise ValueError(f'The expected input size with the current set of input '
                             f'parameters is {expected_input_size} which doesn\'t '
                             f'match the actual input size {input_size}.')
        padding_needed = max(0, (input_size - 1) * stride + kernel_size - output_size)
        padding_before = padding_needed // 2
    else:
        padding_before = padding[0]  # type: ignore[assignment]

    expanded_input_size = (input_size - 1) * stride + 1
    padded_out_size = output_size + kernel_size - 1
    pad_before = kernel_size - 1 - padding_before
    pad_after = padded_out_size - expanded_input_size - pad_before
    return (pad_before, pad_after)

def _flip_axes(x, axes):
    """
    Taken from: https://github.com/google/jax/blob/master/jax/_src/lax/lax.py

    Flip ndarray 'x' along each axis specified in axes tuple.
    """
    for axis in axes:
        x = jnp.flip(x, axis)
    return x

def gradient_based_conv_transpose(lhs,
                                  rhs,
                                  strides,
                                  padding,
                                  output_padding,
                                  output_shape=None,
                                  dilation=None,
                                  dimension_numbers=None,
                                  transpose_kernel=True,
                                  feature_group_count=1,
                                  precision=None):
    """
    Taken from: https://github.com/google/jax/pull/5772/commits

    Convenience wrapper for calculating the N-d transposed convolution.

    Much like `conv_transpose`, this function calculates transposed convolutions
    via fractionally strided convolution rather than calculating the gradient
    (transpose) of a forward convolution. However, the latter is more common
    among deep learning frameworks, such as TensorFlow, PyTorch, and Keras.
    This function provides the same set of APIs to help reproduce results in these frameworks.

    Args:
        lhs: a rank `n+2` dimensional input array.
        rhs: a rank `n+2` dimensional array of kernel weights.
        strides: sequence of `n` integers, amounts to strides of the corresponding forward convolution.
        padding: `"SAME"`, `"VALID"`, or a sequence of `n` integer 2-tuples that controls
            the before-and-after padding for each `n` spatial dimension of
            the corresponding forward convolution.
        output_padding: A sequence of integers specifying the amount of padding along
            each spatial dimension of the output tensor, used to disambiguate the output shape of
            transposed convolutions when the stride is larger than 1.
            (see a detailed description at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html)
            The amount of output padding along a given dimension must
            be lower than the stride along that same dimension.
            If set to `None` (default), the output shape is inferred.
            If both `output_padding` and `output_shape` are specified, they have to be mutually compatible.
        output_shape: Output shape of the spatial dimensions of a transpose
            convolution. Can be `None` or an iterable of `n` integers. If a `None` value is given (default),
            the shape is automatically calculated.
            Similar to `output_padding`, `output_shape` is also for disambiguating the output shape
            when stride > 1 (see also
            https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose)
            If both `output_padding` and `output_shape` are specified, they have to be mutually compatible.
        dilation: `None`, or a sequence of `n` integers, giving the
            dilation factor to apply in each spatial dimension of `rhs`. Dilated convolution
            is also known as atrous convolution.
        dimension_numbers: tuple of dimension descriptors as in lax.conv_general_dilated. Defaults to tensorflow convention.
        transpose_kernel: if `True` flips spatial axes and swaps the input/output
            channel axes of the kernel. This makes the output of this function identical
            to the gradient-derived functions like keras.layers.Conv2DTranspose and
            torch.nn.ConvTranspose2d applied to the same kernel.
            Although for typical use in neural nets this is unnecessary
            and makes input/output channel specification confusing, you need to set this to `True`
            in order to match the behavior in many deep learning frameworks, such as TensorFlow, Keras, and PyTorch.
        precision: Optional. Either ``None``, which means the default precision for
            the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
            ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
            ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

    Returns:
        Transposed N-d convolution.
    """
    assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2
    ndims = len(lhs.shape)
    one = (1,) * (ndims - 2)

    # Set dimensional layout defaults if not specified.
    if dimension_numbers is None:
        if ndims == 2:
            dimension_numbers = ('NC', 'IO', 'NC')
        elif ndims == 3:
            dimension_numbers = ('NHC', 'HIO', 'NHC')
        elif ndims == 4:
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
        elif ndims == 5:
            dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
        else:
            raise ValueError('No 6+ dimensional dimension_number defaults.')
    dn = jax.lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
    k_shape = np.take(rhs.shape, dn.rhs_spec)
    k_sdims = k_shape[2:]  # type: ignore[index]
    i_shape = np.take(lhs.shape, dn.lhs_spec)
    i_sdims = i_shape[2:]  # type: ignore[index]

    # Calculate correct output shape given padding and strides.
    if dilation is None:
        dilation = (1,) * (rhs.ndim - 2)
    if output_padding is None:
        output_padding = [None] * (rhs.ndim - 2)  # type: ignore[list-item]
    if isinstance(padding, str):
        if padding in {'SAME', 'VALID'}:
            padding = [padding] * (rhs.ndim - 2)  # type: ignore[list-item]
        else:
            raise ValueError(f"`padding` must be 'VALID' or 'SAME'. Passed: {padding}.")

    inferred_output_shape = tuple(map(_deconv_output_length, i_sdims, k_sdims, padding, output_padding, strides, dilation))
    if output_shape is None:
        output_shape = inferred_output_shape  # type: ignore[assignment]
    else:
        if not output_shape == inferred_output_shape:
            raise ValueError(f'`output_padding` and `output_shape` are not compatible. '
                             f'Inferred output shape from `output_padding`: {inferred_output_shape}, '
                             f'but got `output_shape` {output_shape}.')

    pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, k_sdims, strides, padding, dilation))
    if transpose_kernel:
        # Flip spatial dims and swap input / output channel axes.
        rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
        rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
    return jax.lax.conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dn, feature_group_count, precision=precision)
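
# Example (illustrative, not part of the original module): a stride-2 'VALID'
# transposed convolution in NHWC/HWIO layout; shapes and weights are assumptions.
#     lhs = random.normal(random.PRNGKey(0), (1, 8, 8, 16))
#     rhs = random.normal(random.PRNGKey(1), (3, 3, 32, 16))   # HWIO before the I/O swap
#     out = gradient_based_conv_transpose(lhs, rhs, strides=(2, 2), padding='VALID',
#                                         output_padding=(0, 0))
#     # out.shape == (1, 17, 17, 32)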