|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""RENI field""" |
|
|
|
import contextlib |
|
from dataclasses import dataclass |
|
from typing import Dict, Literal, Optional |
|
|
|
import torch |
|
from einops.layers.torch import Rearrange |
|
from jaxtyping import Float |
|
from torch import Tensor, nn |
|
|
|
from spar3d.models.network import get_activation_module, trunc_exp |
|
from spar3d.models.utils import BaseModule |
|
|
|
from .components.film_siren import FiLMSiren |
|
from .components.siren import Siren |
|
from .components.transformer_decoder import Decoder |
|
from .components.vn_layers import VNInvariant, VNLinear |
|
|
|
|
|
|
|
|
|
def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor:
    """Return the element-wise expectation of sin(y) for y ~ N(x_means, x_vars).

    The value is computed as ``exp(-0.5 * x_vars) * sin(x_means)``.

    Args:
        x_means: Mean values.
        x_vars: Variance of values.

    Returns:
        torch.Tensor: The expected value of sin.
    """
    # The variance only contributes a damping factor on top of sin(mean).
    damping = torch.exp(-0.5 * x_vars)
    return damping * torch.sin(x_means)
|
|
|
|
|
class NeRFEncoding(torch.nn.Module):
    """Multi-scale sinusoidal encodings. Supports ``integrated positional encodings`` if covariances are provided.

    Each axis is encoded with frequencies ranging from 2^min_freq_exp to 2^max_freq_exp.

    Args:
        in_dim: Input dimension of tensor
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        include_input: Append the input coordinate to the encoding
        off_axis: Encode the projections of the input onto the 21 fixed 3-D
            directions stored in ``P`` instead of the raw input axes. The
            input must be 3-dimensional in that case.
    """

    def __init__(
        self,
        in_dim: int,
        num_frequencies: int,
        min_freq_exp: float,
        max_freq_exp: float,
        include_input: bool = False,
        off_axis: bool = False,
    ) -> None:
        super().__init__()

        self.in_dim = in_dim
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq_exp
        self.max_freq = max_freq_exp
        self.include_input = include_input

        self.off_axis = off_axis

        # Projection directions, stored column-wise with shape [3, 21].
        # Registered as a buffer so the tensor follows the module across
        # devices; non-persistent so existing state dicts stay compatible.
        self.register_buffer(
            "P",
            torch.tensor(
                [
                    [0.8506508, 0, 0.5257311],
                    [0.809017, 0.5, 0.309017],
                    [0.5257311, 0.8506508, 0],
                    [1, 0, 0],
                    [0.809017, 0.5, -0.309017],
                    [0.8506508, 0, -0.5257311],
                    [0.309017, 0.809017, -0.5],
                    [0, 0.5257311, -0.8506508],
                    [0.5, 0.309017, -0.809017],
                    [0, 1, 0],
                    [-0.5257311, 0.8506508, 0],
                    [-0.309017, 0.809017, -0.5],
                    [0, 0.5257311, 0.8506508],
                    [-0.309017, 0.809017, 0.5],
                    [0.309017, 0.809017, 0.5],
                    [0.5, 0.309017, 0.809017],
                    [0.5, -0.309017, 0.809017],
                    [0, 0, 1],
                    [-0.5, 0.309017, 0.809017],
                    [-0.809017, 0.5, 0.309017],
                    [-0.809017, 0.5, -0.309017],
                ]
            ).T,
            persistent=False,
        )

    def get_out_dim(self) -> int:
        """Return the size of the last dimension produced by ``forward``.

        Raises:
            ValueError: If ``in_dim`` has not been set.
        """
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")

        # One sin and one cos term per (encoded axis, frequency) pair.
        encoded_axes = self.P.shape[1] if self.off_axis else self.in_dim
        out_dim = encoded_axes * self.num_frequencies * 2

        if self.include_input:
            out_dim += self.in_dim
        return out_dim

    def forward(
        self,
        in_tensor: Float[Tensor, "*b input_dim"],
        covs: Optional[Float[Tensor, "*b input_dim input_dim"]] = None,
    ) -> Float[Tensor, "*b output_dim"]:
        """Calculates NeRF encoding. If covariances are provided the encodings will be integrated as proposed
        in mip-NeRF.

        Args:
            in_tensor: For best performance, the input tensor should be between 0 and 1.
            covs: Covariances of input points.
        Returns:
            Output values will be between -1 and 1
        """
        freqs = 2 ** torch.linspace(
            self.min_freq, self.max_freq, self.num_frequencies
        ).to(in_tensor.device)

        if self.off_axis:
            # Match the input's device and dtype so matmul cannot fail on a
            # dtype mismatch (e.g. float64 or half inputs).
            directions = self.P.to(device=in_tensor.device, dtype=in_tensor.dtype)
            encoded_axes = torch.matmul(in_tensor, directions)
        else:
            directions = None
            encoded_axes = in_tensor

        # [..., axes, num_frequencies] -> [..., axes * num_frequencies]
        scaled_inputs = (encoded_axes[..., None] * freqs).flatten(start_dim=-2)

        # cos(x) is obtained as sin(x + pi / 2).
        phases = torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1)

        if covs is None:
            encoded_inputs = torch.sin(phases)
        else:
            if directions is not None:
                # The variance of the projection p^T x is p^T Sigma p, so the
                # per-direction variances line up with the projected inputs.
                cov_directions = directions.to(dtype=covs.dtype)
                axis_var = torch.einsum(
                    "id,...ij,jd->...d", cov_directions, covs, cov_directions
                )
            else:
                axis_var = torch.diagonal(covs, dim1=-2, dim2=-1)
            input_var = axis_var[..., :, None] * freqs[None, :] ** 2
            input_var = input_var.reshape((*input_var.shape[:-2], -1))
            encoded_inputs = expected_sin(
                phases,
                torch.cat(2 * [input_var], dim=-1),
            )

        if self.include_input:
            encoded_inputs = torch.cat([encoded_inputs, in_tensor], dim=-1)
        return encoded_inputs
|
|
|
|
|
class RENIField(BaseModule):
    """Field mapping ray directions and a latent code to per-ray outputs.

    The latent code and the directions are first turned into an invariant
    representation (Gram-matrix or vector-neuron based, depending on the
    config), optionally positionally encoded, and then decoded by a Siren,
    FiLM-Siren or attention decoder. Results are returned under ``"rgb"``.
    """

    @dataclass
    class Config(BaseModule.Config):
        """Configuration for model instantiation"""

        fixed_decoder: bool = False
        """Whether to fix the decoder weights"""
        equivariance: str = "SO2"
        """Type of equivariance to use: None, SO2, SO3"""
        axis_of_invariance: str = "y"
        """Which axis should SO2 equivariance be invariant to: x, y, z"""
        invariant_function: str = "GramMatrix"
        """Type of invariant function to use: GramMatrix, VN"""
        conditioning: str = "Concat"
        """Type of conditioning to use: FiLM, Concat, Attention"""
        positional_encoding: str = "NeRF"
        """Type of positional encoding to use. Currently only NeRF is supported"""
        encoded_input: str = "Directions"
        """Type of input to encode: None, Directions, Conditioning, Both"""
        latent_dim: int = 36
        """Dimensionality of latent code, N for a latent code size of (N x 3)"""
        hidden_layers: int = 3
        """Number of hidden layers"""
        hidden_features: int = 128
        """Number of hidden features"""
        mapping_layers: int = 3
        """Number of mapping layers"""
        mapping_features: int = 128
        """Number of mapping features"""
        num_attention_heads: int = 8
        """Number of attention heads"""
        num_attention_layers: int = 3
        """Number of attention layers"""
        out_features: int = 3
        """Number of output features"""
        last_layer_linear: bool = False
        """Whether to use a linear layer as the last layer"""
        output_activation: str = "exp"
        """Activation function for output layer: sigmoid, tanh, relu, exp, None"""
        first_omega_0: float = 30.0
        """Omega_0 for first layer"""
        hidden_omega_0: float = 30.0
        """Omega_0 for hidden layers"""
        old_implementation: bool = False
        """Whether to match implementation of old RENI, when using old checkpoints"""

    cfg: Config

    def configure(self):
        """Copy config values onto the module, select the invariant function and build the network."""
        self.equivariance = self.cfg.equivariance
        self.conditioning = self.cfg.conditioning
        self.latent_dim = self.cfg.latent_dim
        self.hidden_layers = self.cfg.hidden_layers
        self.hidden_features = self.cfg.hidden_features
        self.mapping_layers = self.cfg.mapping_layers
        self.mapping_features = self.cfg.mapping_features
        self.out_features = self.cfg.out_features
        self.last_layer_linear = self.cfg.last_layer_linear
        self.output_activation = self.cfg.output_activation
        self.first_omega_0 = self.cfg.first_omega_0
        self.hidden_omega_0 = self.cfg.hidden_omega_0
        self.old_implementation = self.cfg.old_implementation
        # Axis name -> axis index; raises ValueError for anything but x, y, z.
        self.axis_of_invariance = ["x", "y", "z"].index(self.cfg.axis_of_invariance)

        self.fixed_decoder = self.cfg.fixed_decoder
        if self.cfg.invariant_function == "GramMatrix":
            self.invariant_function = self.gram_matrix_invariance
        else:
            self.vn_proj_in = nn.Sequential(
                Rearrange("... c -> ... 1 c"),
                VNLinear(dim_in=1, dim_out=1, bias_epsilon=0),
            )
            dim_coor = 2 if self.cfg.equivariance == "SO2" else 3
            self.vn_invar = VNInvariant(dim=1, dim_coor=dim_coor)
            self.invariant_function = self.vn_invariance

        self.network = self.setup_network()

        if self.fixed_decoder:
            for param in self.network.parameters():
                param.requires_grad = False

            if self.cfg.invariant_function == "VN":
                for param in self.vn_proj_in.parameters():
                    param.requires_grad = False
                for param in self.vn_invar.parameters():
                    param.requires_grad = False

    @contextlib.contextmanager
    def hold_decoder_fixed(self):
        """Context manager to fix the decoder weights

        Inside the context every parameter of the network (and of the VN
        layers when the VN invariant function is configured) has
        ``requires_grad`` set to False and ``self.fixed_decoder`` is True. The
        previous flags are restored on exit, even if the body raises.

        Example usage:
        ```
        with instance_of_RENIField.hold_decoder_fixed():
            # do stuff
        ```
        """
        frozen_modules = [self.network]
        if self.cfg.invariant_function == "VN":
            frozen_modules += [self.vn_proj_in, self.vn_invar]

        params = [p for module in frozen_modules for p in module.parameters()]
        # Record every flag before touching any of them so the original state
        # can be restored exactly.
        prev_requires_grad = [p.requires_grad for p in params]
        for param in params:
            param.requires_grad = False

        prev_decoder_state = self.fixed_decoder
        self.fixed_decoder = True
        try:
            yield
        finally:
            for param, flag in zip(params, prev_requires_grad):
                param.requires_grad = flag
            self.fixed_decoder = prev_decoder_state

    def vn_invariance(
        self,
        Z: Float[Tensor, "B latent_dim 3"],
        D: Float[Tensor, "B num_rays 3"],
        equivariance: Literal["None", "SO2", "SO3"] = "SO2",
        axis_of_invariance: int = 1,
    ):
        """Generates a batched invariant representation from latent code Z and direction coordinates D.

        Args:
            Z: [B, latent_dim, 3] - Latent code.
            D: [B num_rays, 3] - Direction coordinates.
            equivariance: The type of equivariance to use. Options are 'None', 'SO2', 'SO3'.
            axis_of_invariance: The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).

        Returns:
            Tuple[Tensor, Tensor]: directional_input, conditioning_input

        Raises:
            ValueError: If ``equivariance`` is not one of the supported options.
        """
        assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
        other_axes = [i for i in range(3) if i != axis_of_invariance]

        B, latent_dim, _ = Z.shape
        _, num_rays, _ = D.shape

        if equivariance == "None":
            # Inner product of every latent vector with every direction:
            # [B, num_rays, latent_dim]
            innerprod = torch.sum(Z.unsqueeze(1) * D.unsqueeze(2), dim=-1)
            z_input = (
                Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
            )
            return innerprod, z_input

        if equivariance == "SO2":
            # Components in the plane orthogonal to the axis of invariance.
            z_other = torch.stack(
                (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
            )  # [B, latent_dim, 2]
            d_other = torch.stack(
                (D[..., other_axes[0]], D[..., other_axes[1]]), -1
            ).unsqueeze(2)
            d_other = d_other.expand(B, num_rays, latent_dim, 2)

            z_other_emb = self.vn_proj_in(z_other)
            # presumably [B, latent_dim, 2]; the cat/flatten/expand below
            # relies on it - confirm against VNInvariant
            z_other_invar = self.vn_invar(z_other_emb)

            # The component along the axis of invariance is unaffected by
            # rotations about that axis.
            z_invar = Z[..., axis_of_invariance].unsqueeze(-1)  # [B, latent_dim, 1]

            # In-plane inner products: [B, num_rays, latent_dim]
            innerprod = (z_other.unsqueeze(1) * d_other).sum(dim=-1)

            # In-plane norm of each direction: [B, num_rays, 1]
            d_other_norm = torch.sqrt(
                D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
            ).unsqueeze(-1)

            d_invar = D[..., axis_of_invariance].unsqueeze(-1)  # [B, num_rays, 1]

            directional_input = torch.cat((innerprod, d_invar, d_other_norm), -1)
            conditioning_input = (
                torch.cat((z_other_invar, z_invar), dim=-1)
                .flatten(1)
                .unsqueeze(1)
                .expand(B, num_rays, latent_dim * 3)
            )

            return directional_input, conditioning_input

        if equivariance == "SO3":
            z = self.vn_proj_in(Z)
            z_invar = self.vn_invar(z)
            # Keep the flattened feature size as-is (-1). The previous target
            # size of latent_dim did not match the latent_dim * 3 conditioning
            # width that setup_network declares for VN / SO3.
            conditioning_input = (
                z_invar.flatten(1).unsqueeze(1).expand(B, num_rays, -1)
            )

            innerprod = torch.sum(Z.unsqueeze(1) * D.unsqueeze(2), dim=-1)
            return innerprod, conditioning_input

        raise ValueError(f"Unknown equivariance: {equivariance!r}")

    def gram_matrix_invariance(
        self,
        Z: Float[Tensor, "B latent_dim 3"],
        D: Float[Tensor, "B num_rays 3"],
        equivariance: Literal["None", "SO2", "SO3"] = "SO2",
        axis_of_invariance: int = 1,
    ):
        """Generates an invariant representation from latent code Z and direction coordinates D.

        Args:
            Z (torch.Tensor): Latent code (B x latent_dim x 3)
            D (torch.Tensor): Direction coordinates (B x num_rays x 3)
            equivariance (str): Type of equivariance to use. Options are 'None', 'SO2', and 'SO3'
            axis_of_invariance (int): The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).
                Default is 1 (y-axis).
        Returns:
            Tuple[Tensor, Tensor]: directional_input, conditioning_input. With
            ``old_implementation`` and 'SO2' a single tensor holding all
            features concatenated along the last dimension is returned instead.

        Raises:
            ValueError: If ``equivariance`` is not one of the supported options.
        """
        assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
        other_axes = [i for i in range(3) if i != axis_of_invariance]

        B, latent_dim, _ = Z.shape
        _, num_rays, _ = D.shape

        if equivariance == "None":
            # [B, num_rays, latent_dim]
            innerprod = torch.sum(Z.unsqueeze(1) * D.unsqueeze(2), dim=-1)
            z_input = (
                Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
            )
            return innerprod, z_input

        if equivariance == "SO2":
            # Components in the plane orthogonal to the axis of invariance.
            z_other = torch.stack(
                (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
            )  # [B, latent_dim, 2]
            d_other = torch.stack(
                (D[..., other_axes[0]], D[..., other_axes[1]]), -1
            ).unsqueeze(2)
            d_other = d_other.expand(B, num_rays, latent_dim, 2)

            # Gram matrix of the in-plane latent components:
            # [B, latent_dim, latent_dim]
            G = torch.bmm(z_other, torch.transpose(z_other, 1, 2))

            z_other_invar = G.flatten(start_dim=1)  # [B, latent_dim ** 2]

            z_invar = Z[..., axis_of_invariance]  # [B, latent_dim]

            # In-plane inner products: [B, num_rays, latent_dim]
            innerprod = (z_other.unsqueeze(1) * d_other).sum(dim=-1)

            # In-plane norm of each direction: [B, num_rays, 1]
            d_other_norm = torch.sqrt(
                D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
            ).unsqueeze(-1)

            d_invar = D[..., axis_of_invariance].unsqueeze(-1)  # [B, num_rays, 1]

            if self.old_implementation:
                # Single-tensor layout kept for old checkpoints.
                z_other_invar = z_other_invar.unsqueeze(1).expand(B, num_rays, -1)
                z_invar = z_invar.unsqueeze(1).expand(B, num_rays, -1)
                # The pieces only differ in their last dimension, so they must
                # be joined along the feature axis, not along the ray axis.
                return torch.cat(
                    (innerprod, z_other_invar, d_other_norm, z_invar, d_invar), -1
                )

            directional_input = torch.cat((innerprod, d_invar, d_other_norm), -1)
            # Width is latent_dim ** 2 + latent_dim, so let expand keep the
            # existing feature size instead of forcing latent_dim * 3.
            conditioning_input = (
                torch.cat((z_other_invar, z_invar), -1)
                .unsqueeze(1)
                .expand(B, num_rays, -1)
            )

            return directional_input, conditioning_input

        if equivariance == "SO3":
            G = Z @ torch.transpose(Z, 1, 2)  # [B, latent_dim, latent_dim]
            innerprod = torch.sum(Z.unsqueeze(1) * D.unsqueeze(2), dim=-1)
            z_invar = G.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, -1)
            return innerprod, z_invar

        raise ValueError(f"Unknown equivariance: {equivariance!r}")

    def setup_network(self):
        """Sets up the network architecture

        Also creates the ``direction_encoding`` / ``conditioning_encoding``
        submodules requested by ``cfg.encoded_input``.

        Returns:
            The decoder network matching ``self.conditioning``.
        """
        # Feature widths produced by the invariant functions, before any
        # positional encoding is applied.
        base_input_dims = {
            "VN": {
                "None": {
                    "direction": self.latent_dim,
                    "conditioning": self.latent_dim * 3,
                },
                "SO2": {
                    "direction": self.latent_dim + 2,
                    "conditioning": self.latent_dim * 3,
                },
                "SO3": {
                    "direction": self.latent_dim,
                    "conditioning": self.latent_dim * 3,
                },
            },
            "GramMatrix": {
                "None": {
                    "direction": self.latent_dim,
                    "conditioning": self.latent_dim * 3,
                },
                "SO2": {
                    "direction": self.latent_dim + 2,
                    "conditioning": self.latent_dim**2 + self.latent_dim,
                },
                "SO3": {
                    "direction": self.latent_dim,
                    "conditioning": self.latent_dim**2,
                },
            },
        }

        input_types = ["direction", "conditioning"]
        input_dims = {
            key: base_input_dims[self.cfg.invariant_function][self.cfg.equivariance][
                key
            ]
            for key in input_types
        }

        def create_nerf_encoding(in_dim):
            return NeRFEncoding(
                in_dim=in_dim,
                num_frequencies=2,
                min_freq_exp=0.0,
                max_freq_exp=2.0,
                include_input=True,
            )

        # Which inputs get a positional encoding for each config option.
        encoding_setup = {
            "None": [],
            "Conditioning": ["conditioning"],
            "Directions": ["direction"],
            "Both": ["direction", "conditioning"],
        }

        for input_type in encoding_setup.get(self.cfg.encoded_input, []):
            encoding = create_nerf_encoding(input_dims[input_type])
            # Stored as self.direction_encoding / self.conditioning_encoding.
            setattr(self, f"{input_type}_encoding", encoding)
            # The decoder consumes the encoded width from here on.
            input_dims[input_type] = encoding.get_out_dim()

        output_activation = get_activation_module(self.cfg.output_activation)

        network = None
        if self.conditioning == "Concat":
            network = Siren(
                in_dim=input_dims["direction"] + input_dims["conditioning"],
                hidden_layers=self.hidden_layers,
                hidden_features=self.hidden_features,
                out_dim=self.out_features,
                outermost_linear=self.last_layer_linear,
                first_omega_0=self.first_omega_0,
                hidden_omega_0=self.hidden_omega_0,
                out_activation=output_activation,
            )
        elif self.conditioning == "FiLM":
            network = FiLMSiren(
                in_dim=input_dims["direction"],
                hidden_layers=self.hidden_layers,
                hidden_features=self.hidden_features,
                mapping_network_in_dim=input_dims["conditioning"],
                mapping_network_layers=self.mapping_layers,
                mapping_network_features=self.mapping_features,
                out_dim=self.out_features,
                outermost_linear=True,
                out_activation=output_activation,
            )
        elif self.conditioning == "Attention":
            network = Decoder(
                in_dim=input_dims["direction"],
                conditioning_input_dim=input_dims["conditioning"],
                hidden_features=self.cfg.hidden_features,
                num_heads=self.cfg.num_attention_heads,
                num_layers=self.cfg.num_attention_layers,
                out_activation=output_activation,
            )
        assert network is not None, "unknown conditioning type"
        return network

    def apply_positional_encoding(self, directional_input, conditioning_input):
        """Encode the inputs selected by ``cfg.encoded_input``.

        Args:
            directional_input: Direction features from the invariant function.
            conditioning_input: Conditioning features from the invariant function.

        Returns:
            Tuple of (directional_input, conditioning_input); an input that is
            not selected for encoding is returned unchanged.
        """
        encoded_input = self.cfg.encoded_input
        if encoded_input in ("Directions", "Both"):
            directional_input = self.direction_encoding(directional_input)
        if encoded_input in ("Conditioning", "Both"):
            conditioning_input = self.conditioning_encoding(conditioning_input)

        return directional_input, conditioning_input

    def get_outputs(
        self,
        rays_d: Float[Tensor, "batch num_rays 3"],
        latent_codes: Float[Tensor, "batch_size latent_dim 3"],
        rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None,
        scale: Optional[Float[Tensor, "batch_size"]] = None,
    ) -> Dict[str, Tensor]:
        """Returns the outputs of the field.

        Args:
            rays_d: [batch_size, num_rays, 3]
            latent_codes: [batch_size, latent_dim, 3]
            rotation: [batch_size, 3, 3]
            scale: [batch_size]

        Returns:
            Dict[str, Tensor]: ``{"rgb": tensor}`` with the network output
            reshaped to [batch_size, num_rays, -1].

        Raises:
            NotImplementedError: If ``rotation`` is given but is not 3-dimensional.
        """
        if rotation is not None:
            if rotation.ndim == 3:
                # Apply the per-batch rotation matrix to every latent vector.
                latent_codes = torch.einsum(
                    "bik,blk->bli",
                    rotation,
                    latent_codes,
                )
            else:
                raise NotImplementedError(
                    "Unsupported rotation shape. Expected [batch_size, 3, 3]."
                )

        B, num_rays, _ = rays_d.shape

        if not self.old_implementation:
            directional_input, conditioning_input = self.invariant_function(
                latent_codes,
                rays_d,
                equivariance=self.equivariance,
                axis_of_invariance=self.axis_of_invariance,
            )

            if self.cfg.positional_encoding == "NeRF":
                directional_input, conditioning_input = self.apply_positional_encoding(
                    directional_input, conditioning_input
                )

            if self.conditioning == "Concat":
                flat_outputs = self.network(
                    torch.cat((directional_input, conditioning_input), dim=-1).reshape(
                        B * num_rays, -1
                    )
                )
            elif self.conditioning in ("FiLM", "Attention"):
                # Both decoders take the directional and conditioning inputs
                # as two separate arguments.
                flat_outputs = self.network(
                    directional_input.reshape(B * num_rays, -1),
                    conditioning_input.reshape(B * num_rays, -1),
                )
            else:
                raise ValueError(f"Unknown conditioning type: {self.conditioning!r}")
        else:
            # Old implementation: the y and z components of the directions
            # are swapped before computing the invariant representation.
            directions = torch.stack(
                (rays_d[..., 0], rays_d[..., 2], rays_d[..., 1]), -1
            )
            model_input = self.invariant_function(
                latent_codes,
                directions,
                equivariance=self.equivariance,
                axis_of_invariance=self.axis_of_invariance,
            )

            flat_outputs = self.network(model_input.view(B * num_rays, -1))

        # Infer the channel count rather than hard-coding 3 so that
        # out_features != 3 is handled as well.
        model_outputs = flat_outputs.view(B, num_rays, -1)

        outputs = {}

        if scale is not None:
            scale = trunc_exp(scale)
            model_outputs = model_outputs * scale.view(-1, 1, 1)

        outputs["rgb"] = model_outputs

        return outputs

    def forward(
        self,
        rays_d: Float[Tensor, "batch num_rays 3"],
        latent_codes: Float[Tensor, "batch_size latent_dim 3"],
        rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None,
        scale: Optional[Float[Tensor, "batch_size"]] = None,
    ) -> Dict[str, Tensor]:
        """Evaluates spherical field for a given ray bundle and rotation.

        Args:
            rays_d: [batch_size, num_rays, 3]
            latent_codes: [batch_size, latent_dim, 3]
            rotation: [batch_size, 3, 3]
            scale: [batch_size]

        Returns:
            Dict[str, Tensor]: A dictionary containing the outputs of the field.
        """
        return self.get_outputs(
            rays_d=rays_d,
            latent_codes=latent_codes,
            rotation=rotation,
            scale=scale,
        )
|
|