import torch
from torch import Tensor
from torch.nn import Parameter
from typing import Any
from layers.independent_mlp import IndependentMLPs

# (act_layer, bias) passed to IndependentMLPs for each "parallel_mlp*" modulation_type.
_PARALLEL_MLP_MODULATIONS = {
    "parallel_mlp": (True, True),
    "parallel_mlp_no_bias": (True, False),
    "parallel_mlp_no_act": (False, True),
    "parallel_mlp_no_act_no_bias": (False, False),
}


# Baseline model, a modified convnext with reduced downsampling for a spatially larger feature tensor in the last layer
class IndividualLandmarkConvNext(torch.nn.Module):
    """ConvNeXt backbone with per-landmark soft attention maps and landmark-wise classification.

    The outputs of ``stages[2]`` and the bilinearly upsampled output of ``stages[3]`` are
    concatenated along channels. Each pixel is softly assigned to one of ``num_landmarks + 1``
    prototypes (the 1x1 kernels of ``fc_landmarks``) through a softmax over negative squared
    Euclidean distances. Features are then averaged per prototype with those maps. Only the first
    ``num_landmarks`` prototypes are classified; the last one is excluded (presumably a
    background slot -- TODO confirm).
    """

    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8, num_classes: int = 200,
                 sl_channels: int = 1024, fl_channels: int = 2048, part_dropout: float = 0.3,
                 modulation_type: str = "original", modulation_orth: bool = False,
                 gumbel_softmax: bool = False, gumbel_softmax_temperature: float = 1.0,
                 gumbel_softmax_hard: bool = False, classifier_type: str = "linear",
                 noise_variance: float = 0.0) -> None:
        """Build the landmark head on top of ``init_model``'s stem and stages.

        Args:
            init_model: Backbone exposing ``stem`` and an indexable ``stages`` (indices 0-3 are used).
            num_landmarks: Number of classified landmarks. One extra map is learned as well
                (``num_landmarks + 1`` in total) and excluded from classification.
            num_classes: Number of output classes.
            sl_channels: Channel contribution of one backbone stage. Only the sum with
                ``fl_channels`` is used; it must equal the channel count of the concatenated features.
            fl_channels: See ``sl_channels``.
            part_dropout: Drop probability of the ``Dropout1d`` applied across landmark feature vectors.
            modulation_type: One of ``"layer_norm"``, ``"original"`` (learned elementwise scale),
                ``"parallel_mlp"``, ``"parallel_mlp_no_bias"``, ``"parallel_mlp_no_act"``,
                ``"parallel_mlp_no_act_no_bias"`` or ``"none"``.
            modulation_orth: If True, ``forward`` returns the modulated landmark features as its
                first output; otherwise it returns the unmodulated ones.
            gumbel_softmax: Use ``gumbel_softmax`` instead of ``softmax`` to form attention maps.
            gumbel_softmax_temperature: ``tau`` for the Gumbel softmax.
            gumbel_softmax_hard: ``hard`` flag for the Gumbel softmax.
            classifier_type: ``"linear"`` (one ``Linear`` shared by all landmarks) or
                ``"independent_mlp"`` (separate weights per landmark).
            noise_variance: If > 0, Gaussian noise scaled by ``x.std() * noise_variance`` is added
                to the landmark features in ``forward``.

        Raises:
            ValueError: If ``modulation_type`` or ``classifier_type`` is not recognised.
        """
        super().__init__()
        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.stem = init_model.stem
        self.stages = init_model.stages
        self.feature_dim = sl_channels + fl_channels
        num_parts = num_landmarks + 1
        # 1x1 conv kernels act as landmark prototypes; bias-free so the conv output is exactly <x, a>.
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_parts, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, num_parts])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, num_parts))
        elif modulation_type in _PARALLEL_MLP_MODULATIONS:
            act_layer, bias = _PARALLEL_MLP_MODULATIONS[modulation_type]
            self.modulation = IndependentMLPs(part_dim=num_parts, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=act_layer, bias=bias)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")
        self.modulation_orth = modulation_orth
        # Dropout1d zeroes whole channels (dim 1). It is applied to [B, num_landmarks, C] in forward,
        # so entire landmark feature vectors are dropped.
        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Compute landmark features, attention maps, class scores and landmark distances.

        Args:
            x: Input batch for the backbone stem (presumably [B, 3, H_in, W_in] images -- confirm).

        Returns:
            A 4-tuple ``(features, maps, scores, dist)``:
              - ``features``: [B, feature_dim, num_landmarks + 1] landmark features, modulated if
                ``modulation_orth`` is True, otherwise unmodulated.
              - ``maps``: [B, num_landmarks + 1, H, W] attention maps that sum to 1 over dim 1 per pixel.
              - ``scores``: class scores for the first ``num_landmarks`` landmarks
                ([B, num_classes, num_landmarks] for the linear classifier).
              - ``dist``: [B, num_landmarks + 1, H, W] squared distances of pixel features to prototypes.
        """
        # Pretrained ConvNeXt part of the model
        x = self.stem(x)
        x = self.stages[0](x)
        x = self.stages[1](x)
        l3 = self.stages[2](x)
        x = self.stages[3](l3)
        # Bring the last stage back to the stage-2 resolution and stack both along channels.
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear',
                                            align_corners=False)
        x = torch.cat((x, l3), dim=1)
        height, width = x.shape[-2], x.shape[-1]

        # Squared distance between every pixel feature b and every prototype a:
        # (b - a)^2 = b^2 - 2ab + a^2, where ab is exactly the bias-free 1x1 conv output.
        ab = self.fc_landmarks(x)  # [B, num_landmarks + 1, H, W]
        b_sq = x.pow(2).sum(1, keepdim=True)  # [B, 1, H, W]
        # weight is [num_landmarks + 1, C, 1, 1]. Summing over C leaves [num_landmarks + 1, 1, 1], which
        # broadcasts against ab once a batch axis is added, so no expanded copies are needed.
        a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(0)  # [1, num_landmarks + 1, 1, 1]
        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax so that the attention maps for each pixel add up to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Map-weighted spatial mean of features per landmark:
        # all_features[b, c, k] = mean_{h,w} maps[b, k, h, w] * x[b, c, h, w].
        # einsum avoids materialising the [B, C, num_landmarks + 1, H, W] product. Scaling the maps
        # first keeps the reduction a bounded mean, so it cannot overflow under fp16 autocast.
        all_features = torch.einsum("bkhw,bchw->bck", maps / (height * width), x).contiguous()
        if self.noise_variance > 0.0:
            # NOTE(review): noise is added regardless of self.training -- confirm this is intended at eval.
            all_features = all_features + torch.randn_like(all_features) * x.std().detach() * self.noise_variance

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation
        else:
            all_features_mod = self.modulation(all_features)

        # Classification based on the landmark features (the last, extra landmark is excluded).
        landmark_feats = all_features_mod[..., :-1].permute(0, 2, 1).contiguous()  # [B, num_landmarks, C]
        scores = self.fc_class_landmarks(self.dropout_full_landmarks(landmark_feats)).permute(0, 2, 1).contiguous()
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist