from typing import Optional

import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Multi-Head Attention module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads
        self.scale = self.head_dim**-0.5

        self.query = nn.Linear(direction_input_dim, latent_dim)
        self.key = nn.Linear(conditioning_input_dim, latent_dim)
        self.value = nn.Linear(conditioning_input_dim, latent_dim)
        self.fc_out = nn.Linear(latent_dim, latent_dim)

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Multi-Head Attention module.

        Args:
            query (torch.Tensor): The directional input tensor.
            key (torch.Tensor): The conditioning input tensor for the keys.
            value (torch.Tensor): The conditioning input tensor for the values.

        Returns:
            torch.Tensor: The output tensor of the Multi-Head Attention module.
        """
        batch_size = query.size(0)

        # Project inputs and split into heads: (batch, num_heads, seq_len, head_dim)
        Q = (
            self.query(query)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        K = (
            self.key(key)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        V = (
            self.value(value)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # Scaled attention scores over the conditioning tokens:
        # (batch, num_heads, query_len, key_len)
        attention = (
            torch.einsum("bnqk,bnkh->bnqh", [Q, K.transpose(-2, -1)]) * self.scale
        )
        attention = torch.softmax(attention, dim=-1)

        # Attention-weighted sum of the values, then merge the heads back
        # into (batch, query_len, latent_dim)
        out = torch.einsum("bnqh,bnhv->bnqv", [attention, V])
        out = (
            out.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.head_dim)
        )
        out = self.fc_out(out).squeeze(1)
        return out


class AttentionLayer(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Attention Layer module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        self.mha = MultiHeadAttention(
            direction_input_dim, conditioning_input_dim, latent_dim, num_heads
        )
        self.norm1 = nn.LayerNorm(latent_dim)
        self.norm2 = nn.LayerNorm(latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
        )

    def forward(
        self, directional_input: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Attention Layer module.

        Args:
            directional_input (torch.Tensor): The directional input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Attention Layer module.
        """
        # Cross-attention: queries come from the directional input, keys and
        # values from the conditioning input, followed by a residual + norm
        # and a position-wise feed-forward block.
        attn_output = self.mha(
            directional_input, conditioning_input, conditioning_input
        )
        out1 = self.norm1(attn_output + directional_input)
        fc_output = self.fc(out1)
        out2 = self.norm2(fc_output + out1)
        return out2


class Decoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        conditioning_input_dim: int,
        hidden_features: int,
        num_heads: int,
        num_layers: int,
        out_activation: Optional[nn.Module],
    ):
        """
        Decoder module.

        Args:
            in_dim (int): The input dimension of the module.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            hidden_features (int): The number of hidden features in the module.
            num_heads (int): The number of heads to use in the attention mechanism.
            num_layers (int): The number of layers in the module.
            out_activation (Optional[nn.Module]): The activation function to apply to
                the output tensor, or None to return the raw output.
        """
        super().__init__()
        self.residual_projection = nn.Linear(
            in_dim, hidden_features
        )  # projection for residual connection
        self.layers = nn.ModuleList(
            [
                AttentionLayer(
                    hidden_features, conditioning_input_dim, hidden_features, num_heads
                )
                for _ in range(num_layers)
            ]
        )
        self.fc = nn.Linear(hidden_features, 3)  # 3 for RGB
        self.out_activation = out_activation

    def forward(
        self, x: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Decoder module.

        Args:
            x (torch.Tensor): The input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Decoder module.
        """
        x = self.residual_projection(x)
        for layer in self.layers:
            x = layer(x, conditioning_input)
        x = self.fc(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x
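# Minimal usage sketch (illustrative only, not part of the original module): the
# batch size, token count, and feature sizes below are assumptions chosen for the
# example. It shows a 2D directional input being projected to `hidden_features`,
# cross-attended against a 3D conditioning tensor by each AttentionLayer, and
# decoded to an RGB prediction per batch element.
if __name__ == "__main__":
    decoder = Decoder(
        in_dim=3,                     # e.g. a 3D view direction (assumed)
        conditioning_input_dim=64,    # assumed conditioning feature size
        hidden_features=128,
        num_heads=4,
        num_layers=2,
        out_activation=nn.Sigmoid(),  # keep RGB predictions in [0, 1]
    )
    directions = torch.randn(8, 3)         # (batch, in_dim)
    conditioning = torch.randn(8, 16, 64)  # (batch, tokens, conditioning_input_dim)
    rgb = decoder(directions, conditioning)
    print(rgb.shape)  # torch.Size([8, 3])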