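"""Single-image tokenizer built on a frozen DINOv2 backbone.

AdaLN-style modulation layers are injected into every encoder block so the
frozen image backbone can be conditioned on an external embedding.
"""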
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float
from torch import Tensor

from spar3d.models.tokenizers.dinov2 import Dinov2Model
from spar3d.models.transformers.attention import Modulation
from spar3d.models.utils import BaseModule

class DINOV2SingleImageTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dinov2-large"
        width: int = 512
        height: int = 512
        modulation_cond_dim: int = 768

    cfg: Config
    def configure(self) -> None:
        self.model = Dinov2Model.from_pretrained(
            self.cfg.pretrained_model_name_or_path
        )

        # Freeze the pretrained backbone; only the modulation layers added
        # below hold trainable parameters.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.model.eval()
        self.model.set_gradient_checkpointing(False)

        # Inject AdaLN-style modulation around both norms of every encoder
        # layer so the frozen backbone can be conditioned on an external
        # embedding of size modulation_cond_dim.
        modulations = []
        for layer in self.model.encoder.layer:
            norm1_modulation = Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,
                single_layer=True,
            )
            norm2_modulation = Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,
                single_layer=True,
            )
            layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
            modulations += [norm1_modulation, norm2_modulation]
        self.modulations = nn.ModuleList(modulations)

        # ImageNet normalization statistics, shaped to broadcast over
        # (B, N, C, H, W) image batches.
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        **kwargs,
    ) -> Float[Tensor, "B *N Ct Nt"]:
        model = self.model

        # Accept both packed (B, C, H, W) and multi-view (B, N, C, H, W)
        # inputs; remember whether a view axis was added so it can be
        # squeezed out again before returning.
        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std

        # Fold the view axis into the batch for a single backbone pass.
        out = model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None,
        )

        # (BN, Nt, Ct) -> (BN, Ct, Nt), then restore the view axis.
        local_features = out.last_hidden_state
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)
        return local_features
    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
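
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # threestudio-style BaseModule constructor that parses a plain dict into
    # the nested Config dataclass; adapt the construction call to however
    # BaseModule is instantiated in your checkout.
    tokenizer = DINOV2SingleImageTokenizer(
        {"pretrained_model_name_or_path": "facebook/dinov2-large"}
    )

    images = torch.rand(2, 3, 512, 512)  # packed input: (B, C, H, W) in [0, 1]
    cond = torch.zeros(2, 768)  # one modulation_cond_dim vector per image

    tokens = tokenizer(images, modulation_cond=cond)
    # Ct is the backbone hidden size (1024 for dinov2-large); Nt is the
    # number of output tokens, which depends on the patch grid.
    print(tokens.shape)  # (2, Ct, Nt)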