|
"""FiLM-conditioned SIREN MLP, following pi-GAN (https://marcoamonteiro.github.io/pi-GAN-website/)."""
|
|
|
from typing import Optional |
|
|
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
|
|
|
|
def kaiming_leaky_init(m):
    """Kaiming-normal initialise the weight of "Linear"-named modules.

    Intended to be passed to ``nn.Module.apply``. A module is matched when its
    class name contains the substring ``"Linear"``; its ``weight`` is redrawn
    with ``kaiming_normal_`` using the LeakyReLU gain (negative slope 0.2) in
    fan-in mode. Unmatched modules, and the bias of matched ones, are left
    untouched.
    """
    if "Linear" not in m.__class__.__name__:
        return
    torch.nn.init.kaiming_normal_(
        m.weight,
        a=0.2,
        mode="fan_in",
        nonlinearity="leaky_relu",
    )
|
|
|
|
|
def frequency_init(freq):
    """Return a SIREN-style weight initialiser for ``nn.Module.apply``.

    The returned callable redraws the ``weight`` of every ``nn.Linear`` it is
    given from ``U(-b, b)`` with ``b = sqrt(6 / fan_in) / freq``, where
    ``fan_in`` is the weight's last dimension. Other modules and biases are
    left untouched.

    Args:
        freq: Divisor applied to the uniform bound.
    """

    def _init(m):
        if not isinstance(m, nn.Linear):
            return
        fan_in = m.weight.size(-1)
        bound = np.sqrt(6 / fan_in) / freq
        with torch.no_grad():
            m.weight.uniform_(-bound, bound)

    return _init
|
|
|
|
|
def first_layer_film_sine_init(m):
    """Initialise the first SIREN/FiLM layer's weights for ``nn.Module.apply``.

    For an ``nn.Linear`` module, the ``weight`` is redrawn from
    ``U(-1 / fan_in, 1 / fan_in)``, where ``fan_in`` is the weight's last
    dimension. Other modules and biases are left untouched.
    """
    if isinstance(m, nn.Linear):
        fan_in = m.weight.size(-1)
        limit = 1 / fan_in
        with torch.no_grad():
            m.weight.uniform_(-limit, limit)
|
|
|
|
|
class CustomMappingNetwork(nn.Module):
    """MLP mapping a conditioning input to FiLM frequencies and phase shifts.

    The network is ``map_hidden_layers`` blocks of ``Linear -> LeakyReLU(0.2)``
    followed by a final ``Linear`` producing ``map_output_dim`` features. The
    output is split in half along its last dimension: the first half is
    returned as frequencies, the second half as phase shifts.

    Args:
        in_features: Size of the last dimension of the conditioning input.
        map_hidden_layers: Number of hidden ``Linear -> LeakyReLU`` blocks;
            may be 0, in which case the network is a single ``Linear``.
        map_hidden_dim: Width of each hidden layer.
        map_output_dim: Total output width (frequencies + phase shifts).
    """

    def __init__(self, in_features, map_hidden_layers, map_hidden_dim, map_output_dim):
        super().__init__()

        layers = []
        for _ in range(map_hidden_layers):
            layers.append(nn.Linear(in_features, map_hidden_dim))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_features = map_hidden_dim

        # ``in_features`` is now the width feeding the output layer: the hidden
        # width after at least one hidden block, or the raw input width when
        # ``map_hidden_layers == 0`` (using ``map_hidden_dim`` here would
        # mis-size the layer in that case).
        layers.append(nn.Linear(in_features, map_output_dim))

        self.network = nn.Sequential(*layers)

        self.network.apply(kaiming_leaky_init)
        with torch.no_grad():
            # Scale down the output layer's initial weights, shrinking the
            # magnitude of the initial frequencies / phase shifts.
            self.network[-1].weight *= 0.25

    def forward(self, z):
        """Map ``z`` to ``(frequencies, phase_shifts)``.

        Both outputs are slices of the network output along the last
        dimension. The split point is ``floor(width / 2)``, so with an odd
        width the phase shifts get the extra feature.
        """
        frequencies_offsets = self.network(z)
        half = frequencies_offsets.shape[-1] // 2
        frequencies = frequencies_offsets[..., :half]
        phase_shifts = frequencies_offsets[..., half:]
        return frequencies, phase_shifts
|
|
|
|
|
class FiLMLayer(nn.Module):
    """Linear projection followed by a FiLM-modulated sine activation.

    Computes ``sin(freq * (W x + b) + phase_shift)``. ``freq`` and
    ``phase_shift`` are expanded (via ``expand_as``) to the shape of the
    projected input, so they must be broadcast-compatible with it under
    ``expand`` rules.

    Args:
        input_dim: Input feature size of the linear projection.
        hidden_dim: Output feature size of the linear projection.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, freq, phase_shift):
        projected = self.layer(x)
        scale = freq.expand_as(projected)
        shift = phase_shift.expand_as(projected)
        return torch.sin(scale * projected + shift)
|
|
|
|
|
class FiLMSiren(nn.Module):
    """FiLM-conditioned SIREN network.

    A stack of ``FiLMLayer`` sine layers whose per-layer frequencies and phase
    shifts are produced by a ``CustomMappingNetwork`` from a conditioning
    input. The network can optionally end in a plain ``nn.Linear``
    (``outermost_linear=True``) and/or an ``out_activation`` module.

    Args:
        in_dim: Input feature size of the first FiLM layer (must be > 0).
        hidden_layers: Number of FiLM layers of width ``hidden_features``
            (at least one is always built).
        hidden_features: Width of the hidden FiLM layers.
        mapping_network_in_dim: Feature size of the conditioning input.
        mapping_network_layers: Number of hidden layers in the mapping network.
        mapping_network_features: Hidden width of the mapping network.
        out_dim: Output width; ``None`` falls back to ``hidden_features``.
        outermost_linear: If True, the last layer is a plain ``nn.Linear``
            instead of a FiLM sine layer.
        out_activation: Optional module applied to the final output.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_layers: int,
        hidden_features: int,
        mapping_network_in_dim: int,
        mapping_network_layers: int,
        mapping_network_features: int,
        out_dim: Optional[int],
        outermost_linear: bool = False,
        out_activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        assert self.in_dim > 0
        self.out_dim = out_dim if out_dim is not None else hidden_features
        self.hidden_layers = hidden_layers
        self.hidden_features = hidden_features
        self.mapping_network_in_dim = mapping_network_in_dim
        self.mapping_network_layers = mapping_network_layers
        self.mapping_network_features = mapping_network_features
        self.outermost_linear = outermost_linear
        self.out_activation = out_activation

        self.net = nn.ModuleList()
        self.net.append(FiLMLayer(self.in_dim, self.hidden_features))
        for _ in range(self.hidden_layers - 1):
            self.net.append(FiLMLayer(self.hidden_features, self.hidden_features))

        self.final_layer = None
        if self.outermost_linear:
            self.final_layer = nn.Linear(self.hidden_features, self.out_dim)
            self.final_layer.apply(frequency_init(25))
        else:
            # The last FiLM layer outputs ``out_dim`` features, which may differ
            # from ``hidden_features``; its modulation slice is sized to match.
            self.net.append(FiLMLayer(self.hidden_features, self.out_dim))

        # Every FiLM layer needs one frequency and one phase shift per output
        # feature. The mapping network emits all frequencies, then all phase shifts.
        total_film_width = sum(layer.layer.out_features for layer in self.net)
        self.mapping_network = CustomMappingNetwork(
            in_features=self.mapping_network_in_dim,
            map_hidden_layers=self.mapping_network_layers,
            map_hidden_dim=self.mapping_network_features,
            map_output_dim=total_film_width * 2,
        )

        self.net.apply(frequency_init(25))
        self.net[0].apply(first_layer_film_sine_init)

    def forward_with_frequencies_phase_shifts(self, x, frequencies, phase_shifts):
        """Run the network with explicitly supplied FiLM modulation.

        ``frequencies`` are first rescaled as ``frequencies * 15 + 30``. Each
        FiLM layer then receives the next contiguous slice (along the last
        dimension) of the rescaled frequencies and of ``phase_shifts``. The
        slice width equals that layer's output width. The optional final linear
        layer and ``out_activation`` are applied afterwards.
        """
        frequencies = frequencies * 15 + 30

        start = 0
        for layer in self.net:
            end = start + layer.layer.out_features
            x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
            start = end

        x = self.final_layer(x) if self.final_layer is not None else x
        output = self.out_activation(x) if self.out_activation is not None else x
        return output

    def forward(self, x, conditioning_input):
        """Compute FiLM modulation from ``conditioning_input`` and apply it to ``x``."""
        frequencies, phase_shifts = self.mapping_network(conditioning_input)
        return self.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts)
|
|