"""Siren MLP https://www.vincentsitzmann.com/siren/"""

from typing import Optional

import numpy as np
import torch
from torch import nn


class SineLayer(nn.Module):
    """Linear layer followed by a scaled sine: ``sin(omega_0 * linear(x))``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the underlying ``nn.Linear`` has a bias term.
        is_first: Selects the weight initialization range (see
            ``init_weights``).
        omega_0: Factor multiplied into the linear output before the sine.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        is_first: bool = False,
        omega_0: float = 30.0,
    ) -> None:
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self) -> None:
        """Re-draw the linear weights from a symmetric uniform range.

        The range is ``+-1 / in_features`` when ``is_first`` is set and
        ``+-sqrt(6 / in_features) / omega_0`` otherwise. Only the weight is
        touched; the bias keeps whatever ``nn.Linear`` gave it.
        """
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sin(omega_0 * linear(x))``."""
        return torch.sin(self.omega_0 * self.linear(x))


class Siren(nn.Module):
    """Siren network.

    The network is a ``nn.Sequential`` made of one first ``SineLayer``,
    ``hidden_layers`` further ``SineLayer`` modules, one final layer, and an
    optional output activation.

    Args:
        in_dim: Input layer dimension. Must be positive.
        hidden_layers: Number of hidden sine layers placed after the first
            sine layer (the first and final layers are not counted).
        hidden_features: Width of the first and hidden layers.
        out_dim: Output layer dimension. Uses hidden_features if None.
        outermost_linear: If True the final layer is a plain ``nn.Linear``
            (no sine); otherwise it is another ``SineLayer``.
        first_omega_0: ``omega_0`` of the first sine layer.
        hidden_omega_0: ``omega_0`` of every later sine layer; it also scales
            the weight initialization range of the final linear layer.
        out_activation: Optional module appended after the final layer.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_layers: int,
        hidden_features: int,
        out_dim: Optional[int] = None,
        outermost_linear: bool = False,
        first_omega_0: float = 30,
        hidden_omega_0: float = 30,
        out_activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        assert self.in_dim > 0, "in_dim must be a positive integer"
        self.out_dim = out_dim if out_dim is not None else hidden_features
        self.outermost_linear = outermost_linear
        self.first_omega_0 = first_omega_0
        self.hidden_omega_0 = hidden_omega_0
        self.hidden_layers = hidden_layers
        self.layer_width = hidden_features
        self.out_activation = out_activation

        # Layers are created strictly in network order so that the sequence
        # of random draws (and therefore seeded initialization) is stable.
        layers = [
            SineLayer(in_dim, hidden_features, is_first=True, omega_0=first_omega_0)
        ]
        for _ in range(hidden_layers):
            layers.append(
                SineLayer(
                    hidden_features,
                    hidden_features,
                    is_first=False,
                    omega_0=hidden_omega_0,
                )
            )

        if outermost_linear:
            final_layer = nn.Linear(hidden_features, self.out_dim)
            # Same range a non-first SineLayer uses for its weights.
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                final_layer.weight.uniform_(-bound, bound)
            layers.append(final_layer)
        else:
            layers.append(
                SineLayer(
                    hidden_features,
                    self.out_dim,
                    is_first=False,
                    omega_0=hidden_omega_0,
                )
            )

        if self.out_activation is not None:
            layers.append(self.out_activation)

        self.net = nn.Sequential(*layers)

    def forward(self, model_input: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network"""
        return self.net(model_input)