# Page header captured together with the file (not part of the module):
# jammmmm's picture
# Add spar3d demo files
# 38dbec8
# raw
# history blame
# 2.63 kB
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch
from jaxtyping import Float
from torch import Tensor
from spar3d.models.utils import BaseModule
from .field import RENIField
def _direction_from_coordinate(
    coordinate: Float[Tensor, "*B 2"],
) -> Float[Tensor, "*B 3"]:
    """Map normalized (u, v) coordinates to direction vectors.

    The last axis of ``coordinate`` holds ``(u, v)``. ``u`` is mapped
    linearly to an azimuth of ``2*pi*u - pi`` and ``v`` to a polar angle of
    ``pi*v`` measured from the +Y axis. The result follows the OpenGL
    convention (+X right, +Y up, +Z backward), so ``(u, v) = (0.5, 0.5)``
    yields the -Z direction.

    Args:
        coordinate: Tensor whose last dimension has size 2.

    Returns:
        Tensor with the same leading dimensions and a last dimension of
        size 3 holding the ``(x, y, z)`` components.
    """
    u, v = coordinate.unbind(-1)

    azimuth = (2 * torch.pi * u) - torch.pi
    polar = torch.pi * v
    sin_polar = polar.sin()

    x = azimuth.sin() * sin_polar
    y = polar.cos()
    z = -1 * azimuth.cos() * sin_polar

    return torch.stack((x, y, z), dim=-1)
def _get_sample_coordinates(
    resolution: List[int], device: Optional[torch.device] = None
) -> Float[Tensor, "H W 2"]:
    """Build a grid of pixel-center coordinates normalized to (0, 1).

    Args:
        resolution: ``[height, width]`` of the grid.
        device: Device on which the coordinate tensors are created.

    Returns:
        Tensor of shape ``(height, width, 2)``; channel 0 is the horizontal
        coordinate ``(col + 0.5) / width`` and channel 1 is the vertical
        coordinate ``(row + 0.5) / height``.
    """
    height, width = resolution[0], resolution[1]

    # The +0.5 offset samples at pixel centers rather than pixel corners.
    u = (torch.arange(width, device=device) + 0.5) / width
    v = (torch.arange(height, device=device) + 0.5) / height

    # "xy" indexing makes both output grids (height, width) shaped.
    grid_u, grid_v = torch.meshgrid(u, v, indexing="xy")
    return torch.stack((grid_u, grid_v), -1)
class RENIEnvMap(BaseModule):
    """Evaluate a RENI field on a fixed grid of directions to form an image.

    ``configure`` precomputes one direction per pixel of a
    ``resolution x (2 * resolution)`` grid; ``forward`` queries the field
    along those directions for every latent code in the batch and reshapes
    each output back to image layout.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Passed unchanged to the RENIField constructor.
        reni_config: dict = field(default_factory=dict)
        # Image height in pixels; the width is twice this value.
        resolution: int = 128

    cfg: Config

    def configure(self):
        """Create the RENI field and cache the per-pixel sample directions."""
        self.field = RENIField(self.cfg.reni_config)

        height = self.cfg.resolution
        coordinates = _get_sample_coordinates((height, height * 2))
        directions = _direction_from_coordinate(coordinates)

        # Remember the (H, W) layout so forward() can restore it.
        self.img_shape = directions.shape[:-1]

        # The directions are y-up but RENI expects z-up: rotate 90 degrees
        # about the x axis, i.e. (x, y, z) -> (x, -z, y).
        x, y, z = directions.view(-1, 3).unbind(-1)
        reni_directions = torch.stack([x, -z, y], -1)

        # Stored as a non-trainable parameter so it follows the module
        # across devices.
        self.sample_directions = torch.nn.Parameter(
            reni_directions, requires_grad=False
        )

    def forward(
        self,
        latent_codes: Float[Tensor, "B latent_dim 3"],
        rotation: Optional[Float[Tensor, "B 3 3"]] = None,
        scale: Optional[Float[Tensor, "B"]] = None,
    ) -> Dict[str, Tensor]:
        """Query the field for each latent code and return image-shaped maps.

        Args:
            latent_codes: Batch of latent codes; the batch size is taken
                from its first dimension.
            rotation: Optional per-sample rotation, forwarded to the field.
            scale: Optional per-sample scale, forwarded to the field.

        Returns:
            Dict with the same keys as the field's output; each value is
            viewed as ``(B, H, W, -1)``.
        """
        batch_size = latent_codes.shape[0]

        # Every batch element is evaluated along the same set of directions.
        directions = self.sample_directions.unsqueeze(0).repeat(batch_size, 1, 1)

        field_outputs = self.field(
            directions,
            latent_codes,
            rotation=rotation,
            scale=scale,
        )

        return {
            name: value.view(batch_size, *self.img_shape, -1)
            for name, value in field_outputs.items()
        }