from dataclasses import dataclass
from typing import Optional

import torch
from jaxtyping import Float
from torch import Tensor

from spar3d.models.transformers.transformer_1d import Transformer1D
from spar3d.models.utils import BaseModule


class TransformerPointTokenizer(BaseModule):
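    """Tokenize a point cloud into a sequence of feature tokens.

    Each point's Ci input features are linearly lifted to the transformer
    width (num_attention_heads * attention_head_dim), mixed across the N
    points by a Transformer1D, then projected to out_channels-dim tokens.
    """
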
    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        in_channels: Optional[int] = 6
        out_channels: Optional[int] = 1024
        num_layers: int = 16
        norm_num_groups: int = 32
        attention_bias: bool = False
        activation_fn: str = "geglu"
        norm_elementwise_affine: bool = True

    cfg: Config

    def configure(self) -> None:
        # Forward the config to Transformer1D, overriding in_channels so it
        # matches the transformer width (num_attention_heads * attention_head_dim).
        transformer_cfg = dict(self.cfg.copy())
        transformer_cfg["in_channels"] = (
            self.cfg.num_attention_heads * self.cfg.attention_head_dim
        )
        self.model = Transformer1D(transformer_cfg)
        # Project per-point features up to the transformer width, and the
        # transformer output down to the final token dimension.
        self.linear_in = torch.nn.Linear(
            self.cfg.in_channels, transformer_cfg["in_channels"]
        )
        self.linear_out = torch.nn.Linear(
            transformer_cfg["in_channels"], self.cfg.out_channels
        )

    def forward(
        self, points: Float[Tensor, "B N Ci"], **kwargs
    ) -> Float[Tensor, "B N Co"]:
        assert points.ndim == 3
        # (B, N, Ci) -> (B, N, Ct) -> (B, Ct, N): lift to the transformer
        # width Ct, then move channels first for Transformer1D.
        inputs = self.linear_in(points).permute(0, 2, 1)
        # (B, Ct, N) -> (B, N, Ct): back to channels-last after attention.
        out = self.model(inputs).permute(0, 2, 1)
        out = self.linear_out(out)  # (B, N, Ct) -> (B, N, Co)
        return out

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
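

if __name__ == "__main__":
    # Illustrative smoke test; assumes BaseModule's constructor accepts a
    # plain dict for its Config (the threestudio-style pattern this module
    # follows, and how Transformer1D is instantiated above).
    tokenizer = TransformerPointTokenizer({})  # all Config defaults
    points = torch.randn(2, 512, 6)  # B=2 clouds, N=512 points, Ci=6 (e.g. xyz + rgb)
    tokens = tokenizer(points)
    print(tokens.shape)  # expected: torch.Size([2, 512, 1024])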