|
from dataclasses import dataclass |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
from spar3d.models.tokenizers.dinov2 import Dinov2Model |
|
from spar3d.models.transformers.attention import Modulation |
|
from spar3d.models.utils import BaseModule |
|
|
|
|
|
class DINOV2SingleImageTokenizer(BaseModule):
    """Image tokenizer built on a pretrained, frozen ``Dinov2Model``.

    Every encoder layer of the backbone gets two ``Modulation`` modules
    registered on it (one per norm, going by the ``register_ada_norm_modulation``
    argument names), which receive the conditioning vector passed to
    ``forward``. The backbone weights are frozen; the modulation modules are
    created afterwards and are therefore not touched by the freezing loop.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Checkpoint identifier handed to ``Dinov2Model.from_pretrained``.
        pretrained_model_name_or_path: str = "facebook/dinov2-large"
        # NOTE(review): width/height are not read anywhere in this class;
        # presumably consumed by callers or kept for config compatibility.
        width: int = 512
        height: int = 512
        # Second positional argument given to every ``Modulation`` module.
        modulation_cond_dim: int = 768

    cfg: Config

    def configure(self) -> None:
        """Load the backbone, freeze it, and attach the modulation modules."""
        self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)

        # Freeze every backbone parameter and switch the backbone to eval mode.
        for weight in self.model.parameters():
            weight.requires_grad_(False)
        self.model.eval()

        self.model.set_gradient_checkpointing(False)

        def new_modulation() -> Modulation:
            # All modulation modules share the same construction arguments.
            return Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,
                single_layer=True,
            )

        # Two modulation modules per encoder layer, created in (norm1, norm2)
        # order and collected in that same order so the ModuleList indices
        # stay stable.
        registered = []
        for block in self.model.encoder.layer:
            pair = (new_modulation(), new_modulation())
            block.register_ada_norm_modulation(*pair)
            registered.extend(pair)
        self.modulations = nn.ModuleList(registered)

        # Per-channel normalisation constants, shaped to broadcast against
        # (B, N, C, H, W) inputs. Non-persistent: excluded from state_dict.
        # NOTE(review): values look like the usual ImageNet RGB statistics.
        channel_stats = {
            "image_mean": [0.485, 0.456, 0.406],
            "image_std": [0.229, 0.224, 0.225],
        }
        for buffer_name, values in channel_stats.items():
            self.register_buffer(
                buffer_name,
                torch.as_tensor(values).reshape(1, 1, 3, 1, 1),
                persistent=False,
            )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        **kwargs,
    ) -> Float[Tensor, "B *N Ct Nt"]:
        """Encode images into backbone features.

        Args:
            images: Either a 4-D tensor (B, C, H, W) or a tensor with an
                explicit view axis (B, N, C, H, W).
            modulation_cond: Optional conditioning tensor forwarded to the
                backbone. When ``images`` is 4-D this must be 2-D (B, Cc);
                otherwise it is expected as (B, N, Cc).
            **kwargs: Accepted but unused.

        Returns:
            The backbone's ``last_hidden_state`` with its last two axes
            swapped, shaped (B, N, Ct, Nt), or (B, Ct, Nt) when ``images``
            was given without a view axis.
        """
        # A 4-D input has no view axis: add a singleton one now and strip it
        # again from the result.
        packed = images.ndim == 4
        if packed:
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, _ = images.shape[:2]
        normalized = (images - self.image_mean) / self.image_std

        # Fold the view axis into the batch axis for the backbone call.
        flat_images = rearrange(normalized, "B N C H W -> (B N) C H W")
        if modulation_cond is None:
            flat_cond = None
        else:
            flat_cond = rearrange(modulation_cond, "B N Cc -> (B N) Cc")

        hidden = self.model(flat_images, modulation_cond=flat_cond).last_hidden_state

        # Swap the last two axes, then restore the separate batch/view axes.
        tokens = rearrange(
            hidden.permute(0, 2, 1), "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        return tokens.squeeze(1) if packed else tokens

    def detokenize(self, *args, **kwargs):
        """Not supported by this tokenizer; always raises ``NotImplementedError``."""
        raise NotImplementedError
|
|