|
import math |
|
from dataclasses import dataclass |
|
|
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange, repeat |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
from spar3d.models.utils import BaseModule |
|
|
|
|
|
class TriplaneLearnablePositionalEmbedding(BaseModule):
    """Learnable positional embedding covering three square feature planes.

    A single parameter of shape ``(3, num_channels, plane_size, plane_size)``
    is held by the module. ``forward`` exposes it as a flattened token
    sequence per batch element, and ``detokenize`` folds such a sequence back
    into the per-plane layout.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Side length (both height and width) of each of the three planes.
        plane_size: int = 96
        # Channel dimension of every token.
        num_channels: int = 1024

    cfg: Config

    def configure(self) -> None:
        """Create the ``embeddings`` parameter from scaled Gaussian noise."""
        channels = self.cfg.num_channels
        side = self.cfg.plane_size
        noise = torch.randn((3, channels, side, side), dtype=torch.float32)
        # Standard-normal samples are divided by sqrt(num_channels).
        self.embeddings = nn.Parameter(noise / math.sqrt(channels))

    def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
        """Return the embeddings as tokens, replicated ``batch_size`` times.

        Args:
            batch_size: Size of the leading batch axis of the result.

        Returns:
            Tensor of shape ``(B, Ct, Np * Hp * Wp)``: the plane and both
            spatial axes are merged into a single token axis.
        """
        # First add the batch axis, then flatten (plane, height, width).
        tiled = repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size)
        return rearrange(tiled, "B Np Ct Hp Wp -> B Ct (Np Hp Wp)")

    def detokenize(
        self, tokens: Float[Tensor, "B Ct Nt"]
    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
        """Fold a flat token sequence back into three separate planes.

        Args:
            tokens: Tensor of shape ``(B, Ct, Nt)`` where ``Nt`` must equal
                ``3 * plane_size**2`` and ``Ct`` must equal ``num_channels``.

        Returns:
            Tensor of shape ``(B, 3, Ct, plane_size, plane_size)``.

        Raises:
            AssertionError: If the token count or the channel count does not
                match the configuration.
        """
        side = self.cfg.plane_size
        _, channel_count, token_count = tokens.shape
        assert token_count == 3 * side * side
        assert channel_count == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=side,
            Wp=side,
        )
|
|