from typing import Optional
import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Multi-Head Attention module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads
        self.scale = self.head_dim**-0.5

        self.query = nn.Linear(direction_input_dim, latent_dim)
        self.key = nn.Linear(conditioning_input_dim, latent_dim)
        self.value = nn.Linear(conditioning_input_dim, latent_dim)
        self.fc_out = nn.Linear(latent_dim, latent_dim)

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Multi-Head Attention module.

        Args:
            query (torch.Tensor): The directional input tensor.
            key (torch.Tensor): The conditioning input tensor for the keys.
            value (torch.Tensor): The conditioning input tensor for the values.

        Returns:
            torch.Tensor: The output tensor of the Multi-Head Attention module.
        """
        batch_size = query.size(0)

        # Project the inputs and split the latent dimension into heads:
        # (batch, num_heads, seq_len, head_dim)
        Q = (
            self.query(query)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        K = (
            self.key(key)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        V = (
            self.value(value)
            .view(batch_size, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # Scaled dot-product scores: (batch, num_heads, query_len, key_len)
        attention = (
            torch.einsum("bnqk,bnkh->bnqh", [Q, K.transpose(-2, -1)]) * self.scale
        )
        attention = torch.softmax(attention, dim=-1)

        # Attention-weighted sum of the values: (batch, num_heads, query_len, head_dim)
        out = torch.einsum("bnqh,bnhv->bnqv", [attention, V])

        # Merge the heads back into a single latent dimension and project out;
        # squeeze(1) drops the sequence axis when the query length is 1.
        out = (
            out.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.head_dim)
        )
        out = self.fc_out(out).squeeze(1)
        return out
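
# Usage sketch (illustrative only, not part of the original demo): a single
# directional query per batch element attends over a sequence of conditioning
# tokens. The dimensions below are assumptions chosen only to show the shapes.
#
#     mha = MultiHeadAttention(
#         direction_input_dim=3,
#         conditioning_input_dim=512,
#         latent_dim=256,
#         num_heads=8,
#     )
#     dirs = torch.randn(4, 1, 3)     # (batch, 1, direction_input_dim)
#     cond = torch.randn(4, 16, 512)  # (batch, tokens, conditioning_input_dim)
#     out = mha(dirs, cond, cond)     # -> (4, 256) after the final squeeze(1)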


class AttentionLayer(nn.Module):
    def __init__(
        self,
        direction_input_dim: int,
        conditioning_input_dim: int,
        latent_dim: int,
        num_heads: int,
    ):
        """
        Attention Layer module.

        Args:
            direction_input_dim (int): The input dimension of the directional input.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            latent_dim (int): The latent dimension of the module.
            num_heads (int): The number of heads to use in the attention mechanism.
        """
        super().__init__()
        self.mha = MultiHeadAttention(
            direction_input_dim, conditioning_input_dim, latent_dim, num_heads
        )
        self.norm1 = nn.LayerNorm(latent_dim)
        self.norm2 = nn.LayerNorm(latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
        )

    def forward(
        self, directional_input: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Attention Layer module.

        Args:
            directional_input (torch.Tensor): The directional input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Attention Layer module.
        """
        # Cross-attention: the directional input queries the conditioning input.
        attn_output = self.mha(
            directional_input, conditioning_input, conditioning_input
        )
        # Residual connections with post-norm around the attention and MLP blocks.
        out1 = self.norm1(attn_output + directional_input)
        fc_output = self.fc(out1)
        out2 = self.norm2(fc_output + out1)
        return out2
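
# Usage sketch (illustrative assumption, not part of the original file). Note that
# the residual addition above requires direction_input_dim to match latent_dim.
#
#     layer = AttentionLayer(
#         direction_input_dim=256,
#         conditioning_input_dim=512,
#         latent_dim=256,
#         num_heads=8,
#     )
#     x = torch.randn(4, 256)         # directional features, (batch, latent_dim)
#     cond = torch.randn(4, 16, 512)  # conditioning tokens
#     y = layer(x, cond)              # -> (4, 256)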


class Decoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        conditioning_input_dim: int,
        hidden_features: int,
        num_heads: int,
        num_layers: int,
        out_activation: Optional[nn.Module],
    ):
        """
        Decoder module.

        Args:
            in_dim (int): The input dimension of the module.
            conditioning_input_dim (int): The input dimension of the conditioning input.
            hidden_features (int): The number of hidden features in the module.
            num_heads (int): The number of heads to use in the attention mechanism.
            num_layers (int): The number of layers in the module.
            out_activation (Optional[nn.Module]): The activation function applied to the
                output tensor, or None to return the raw linear output.
        """
        super().__init__()
        self.residual_projection = nn.Linear(
            in_dim, hidden_features
        )  # projection for residual connection
        self.layers = nn.ModuleList(
            [
                AttentionLayer(
                    hidden_features, conditioning_input_dim, hidden_features, num_heads
                )
                for _ in range(num_layers)
            ]
        )
        self.fc = nn.Linear(hidden_features, 3)  # 3 for RGB
        self.out_activation = out_activation

    def forward(
        self, x: torch.Tensor, conditioning_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the Decoder module.

        Args:
            x (torch.Tensor): The input tensor.
            conditioning_input (torch.Tensor): The conditioning input tensor.

        Returns:
            torch.Tensor: The output tensor of the Decoder module.
        """
        # Project the input to the hidden width so it can serve as the residual
        # stream for the stacked attention layers.
        x = self.residual_projection(x)
        for layer in self.layers:
            x = layer(x, conditioning_input)
        x = self.fc(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x
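

if __name__ == "__main__":
    # Minimal smoke test, not part of the original demo file. The sizes below are
    # illustrative assumptions: per-direction RGB values are decoded from a 3-D
    # direction conditioned on a 512-D latent vector.
    decoder = Decoder(
        in_dim=3,
        conditioning_input_dim=512,
        hidden_features=256,
        num_heads=8,
        num_layers=2,
        out_activation=nn.Sigmoid(),
    )
    directions = torch.randn(4, 3)      # (batch, in_dim)
    conditioning = torch.randn(4, 512)  # (batch, conditioning_input_dim)
    rgb = decoder(directions, conditioning)
    print(rgb.shape)  # expected: torch.Size([4, 3])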