from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float
from torch import Tensor

from spar3d.models.tokenizers.dinov2 import Dinov2Model
from spar3d.models.transformers.attention import Modulation
from spar3d.models.utils import BaseModule


class DINOV2SingleImageTokenizer(BaseModule):
    """Turn (multi-view) images into patch-token features with a frozen DINOv2.

    Every backbone parameter has ``requires_grad`` switched off. Each encoder
    layer is then given a pair of ``Modulation`` modules (built with
    ``zero_init=True``) through ``register_ada_norm_modulation``. Those modules
    are created after the freeze, so the freeze does not touch them.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Checkpoint name or path handed to ``Dinov2Model.from_pretrained``.
        pretrained_model_name_or_path: str = "facebook/dinov2-large"
        # NOTE(review): width/height are not read anywhere in this class.
        width: int = 512
        height: int = 512
        # Size of the conditioning vector consumed by every Modulation.
        modulation_cond_dim: int = 768

    cfg: Config

    def configure(self) -> None:
        """Load and freeze the backbone, attach modulations, register buffers."""
        backbone = Dinov2Model.from_pretrained(
            self.cfg.pretrained_model_name_or_path
        )
        self.model = backbone

        # Freeze the backbone and put it in eval mode. This must run BEFORE the
        # modulations are attached below: they are presumably registered as
        # submodules of the encoder layers, so freezing afterwards would also
        # freeze them.
        for param in backbone.parameters():
            param.requires_grad_(False)
        backbone.eval()
        backbone.set_gradient_checkpointing(False)

        hidden_size = backbone.config.hidden_size
        cond_dim = self.cfg.modulation_cond_dim

        def new_modulation() -> Modulation:
            # One modulation head mapping the conditioning vector to a
            # hidden_size-wide adjustment of a norm layer.
            return Modulation(
                hidden_size,
                cond_dim,
                zero_init=True,
                single_layer=True,
            )

        # Build the heads in a fixed order (norm1 then norm2, layer by layer).
        # That order determines the ``modulations.<i>`` state_dict keys.
        modulations = []
        for encoder_layer in backbone.encoder.layer:
            pair = (new_modulation(), new_modulation())
            encoder_layer.register_ada_norm_modulation(*pair)
            modulations.extend(pair)
        self.modulations = nn.ModuleList(modulations)

        # Per-channel normalisation constants, shaped to broadcast over a
        # (B, N, C, H, W) image tensor. They are not saved in the state_dict.
        # The values match those conventionally used with ImageNet-pretrained
        # models.
        for buffer_name, stats in (
            ("image_mean", [0.485, 0.456, 0.406]),
            ("image_std", [0.229, 0.224, 0.225]),
        ):
            self.register_buffer(
                buffer_name,
                torch.as_tensor(stats).reshape(1, 1, 3, 1, 1),
                persistent=False,
            )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        **kwargs,
    ) -> Float[Tensor, "B *N Ct Nt"]:
        """Encode images into per-view token features.

        Args:
            images: a 4-D (B, C, H, W) batch of single views, or a 5-D
                (B, N, C, H, W) batch of N views per sample.
            modulation_cond: optional conditioning. It must be 2-D (B, Cc)
                when ``images`` is 4-D, and (B, N, Cc) otherwise.
            **kwargs: accepted and ignored.

        Returns:
            The backbone's ``last_hidden_state`` with its last two axes
            swapped, regrouped per view as (B, N, Ct, Nt). For 4-D input the
            singleton view axis is squeezed away, giving (B, Ct, Nt).
        """
        squeeze_views = images.ndim == 4
        if squeeze_views:
            # Treat single-view input as a batch with exactly one view.
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, _num_views = images.shape[:2]
        normalized = (images - self.image_mean) / self.image_std

        # Fold the view axis into the batch axis: the backbone only sees a
        # flat batch of images.
        flat_cond = None
        if modulation_cond is not None:
            flat_cond = rearrange(modulation_cond, "B N Cc -> (B N) Cc")
        outputs = self.model(
            rearrange(normalized, "B N C H W -> (B N) C H W"),
            modulation_cond=flat_cond,
        )

        # Swap the last two axes (presumably tokens x channels -> channels x
        # tokens; TODO confirm against Dinov2Model), then restore the view axis.
        tokens = outputs.last_hidden_state.permute(0, 2, 1)
        tokens = rearrange(tokens, "(B N) Ct Nt -> B N Ct Nt", B=batch_size)
        return tokens.squeeze(1) if squeeze_views else tokens

    def detokenize(self, *args, **kwargs):
        """Unsupported: this tokenizer only encodes."""
        raise NotImplementedError