|
from dataclasses import dataclass, field |
|
from typing import Dict, List, Optional |
|
|
|
import torch |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
from spar3d.models.utils import BaseModule |
|
|
|
from .field import RENIField |
|
|
|
|
|
def _direction_from_coordinate( |
|
coordinate: Float[Tensor, "*B 2"], |
|
) -> Float[Tensor, "*B 3"]: |
|
|
|
|
|
|
|
|
|
|
|
u, v = coordinate.unbind(-1) |
|
theta = (2 * torch.pi * u) - torch.pi |
|
phi = torch.pi * v |
|
|
|
dir = torch.stack( |
|
[ |
|
theta.sin() * phi.sin(), |
|
phi.cos(), |
|
-1 * theta.cos() * phi.sin(), |
|
], |
|
-1, |
|
) |
|
return dir |
|
|
|
|
|
def _get_sample_coordinates( |
|
resolution: List[int], device: Optional[torch.device] = None |
|
) -> Float[Tensor, "H W 2"]: |
|
return torch.stack( |
|
torch.meshgrid( |
|
(torch.arange(resolution[1], device=device) + 0.5) / resolution[1], |
|
(torch.arange(resolution[0], device=device) + 0.5) / resolution[0], |
|
indexing="xy", |
|
), |
|
-1, |
|
) |
|
|
|
|
|
class RENIEnvMap(BaseModule):
    """Evaluate a ``RENIField`` on a fixed grid of directions.

    ``configure`` precomputes one direction per pixel of a
    ``resolution x (2 * resolution)`` grid; ``forward`` queries the field at
    those directions for every batch element and reshapes each output back
    onto the grid.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Configuration passed straight through to ``RENIField``.
        reni_config: dict = field(default_factory=dict)
        # Grid height in pixels; the grid width is twice this value.
        resolution: int = 128

    cfg: Config

    def configure(self):
        """Create the field and precompute the per-pixel sample directions."""
        self.field = RENIField(self.cfg.reni_config)

        height = self.cfg.resolution
        grid_resolution = (height, height * 2)
        grid_directions = _direction_from_coordinate(
            _get_sample_coordinates(grid_resolution)
        )
        # Remember the (H, W) layout so forward() can restore it.
        self.img_shape = grid_directions.shape[:-1]

        # Flatten to (H * W, 3) and remap the axes (x, y, z) -> (x, -z, y).
        # NOTE(review): presumably this converts to the axis convention
        # RENIField expects - confirm against the field implementation.
        x, y, z = grid_directions.view(-1, 3).unbind(-1)
        remapped = torch.stack([x, -z, y], -1)

        # Stored as a frozen parameter so it follows the module across devices.
        self.sample_directions = torch.nn.Parameter(remapped, requires_grad=False)

    def forward(
        self,
        latent_codes: Float[Tensor, "B latent_dim 3"],
        rotation: Optional[Float[Tensor, "B 3 3"]] = None,
        scale: Optional[Float[Tensor, "B"]] = None,
    ) -> Dict[str, Tensor]:
        """Query the field at every grid direction for each batch element.

        Args:
            latent_codes: Latent codes; the first dimension is the batch size.
            rotation: Optional rotation forwarded unchanged to the field.
            scale: Optional scale forwarded unchanged to the field.

        Returns:
            Dict with the same keys as the field output, each value viewed
            as ``(B, H, W, -1)``.
        """
        batch_size = latent_codes.shape[0]

        # Give every batch element its own copy of the direction grid.
        directions = self.sample_directions.unsqueeze(0).repeat(batch_size, 1, 1)

        field_outputs = self.field(
            directions,
            latent_codes,
            rotation=rotation,
            scale=scale,
        )
        return {
            name: value.view(batch_size, *self.img_shape, -1)
            for name, value in field_outputs.items()
        }
|
|