Spaces:
Sleeping
Sleeping
import torch | |
from torch import Tensor | |
from torch.nn import Parameter | |
from typing import Any | |
from layers.independent_mlp import IndependentMLPs | |
# Baseline model, a modified convnext with reduced downsampling for a spatially larger feature tensor in the last layer | |
class IndividualLandmarkConvNext(torch.nn.Module):
    """Part-discovery classifier built on top of a pretrained ConvNeXt-style backbone.

    The outputs of the backbone's last two stages are fused into one feature map: the
    deepest stage is bilinearly upsampled to the resolution of ``stages[2]`` and
    concatenated with it. Each pixel is softly assigned to one of ``num_landmarks`` part
    prototypes or to one extra background prototype. The assignment logits are the
    negative squared Euclidean distances to the 1x1 kernels of ``fc_landmarks``.
    Features are then average-pooled per part under those maps, optionally modulated,
    and each foreground part is classified.

    Args:
        init_model: Backbone exposing ``stem`` and an indexable ``stages`` with at least
            four entries (presumably a timm ConvNeXt — TODO confirm).
        num_landmarks: Number of foreground parts; one background slot is always added.
        num_classes: Number of output classes.
        sl_channels: Channels of the ``stages[2]`` output.
        fl_channels: Channels of the ``stages[3]`` output. Only ``sl_channels +
            fl_channels`` is used; it must equal the channel count of the fused map.
        part_dropout: Drop probability of ``Dropout1d`` applied with parts on dim 1,
            i.e. whole part feature vectors are zeroed.
        modulation_type: ``"original"``, ``"layer_norm"``, ``"parallel_mlp"``,
            ``"parallel_mlp_no_bias"``, ``"parallel_mlp_no_act"``,
            ``"parallel_mlp_no_act_no_bias"`` or ``"none"``.
        modulation_orth: If True, ``forward`` returns the modulated part features
            instead of the raw pooled ones.
        gumbel_softmax: Use Gumbel-softmax instead of softmax for the part assignment.
        gumbel_softmax_temperature: Gumbel-softmax temperature ``tau``.
        gumbel_softmax_hard: Pass ``hard=True`` to Gumbel-softmax (straight-through one-hot).
        classifier_type: ``"linear"`` (one ``Linear`` shared by all parts) or
            ``"independent_mlp"``.
        noise_variance: If > 0, Gaussian noise scaled by this factor and by the std of
            the fused feature map is added to the pooled features (in train AND eval mode).

    Raises:
        ValueError: If ``modulation_type`` or ``classifier_type`` is not recognised.
    """

    # Modulation variants backed by IndependentMLPs, mapped to (act_layer, bias).
    _MLP_MODULATIONS = {
        "parallel_mlp": (True, True),
        "parallel_mlp_no_bias": (True, False),
        "parallel_mlp_no_act": (False, True),
        "parallel_mlp_no_act_no_bias": (False, False),
    }

    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
                 num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048, part_dropout: float = 0.3,
                 modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()
        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.stem = init_model.stem
        self.stages = init_model.stages
        self.feature_dim = sl_channels + fl_channels
        num_parts = num_landmarks + 1  # foreground landmarks + background
        # One 1x1 kernel per part; the kernels double as the part prototypes used in the
        # distance computation in ``forward``.
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_parts, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, num_parts])
        elif modulation_type == "original":
            # Learnable per-channel, per-part scale, applied by broadcasting in ``forward``.
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, num_parts))
        elif modulation_type in self._MLP_MODULATIONS:
            act_layer, bias = self._MLP_MODULATIONS[modulation_type]
            self.modulation = IndependentMLPs(part_dim=num_parts, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=act_layer, bias=bias)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")
        self.modulation_orth = modulation_orth
        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Discover parts in ``x`` and classify each foreground part.

        Args:
            x: Image batch accepted by the backbone stem (presumably ``[B, 3, H_in, W_in]``
                — TODO confirm).

        Returns:
            ``(features, maps, scores, dist)`` where, with ``K = num_landmarks`` and
            ``C = feature_dim``:

            - features: pooled part features. The raw features are ``[B, C, K + 1]``
              (last slot = background). The modulated ones are returned instead when
              ``modulation_orth`` is True.
            - maps: part-assignment maps ``[B, K + 1, H, W]`` summing to 1 over dim 1.
            - scores: per-part class scores, ``[B, num_classes, K]`` for the linear
              classifier (layout for ``independent_mlp`` depends on ``IndependentMLPs``).
            - dist: squared distance of every pixel feature to every part prototype,
              ``[B, K + 1, H, W]``.
        """
        # Pretrained ConvNeXt part of the model
        x = self.stem(x)
        x = self.stages[0](x)
        x = self.stages[1](x)
        l3 = self.stages[2](x)
        x = self.stages[3](l3)
        # Upsample the deepest stage to the resolution of stages[2] and fuse both.
        x = torch.nn.functional.interpolate(x, size=l3.shape[-2:], mode='bilinear', align_corners=False)
        x = torch.cat((x, l3), dim=1)  # [B, C, H, W]

        # Squared distance to each prototype, expanded as
        # ||b - a||^2 = ||b||^2 - 2ab + ||a||^2 (b = pixel features, a = 1x1 conv kernel).
        # The two norm terms are broadcast rather than materialised at full size.
        ab = self.fc_landmarks(x)  # [B, K + 1, H, W]
        b_sq = x.pow(2).sum(1, keepdim=True)  # [B, 1, H, W]
        a_sq = self.fc_landmarks.weight.pow(2).sum(dim=(1, 2, 3)).view(1, -1, 1, 1)  # [1, K + 1, 1, 1]
        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Normalise over parts so that the maps of each pixel add up to 1.
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, K + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, K + 1, H, W]

        # Map-weighted spatial mean per part: mean over (H, W) of maps * x.
        # einsum avoids materialising the [B, C, K + 1, H, W] broadcast product.
        height, width = x.shape[-2:]
        all_features = (torch.einsum("bkhw,bchw->bck", maps, x) / (height * width)).contiguous()
        if self.noise_variance > 0.0:
            # Out-of-place add keeps the autograd graph free of in-place modifications.
            noise = torch.randn_like(all_features) * x.std().detach() * self.noise_variance
            all_features = all_features + noise

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation
        else:
            all_features_mod = self.modulation(all_features)

        # Drop the background slot and move parts to dim 1 so Dropout1d zeroes whole
        # parts, classify, then permute back to put the classifier output on dim 1.
        part_features = all_features_mod[..., :-1].permute(0, 2, 1).contiguous()
        scores = self.fc_class_landmarks(self.dropout_full_landmarks(part_features))
        scores = scores.permute(0, 2, 1).contiguous()

        features = all_features_mod if self.modulation_orth else all_features
        return features, maps, scores, dist