from dataclasses import dataclass, field
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from spar3d.models.illumination.reni.env_map import RENIEnvMap
from spar3d.models.utils import BaseModule


def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Convert a 6D rotation representation (..., 6) to a rotation matrix (..., 3, 3).

    The first three components are taken as the raw x-axis and the last three as
    the raw y-axis; Gram-Schmidt orthonormalization plus a cross product recover
    a proper rotation (Zhou et al., "On the Continuity of Rotation
    Representations in Neural Networks", CVPR 2019).
    """
    assert d6.shape[-1] == 6, "Input tensor must have shape (..., 6)"

    def proj_u2a(u, a):
        """Project a onto the direction of u (both shaped (..., 3))."""
        inner_prod = torch.sum(u * a, dim=-1, keepdim=True)
        norm2 = torch.sum(u**2, dim=-1, keepdim=True)
        # Guard against division by zero for degenerate (near-zero) u.
        norm2 = torch.clamp(norm2, min=1e-8)
        return (inner_prod / norm2) * u

    x_raw, y_raw = d6[..., :3], d6[..., 3:]

    # Gram-Schmidt: normalize x, strip its component from y, complete with z = x × y.
    x = F.normalize(x_raw, dim=-1)
    y = F.normalize(y_raw - proj_u2a(x, y_raw), dim=-1)
    z = torch.cross(x, y, dim=-1)

    # Stack the orthonormal basis vectors as matrix columns.
    return torch.stack((x, y, z), dim=-1)
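

# Example (illustrative, not from the original source): the 6D vector
# (1, 0, 0, 0, 1, 0) already encodes two orthonormal axes, so it maps to the
# identity rotation:
#
#   >>> rotation_6d_to_matrix(torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]]))
#   tensor([[[1., 0., 0.],
#            [0., 1., 0.],
#            [0., 0., 1.]]])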


class ReniLatentCodeEstimator(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        triplane_features: int = 40

        n_layers: int = 5
        hidden_features: int = 512
        activation: str = "relu"

        pool: str = "mean"

        reni_env_config: dict = field(default_factory=dict)

    cfg: Config

    def configure(self):
        # Stack of stride-2 "valid" (padding=0) 3x3 convolutions that fuses the
        # channel-concatenated triplanes while downsampling them spatially.
        layers = []
        cur_features = self.cfg.triplane_features * 3  # three planes of F features
        for _ in range(self.cfg.n_layers):
            layers.append(
                nn.Conv2d(
                    cur_features,
                    self.cfg.hidden_features,
                    kernel_size=3,
                    padding=0,
                    stride=2,
                )
            )
            layers.append(self.make_activation(self.cfg.activation))

            cur_features = self.cfg.hidden_features

        self.layers = nn.Sequential(*layers)

        # The RENI environment-map field determines the latent dimensionality
        # that the prediction heads below must match.
        self.reni_env_map = RENIEnvMap(self.cfg.reni_env_config)
        self.latent_dim = self.reni_env_map.field.latent_dim

        # Head predicting the flattened (latent_dim, 3) RENI latent code.
        self.fc_latents = nn.Linear(self.cfg.hidden_features, self.latent_dim * 3)
        nn.init.normal_(self.fc_latents.weight, mean=0.0, std=0.3)

        # Head predicting a 6D rotation representation (see rotation_6d_to_matrix).
        self.fc_rotations = nn.Linear(self.cfg.hidden_features, 6)
        nn.init.constant_(self.fc_rotations.bias, 0.0)
        nn.init.normal_(
            self.fc_rotations.weight, mean=0.0, std=0.01
        )  # Small variance here

        # Head predicting a scalar scale passed to the RENI environment map.
        self.fc_scale = nn.Linear(self.cfg.hidden_features, 1)
        nn.init.constant_(self.fc_scale.bias, 0.0)
        nn.init.normal_(self.fc_scale.weight, mean=0.0, std=0.01)  # Small variance here

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(f"Unsupported activation: {activation}")

    def forward(
        self,
        triplane: Float[Tensor, "B 3 F Ht Wt"],
    ) -> dict[str, Any]:
        # Concatenate the three triplanes along the channel dimension:
        # (B, 3, F, Ht, Wt) -> (B, 3F, Ht, Wt), then run the conv trunk.
        x = self.layers(
            triplane.reshape(
                triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
            )
        )
        # Global average pooling over the remaining spatial dimensions
        # (cfg.pool currently only supports "mean").
        x = x.mean(dim=[-2, -1])

        # Predict the RENI latent code, a 6D rotation, and a scalar scale, then
        # decode them into an environment map.
        latents = self.fc_latents(x).reshape(-1, self.latent_dim, 3)
        rotations = self.fc_rotations(x)
        scale = self.fc_scale(x)

        env_map = self.reni_env_map(latents, rotation_6d_to_matrix(rotations), scale)

        return {"illumination": env_map["rgb"]}