from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch
from jaxtyping import Float
from torch import Tensor

from spar3d.models.utils import BaseModule

from .field import RENIField


def _direction_from_coordinate(
    coordinate: Float[Tensor, "*B 2"],
) -> Float[Tensor, "*B 3"]:
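    # Map equirectangular (u, v) coordinates in [0, 1]^2 to unit direction
    # vectors: u sweeps the azimuth, v the polar angle measured from +Y.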
    # OpenGL Convention
    # +X Right
    # +Y Up
    # +Z Backward

    u, v = coordinate.unbind(-1)
    theta = (2 * torch.pi * u) - torch.pi
    phi = torch.pi * v

    directions = torch.stack(
        [
            theta.sin() * phi.sin(),
            phi.cos(),
            -1 * theta.cos() * phi.sin(),
        ],
        -1,
    )
    return directions


def _get_sample_coordinates(
    resolution: List[int], device: Optional[torch.device] = None
) -> Float[Tensor, "H W 2"]:
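    # Pixel-center (u, v) coordinates on an equirectangular grid of shape
    # (H, W) = (resolution[0], resolution[1]), stacked into an H x W x 2 tensor.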
    return torch.stack(
        torch.meshgrid(
            (torch.arange(resolution[1], device=device) + 0.5) / resolution[1],
            (torch.arange(resolution[0], device=device) + 0.5) / resolution[0],
            indexing="xy",
        ),
        -1,
    )


class RENIEnvMap(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        reni_config: dict = field(default_factory=dict)  # passed through to RENIField
        resolution: int = 128  # environment-map height; the width is twice this value

    cfg: Config

    def configure(self):
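        # Precompute one fixed set of unit sample directions covering the full
        # equirectangular environment map; the RENI field is queried at these
        # same directions for every latent code.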
        self.field = RENIField(self.cfg.reni_config)
        resolution = (self.cfg.resolution, self.cfg.resolution * 2)
        sample_directions = _direction_from_coordinate(
            _get_sample_coordinates(resolution)
        )
        self.img_shape = sample_directions.shape[:-1]

        sample_directions_flat = sample_directions.view(-1, 3)
        # These directions are y-up, but RENI expects z-up: rotate 90 degrees about the x axis.
        sample_directions_flat = torch.stack(
            [
                sample_directions_flat[:, 0],
                -sample_directions_flat[:, 2],
                sample_directions_flat[:, 1],
            ],
            -1,
        )
        self.sample_directions = torch.nn.Parameter(
            sample_directions_flat, requires_grad=False
        )

    def forward(
        self,
        latent_codes: Float[Tensor, "B latent_dim 3"],
        rotation: Optional[Float[Tensor, "B 3 3"]] = None,
        scale: Optional[Float[Tensor, "B"]] = None,
    ) -> Dict[str, Tensor]:
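        # Query the RENI field at the precomputed directions (shared across the
        # batch) and reshape each returned quantity into a (B, H, W, C) image.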
        return {
            k: v.view(latent_codes.shape[0], *self.img_shape, -1)
            for k, v in self.field(
                self.sample_directions.unsqueeze(0).repeat(latent_codes.shape[0], 1, 1),
                latent_codes,
                rotation=rotation,
                scale=scale,
            ).items()
        }
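

# Example usage (a minimal sketch, not part of the module: construction from a
# plain config dict assumes BaseModule's usual cfg parsing, and the contents of
# `reni_config` and the latent dimensionality are set by RENIField; all of these
# are assumptions here):
#
#   env_map = RENIEnvMap({"reni_config": {...}, "resolution": 64})
#   latents = torch.randn(4, latent_dim, 3)  # one latent code per batch item
#   outputs = env_map(latents)               # dict of [4, 64, 128, C] tensors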