"""FiLM Siren MLP as per https://marcoamonteiro.github.io/pi-GAN-website/.""" from typing import Optional import numpy as np import torch from torch import nn def kaiming_leaky_init(m): classname = m.__class__.__name__ if classname.find("Linear") != -1: torch.nn.init.kaiming_normal_( m.weight, a=0.2, mode="fan_in", nonlinearity="leaky_relu" ) def frequency_init(freq): def init(m): with torch.no_grad(): if isinstance(m, nn.Linear): num_input = m.weight.size(-1) m.weight.uniform_( -np.sqrt(6 / num_input) / freq, np.sqrt(6 / num_input) / freq ) return init def first_layer_film_sine_init(m): with torch.no_grad(): if isinstance(m, nn.Linear): num_input = m.weight.size(-1) m.weight.uniform_(-1 / num_input, 1 / num_input) class CustomMappingNetwork(nn.Module): def __init__(self, in_features, map_hidden_layers, map_hidden_dim, map_output_dim): super().__init__() self.network = [] for _ in range(map_hidden_layers): self.network.append(nn.Linear(in_features, map_hidden_dim)) self.network.append(nn.LeakyReLU(0.2, inplace=True)) in_features = map_hidden_dim self.network.append(nn.Linear(map_hidden_dim, map_output_dim)) self.network = nn.Sequential(*self.network) self.network.apply(kaiming_leaky_init) with torch.no_grad(): self.network[-1].weight *= 0.25 def forward(self, z): frequencies_offsets = self.network(z) frequencies = frequencies_offsets[ ..., : torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor") ] phase_shifts = frequencies_offsets[ ..., torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor") : ] return frequencies, phase_shifts class FiLMLayer(nn.Module): def __init__(self, input_dim, hidden_dim): super().__init__() self.layer = nn.Linear(input_dim, hidden_dim) def forward(self, x, freq, phase_shift): x = self.layer(x) freq = freq.expand_as(x) phase_shift = phase_shift.expand_as(x) return torch.sin(freq * x + phase_shift) class FiLMSiren(nn.Module): """FiLM Conditioned Siren network.""" def __init__( self, in_dim: int, hidden_layers: int, hidden_features: int, mapping_network_in_dim: int, mapping_network_layers: int, mapping_network_features: int, out_dim: int, outermost_linear: bool = False, out_activation: Optional[nn.Module] = None, ) -> None: super().__init__() self.in_dim = in_dim assert self.in_dim > 0 self.out_dim = out_dim if out_dim is not None else hidden_features self.hidden_layers = hidden_layers self.hidden_features = hidden_features self.mapping_network_in_dim = mapping_network_in_dim self.mapping_network_layers = mapping_network_layers self.mapping_network_features = mapping_network_features self.outermost_linear = outermost_linear self.out_activation = out_activation self.net = nn.ModuleList() self.net.append(FiLMLayer(self.in_dim, self.hidden_features)) for _ in range(self.hidden_layers - 1): self.net.append(FiLMLayer(self.hidden_features, self.hidden_features)) self.final_layer = None if self.outermost_linear: self.final_layer = nn.Linear(self.hidden_features, self.out_dim) self.final_layer.apply(frequency_init(25)) else: final_layer = FiLMLayer(self.hidden_features, self.out_dim) self.net.append(final_layer) self.mapping_network = CustomMappingNetwork( in_features=self.mapping_network_in_dim, map_hidden_layers=self.mapping_network_layers, map_hidden_dim=self.mapping_network_features, map_output_dim=(len(self.net)) * self.hidden_features * 2, ) self.net.apply(frequency_init(25)) self.net[0].apply(first_layer_film_sine_init) def forward_with_frequencies_phase_shifts(self, x, frequencies, phase_shifts): """Get conditiional frequencies and phase shifts from mapping network.""" frequencies = frequencies * 15 + 30 for index, layer in enumerate(self.net): start = index * self.hidden_features end = (index + 1) * self.hidden_features x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end]) x = self.final_layer(x) if self.final_layer is not None else x output = self.out_activation(x) if self.out_activation is not None else x return output def forward(self, x, conditioning_input): """Forward pass.""" frequencies, phase_shifts = self.mapping_network(conditioning_input) return self.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts)