# This file contains the implementation of the IndependentMLPs class.

import torch


class IndependentMLPs(torch.nn.Module):
    """
    Applies a separate, independently parameterized MLP to each part of the input.
    Used as an optional extra projection layer before classification, with one
    set of weights per part.
    """

    def __init__(self, part_dim, latent_dim, bias=False, num_lin_layers=1, act_layer=True, out_dim=None, stack_dim=-1):
        """
        :param part_dim: Number of parts (one independent MLP is built per part)
        :param latent_dim: Latent dimension of each part's feature vector
        :param bias: Whether to use a bias in the linear layers
        :param num_lin_layers: Number of linear layers per MLP
        :param act_layer: Whether to follow each linear layer with a GELU activation
        :param out_dim: Output dimension (defaults to latent_dim if None)
        :param stack_dim: Dimension along which parts are indexed and outputs are stacked (default: -1)
        """
        super().__init__()
        self.bias = bias
        self.latent_dim = latent_dim
        if out_dim is None:
            out_dim = latent_dim
        self.out_dim = out_dim
        self.part_dim = part_dim
        self.stack_dim = stack_dim
        # Build one small MLP per part; each part gets its own, independent weights.
        layer_stack = torch.nn.ModuleList()
        for i in range(part_dim):
            layer_stack.append(torch.nn.Sequential())
            in_dim = latent_dim
            for j in range(num_lin_layers):
                layer_stack[i].add_module(f"fc_{j}", torch.nn.Linear(in_dim, self.out_dim, bias=bias))
                in_dim = self.out_dim  # Subsequent layers take the previous layer's output
                if act_layer:
                    layer_stack[i].add_module(f"act_{j}", torch.nn.GELU())
        self.feature_layers = layer_stack
        self.reset_weights()

    def __repr__(self):
        return f"IndependentMLPs(part_dim={self.part_dim}, latent_dim={self.latent_dim}, bias={self.bias})"

    def reset_weights(self):
        """ Initialize linear weights with a truncated normal distribution and biases with zeros """
        for layer in self.feature_layers:
            for m in layer.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.trunc_normal_(m.weight, std=0.02)
                    if m.bias is not None:
                        torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """ Input x has shape batch x latent_dim x part_dim when stack_dim is -1;
        otherwise parts are indexed along dimension 1 """
        outputs = []
        for i, layer in enumerate(self.feature_layers):
            if self.stack_dim == -1:
                in_ = x[..., i]  # Select feature i from the last dimension
            else:
                in_ = x[:, i, ...]  # Select feature i from dimension 1
            out = layer(in_)  # Apply part i's own MLP to feature i
            outputs.append(out)
        x = torch.stack(outputs, dim=self.stack_dim)  # Stack the per-part outputs
        return x
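

# A minimal usage sketch, not part of the original module: it shows the default
# stack_dim=-1 layout of batch x latent_dim x part_dim. The sizes below are
# illustrative assumptions, not values from the source.
if __name__ == "__main__":
    batch, latent_dim, part_dim = 4, 128, 8
    mlp = IndependentMLPs(part_dim=part_dim, latent_dim=latent_dim)
    x = torch.randn(batch, latent_dim, part_dim)
    y = mlp(x)
    # Each of the 8 parts is transformed by its own MLP; the output keeps the input layout.
    print(mlp)  # IndependentMLPs(part_dim=8, latent_dim=128, bias=False)
    assert y.shape == (batch, latent_dim, part_dim)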