from dataclasses import dataclass
from typing import Optional

import torch
from jaxtyping import Float
from torch import Tensor

from spar3d.models.transformers.transformer_1d import Transformer1D
from spar3d.models.utils import BaseModule


class TransformerPointTokenizer(BaseModule):
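    """Tokenizes a point cloud into per-point feature tokens.

    Input points of shape (B, N, in_channels) are projected to the
    transformer width, processed by a Transformer1D, and projected to
    output tokens of shape (B, N, out_channels).
    """
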
    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        in_channels: Optional[int] = 6
        out_channels: Optional[int] = 1024
        num_layers: int = 16
        norm_num_groups: int = 32
        attention_bias: bool = False
        activation_fn: str = "geglu"
        norm_elementwise_affine: bool = True

    cfg: Config

    def configure(self) -> None:
        transformer_cfg = dict(self.cfg.copy())
        # the transformer runs at width num_attention_heads * attention_head_dim,
        # so override in_channels; the tokenizer's own in/out channel counts are
        # handled by the linear projections below
        transformer_cfg["in_channels"] = (
            self.cfg.num_attention_heads * self.cfg.attention_head_dim
        )
        self.model = Transformer1D(transformer_cfg)
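        # project per-point inputs up to the transformer width, and transformer
        # outputs down/up to the requested token dimension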
        self.linear_in = torch.nn.Linear(
            self.cfg.in_channels, transformer_cfg["in_channels"]
        )
        self.linear_out = torch.nn.Linear(
            transformer_cfg["in_channels"], self.cfg.out_channels
        )

    def forward(
        self, points: Float[Tensor, "B N Ci"], **kwargs
    ) -> Float[Tensor, "B N Cp"]:
        assert points.ndim == 3
        inputs = self.linear_in(points).permute(0, 2, 1)  # B N Ci -> B Ct N (Ct = transformer width)
        out = self.model(inputs).permute(0, 2, 1)  # B Ct N -> B N Ct
        out = self.linear_out(out)  # B N Ct -> B N Cp
        return out

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
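

if __name__ == "__main__":
    # Smoke-test sketch (not part of the original module). It assumes BaseModule
    # accepts a plain dict (or None) for its config and fills the remaining
    # Config fields with their defaults, as in threestudio-style codebases;
    # pass an OmegaConf DictConfig instead if your BaseModule requires one.
    tokenizer = TransformerPointTokenizer({"num_layers": 4})  # smaller model for a quick check
    points = torch.rand(2, 512, 6)  # B=2 clouds, 512 points, 6 channels (e.g. xyz + color)
    tokens = tokenizer(points)
    print(tokens.shape)  # expected: torch.Size([2, 512, 1024])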