import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor

from spar3d.models.utils import BaseModule


class TriplaneLearnablePositionalEmbedding(BaseModule):
    """Learnable positional embedding stored as three square feature planes.

    The learnable tensor has shape ``(3, num_channels, plane_size, plane_size)``.
    ``forward`` flattens it into a batch of token sequences, and ``detokenize``
    reverses that flattening for a batch of tokens.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Side length (height == width) of each of the three planes.
        plane_size: int = 96
        # Number of feature channels per spatial location.
        num_channels: int = 1024

    cfg: Config

    def configure(self) -> None:
        """Create the learnable plane embeddings.

        NOTE(review): presumably invoked by ``BaseModule`` during construction;
        confirm against ``spar3d.models.utils``.

        Values are drawn from a standard normal distribution and scaled by
        ``1 / sqrt(num_channels)``.
        """
        self.embeddings = nn.Parameter(
            torch.randn(
                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
                dtype=torch.float32,
            )
            / math.sqrt(self.cfg.num_channels)
        )

    def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
        """Return the embeddings as a batch of flattened token sequences.

        Args:
            batch_size: Number of copies ``B`` to produce along the leading axis.

        Returns:
            Tensor of shape ``(B, num_channels, 3 * plane_size**2)``. The token
            axis enumerates plane index, then row, then column (``(Np Hp Wp)``).
        """
        # One einops ``repeat`` both broadcasts over the new batch axis and
        # flattens the three planes into the token axis. This avoids building
        # a separate 5-D intermediate and then calling ``rearrange`` on it.
        return repeat(
            self.embeddings,
            "Np Ct Hp Wp -> B Ct (Np Hp Wp)",
            B=batch_size,
        )

    def detokenize(
        self, tokens: Float[Tensor, "B Ct Nt"]
    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
        """Unfold flattened tokens back into three planes (inverse of ``forward``).

        Args:
            tokens: Tensor of shape ``(B, num_channels, 3 * plane_size**2)``
                whose token axis is ordered ``(Np Hp Wp)``, as produced by
                ``forward``.

        Returns:
            Tensor of shape ``(B, 3, num_channels, plane_size, plane_size)``.

        Raises:
            AssertionError: If the token count or channel count does not match
                the configuration. These checks are skipped under ``python -O``.
        """
        # The batch size is not needed for validation or reshaping.
        _, num_channels, num_tokens = tokens.shape
        expected_tokens = self.cfg.plane_size**2 * 3
        assert num_tokens == expected_tokens, (
            f"expected {expected_tokens} tokens (3 planes of "
            f"{self.cfg.plane_size}x{self.cfg.plane_size}), got {num_tokens}"
        )
        assert num_channels == self.cfg.num_channels, (
            f"expected {self.cfg.num_channels} channels, got {num_channels}"
        )
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=self.cfg.plane_size,
            Wp=self.cfg.plane_size,
        )