from typing import Optional

import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.activations import get_activation
from diffusers.models.modeling_utils import ModelMixin


class CameraMatrixEmbedding(ModelMixin, ConfigMixin):
    """Projects flattened camera matrices into an embedding space via a small MLP.

    The layout mirrors diffusers' timestep embedding: linear -> activation -> linear,
    with an optional conditioning projection added to the input and an optional
    post-activation on the output.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int,
        camera_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, camera_embed_dim)

        # Optional projection for an extra conditioning signal that is added to the
        # input before the first linear layer.
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        # Output width defaults to the hidden width unless out_dim is given.
        if out_dim is not None:
            camera_embed_dim_out = out_dim
        else:
            camera_embed_dim_out = camera_embed_dim
        self.linear_2 = nn.Linear(camera_embed_dim, camera_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        # Passing a condition requires cond_proj_dim to have been set at init.
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
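

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The concrete sizes are assumptions,
    # not values prescribed by this module: in_channels=16 stands in for a flattened
    # 4x4 camera-to-world matrix, and camera_embed_dim=1280 for a typical UNet
    # time-embedding width. Adjust both to match the upstream camera parameterization
    # and the host model.
    import torch

    camera_embed = CameraMatrixEmbedding(in_channels=16, camera_embed_dim=1280)
    cameras = torch.randn(2, 16)  # (batch, flattened camera matrix)
    emb = camera_embed(cameras)
    print(emb.shape)  # torch.Size([2, 1280])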