File size: 3,584 Bytes
38dbec8
64fccd8
38dbec8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64fccd8
38dbec8
 
 
 
 
 
 
 
 
64fccd8
38dbec8
 
64fccd8
 
 
 
38dbec8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from dataclasses import dataclass, field
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from spar3d.models.illumination.reni.env_map import RENIEnvMap
from spar3d.models.utils import BaseModule


def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Convert a 6D rotation representation to a 3x3 rotation matrix.

    The first three components of ``d6`` define the (unnormalized) first
    column of the matrix; the last three are Gram-Schmidt-orthogonalized
    against it to form the second column, and the third column is their
    cross product, yielding a proper rotation.

    Args:
        d6: Tensor of shape (..., 6).

    Returns:
        Tensor of shape (..., 3, 3) whose columns are the orthonormal basis.
    """
    assert d6.shape[-1] == 6, "Input tensor must have shape (..., 6)"

    a1 = d6[..., :3]
    a2 = d6[..., 3:]

    col_x = F.normalize(a1, dim=-1)

    # Project a2 onto col_x and subtract that component (Gram-Schmidt).
    # The clamp and the extra 1e-10 both guard against a degenerate
    # near-zero first column.
    dot = torch.sum(col_x * a2, dim=-1, keepdim=True)
    sq_norm = torch.clamp(torch.sum(col_x**2, dim=-1, keepdim=True), min=1e-8)
    projection = (dot / (sq_norm + 1e-10)) * col_x

    col_y = F.normalize(a2 - projection, dim=-1)
    col_z = torch.cross(col_x, col_y, dim=-1)

    return torch.stack((col_x, col_y, col_z), dim=-1)


class ReniLatentCodeEstimator(BaseModule):
    """Estimate RENI illumination parameters from triplane features.

    A small strided conv stack encodes the flattened triplane, the result is
    spatially pooled, and three linear heads predict the RENI latent codes,
    a 6D rotation (converted to a matrix), and a log-scale.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Feature channels per triplane plane; the encoder input is 3x this.
        triplane_features: int = 40

        # Conv encoder depth, width, and nonlinearity.
        n_layers: int = 5
        hidden_features: int = 512
        activation: str = "relu"

        # Spatial pooling applied after the conv stack: "mean" or "max".
        pool: str = "mean"

        # Forwarded verbatim to RENIEnvMap.
        reni_env_config: dict = field(default_factory=dict)

    cfg: Config

    def configure(self):
        # Build the conv encoder: each layer halves spatial resolution
        # (stride 2, no padding) and maps to hidden_features channels.
        layers = []
        cur_features = self.cfg.triplane_features * 3
        for _ in range(self.cfg.n_layers):
            layers.append(
                nn.Conv2d(
                    cur_features,
                    self.cfg.hidden_features,
                    kernel_size=3,
                    padding=0,
                    stride=2,
                )
            )
            layers.append(self.make_activation(self.cfg.activation))

            cur_features = self.cfg.hidden_features

        self.layers = nn.Sequential(*layers)

        self.reni_env_map = RENIEnvMap(self.cfg.reni_env_config)
        self.latent_dim = self.reni_env_map.field.latent_dim

        # Latent head: wide init so predicted latents start with real spread.
        self.fc_latents = nn.Linear(self.cfg.hidden_features, self.latent_dim * 3)
        nn.init.normal_(self.fc_latents.weight, mean=0.0, std=0.3)

        # Rotation head: near-zero init so rotations start near identity.
        self.fc_rotations = nn.Linear(self.cfg.hidden_features, 6)
        nn.init.constant_(self.fc_rotations.bias, 0.0)
        nn.init.normal_(
            self.fc_rotations.weight, mean=0.0, std=0.01
        )  # Small variance here

        # Scale head: near-zero init for a stable initial scale.
        self.fc_scale = nn.Linear(self.cfg.hidden_features, 1)
        nn.init.constant_(self.fc_scale.bias, 0.0)
        nn.init.normal_(self.fc_scale.weight, mean=0.0, std=0.01)  # Small variance here

    def make_activation(self, activation):
        """Return the activation module named by ``activation`` ("relu"/"silu")."""
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def _pool(self, x: torch.Tensor) -> torch.Tensor:
        """Spatially pool (B, C, H, W) -> (B, C) per cfg.pool.

        Previously forward() hardcoded mean pooling and ignored cfg.pool;
        this honors the config while keeping the default ("mean") identical.
        """
        if self.cfg.pool == "mean":
            return x.mean(dim=[-2, -1])
        elif self.cfg.pool == "max":
            return x.amax(dim=(-2, -1))
        else:
            raise NotImplementedError(f"Unknown pool: {self.cfg.pool}")

    def forward(
        self,
        triplane: Float[Tensor, "B 3 F Ht Wt"],
        rotation: Optional[Float[Tensor, "B 3 3"]] = None,
    ) -> dict[str, Any]:
        """Predict an environment map from triplane features.

        Args:
            triplane: Per-plane feature maps, flattened to (B, 3*F, Ht, Wt)
                before encoding.
            rotation: Optional extra rotation composed (right-multiplied)
                with the predicted one.

        Returns:
            Dict with key "illumination" holding the RENI env-map RGB.
        """
        x = self.layers(
            triplane.reshape(
                triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
            )
        )
        x = self._pool(x)

        latents = self.fc_latents(x).reshape(-1, self.latent_dim, 3)
        rotations = rotation_6d_to_matrix(self.fc_rotations(x))
        scale = self.fc_scale(x)

        if rotation is not None:
            rotations = rotations @ rotation.to(dtype=rotations.dtype)

        env_map = self.reni_env_map(latents, rotations, scale)

        return {"illumination": env_map["rgb"]}