jammmmm's picture
Add spar3d demo files
38dbec8
raw
history blame
1.39 kB
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor
from spar3d.models.utils import BaseModule
class TriplaneLearnablePositionalEmbedding(BaseModule):
    """Learnable embedding stored as three square feature planes.

    The parameter has shape ``(3, num_channels, plane_size, plane_size)``.
    ``forward`` hands it out as a flat token sequence replicated over a batch,
    and ``detokenize`` folds such a token sequence back into plane layout.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Height and width of each of the three planes.
        plane_size: int = 96
        # Number of feature channels per plane position.
        num_channels: int = 1024

    cfg: Config

    def configure(self) -> None:
        """Allocate the embedding parameter.

        Values are drawn from a standard normal distribution and divided by
        ``sqrt(num_channels)``.
        NOTE(review): presumably invoked by ``BaseModule`` during construction
        -- confirm against the base class.
        """
        channels = self.cfg.num_channels
        side = self.cfg.plane_size
        noise = torch.randn((3, channels, side, side), dtype=torch.float32)
        self.embeddings = nn.Parameter(noise / math.sqrt(channels))

    def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
        """Return the embeddings as tokens, replicated ``batch_size`` times.

        The plane, height and width axes are merged (in that order) into one
        token axis, giving ``Nt = 3 * plane_size * plane_size`` tokens of
        ``Ct = num_channels`` channels each.
        """
        # Prepend a batch axis by replicating the shared parameter.
        batched = repeat(
            self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size
        )
        # Move channels ahead of the plane axis and flatten planes into tokens.
        return rearrange(batched, "B Np Ct Hp Wp -> B Ct (Np Hp Wp)")

    def detokenize(
        self, tokens: Float[Tensor, "B Ct Nt"]
    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
        """Fold a flat token sequence back into three feature planes.

        Applies the reverse of the flattening done in ``forward``: the token
        axis is split into (plane, height, width) and the plane axis is moved
        in front of the channel axis.

        Raises:
            AssertionError: if the token count is not ``3 * plane_size**2`` or
                the channel count is not ``num_channels``.
        """
        side = self.cfg.plane_size
        # Unpacking into three names also requires the input to be 3-D.
        _, channel_count, token_count = tokens.shape
        assert token_count == side**2 * 3
        assert channel_count == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=side,
            Wp=side,
        )