jammmmm's picture
Add spar3d demo files
38dbec8
"""FiLM Siren MLP as per https://marcoamonteiro.github.io/pi-GAN-website/."""
from typing import Optional
import numpy as np
import torch
from torch import nn
def kaiming_leaky_init(m):
    """Apply Kaiming-normal initialisation to the weight of Linear-like modules.

    A module is selected when its class name contains the substring
    ``"Linear"``; any other module is left untouched.  Only ``m.weight`` is
    re-initialised (fan-in mode, leaky-ReLU gain with negative slope 0.2);
    the bias is not modified.

    Args:
        m: Module to (possibly) initialise, e.g. as visited by
            ``nn.Module.apply``.
    """
    # Match on the class name (not isinstance) so that any class whose name
    # contains "Linear" is covered.
    if "Linear" not in m.__class__.__name__:
        return
    torch.nn.init.kaiming_normal_(
        m.weight, a=0.2, mode="fan_in", nonlinearity="leaky_relu"
    )
def frequency_init(freq):
    """Build an initialiser that draws Linear weights from a scaled uniform range.

    The returned callable re-initialises, in place, the weight of every
    ``nn.Linear`` it receives from a uniform distribution over ``(-b, b)``
    with ``b = sqrt(6 / fan_in) / freq``, where ``fan_in`` is the size of the
    weight's last dimension.  Biases and non-``nn.Linear`` modules are left
    untouched.

    Args:
        freq: Divisor applied to the uniform bound.

    Returns:
        A one-argument function taking a module, usable with
        ``nn.Module.apply``.
    """

    def init(m):
        if not isinstance(m, nn.Linear):
            return
        fan_in = m.weight.size(-1)
        bound = np.sqrt(6 / fan_in) / freq
        # In-place re-initialisation must not be recorded by autograd.
        with torch.no_grad():
            m.weight.uniform_(-bound, bound)

    return init
def first_layer_film_sine_init(m):
    """Re-initialise an ``nn.Linear`` weight uniformly over ``(-1/fan_in, 1/fan_in)``.

    ``fan_in`` is the size of the weight's last dimension.  The update is done
    in place without gradient tracking; the bias and any module that is not an
    ``nn.Linear`` are left untouched.

    Args:
        m: Module to (possibly) initialise, e.g. as visited by
            ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    limit = 1 / m.weight.size(-1)
    with torch.no_grad():
        m.weight.uniform_(-limit, limit)
class CustomMappingNetwork(nn.Module):
    """MLP that maps a conditioning input to FiLM frequencies and phase shifts.

    The network consists of ``map_hidden_layers`` blocks of
    ``Linear -> LeakyReLU(0.2)`` followed by one output ``Linear`` producing
    ``map_output_dim`` values.  All linear weights are initialised with
    ``kaiming_leaky_init`` and the output layer's weight is then scaled by
    0.25.

    Args:
        in_features: Size of the last dimension of the conditioning input.
        map_hidden_layers: Number of ``Linear -> LeakyReLU`` blocks; may be 0,
            in which case the network is a single linear projection.
        map_hidden_dim: Width of every hidden block.
        map_output_dim: Size of the network output, later split into
            frequencies (first half) and phase shifts (second half).
    """

    def __init__(self, in_features, map_hidden_layers, map_hidden_dim, map_output_dim):
        super().__init__()

        layers = []
        for _ in range(map_hidden_layers):
            layers.append(nn.Linear(in_features, map_hidden_dim))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_features = map_hidden_dim
        # ``in_features`` now holds the width feeding the output layer: it is
        # ``map_hidden_dim`` after at least one hidden block and the original
        # input width when ``map_hidden_layers == 0`` (previously the output
        # layer always assumed ``map_hidden_dim`` and broke in that case).
        layers.append(nn.Linear(in_features, map_output_dim))

        self.network = nn.Sequential(*layers)
        self.network.apply(kaiming_leaky_init)
        # Shrink the output layer's weights after the Kaiming initialisation.
        with torch.no_grad():
            self.network[-1].weight *= 0.25

    def forward(self, z):
        """Split the network output into frequencies and phase shifts.

        Args:
            z: Conditioning input whose last dimension is ``in_features``.

        Returns:
            Tuple ``(frequencies, phase_shifts)``: the first
            ``floor(map_output_dim / 2)`` entries of the last dimension and
            the remaining entries, respectively.  For an odd output size the
            phase-shift part is one element longer.
        """
        frequencies_offsets = self.network(z)
        # NOTE(review): torch.div is kept instead of ``//``, presumably so the
        # split also works when the shape entry is a tensor (e.g. while
        # tracing) -- confirm before simplifying.
        half = torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor")
        frequencies = frequencies_offsets[..., :half]
        phase_shifts = frequencies_offsets[..., half:]
        return frequencies, phase_shifts
class FiLMLayer(nn.Module):
    """Linear layer followed by a FiLM-modulated sine activation.

    Computes ``sin(freq * Linear(x) + phase_shift)``.

    Args:
        input_dim: Number of input features of the linear layer.
        hidden_dim: Number of output features of the linear layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, freq, phase_shift):
        """Apply the linear map, then the modulated sine.

        Args:
            x: Input whose last dimension is ``input_dim``.
            freq: Multiplicative modulation; must be expandable to the shape
                of the linear layer's output.
            phase_shift: Additive modulation; must be expandable to the shape
                of the linear layer's output.

        Returns:
            Tensor with the same shape as the linear layer's output.
        """
        projected = self.layer(x)
        # expand_as (rather than implicit broadcasting) rejects modulation
        # tensors that cannot be expanded to the activation's shape.
        scale = freq.expand_as(projected)
        shift = phase_shift.expand_as(projected)
        return torch.sin(scale * projected + shift)
class FiLMSiren(nn.Module):
    """FiLM Conditioned Siren network.

    A stack of ``FiLMLayer`` modules whose sine activations are modulated by
    per-layer frequencies and phase shifts.  The modulation parameters are
    predicted from a conditioning input by a ``CustomMappingNetwork``.

    Args:
        in_dim: Number of input features; must be positive.
        hidden_layers: Number of hidden FiLM layers.  One input FiLM layer is
            always created, followed by ``hidden_layers - 1`` more.
        hidden_features: Width of the hidden FiLM layers.
        mapping_network_in_dim: Feature size of the conditioning input.
        mapping_network_layers: Number of hidden blocks in the mapping
            network.
        mapping_network_features: Hidden width of the mapping network.
        out_dim: Number of output features; ``None`` falls back to
            ``hidden_features``.
        outermost_linear: If true, the output is produced by a plain
            ``nn.Linear``; otherwise by one more FiLM layer of width
            ``out_dim``.
        out_activation: Optional module applied to the final output.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_layers: int,
        hidden_features: int,
        mapping_network_in_dim: int,
        mapping_network_layers: int,
        mapping_network_features: int,
        out_dim: int,
        outermost_linear: bool = False,
        out_activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        assert self.in_dim > 0
        self.out_dim = out_dim if out_dim is not None else hidden_features
        self.hidden_layers = hidden_layers
        self.hidden_features = hidden_features
        self.mapping_network_in_dim = mapping_network_in_dim
        self.mapping_network_layers = mapping_network_layers
        self.mapping_network_features = mapping_network_features
        self.outermost_linear = outermost_linear
        self.out_activation = out_activation

        self.net = nn.ModuleList()
        self.net.append(FiLMLayer(self.in_dim, self.hidden_features))
        for _ in range(self.hidden_layers - 1):
            self.net.append(FiLMLayer(self.hidden_features, self.hidden_features))

        self.final_layer = None
        if self.outermost_linear:
            self.final_layer = nn.Linear(self.hidden_features, self.out_dim)
            self.final_layer.apply(frequency_init(25))
        else:
            self.net.append(FiLMLayer(self.hidden_features, self.out_dim))

        # One frequency and one phase shift per output feature of every FiLM
        # layer.  Summing the real widths (instead of assuming
        # ``hidden_features`` everywhere) keeps the final FiLM layer usable
        # when ``out_dim != hidden_features``; it is the same number when all
        # widths are equal.
        self.mapping_network = CustomMappingNetwork(
            in_features=self.mapping_network_in_dim,
            map_hidden_layers=self.mapping_network_layers,
            map_hidden_dim=self.mapping_network_features,
            map_output_dim=2 * sum(self._film_widths()),
        )

        self.net.apply(frequency_init(25))
        # The first layer overrides the generic initialisation applied above.
        self.net[0].apply(first_layer_film_sine_init)

    def _film_widths(self):
        """Return the output width of each FiLM layer in ``self.net``, in order."""
        return [film.layer.out_features for film in self.net]

    def forward_with_frequencies_phase_shifts(self, x, frequencies, phase_shifts):
        """Run the FiLM stack with explicitly supplied modulation parameters.

        Args:
            x: Input whose last dimension is ``in_dim``.
            frequencies: Raw frequencies for all FiLM layers concatenated
                along the last dimension, in layer order.  They are remapped
                as ``frequencies * 15 + 30`` before use.
            phase_shifts: Phase shifts laid out like ``frequencies``.

        Returns:
            The network output, after the optional final linear layer and the
            optional output activation.
        """
        # NOTE(review): fixed affine remap of the raw frequencies; the
        # constants presumably follow the pi-GAN reference -- confirm.
        frequencies = frequencies * 15 + 30

        # Walk the concatenated parameters, giving each layer a slice as wide
        # as its own output.
        start = 0
        for layer, width in zip(self.net, self._film_widths()):
            end = start + width
            x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
            start = end

        if self.final_layer is not None:
            x = self.final_layer(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x

    def forward(self, x, conditioning_input):
        """Forward pass.

        Args:
            x: Input whose last dimension is ``in_dim``.
            conditioning_input: Input to the mapping network that predicts
                the frequencies and phase shifts.

        Returns:
            Output of ``forward_with_frequencies_phase_shifts`` for the
            predicted modulation parameters.
        """
        frequencies, phase_shifts = self.mapping_network(conditioning_input)
        return self.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts)