from dataclasses import dataclass
from typing import Optional

import torch
from jaxtyping import Float
from torch import Tensor

from spar3d.models.transformers.transformer_1d import Transformer1D
from spar3d.models.utils import BaseModule

class TransformerPointTokenizer(BaseModule):
    """Tokenizes a point cloud of shape (B, N, in_channels) into per-point
    features of shape (B, N, out_channels) using a 1D transformer."""

    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        in_channels: Optional[int] = 6
        out_channels: Optional[int] = 1024
        num_layers: int = 16
        norm_num_groups: int = 32
        attention_bias: bool = False
        activation_fn: str = "geglu"
        norm_elementwise_affine: bool = True

    cfg: Config

    def configure(self) -> None:
        transformer_cfg = dict(self.cfg.copy())
        # the transformer runs at its inner width (heads * head_dim), not at
        # the raw point feature size, so override in_channels accordingly
        transformer_cfg["in_channels"] = (
            self.cfg.num_attention_heads * self.cfg.attention_head_dim
        )
        self.model = Transformer1D(transformer_cfg)
        # project raw point features up to the transformer width ...
        self.linear_in = torch.nn.Linear(
            self.cfg.in_channels, transformer_cfg["in_channels"]
        )
        # ... and project the transformer output down/up to the token dimension
        self.linear_out = torch.nn.Linear(
            transformer_cfg["in_channels"], self.cfg.out_channels
        )

    def forward(
        self, points: Float[Tensor, "B N Ci"], **kwargs
    ) -> Float[Tensor, "B N Cp"]:
        assert points.ndim == 3
        inputs = self.linear_in(points).permute(0, 2, 1)  # B N Ci -> B Ci N
        out = self.model(inputs).permute(0, 2, 1)  # B Ci N -> B N Ci
        out = self.linear_out(out)  # B N Ci -> B N Co
        return out

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
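
# Usage sketch (illustrative, not part of the original module). Construction
# assumes the BaseModule convention of passing a config dict that is merged
# with the defaults in Config; the shapes follow the jaxtyping annotations on
# `forward`:
#
#   tokenizer = TransformerPointTokenizer({})   # all defaults from Config
#   points = torch.rand(2, 512, 6)              # (B, N, in_channels), e.g. xyz + rgb
#   tokens = tokenizer(points)                  # (B, N, out_channels) == (2, 512, 1024)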