from dataclasses import dataclass
from typing import Optional

import torch
from jaxtyping import Float
from torch import Tensor

from spar3d.models.transformers.transformer_1d import Transformer1D
from spar3d.models.utils import BaseModule


class TransformerPointTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        in_channels: Optional[int] = 6
        out_channels: Optional[int] = 1024
        num_layers: int = 16
        norm_num_groups: int = 32
        attention_bias: bool = False
        activation_fn: str = "geglu"
        norm_elementwise_affine: bool = True

    cfg: Config

    def configure(self) -> None:
        transformer_cfg = dict(self.cfg.copy())
        # The transformer operates at a width of num_attention_heads * attention_head_dim.
        transformer_cfg["in_channels"] = (
            self.cfg.num_attention_heads * self.cfg.attention_head_dim
        )
        self.model = Transformer1D(transformer_cfg)
        # Project per-point input features up to the transformer width,
        # and back down to the output token dimension.
        self.linear_in = torch.nn.Linear(
            self.cfg.in_channels, transformer_cfg["in_channels"]
        )
        self.linear_out = torch.nn.Linear(
            transformer_cfg["in_channels"], self.cfg.out_channels
        )

    def forward(
        self, points: Float[Tensor, "B N Ci"], **kwargs
    ) -> Float[Tensor, "B N Cp"]:
        assert points.ndim == 3
        # Transformer1D takes channels-first input, so permute to (B, C, N)
        # before the transformer and back to (B, N, C) afterwards.
        inputs = self.linear_in(points).permute(0, 2, 1)
        out = self.model(inputs).permute(0, 2, 1)
        out = self.linear_out(out)
        return out

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
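

# A minimal usage sketch, not part of the original module. It assumes the
# spar3d BaseModule can be constructed from a plain dict of config overrides
# (threestudio-style config parsing) and that the Config defaults above are
# acceptable. It tokenizes a batch of 6-channel points into 1024-dim tokens.
if __name__ == "__main__":
    tokenizer = TransformerPointTokenizer({})  # assumed: dict is parsed into Config
    points = torch.rand(2, 512, 6)  # (batch, num_points, in_channels)
    tokens = tokenizer(points)
    print(tokens.shape)  # expected: torch.Size([2, 512, 1024])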