File size: 961 Bytes
38dbec8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from dataclasses import dataclass, field
from typing import List
import torch
import torch.nn as nn
from spar3d.models.utils import BaseModule
class LinearCameraEmbedder(BaseModule):
    """Embed camera conditioning tensors with a single linear projection.

    Each condition named in ``cfg.conditions`` is read from the keyword
    arguments of ``forward``, flattened over every dimension after the first
    two, concatenated along the last axis, and projected from
    ``cfg.in_channels`` to ``cfg.out_channels`` features.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Width of the concatenated, flattened conditions. Must equal the sum
        # of the flattened trailing sizes of every entry in ``conditions``
        # (checked at runtime in ``forward``).
        in_channels: int = 25
        # Width of the embedding produced by the linear layer.
        out_channels: int = 768
        # Names of the ``forward`` keyword arguments to embed, in the order
        # they are concatenated.
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        """Create the linear projection ``in_channels -> out_channels``."""
        # NOTE(review): presumably invoked by BaseModule during construction
        # after ``self.cfg`` is populated — confirm against BaseModule.
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs) -> torch.Tensor:
        """Flatten, concatenate and linearly embed the configured conditions.

        Args:
            **kwargs: Must contain a tensor for every name in
                ``cfg.conditions``. Each tensor is expected in shape
                ``(B, Nv, ...)``; extra keyword arguments are ignored.

        Returns:
            The output of the linear layer applied to the concatenated
            conditions, with ``cfg.out_channels`` features in the last axis.

        Raises:
            ValueError: If ``cfg.conditions`` is empty, or if the concatenated
                width does not equal ``cfg.in_channels``.
            KeyError: If a configured condition is missing from ``kwargs``.
        """
        if not self.cfg.conditions:
            # torch.cat on an empty list fails with an opaque error; report
            # the configuration problem directly instead.
            raise ValueError(
                "LinearCameraEmbedder: cfg.conditions is empty; nothing to embed"
            )

        cond_tensors = []
        for cond_name in self.cfg.conditions:
            # Explicit check (not `assert`) so validation survives `python -O`.
            if cond_name not in kwargs:
                raise KeyError(
                    f"LinearCameraEmbedder: missing condition '{cond_name}' "
                    f"(received: {sorted(kwargs)})"
                )
            cond = kwargs[cond_name]
            # cond in shape (B, Nv, ...): keep the two leading dims and flatten
            # the rest. `reshape` (unlike `view`) also accepts non-contiguous
            # tensors, copying only when a view is impossible.
            cond_tensors.append(cond.reshape(*cond.shape[:2], -1))
        cond_tensor = torch.cat(cond_tensors, dim=-1)

        if cond_tensor.shape[-1] != self.cfg.in_channels:
            raise ValueError(
                f"LinearCameraEmbedder: concatenated conditions have "
                f"{cond_tensor.shape[-1]} channels, expected "
                f"cfg.in_channels={self.cfg.in_channels}"
            )
        return self.linear(cond_tensor)
|