File size: 2,629 Bytes
38dbec8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch
from jaxtyping import Float
from torch import Tensor
from spar3d.models.utils import BaseModule
from .field import RENIField
def _direction_from_coordinate(
    coordinate: Float[Tensor, "*B 2"],
) -> Float[Tensor, "*B 3"]:
    """Map normalised (u, v) image coordinates to direction vectors.

    Axis convention is OpenGL: +X right, +Y up, +Z backward.

    ``u`` is mapped linearly to an azimuth ``theta = 2*pi*u - pi`` and ``v`` to a
    polar angle ``phi = pi*v`` measured from +Y.
    - ``v = 0`` points straight up (+Y).
    - ``u = 0.5`` faces -Z (forward).

    For ``u, v`` in ``[0, 1]``, theta covers ``[-pi, pi]`` and phi covers
    ``[0, pi]``.

    Args:
        coordinate: Tensor whose last dimension holds ``(u, v)``.

    Returns:
        Tensor with the same leading dimensions and a trailing dimension of 3
        holding ``(x, y, z)``. Each vector has unit length, since
        sin^2(theta)sin^2(phi) + cos^2(phi) + cos^2(theta)sin^2(phi) = 1.
    """
    u, v = coordinate.unbind(-1)
    theta = (2 * torch.pi * u) - torch.pi  # azimuth
    phi = torch.pi * v  # polar angle from +Y
    sin_phi = phi.sin()  # shared by the x and z components
    # Named `direction` rather than `dir` to avoid shadowing the builtin.
    direction = torch.stack(
        [
            theta.sin() * sin_phi,
            phi.cos(),
            -theta.cos() * sin_phi,
        ],
        -1,
    )
    return direction
def _get_sample_coordinates(
    resolution: List[int], device: Optional[torch.device] = None
) -> Float[Tensor, "H W 2"]:
    """Build a grid of pixel-centre coordinates normalised to (0, 1).

    Args:
        resolution: ``(height, width)`` of the grid. Only the first two entries
            are read.
        device: Device on which to allocate the grid.

    Returns:
        Tensor of shape ``(height, width, 2)``. ``out[i, j, 0]`` is the
        column coordinate ``(j + 0.5) / width`` and ``out[i, j, 1]`` is the
        row coordinate ``(i + 0.5) / height``.
    """
    height, width = resolution[0], resolution[1]
    # Offsetting by half a pixel samples the centre of each cell.
    col_coords = (torch.arange(width, device=device) + 0.5) / width
    row_coords = (torch.arange(height, device=device) + 0.5) / height
    # "xy" indexing gives (height, width)-shaped grids, with the first grid
    # varying along columns.
    grid_u, grid_v = torch.meshgrid(col_coords, row_coords, indexing="xy")
    return torch.stack((grid_u, grid_v), dim=-1)
class RENIEnvMap(BaseModule):
    """Evaluate a RENI field over a fixed image grid of view directions.

    ``configure`` precomputes one direction per pixel of a
    ``resolution x 2*resolution`` image. ``forward`` queries the field with
    those directions for every latent code in the batch and reshapes each
    output back into image layout.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Passed unchanged to RENIField's constructor.
        reni_config: dict = field(default_factory=dict)
        # Image height in pixels; the width is twice this value.
        resolution: int = 128

    cfg: Config

    def configure(self):
        """Create the RENI field and the per-pixel sample directions."""
        self.field = RENIField(self.cfg.reni_config)

        height = self.cfg.resolution
        pixel_grid = _get_sample_coordinates((height, height * 2))
        directions_y_up = _direction_from_coordinate(pixel_grid)  # (H, W, 3)
        self.img_shape = directions_y_up.shape[:-1]

        x_comp, y_comp, z_comp = directions_y_up.view(-1, 3).unbind(-1)
        # The grid directions are y-up but RENI expects z-up, so rotate
        # 90 degrees about the x axis: (x, y, z) -> (x, -z, y).
        directions_z_up = torch.stack((x_comp, -z_comp, y_comp), -1)

        # Stored as a frozen Parameter (requires_grad=False).
        self.sample_directions = torch.nn.Parameter(
            directions_z_up, requires_grad=False
        )

    def forward(
        self,
        latent_codes: Float[Tensor, "B latent_dim 3"],
        rotation: Optional[Float[Tensor, "B 3 3"]] = None,
        scale: Optional[Float[Tensor, "B"]] = None,
    ) -> Dict[str, Tensor]:
        """Query the RENI field for every pixel direction and every latent code.

        Args:
            latent_codes: Batch of latent codes; dim 0 is the batch size.
            rotation: Optional per-sample rotation, passed through to the field.
            scale: Optional per-sample scale, passed through to the field.

        Returns:
            Dict mapping each output name returned by the field to a tensor of
            shape ``(B, H, W, C)``. ``C`` is whatever trailing size the field
            produced for that entry.
        """
        batch_size = latent_codes.shape[0]
        # Give every batch element its own copy of the full direction grid.
        batched_directions = self.sample_directions.unsqueeze(0).repeat(
            batch_size, 1, 1
        )
        field_outputs = self.field(
            batched_directions,
            latent_codes,
            rotation=rotation,
            scale=scale,
        )
        images = {}
        for name, flat_values in field_outputs.items():
            images[name] = flat_values.view(batch_size, *self.img_shape, -1)
        return images
|