ananthu-aniraj committed on
Commit 20239f9 · 1 Parent(s): b507f8e

add initial files

.gitignore ADDED
@@ -0,0 +1,40 @@
+ # editor settings
+ .idea
+ .vscode
+ _darcs
+
+ # compilation and distribution
+ __pycache__
+ _ext
+ *.pyc
+ *.pyd
+ *.so
+ *.dll
+ *.egg-info/
+ build/
+ dist/
+ wheels/
+
+ # pytorch/python/numpy formats
+ *.pth
+ *.pkl
+ *.npy
+ *.ts
+ *.pt
+
+ # ipython/jupyter notebooks
+ *.ipynb
+ **/.ipynb_checkpoints/
+
+ # Editor temporaries
+ *.swn
+ *.swo
+ *.swp
+ *~
+
+ # Results temporary
+ *.png
+ *.txt
+ *.tsv
+ wandb/
+ exps/
files/images/Laysan_Albatross_0050_870.jpg ADDED
layers/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .transformer_layers import *
+ from .independent_mlp import *
layers/independent_mlp.py ADDED
@@ -0,0 +1,69 @@
+ # This file contains the implementation of the IndependentMLPs class
+ import torch
+
+
+ class IndependentMLPs(torch.nn.Module):
+     """
+     This class implements the MLP used for classification with the option to use an additional independent MLP layer
+     """
+
+     def __init__(self, part_dim, latent_dim, bias=False, num_lin_layers=1, act_layer=True, out_dim=None, stack_dim=-1):
+         """
+
+         :param part_dim: Number of parts
+         :param latent_dim: Latent dimension
+         :param bias: Whether to use bias
+         :param num_lin_layers: Number of linear layers
+         :param act_layer: Whether to use activation layer
+         :param out_dim: Output dimension (default: None)
+         :param stack_dim: Dimension to stack the outputs (default: -1)
+         """
+
+         super().__init__()
+
+         self.bias = bias
+         self.latent_dim = latent_dim
+         if out_dim is None:
+             out_dim = latent_dim
+         self.out_dim = out_dim
+         self.part_dim = part_dim
+         self.stack_dim = stack_dim
+
+         layer_stack = torch.nn.ModuleList()
+         for i in range(part_dim):
+             layer_stack.append(torch.nn.Sequential())
+             for j in range(num_lin_layers):
+                 layer_stack[i].add_module(f"fc_{j}", torch.nn.Linear(latent_dim, self.out_dim, bias=bias))
+                 if act_layer:
+                     layer_stack[i].add_module(f"act_{j}", torch.nn.GELU())
+         self.feature_layers = layer_stack
+         self.reset_weights()
+
+     def __repr__(self):
+         return f"IndependentMLPs(part_dim={self.part_dim}, latent_dim={self.latent_dim}, bias={self.bias})"
+
+     def reset_weights(self):
+         """ Initialize weights with a truncated normal distribution """
+         for layer in self.feature_layers:
+             for m in layer.modules():
+                 if isinstance(m, torch.nn.Linear):
+                     # Initialize weights with a truncated normal distribution
+                     torch.nn.init.trunc_normal_(m.weight, std=0.02)
+                     if m.bias is not None:
+                         torch.nn.init.zeros_(m.bias)
+
+     def forward(self, x):
+         """ Input X has the dimensions batch x latent_dim x part_dim """
+
+         outputs = []
+         for i, layer in enumerate(self.feature_layers):
+             if self.stack_dim == -1:
+                 in_ = x[..., i]
+             else:
+                 in_ = x[:, i, ...]  # Select feature i
+             out = layer(in_)  # Apply MLP to feature i
+             outputs.append(out)
+
+         x = torch.stack(outputs, dim=self.stack_dim)  # Stack the outputs
+
+         return x
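A minimal usage sketch (shapes are illustrative, not taken from this commit): each part index gets its own Linear layer, and the per-part outputs are stacked back along `stack_dim`.

import torch
from layers.independent_mlp import IndependentMLPs

# 8 parts + 1 background slot, 768-dim features, matching how the models below use this layer
mlp = IndependentMLPs(part_dim=9, latent_dim=768, bias=True, num_lin_layers=1, act_layer=True)

feats = torch.randn(4, 768, 9)   # batch x latent_dim x part_dim (stack_dim=-1)
out = mlp(feats)                 # each of the 9 parts goes through its own Linear(768, 768) + GELU
print(out.shape)                 # torch.Size([4, 768, 9])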
layers/transformer_layers.py ADDED
@@ -0,0 +1,54 @@
+ # Attention and Block variants with the option to return the q/k/v tensors from attention
+
+ import torch
+ from timm.models.vision_transformer import Attention, Block
+ import torch.nn.functional as F
+ from typing import Tuple
+
+
+ class AttentionWQKVReturn(Attention):
+     """
+     Modifications:
+     - Return the qkv tensors from the attention
+     """
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv.unbind(0)
+         q, k = self.q_norm(q), self.k_norm(k)
+
+         if self.fused_attn:
+             x = F.scaled_dot_product_attention(
+                 q, k, v,
+                 dropout_p=self.attn_drop.p if self.training else 0.,
+             )
+         else:
+             q = q * self.scale
+             attn = q @ k.transpose(-2, -1)
+             attn = attn.softmax(dim=-1)
+             attn = self.attn_drop(attn)
+             x = attn @ v
+
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x, torch.stack((q, k, v), dim=0)
+
+
+ class BlockWQKVReturn(Block):
+     """
+     Modifications:
+     - Use AttentionWQKVReturn instead of Attention
+     - Return the qkv tensors from the attention
+     """
+
+     def forward(self, x: torch.Tensor, return_qkv: bool = False) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
+         # Note: this is copied from timm.models.vision_transformer.Block with modifications.
+         x_attn, qkv = self.attn(self.norm1(x))
+         x = x + self.drop_path1(self.ls1(x_attn))
+         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+         if return_qkv:
+             return x, qkv
+         else:
+             return x
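A rough sketch of how these classes are used by the models later in this commit: the existing timm modules are re-typed in place by reassigning `__class__` (see `convert_blocks_and_attention` below), after which a block can also hand back the stacked q/k/v tensors. The backbone choice and shapes assume a recent timm `vit_base_patch16_224` and are only an example.

import torch
from timm.models import create_model
from timm.models.vision_transformer import Attention, Block
from layers.transformer_layers import AttentionWQKVReturn, BlockWQKVReturn

vit = create_model("vit_base_patch16_224", pretrained=False)
# Re-type the existing modules in place instead of rebuilding the backbone
for module in vit.modules():
    if isinstance(module, Block):
        module.__class__ = BlockWQKVReturn
    elif isinstance(module, Attention):
        module.__class__ = AttentionWQKVReturn

tokens = torch.randn(2, 197, 768)                  # [B, num_tokens, embed_dim]
out, qkv = vit.blocks[0](tokens, return_qkv=True)
print(out.shape, qkv.shape)                        # [2, 197, 768] and [3, 2, 12, 197, 64]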
load_model.py ADDED
@@ -0,0 +1,226 @@
1
+ import copy
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ from timm.models import create_model
7
+ from torchvision.models import get_model
8
+
9
+ from models import pdiscoformer_vit_bb, pdisconet_vit_bb, pdisconet_resnet_torchvision_bb
10
+ from models.individual_landmark_resnet import IndividualLandmarkResNet
11
+ from models.individual_landmark_convnext import IndividualLandmarkConvNext
12
+ from models.individual_landmark_vit import IndividualLandmarkViT
13
+ from utils import load_state_dict_pdisco
14
+
15
+
16
+ def load_model_arch(args, num_cls):
17
+ """
18
+ Function to load the model
19
+ :param args: Arguments from the command line
20
+ :param num_cls: Number of classes in the dataset
21
+ :return:
22
+ """
23
+ if 'resnet' in args.model_arch:
24
+ num_layers_split = [int(s) for s in args.model_arch if s.isdigit()]
25
+ num_layers = int(''.join(map(str, num_layers_split)))
26
+ if num_layers >= 100:
27
+ timm_model_arch = args.model_arch + ".a1h_in1k"
28
+ else:
29
+ timm_model_arch = args.model_arch + ".a1_in1k"
30
+
31
+ if "resnet" in args.model_arch and args.use_torchvision_resnet_model:
32
+ weights = "DEFAULT" if args.pretrained_start_weights else None
33
+ base_model = get_model(args.model_arch, weights=weights)
34
+ elif "resnet" in args.model_arch and not args.use_torchvision_resnet_model:
35
+ if args.eval_only:
36
+ base_model = create_model(
37
+ timm_model_arch,
38
+ pretrained=args.pretrained_start_weights,
39
+ num_classes=num_cls,
40
+ output_stride=args.output_stride,
41
+ )
42
+ else:
43
+ base_model = create_model(
44
+ timm_model_arch,
45
+ pretrained=args.pretrained_start_weights,
46
+ drop_path_rate=args.drop_path,
47
+ num_classes=num_cls,
48
+ output_stride=args.output_stride,
49
+ )
50
+
51
+ elif "convnext" in args.model_arch:
52
+ if args.eval_only:
53
+ base_model = create_model(
54
+ args.model_arch,
55
+ pretrained=args.pretrained_start_weights,
56
+ num_classes=num_cls,
57
+ output_stride=args.output_stride,
58
+ )
59
+ else:
60
+ base_model = create_model(
61
+ args.model_arch,
62
+ pretrained=args.pretrained_start_weights,
63
+ drop_path_rate=args.drop_path,
64
+ num_classes=num_cls,
65
+ output_stride=args.output_stride,
66
+ )
67
+ elif "vit" in args.model_arch:
68
+ if args.eval_only:
69
+ base_model = create_model(
70
+ args.model_arch,
71
+ pretrained=args.pretrained_start_weights,
72
+ img_size=args.image_size,
73
+ )
74
+ else:
75
+ base_model = create_model(
76
+ args.model_arch,
77
+ pretrained=args.pretrained_start_weights,
78
+ drop_path_rate=args.drop_path,
79
+ img_size=args.image_size,
80
+ )
81
+ vit_patch_size = base_model.patch_embed.proj.kernel_size[0]
82
+ if args.image_size % vit_patch_size != 0:
83
+ raise ValueError(f"Image size {args.image_size} must be divisible by patch size {vit_patch_size}")
84
+ else:
85
+ raise ValueError('Model not supported.')
86
+
87
+ return base_model
88
+
89
+
90
+ def init_pdisco_model(base_model, args, num_cls):
91
+ """
92
+ Function to initialize the model
93
+ :param base_model: Base model
94
+ :param args: Arguments from the command line
95
+ :param num_cls: Number of classes in the dataset
96
+ :return:
97
+ """
98
+ # Initialize the network
99
+ if 'convnext' in args.model_arch:
100
+ sl_channels = base_model.stages[-1].downsample[-1].in_channels
101
+ fl_channels = base_model.head.in_features
102
+ model = IndividualLandmarkConvNext(base_model, args.num_parts, num_classes=num_cls,
103
+ sl_channels=sl_channels, fl_channels=fl_channels,
104
+ part_dropout=args.part_dropout, modulation_type=args.modulation_type,
105
+ gumbel_softmax=args.gumbel_softmax,
106
+ gumbel_softmax_temperature=args.gumbel_softmax_temperature,
107
+ gumbel_softmax_hard=args.gumbel_softmax_hard,
108
+ modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
109
+ noise_variance=args.noise_variance)
110
+ elif 'resnet' in args.model_arch:
111
+ sl_channels = base_model.layer4[0].conv1.in_channels
112
+ fl_channels = base_model.fc.in_features
113
+ model = IndividualLandmarkResNet(base_model, args.num_parts, num_classes=num_cls,
114
+ sl_channels=sl_channels, fl_channels=fl_channels,
115
+ use_torchvision_model=args.use_torchvision_resnet_model,
116
+ part_dropout=args.part_dropout, modulation_type=args.modulation_type,
117
+ gumbel_softmax=args.gumbel_softmax,
118
+ gumbel_softmax_temperature=args.gumbel_softmax_temperature,
119
+ gumbel_softmax_hard=args.gumbel_softmax_hard,
120
+ modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
121
+ noise_variance=args.noise_variance)
122
+ elif 'vit' in args.model_arch:
123
+ model = IndividualLandmarkViT(base_model, num_landmarks=args.num_parts, num_classes=num_cls,
124
+ part_dropout=args.part_dropout,
125
+ modulation_type=args.modulation_type, gumbel_softmax=args.gumbel_softmax,
126
+ gumbel_softmax_temperature=args.gumbel_softmax_temperature,
127
+ gumbel_softmax_hard=args.gumbel_softmax_hard,
128
+ modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
129
+ noise_variance=args.noise_variance)
130
+ else:
131
+ raise ValueError('Model not supported.')
132
+
133
+ return model
134
+
135
+
136
+ def load_model_pdisco(args, num_cls):
137
+ """
138
+ Function to load the model
139
+ :param args: Arguments from the command line
140
+ :param num_cls: Number of classes in the dataset
141
+ :return:
142
+ """
143
+ base_model = load_model_arch(args, num_cls)
144
+ model = init_pdisco_model(base_model, args, num_cls)
145
+
146
+ return model
147
+
148
+
149
+ def pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8, model_url="", img_size=224, num_cls=200):
150
+ """
151
+ Function to load the PDiscoFormer model with ViT backbone
152
+ :param pretrained: Boolean flag to load the pretrained weights
153
+ :param model_dataset: Dataset for which the model is trained
154
+ :param k: Number of unsupervised landmarks the model is trained on
155
+ :param model_url: URL to load the model weights from
156
+ :param img_size: Image size
157
+ :param num_cls: Number of classes in the dataset
158
+ :return: PDiscoFormer model with ViT backbone
159
+ """
160
+ model = pdiscoformer_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
161
+ if pretrained:
162
+ hub_dir = torch.hub.get_dir()
163
+ model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdiscoformer_{model_dataset}")
164
+
165
+ Path(model_dir).mkdir(parents=True, exist_ok=True)
166
+ url_path = model_url + str(k) + "_parts_snapshot_best.pt"
167
+ snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
168
+ if 'model_state' in snapshot_data:
169
+ _, state_dict = load_state_dict_pdisco(snapshot_data)
170
+ else:
171
+ state_dict = copy.deepcopy(snapshot_data)
172
+ model.load_state_dict(state_dict, strict=True)
173
+ return model
174
+
175
+
176
+ def pdisconet_vit(pretrained=True, model_dataset="nabirds", k=8, model_url="", img_size=224, num_cls=555):
177
+ """
178
+ Function to load the PDiscoNet model with ViT backbone
179
+ :param pretrained: Boolean flag to load the pretrained weights
180
+ :param model_dataset: Dataset for which the model is trained
181
+ :param k: Number of unsupervised landmarks the model is trained on
182
+ :param model_url: URL to load the model weights from
183
+ :param img_size: Image size
184
+ :param num_cls: Number of classes in the dataset
185
+ :return: PDiscoNet model with ViT backbone
186
+ """
187
+ model = pdisconet_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
188
+ if pretrained:
189
+ hub_dir = torch.hub.get_dir()
190
+ model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")
191
+
192
+ Path(model_dir).mkdir(parents=True, exist_ok=True)
193
+ url_path = model_url + str(k) + "_parts_snapshot_best.pt"
194
+ snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
195
+ if 'model_state' in snapshot_data:
196
+ _, state_dict = load_state_dict_pdisco(snapshot_data)
197
+ else:
198
+ state_dict = copy.deepcopy(snapshot_data)
199
+ model.load_state_dict(state_dict, strict=True)
200
+ return model
201
+
202
+
203
+ def pdisconet_resnet101(pretrained=True, model_dataset="nabirds", k=8, model_url="", num_cls=555):
204
+ """
205
+ Function to load the PDiscoNet model with ResNet-101 backbone
206
+ :param pretrained: Boolean flag to load the pretrained weights
207
+ :param model_dataset: Dataset for which the model is trained
208
+ :param k: Number of unsupervised landmarks the model is trained on
209
+ :param model_url: URL to load the model weights from
210
+ :param num_cls: Number of classes in the dataset
211
+ :return: PDiscoNet model with ResNet-101 backbone
212
+ """
213
+ model = pdisconet_resnet_torchvision_bb("resnet101", num_cls=num_cls, k=k)
214
+ if pretrained:
215
+ hub_dir = torch.hub.get_dir()
216
+ model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")
217
+
218
+ Path(model_dir).mkdir(parents=True, exist_ok=True)
219
+ url_path = model_url + str(k) + "_parts_snapshot_best.pt"
220
+ snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
221
+ if 'model_state' in snapshot_data:
222
+ _, state_dict = load_state_dict_pdisco(snapshot_data)
223
+ else:
224
+ state_dict = copy.deepcopy(snapshot_data)
225
+ model.load_state_dict(state_dict, strict=True)
226
+ return model
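A sketch of calling the hub-style entry points above; the weight URL is a placeholder (the real checkpoint locations are not part of this commit), and the loader simply appends `<k>_parts_snapshot_best.pt` to whatever prefix it is given.

import torch
from load_model import pdiscoformer_vit

# Hypothetical URL prefix; pretrained=True will try to download from it
weights_prefix = "https://example.com/pdiscoformer/cub/"
model = pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8,
                         model_url=weights_prefix, img_size=518, num_cls=200)
model.eval()

with torch.no_grad():
    feats, maps, scores, dist = model(torch.randn(1, 3, 518, 518))
print(maps.shape)   # soft part-assignment maps, [1, 9, 37, 37] for a 518x518 input and k=8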
models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .individual_landmark_resnet import *
2
+ from .individual_landmark_convnext import *
3
+ from .vit_baseline import *
4
+ from .individual_landmark_vit import *
models/individual_landmark_convnext.py ADDED
@@ -0,0 +1,110 @@
1
+ import torch
2
+ from torch import Tensor
3
+ from torch.nn import Parameter
4
+ from typing import Any
5
+ from layers.independent_mlp import IndependentMLPs
6
+
7
+
8
+ # Baseline model, a modified convnext with reduced downsampling for a spatially larger feature tensor in the last layer
9
+ class IndividualLandmarkConvNext(torch.nn.Module):
10
+ def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
11
+ num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048, part_dropout: float = 0.3,
12
+ modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
13
+ gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
14
+ classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
15
+ super().__init__()
16
+
17
+ self.num_landmarks = num_landmarks
18
+ self.num_classes = num_classes
19
+ self.noise_variance = noise_variance
20
+ self.stem = init_model.stem
21
+ self.stages = init_model.stages
22
+ self.feature_dim = sl_channels + fl_channels
23
+ self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
24
+ self.gumbel_softmax = gumbel_softmax
25
+ self.gumbel_softmax_temperature = gumbel_softmax_temperature
26
+ self.gumbel_softmax_hard = gumbel_softmax_hard
27
+ self.modulation_type = modulation_type
28
+ if modulation_type == "layer_norm":
29
+ self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
30
+ elif modulation_type == "original":
31
+ self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
32
+ elif modulation_type == "parallel_mlp":
33
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
34
+ num_lin_layers=1, act_layer=True, bias=True)
35
+ elif modulation_type == "parallel_mlp_no_bias":
36
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
37
+ num_lin_layers=1, act_layer=True, bias=False)
38
+ elif modulation_type == "parallel_mlp_no_act":
39
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
40
+ num_lin_layers=1, act_layer=False, bias=True)
41
+ elif modulation_type == "parallel_mlp_no_act_no_bias":
42
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
43
+ num_lin_layers=1, act_layer=False, bias=False)
44
+ elif modulation_type == "none":
45
+ self.modulation = torch.nn.Identity()
46
+ else:
47
+ raise ValueError("modulation_type not implemented")
48
+ self.modulation_orth = modulation_orth
49
+ self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
50
+ self.classifier_type = classifier_type
51
+ if classifier_type == "independent_mlp":
52
+ self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
53
+ num_lin_layers=1, act_layer=False, out_dim=num_classes,
54
+ bias=False, stack_dim=1)
55
+ elif classifier_type == "linear":
56
+ self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
57
+ bias=False)
58
+ else:
59
+ raise ValueError("classifier_type not implemented")
60
+
61
+ def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any]:
62
+ # Pretrained ConvNeXt part of the model
63
+ x = self.stem(x)
64
+ x = self.stages[0](x)
65
+ x = self.stages[1](x)
66
+ l3 = self.stages[2](x)
67
+ x = self.stages[3](l3)
68
+ x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False)
69
+ x = torch.cat((x, l3), dim=1)
70
+
71
+ # Compute per landmark attention maps
72
+ # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps resnet, a = convolution kernel
73
+ batch_size = x.shape[0]
74
+ ab = self.fc_landmarks(x)
75
+ b_sq = x.pow(2).sum(1, keepdim=True)
76
+ b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
77
+ a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2],
78
+ x.shape[-1]).contiguous()
79
+ a_sq = a_sq.permute(1, 0, 2, 3).contiguous()
80
+
81
+ dist = b_sq - 2 * ab + a_sq
82
+ maps = -dist
83
+
84
+ # Softmax so that the attention maps for each pixel add up to 1
85
+ if self.gumbel_softmax:
86
+ maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
87
+ hard=self.gumbel_softmax_hard) # [B, num_landmarks + 1, H, W]
88
+ else:
89
+ maps = torch.nn.functional.softmax(maps, dim=1) # [B, num_landmarks + 1, H, W]
90
+
91
+ # Use maps to get weighted average features per landmark
92
+ all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()
93
+ if self.noise_variance > 0.0:
94
+ all_features += torch.randn_like(all_features,
95
+ device=all_features.device) * x.std().detach() * self.noise_variance
96
+
97
+ # Modulate the features
98
+ if self.modulation_type == "original":
99
+ all_features_mod = all_features * self.modulation
100
+ else:
101
+ all_features_mod = self.modulation(all_features)
102
+
103
+ # Classification based on the landmark features
104
+ scores = self.fc_class_landmarks(
105
+ self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
106
+ 1).contiguous()
107
+ if self.modulation_orth:
108
+ return all_features_mod, maps, scores, dist
109
+ else:
110
+ return all_features, maps, scores, dist
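For orientation, a rough shape walkthrough assuming a timm `convnext_base` backbone (512 channels going into the last stage, 1024 out, so feature_dim = 512 + 1024 = 1536) and 448x448 inputs; the backbone and sizes here are illustrative only.

import torch
from timm.models import create_model
from models.individual_landmark_convnext import IndividualLandmarkConvNext

base = create_model("convnext_base", pretrained=False)
net = IndividualLandmarkConvNext(base, num_landmarks=8, num_classes=200,
                                 sl_channels=512, fl_channels=1024)

feats, maps, scores, dist = net(torch.randn(2, 3, 448, 448))
print(feats.shape)   # [2, 1536, 9]   per-part pooled features (8 parts + background)
print(maps.shape)    # [2, 9, 28, 28] soft part-assignment maps at the stride-16 resolution
print(scores.shape)  # [2, 200, 8]    per-part class logits (background column dropped)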
models/individual_landmark_resnet.py ADDED
@@ -0,0 +1,141 @@
1
+ # Modified from https://github.com/robertdvdk/part_detection/blob/main/nets.py
2
+ import torch
3
+ from torch import Tensor
4
+ from timm.models import create_model
5
+ from torchvision.models import get_model
6
+ from torch.nn import Parameter
7
+ from typing import Any
8
+ from layers.independent_mlp import IndependentMLPs
9
+
10
+
11
+ # Baseline model, a modified ResNet with reduced downsampling for a spatially larger feature tensor in the last layer
12
+ class IndividualLandmarkResNet(torch.nn.Module):
13
+ def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
14
+ num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048,
15
+ use_torchvision_model: bool = False, part_dropout: float = 0.3,
16
+ modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
17
+ gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
18
+ classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
19
+ super().__init__()
20
+
21
+ self.num_landmarks = num_landmarks
22
+ self.num_classes = num_classes
23
+ self.noise_variance = noise_variance
24
+ self.conv1 = init_model.conv1
25
+ self.bn1 = init_model.bn1
26
+ if use_torchvision_model:
27
+ self.act1 = init_model.relu
28
+ else:
29
+ self.act1 = init_model.act1
30
+ self.maxpool = init_model.maxpool
31
+ self.layer1 = init_model.layer1
32
+ self.layer2 = init_model.layer2
33
+ self.layer3 = init_model.layer3
34
+ self.layer4 = init_model.layer4
35
+ self.feature_dim = sl_channels + fl_channels
36
+ self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
37
+ self.gumbel_softmax = gumbel_softmax
38
+ self.gumbel_softmax_temperature = gumbel_softmax_temperature
39
+ self.gumbel_softmax_hard = gumbel_softmax_hard
40
+ self.modulation_type = modulation_type
41
+ if modulation_type == "layer_norm":
42
+ self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
43
+ elif modulation_type == "original":
44
+ self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
45
+ elif modulation_type == "parallel_mlp":
46
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
47
+ num_lin_layers=1, act_layer=True, bias=True)
48
+ elif modulation_type == "parallel_mlp_no_bias":
49
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
50
+ num_lin_layers=1, act_layer=True, bias=False)
51
+ elif modulation_type == "parallel_mlp_no_act":
52
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
53
+ num_lin_layers=1, act_layer=False, bias=True)
54
+ elif modulation_type == "parallel_mlp_no_act_no_bias":
55
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
56
+ num_lin_layers=1, act_layer=False, bias=False)
57
+ elif modulation_type == "none":
58
+ self.modulation = torch.nn.Identity()
59
+ else:
60
+ raise ValueError("modulation_type not implemented")
61
+
62
+ self.modulation_orth = modulation_orth
63
+
64
+ self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
65
+ self.classifier_type = classifier_type
66
+ if classifier_type == "independent_mlp":
67
+ self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
68
+ num_lin_layers=1, act_layer=False, out_dim=num_classes,
69
+ bias=False, stack_dim=1)
70
+ elif classifier_type == "linear":
71
+ self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
72
+ bias=False)
73
+ else:
74
+ raise ValueError("classifier_type not implemented")
75
+
76
+ def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any]:
77
+ # Pretrained ResNet part of the model
78
+ x = self.conv1(x)
79
+ x = self.bn1(x)
80
+ x = self.act1(x)
81
+ x = self.maxpool(x)
82
+ x = self.layer1(x)
83
+ x = self.layer2(x)
84
+ l3 = self.layer3(x)
85
+ x = self.layer4(l3)
86
+ x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False)
87
+ x = torch.cat((x, l3), dim=1)
88
+
89
+ # Compute per landmark attention maps
90
+ # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps resnet, a = convolution kernel
91
+ batch_size = x.shape[0]
92
+
93
+ ab = self.fc_landmarks(x)
94
+ b_sq = x.pow(2).sum(1, keepdim=True)
95
+ b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
96
+ a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2],
97
+ x.shape[-1]).contiguous()
98
+ a_sq = a_sq.permute(1, 0, 2, 3).contiguous()
99
+
100
+ dist = b_sq - 2 * ab + a_sq
101
+ maps = -dist
102
+
103
+ # Softmax so that the attention maps for each pixel add up to 1
104
+ if self.gumbel_softmax:
105
+ maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
106
+ hard=self.gumbel_softmax_hard) # [B, num_landmarks + 1, H, W]
107
+ else:
108
+ maps = torch.nn.functional.softmax(maps, dim=1) # [B, num_landmarks + 1, H, W]
109
+
110
+ # Use maps to get weighted average features per landmark
111
+ all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()
112
+ if self.noise_variance > 0.0:
113
+ all_features += torch.randn_like(all_features,
114
+ device=all_features.device) * x.std().detach() * self.noise_variance
115
+
116
+ # Modulate the features
117
+ if self.modulation_type == "original":
118
+ all_features_mod = all_features * self.modulation
119
+ else:
120
+ all_features_mod = self.modulation(all_features)
121
+
122
+ # Classification based on the landmark features
123
+ scores = self.fc_class_landmarks(
124
+ self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
125
+ 1).contiguous()
126
+ if self.modulation_orth:
127
+ return all_features_mod, maps, scores, dist
128
+ else:
129
+ return all_features, maps, scores, dist
130
+
131
+
132
+ def pdisconet_resnet_torchvision_bb(backbone, num_cls=200, k=8, **kwargs):
133
+ base_model = get_model(backbone)
134
+ return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls,
135
+ modulation_type="original")
136
+
137
+
138
+ def pdisconet_resnet_timm_bb(backbone, num_cls=200, k=8, output_stride=32, **kwargs):
139
+ base_model = create_model(backbone, pretrained=True, output_stride=output_stride)
140
+ return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls,
141
+ modulation_type="original")
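The landmark maps in these models are negative squared Euclidean distances between each spatial feature vector and the 1x1 convolution kernels, computed via the expansion (b - a)^2 = b^2 - 2ab + a^2. A small standalone sanity check of that identity (not code from this commit):

import torch

torch.manual_seed(0)
x = torch.randn(2, 1536, 7, 7)                   # fake backbone features: [B, C, H, W]
fc = torch.nn.Conv2d(1536, 9, 1, bias=False)     # landmark prototypes a: [K+1, C, 1, 1]

# Expansion used in the forward pass: ||b - a||^2 = ||b||^2 - 2 a.b + ||a||^2
ab = fc(x)                                                   # [B, K+1, H, W]
b_sq = x.pow(2).sum(1, keepdim=True)                         # [B, 1, H, W]
a_sq = fc.weight.pow(2).sum(dim=(1, 2, 3)).view(1, -1, 1, 1) # [1, K+1, 1, 1]
dist = b_sq - 2 * ab + a_sq

# Direct computation for comparison
direct = ((x.unsqueeze(1) - fc.weight.view(1, 9, 1536, 1, 1)) ** 2).sum(2)
print(torch.allclose(dist, direct, atol=1e-2))   # True, up to float error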
models/individual_landmark_vit.py ADDED
@@ -0,0 +1,366 @@
1
+ # Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
2
+ from pathlib import Path
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+ from typing import Any, Union, Sequence, Optional, Dict
8
+
9
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
10
+
11
+ from timm.models import create_model
12
+ from timm.models.vision_transformer import Block, Attention
13
+ from utils.misc_utils import compute_attention
14
+
15
+ from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn
16
+ from layers.independent_mlp import IndependentMLPs
17
+
18
+ SAFETENSORS_SINGLE_FILE = "model.safetensors"
19
+
20
+
21
+ class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
22
+ pipeline_tag='image-classification',
23
+ repo_url='https://github.com/ananthu-aniraj/pdiscoformer'):
24
+
25
+ def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8, num_classes: int = 200,
26
+ part_dropout: float = 0.3, return_transformer_qkv: bool = False,
27
+ modulation_type: str = "original", gumbel_softmax: bool = False,
28
+ gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
29
+ modulation_orth: bool = False, classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
30
+ super().__init__()
31
+ self.num_landmarks = num_landmarks
32
+ self.num_classes = num_classes
33
+ self.noise_variance = noise_variance
34
+ self.num_prefix_tokens = init_model.num_prefix_tokens
35
+ self.num_reg_tokens = init_model.num_reg_tokens
36
+ self.has_class_token = init_model.has_class_token
37
+ self.no_embed_class = init_model.no_embed_class
38
+ self.cls_token = init_model.cls_token
39
+ self.reg_token = init_model.reg_token
40
+
41
+ self.feature_dim = init_model.embed_dim
42
+ self.patch_embed = init_model.patch_embed
43
+ self.pos_embed = init_model.pos_embed
44
+ self.pos_drop = init_model.pos_drop
45
+ self.norm_pre = init_model.norm_pre
46
+ self.blocks = init_model.blocks
47
+ self.norm = init_model.norm
48
+ self.return_transformer_qkv = return_transformer_qkv
49
+ self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
50
+ self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])
51
+
52
+ self.unflatten = nn.Unflatten(1, (self.h_fmap, self.w_fmap))
53
+ self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
54
+ self.gumbel_softmax = gumbel_softmax
55
+ self.gumbel_softmax_temperature = gumbel_softmax_temperature
56
+ self.gumbel_softmax_hard = gumbel_softmax_hard
57
+ self.modulation_type = modulation_type
58
+ if modulation_type == "layer_norm":
59
+ self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
60
+ elif modulation_type == "original":
61
+ self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
62
+ elif modulation_type == "parallel_mlp":
63
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
64
+ num_lin_layers=1, act_layer=True, bias=True)
65
+ elif modulation_type == "parallel_mlp_no_bias":
66
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
67
+ num_lin_layers=1, act_layer=True, bias=False)
68
+ elif modulation_type == "parallel_mlp_no_act":
69
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
70
+ num_lin_layers=1, act_layer=False, bias=True)
71
+ elif modulation_type == "parallel_mlp_no_act_no_bias":
72
+ self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
73
+ num_lin_layers=1, act_layer=False, bias=False)
74
+ elif modulation_type == "none":
75
+ self.modulation = torch.nn.Identity()
76
+ else:
77
+ raise ValueError("modulation_type not implemented")
78
+ self.modulation_orth = modulation_orth
79
+ self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
80
+ self.classifier_type = classifier_type
81
+ if classifier_type == "independent_mlp":
82
+ self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
83
+ num_lin_layers=1, act_layer=False, out_dim=num_classes,
84
+ bias=False, stack_dim=1)
85
+ elif classifier_type == "linear":
86
+ self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
87
+ bias=False)
88
+ else:
89
+ raise ValueError("classifier_type not implemented")
90
+ self.convert_blocks_and_attention()
91
+ self._init_weights()
92
+
93
+ def _init_weights_head(self):
94
+ # Initialize weights with a truncated normal distribution
95
+ if self.classifier_type == "independent_mlp":
96
+ self.fc_class_landmarks.reset_weights()
97
+ else:
98
+ torch.nn.init.trunc_normal_(self.fc_class_landmarks.weight, std=0.02)
99
+ if self.fc_class_landmarks.bias is not None:
100
+ torch.nn.init.zeros_(self.fc_class_landmarks.bias)
101
+
102
+ def _init_weights(self):
103
+ self._init_weights_head()
104
+
105
+ def convert_blocks_and_attention(self):
106
+ for module in self.modules():
107
+ if isinstance(module, Block):
108
+ module.__class__ = BlockWQKVReturn
109
+ elif isinstance(module, Attention):
110
+ module.__class__ = AttentionWQKVReturn
111
+
112
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
113
+ pos_embed = self.pos_embed
114
+ to_cat = []
115
+ if self.cls_token is not None:
116
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
117
+ if self.reg_token is not None:
118
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
119
+ if self.no_embed_class:
120
+ # deit-3, updated JAX (big vision)
121
+ # position embedding does not overlap with class token, add then concat
122
+ x = x + pos_embed
123
+ if to_cat:
124
+ x = torch.cat(to_cat + [x], dim=1)
125
+ else:
126
+ # original timm, JAX, and deit vit impl
127
+ # pos_embed has entry for class token, concat then add
128
+ if to_cat:
129
+ x = torch.cat(to_cat + [x], dim=1)
130
+ x = x + pos_embed
131
+ return self.pos_drop(x)
132
+
133
+ def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any]:
134
+
135
+ x = self.patch_embed(x)
136
+
137
+ # Position Embedding
138
+ x = self._pos_embed(x)
139
+
140
+ # Forward pass through transformer
141
+ x = self.norm_pre(x)
142
+
143
+ x = self.blocks(x)
144
+ x = self.norm(x)
145
+
146
+ # Compute per landmark attention maps
147
+ # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps vit, a = convolution kernel
148
+ batch_size = x.shape[0]
149
+ x = x[:, self.num_prefix_tokens:, :] # [B, num_patch_tokens, embed_dim]
150
+ x = self.unflatten(x) # [B, H, W, embed_dim]
151
+ x = x.permute(0, 3, 1, 2).contiguous() # [B, embed_dim, H, W]
152
+ ab = self.fc_landmarks(x) # [B, num_landmarks + 1, H, W]
153
+ b_sq = x.pow(2).sum(1, keepdim=True)
154
+ b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
155
+ a_sq = self.fc_landmarks.weight.pow(2).sum(1, keepdim=True).expand(-1, batch_size, x.shape[-2],
156
+ x.shape[-1]).contiguous()
157
+ a_sq = a_sq.permute(1, 0, 2, 3).contiguous()
158
+
159
+ dist = b_sq - 2 * ab + a_sq
160
+ maps = -dist
161
+
162
+ # Softmax so that the attention maps for each pixel add up to 1
163
+ if self.gumbel_softmax:
164
+ maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
165
+ hard=self.gumbel_softmax_hard) # [B, num_landmarks + 1, H, W]
166
+ else:
167
+ maps = torch.nn.functional.softmax(maps, dim=1) # [B, num_landmarks + 1, H, W]
168
+
169
+ # Use maps to get weighted average features per landmark
170
+ all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).contiguous()
171
+ if self.noise_variance > 0.0:
172
+ all_features += torch.randn_like(all_features,
173
+ device=all_features.device) * x.std().detach() * self.noise_variance
174
+
175
+ all_features = all_features.mean(-1).mean(-1).contiguous() # [B, embed_dim, num_landmarks + 1]
176
+
177
+ # Modulate the features
178
+ if self.modulation_type == "original":
179
+ all_features_mod = all_features * self.modulation # [B, embed_dim, num_landmarks + 1]
180
+ else:
181
+ all_features_mod = self.modulation(all_features) # [B, embed_dim, num_landmarks + 1]
182
+
183
+ # Classification based on the landmark features
184
+ scores = self.fc_class_landmarks(
185
+ self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
186
+ 1).contiguous()
187
+ if self.modulation_orth:
188
+ return all_features_mod, maps, scores, dist
189
+ else:
190
+ return all_features, maps, scores, dist
191
+
192
+ def get_specific_intermediate_layer(
193
+ self,
194
+ x: torch.Tensor,
195
+ n: int = 1,
196
+ return_qkv: bool = False,
197
+ return_att_weights: bool = False,
198
+ ):
199
+ num_blocks = len(self.blocks)
200
+ attn_weights = []
201
+ if n >= num_blocks:
202
+ raise ValueError(f"n must be less than {num_blocks}")
203
+
204
+ # forward pass
205
+ x = self.patch_embed(x)
206
+ x = self._pos_embed(x)
207
+ x = self.norm_pre(x)
208
+
209
+ if n == -1:
210
+ if return_qkv:
211
+ raise ValueError("take_indice cannot be -1 if return_transformer_qkv is True")
212
+ else:
213
+ return x
214
+
215
+ for i, blk in enumerate(self.blocks):
216
+ if self.return_transformer_qkv:
217
+ x, qkv = blk(x, return_qkv=True)
218
+
219
+ if return_att_weights:
220
+ attn_weight, _ = compute_attention(qkv)
221
+ attn_weights.append(attn_weight.detach())
222
+ else:
223
+ x = blk(x)
224
+ if i == n:
225
+ output = x.clone()
226
+ if self.return_transformer_qkv and return_qkv:
227
+ qkv_output = qkv.clone()
228
+ break
229
+ if self.return_transformer_qkv and return_qkv and return_att_weights:
230
+ return output, qkv_output, attn_weights
231
+ elif self.return_transformer_qkv and return_qkv:
232
+ return output, qkv_output
233
+ elif self.return_transformer_qkv and return_att_weights:
234
+ return output, attn_weights
235
+ else:
236
+ return output
237
+
238
+ def _intermediate_layers(
239
+ self,
240
+ x: torch.Tensor,
241
+ n: Union[int, Sequence] = 1,
242
+ ):
243
+ outputs, num_blocks = [], len(self.blocks)
244
+ if self.return_transformer_qkv:
245
+ qkv_outputs = []
246
+ take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
247
+
248
+ # forward pass
249
+ x = self.patch_embed(x)
250
+ x = self._pos_embed(x)
251
+ x = self.norm_pre(x)
252
+
253
+ for i, blk in enumerate(self.blocks):
254
+ if self.return_transformer_qkv:
255
+ x, qkv = blk(x, return_qkv=True)
256
+ else:
257
+ x = blk(x)
258
+ if i in take_indices:
259
+ outputs.append(x)
260
+ if self.return_transformer_qkv:
261
+ qkv_outputs.append(qkv)
262
+ if self.return_transformer_qkv:
263
+ return outputs, qkv_outputs
264
+ else:
265
+ return outputs
266
+
267
+ def get_intermediate_layers(
268
+ self,
269
+ x: torch.Tensor,
270
+ n: Union[int, Sequence] = 1,
271
+ reshape: bool = False,
272
+ return_prefix_tokens: bool = False,
273
+ norm: bool = False,
274
+ ) -> tuple[tuple, Any]:
275
+ """ Intermediate layer accessor (NOTE: This is a WIP experiment).
276
+ Inspired by DINO / DINOv2 interface
277
+ """
278
+ # take last n blocks if n is an int, if n is a sequence, select by matching indices
279
+ if self.return_transformer_qkv:
280
+ outputs, qkv = self._intermediate_layers(x, n)
281
+ else:
282
+ outputs = self._intermediate_layers(x, n)
283
+
284
+ if norm:
285
+ outputs = [self.norm(out) for out in outputs]
286
+ prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
287
+ outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
288
+
289
+ if reshape:
290
+ grid_size = self.patch_embed.grid_size
291
+ outputs = [
292
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
293
+ for out in outputs
294
+ ]
295
+
296
+ if return_prefix_tokens:
297
+ return_out = tuple(zip(outputs, prefix_tokens))
298
+ else:
299
+ return_out = tuple(outputs)
300
+
301
+ if self.return_transformer_qkv:
302
+ return return_out, qkv
303
+ else:
304
+ return return_out
305
+
306
+ @classmethod
307
+ def _from_pretrained(
308
+ cls,
309
+ *,
310
+ model_id: str,
311
+ revision: Optional[str],
312
+ cache_dir: Optional[Union[str, Path]],
313
+ force_download: bool,
314
+ proxies: Optional[Dict],
315
+ resume_download: Optional[bool],
316
+ local_files_only: bool,
317
+ token: Union[str, bool, None],
318
+ map_location: str = "cpu",
319
+ strict: bool = False,
320
+ timm_backbone: str = "hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m",
321
+ input_size: int = 518,
322
+ **model_kwargs):
323
+ base_model = create_model(timm_backbone, pretrained=False, img_size=input_size)
324
+ model = cls(base_model, **model_kwargs)
325
+ if os.path.isdir(model_id):
326
+ print("Loading weights from local directory")
327
+ model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
328
+ return cls._load_as_safetensor(model, model_file, map_location, strict)
329
+ else:
330
+ model_file = hf_hub_download(
331
+ repo_id=model_id,
332
+ filename=SAFETENSORS_SINGLE_FILE,
333
+ revision=revision,
334
+ cache_dir=cache_dir,
335
+ force_download=force_download,
336
+ proxies=proxies,
337
+ resume_download=resume_download,
338
+ token=token,
339
+ local_files_only=local_files_only,
340
+ )
341
+ return cls._load_as_safetensor(model, model_file, map_location, strict)
342
+
343
+
344
+ def pdiscoformer_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs):
345
+ base_model = create_model(
346
+ backbone,
347
+ pretrained=False,
348
+ img_size=img_size,
349
+ )
350
+
351
+ model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls,
352
+ modulation_type="layer_norm", gumbel_softmax=True,
353
+ modulation_orth=True)
354
+ return model
355
+
356
+
357
+ def pdisconet_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs):
358
+ base_model = create_model(
359
+ backbone,
360
+ pretrained=False,
361
+ img_size=img_size,
362
+ )
363
+
364
+ model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls,
365
+ modulation_type="original")
366
+ return model
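Since the class mixes in `PyTorchModelHubMixin`, a published `model.safetensors` checkpoint can be pulled with `from_pretrained`; the repository id below is purely a placeholder, and the keyword arguments are forwarded to the constructor after `_from_pretrained` builds the timm backbone.

from models.individual_landmark_vit import IndividualLandmarkViT

# Hypothetical repo id; any Hub repo storing a compatible model.safetensors would do
model = IndividualLandmarkViT.from_pretrained(
    "some-user/pdiscoformer-cub-k8",   # placeholder
    input_size=518,                    # image size used to build the timm backbone
    num_landmarks=8,
    num_classes=200,
    modulation_type="layer_norm",
    gumbel_softmax=True,
    modulation_orth=True,
)
model.eval()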
models/vit_baseline.py ADDED
@@ -0,0 +1,239 @@
1
+ # Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Tuple, Union, Sequence, Any
5
+ from timm.layers import trunc_normal_
6
+ from timm.models.vision_transformer import Block, Attention
7
+ from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn
8
+
9
+ from utils.misc_utils import compute_attention
10
+
11
+
12
+ class BaselineViT(torch.nn.Module):
13
+ """
14
+ Modifications:
15
+ - Use BlockWQKVReturn instead of Block
16
+ - Use AttentionWQKVReturn instead of Attention
17
+ - Return the mean of k over heads from attention
18
+ - Option to use only class tokens or only patch tokens or both (concat) for classification
19
+ """
20
+
21
+ def __init__(self, init_model: torch.nn.Module, num_classes: int,
22
+ class_tokens_only: bool = False,
23
+ patch_tokens_only: bool = False, return_transformer_qkv: bool = False) -> None:
24
+ super().__init__()
25
+ self.num_classes = num_classes
26
+ self.class_tokens_only = class_tokens_only
27
+ self.patch_tokens_only = patch_tokens_only
28
+ self.num_prefix_tokens = init_model.num_prefix_tokens
29
+ self.num_reg_tokens = init_model.num_reg_tokens
30
+ self.has_class_token = init_model.has_class_token
31
+ self.no_embed_class = init_model.no_embed_class
32
+ self.cls_token = init_model.cls_token
33
+ self.reg_token = init_model.reg_token
34
+
35
+ self.patch_embed = init_model.patch_embed
36
+
37
+ self.pos_embed = init_model.pos_embed
38
+ self.pos_drop = init_model.pos_drop
39
+ self.part_embed = nn.Identity()
40
+ self.patch_prune = nn.Identity()
41
+ self.norm_pre = init_model.norm_pre
42
+ self.blocks = init_model.blocks
43
+ self.norm = init_model.norm
44
+
45
+ self.fc_norm = init_model.fc_norm
46
+ if class_tokens_only or patch_tokens_only:
47
+ self.head = nn.Linear(init_model.embed_dim, num_classes)
48
+ else:
49
+ self.head = nn.Linear(init_model.embed_dim * 2, num_classes)
50
+
51
+ self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
52
+ self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])
53
+
54
+ self.return_transformer_qkv = return_transformer_qkv
55
+ self.convert_blocks_and_attention()
56
+ self._init_weights_head()
57
+
58
+ def convert_blocks_and_attention(self):
59
+ for module in self.modules():
60
+ if isinstance(module, Block):
61
+ module.__class__ = BlockWQKVReturn
62
+ elif isinstance(module, Attention):
63
+ module.__class__ = AttentionWQKVReturn
64
+
65
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
66
+ pos_embed = self.pos_embed
67
+ to_cat = []
68
+ if self.cls_token is not None:
69
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
70
+ if self.reg_token is not None:
71
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
72
+ if self.no_embed_class:
73
+ # deit-3, updated JAX (big vision)
74
+ # position embedding does not overlap with class token, add then concat
75
+ x = x + pos_embed
76
+ if to_cat:
77
+ x = torch.cat(to_cat + [x], dim=1)
78
+ else:
79
+ # original timm, JAX, and deit vit impl
80
+ # pos_embed has entry for class token, concat then add
81
+ if to_cat:
82
+ x = torch.cat(to_cat + [x], dim=1)
83
+ x = x + pos_embed
84
+ return self.pos_drop(x)
85
+
86
+ def _init_weights_head(self):
87
+ trunc_normal_(self.head.weight, std=.02)
88
+ if self.head.bias is not None:
89
+ nn.init.constant_(self.head.bias, 0.)
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
92
+
93
+ x = self.patch_embed(x)
94
+
95
+ # Position Embedding
96
+ x = self._pos_embed(x)
97
+
98
+ x = self.part_embed(x)
99
+ x = self.patch_prune(x)
100
+
101
+ # Forward pass through transformer
102
+ x = self.norm_pre(x)
103
+
104
+ if self.return_transformer_qkv:
105
+ # Return keys of last attention layer
106
+ for i, blk in enumerate(self.blocks):
107
+ x, qkv = blk(x, return_qkv=True)
108
+ else:
109
+ x = self.blocks(x)
110
+
111
+ x = self.norm(x)
112
+
113
+ # Classification head
114
+ x = self.fc_norm(x)
115
+ if self.class_tokens_only: # only use class token
116
+ x = x[:, 0, :]
117
+ elif self.patch_tokens_only: # only use patch tokens
118
+ x = x[:, self.num_prefix_tokens:, :].mean(dim=1)
119
+ else:
120
+ x = torch.cat([x[:, 0, :], x[:, self.num_prefix_tokens:, :].mean(dim=1)], dim=1)
121
+ x = self.head(x)
122
+ if self.return_transformer_qkv:
123
+ return x, qkv
124
+ else:
125
+ return x
126
+
127
+ def get_specific_intermediate_layer(
128
+ self,
129
+ x: torch.Tensor,
130
+ n: int = 1,
131
+ return_qkv: bool = False,
132
+ return_att_weights: bool = False,
133
+ ):
134
+ num_blocks = len(self.blocks)
135
+ attn_weights = []
136
+ if n >= num_blocks:
137
+ raise ValueError(f"n must be less than {num_blocks}")
138
+
139
+ # forward pass
140
+ x = self.patch_embed(x)
141
+ x = self._pos_embed(x)
142
+ x = self.norm_pre(x)
143
+
144
+ if n == -1:
145
+ if return_qkv:
146
+ raise ValueError("take_indice cannot be -1 if return_transformer_qkv is True")
147
+ else:
148
+ return x
149
+
150
+ for i, blk in enumerate(self.blocks):
151
+ if self.return_transformer_qkv:
152
+ x, qkv = blk(x, return_qkv=True)
153
+
154
+ if return_att_weights:
155
+ attn_weight, _ = compute_attention(qkv)
156
+ attn_weights.append(attn_weight.detach())
157
+ else:
158
+ x = blk(x)
159
+ if i == n:
160
+ output = x.clone()
161
+ if self.return_transformer_qkv and return_qkv:
162
+ qkv_output = qkv.clone()
163
+ break
164
+ if self.return_transformer_qkv and return_qkv and return_att_weights:
165
+ return output, qkv_output, attn_weights
166
+ elif self.return_transformer_qkv and return_qkv:
167
+ return output, qkv_output
168
+ elif self.return_transformer_qkv and return_att_weights:
169
+ return output, attn_weights
170
+ else:
171
+ return output
172
+
173
+ def _intermediate_layers(
174
+ self,
175
+ x: torch.Tensor,
176
+ n: Union[int, Sequence] = 1,
177
+ ):
178
+ outputs, num_blocks = [], len(self.blocks)
179
+ if self.return_transformer_qkv:
180
+ qkv_outputs = []
181
+ take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
182
+
183
+ # forward pass
184
+ x = self.patch_embed(x)
185
+ x = self._pos_embed(x)
186
+ x = self.norm_pre(x)
187
+
188
+ for i, blk in enumerate(self.blocks):
189
+ if self.return_transformer_qkv:
190
+ x, qkv = blk(x, return_qkv=True)
191
+ else:
192
+ x = blk(x)
193
+ if i in take_indices:
194
+ outputs.append(x)
195
+ if self.return_transformer_qkv:
196
+ qkv_outputs.append(qkv)
197
+ if self.return_transformer_qkv:
198
+ return outputs, qkv_outputs
199
+ else:
200
+ return outputs
201
+
202
+ def get_intermediate_layers(
203
+ self,
204
+ x: torch.Tensor,
205
+ n: Union[int, Sequence] = 1,
206
+ reshape: bool = False,
207
+ return_prefix_tokens: bool = False,
208
+ norm: bool = False,
209
+ ) -> tuple[tuple, Any]:
210
+ """ Intermediate layer accessor (NOTE: This is a WIP experiment).
211
+ Inspired by DINO / DINOv2 interface
212
+ """
213
+ # take last n blocks if n is an int, if n is a sequence, select by matching indices
214
+ if self.return_transformer_qkv:
215
+ outputs, qkv = self._intermediate_layers(x, n)
216
+ else:
217
+ outputs = self._intermediate_layers(x, n)
218
+
219
+ if norm:
220
+ outputs = [self.norm(out) for out in outputs]
221
+ prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
222
+ outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
223
+
224
+ if reshape:
225
+ grid_size = self.patch_embed.grid_size
226
+ outputs = [
227
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
228
+ for out in outputs
229
+ ]
230
+
231
+ if return_prefix_tokens:
232
+ return_out = tuple(zip(outputs, prefix_tokens))
233
+ else:
234
+ return_out = tuple(outputs)
235
+
236
+ if self.return_transformer_qkv:
237
+ return return_out, qkv
238
+ else:
239
+ return return_out
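A short sketch of the baseline wrapper, assuming the same registers-style DINOv2 backbone used elsewhere in this commit; with neither token flag set, the head sees the class token concatenated with the mean patch token. The backbone name is an example, not a requirement.

import torch
from timm.models import create_model
from models.vit_baseline import BaselineViT

base = create_model("vit_base_patch14_reg4_dinov2.lvd142m", pretrained=False, img_size=224)
net = BaselineViT(base, num_classes=200)   # class token + mean patch token -> Linear(2 * 768, 200)
logits = net(torch.randn(2, 3, 224, 224))
print(logits.shape)                        # torch.Size([2, 200])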
requirements.txt CHANGED
@@ -3,4 +3,8 @@ timm
  colorcet
  matplotlib
  torchvision
- streamlit
+ streamlit
+ numpy
+ pillow
+ scikit-image
+ huggingface-hub
utils/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .data_utils import *
+ from .visualize_att_maps import *
+ from .misc_utils import *
+ from .get_landmark_coordinates import *
+
+
utils/data_utils/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .dataset_utils import *
+ from .reversible_affine_transform import *
+ from .transform_utils import *
+ from .class_balanced_distributed_sampler import *
+ from .class_balanced_sampler import *
utils/data_utils/class_balanced_distributed_sampler.py ADDED
@@ -0,0 +1,100 @@
+ import torch
+ from torch.utils.data import Dataset
+ from typing import Optional
+ import math
+ import torch.distributed as dist
+
+
+ class ClassBalancedDistributedSampler(torch.utils.data.Sampler):
+     """
+     A custom sampler that sub-samples a given dataset based on class labels. Based on the DistributedSampler class
+     Ref: https://github.com/pytorch/pytorch/blob/04c1df651aa58bea50977f4efcf19b09ce27cefd/torch/utils/data/distributed.py#L13
+     """
+
+     def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None,
+                  shuffle: bool = True, seed: int = 0, drop_last: bool = False, num_samples_per_class=100) -> None:
+
+         if not shuffle:
+             raise ValueError("ClassBalancedDatasetSubSampler requires shuffling, otherwise use DistributedSampler")
+
+         # Check if the dataset has a generate_class_balanced_indices method
+         if not hasattr(dataset, 'generate_class_balanced_indices'):
+             raise ValueError("Dataset does not have a generate_class_balanced_indices method")
+
+         self.shuffle = shuffle
+         self.seed = seed
+         if num_replicas is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             num_replicas = dist.get_world_size()
+         if rank is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             rank = dist.get_rank()
+         if rank >= num_replicas or rank < 0:
+             raise ValueError(
+                 f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         self.drop_last = drop_last
+
+         # Calculate the number of samples
+         g = torch.Generator()
+         g.manual_seed(self.seed + self.epoch)
+         self.num_samples_per_class = num_samples_per_class
+         indices = dataset.generate_class_balanced_indices(torch.Generator(),
+                                                           num_samples_per_class=num_samples_per_class)
+         dataset_size = len(indices)
+
+         # If the dataset length is evenly divisible by # of replicas, then there
+         # is no need to drop any data, since the dataset will be split equally.
+         if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
+             # Split to nearest available length that is evenly divisible.
+             # This is to ensure each rank receives the same amount of data when
+             # using this Sampler.
+             self.num_samples = math.ceil(
+                 (dataset_size - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
+             )
+         else:
+             self.num_samples = math.ceil(dataset_size / self.num_replicas)  # type: ignore[arg-type]
+         self.total_size = self.num_samples * self.num_replicas
+
+     def __iter__(self):
+         # deterministically shuffle based on epoch and seed, here shuffle is assumed to be True
+         g = torch.Generator()
+         g.manual_seed(self.seed + self.epoch)
+         indices = self.dataset.generate_class_balanced_indices(g, num_samples_per_class=self.num_samples_per_class)
+
+         if not self.drop_last:
+             # add extra samples to make it evenly divisible
+             padding_size = self.total_size - len(indices)
+             if padding_size <= len(indices):
+                 indices += indices[:padding_size]
+             else:
+                 indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+         else:
+             # remove tail of data to make it evenly divisible.
+             indices = indices[:self.total_size]
+
+         # subsample
+         indices = indices[self.rank:self.total_size:self.num_replicas]
+
+         return iter(indices)
+
+     def __len__(self) -> int:
+         return self.num_samples
+
+     def set_epoch(self, epoch: int) -> None:
+         r"""
+         Set the epoch for this sampler.
+
+         When :attr:`shuffle=True`, this ensures all replicas
+         use a different random ordering for each epoch. Otherwise, the next iteration of this
+         sampler will yield the same ordering.
+
+         Args:
+             epoch (int): Epoch number.
+         """
+         self.epoch = epoch
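The sampler only requires the dataset to expose `generate_class_balanced_indices(generator, num_samples_per_class=...)`; a toy single-process sketch (hypothetical dataset, not part of this commit) could look like:

import torch
from collections import defaultdict
from utils.data_utils.class_balanced_distributed_sampler import ClassBalancedDistributedSampler

class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, labels):
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.zeros(3, 8, 8), self.labels[idx]

    def generate_class_balanced_indices(self, generator, num_samples_per_class=2):
        # Collect indices per class, then keep at most num_samples_per_class random ones each
        per_class = defaultdict(list)
        for idx, y in enumerate(self.labels):
            per_class[y].append(idx)
        indices = []
        for idxs in per_class.values():
            perm = torch.randperm(len(idxs), generator=generator).tolist()
            indices += [idxs[i] for i in perm[:num_samples_per_class]]
        return indices

dataset = ToyDataset(labels=[0] * 10 + [1] * 50 + [2] * 5)
sampler = ClassBalancedDistributedSampler(dataset, num_replicas=1, rank=0,
                                          num_samples_per_class=4)
print(len(list(iter(sampler))))   # 12 indices (4 per class) for this single replica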
utils/data_utils/class_balanced_sampler.py ADDED
@@ -0,0 +1,31 @@
+ import torch
+ from torch.utils.data import Dataset
+
+
+ class ClassBalancedRandomSampler(torch.utils.data.Sampler):
+     """
+     A custom sampler that sub-samples a given dataset based on class labels. Based on the RandomSampler class.
+     This is essentially the non-DDP version of ClassBalancedDistributedSampler.
+     Ref: https://github.com/pytorch/pytorch/blob/abe3c55a6a01c5b625eeb4fc9aab1421a5965cd2/torch/utils/data/sampler.py#L117
+     """
+
+     def __init__(self, dataset: Dataset, num_samples_per_class=100, seed: int = 0) -> None:
+         self.dataset = dataset
+         self.seed = seed
+         # Calculate the number of samples
+         self.generator = torch.Generator()
+         self.generator.manual_seed(self.seed)
+         self.num_samples_per_class = num_samples_per_class
+         indices = dataset.generate_class_balanced_indices(self.generator,
+                                                           num_samples_per_class=num_samples_per_class)
+         self.num_samples = len(indices)
+
+     def __iter__(self):
+         # Change seed for every function call
+         seed = int(torch.empty((), dtype=torch.int64).random_().item())
+         self.generator.manual_seed(seed)
+         indices = self.dataset.generate_class_balanced_indices(self.generator,
+                                                                num_samples_per_class=self.num_samples_per_class)
+         return iter(indices)
+
+     def __len__(self) -> int:
+         return self.num_samples
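Both samplers rely on the dataset exposing a generate_class_balanced_indices method. The sketch below shows one plausible shape of that contract with a toy dataset; the real datasets in this repo may implement it differently.

import torch
from torch.utils.data import Dataset, DataLoader


class ToyLabelledDataset(Dataset):
    def __init__(self, labels):
        self.labels = labels  # one integer class label per sample

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.zeros(3, 8, 8), self.labels[idx]

    def generate_class_balanced_indices(self, generator, num_samples_per_class=100):
        # For every class, keep at most num_samples_per_class randomly chosen indices.
        indices = []
        for cls in sorted(set(self.labels)):
            cls_idx = [i for i, y in enumerate(self.labels) if y == cls]
            perm = torch.randperm(len(cls_idx), generator=generator).tolist()
            indices += [cls_idx[i] for i in perm[:num_samples_per_class]]
        # Shuffle across classes so a batch is not dominated by a single class.
        order = torch.randperm(len(indices), generator=generator).tolist()
        return [indices[i] for i in order]


dataset = ToyLabelledDataset(labels=[0] * 10 + [1] * 50 + [2] * 5)
sampler = ClassBalancedRandomSampler(dataset, num_samples_per_class=8, seed=0)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
# len(sampler) == 8 + 8 + 5 == 21: at most 8 indices per class are drawn each epoch.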
utils/data_utils/dataset_utils.py ADDED
@@ -0,0 +1,161 @@
+ from PIL import Image
+ from torch import Tensor
+ from typing import List, Optional
+ import numpy as np
+ import torchvision
+ import json
+
+
+ def load_json(path: str):
+     """
+     Load json file from path and return the data
+     :param path: Path to the json file
+     :return:
+     data: Data in the json file
+     """
+     with open(path, 'r') as f:
+         data = json.load(f)
+     return data
+
+
+ def save_json(data: dict, path: str):
+     """
+     Save data to a json file
+     :param data: Data to be saved
+     :param path: Path to save the data
+     :return:
+     """
+     with open(path, "w") as f:
+         json.dump(data, f)
+
+
+ def pil_loader(path):
+     """
+     Load image from path using PIL
+     :param path: Path to the image
+     :return:
+     img: PIL Image
+     """
+     with open(path, 'rb') as f:
+         img = Image.open(f)
+         return img.convert('RGB')
+
+
+ def get_dimensions(image: Tensor):
+     """
+     Get the dimensions of the image
+     :param image: Tensor or PIL Image or np.ndarray
+     :return:
+     h: Height of the image
+     w: Width of the image
+     """
+     if isinstance(image, Tensor):
+         _, h, w = image.shape
+     elif isinstance(image, np.ndarray):
+         h, w, _ = image.shape
+     elif isinstance(image, Image.Image):
+         w, h = image.size
+     else:
+         raise ValueError(f"Invalid image type: {type(image)}")
+     return h, w
+
+
+ def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None,
+                           boxes: Optional[Tensor] = None, num_keypoints: int = 15):
+     """
+     Calculate the center crop parameters for the bounding boxes and landmarks and update them
+     :param img: Image
+     :param output_size: Output size of the cropped image
+     :param parts: Locations of the landmarks of following format: <part_id> <x> <y> <visible>
+     :param boxes: Bounding boxes of the landmarks of following format: <image_id> <x> <y> <width> <height>
+     :param num_keypoints: Number of keypoints
+     :return:
+     cropped_img: Center cropped image
+     parts: Updated locations of the landmarks
+     boxes: Updated bounding boxes of the landmarks
+     """
+     if isinstance(output_size, int):
+         output_size = (output_size, output_size)
+     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
+         output_size = (output_size[0], output_size[0])
+     elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
+         output_size = output_size
+     else:
+         raise ValueError(f"Invalid output size: {output_size}")
+
+     crop_height, crop_width = output_size
+     image_height, image_width = get_dimensions(img)
+     img = torchvision.transforms.functional.center_crop(img, output_size)
+
+     crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size)
+
+     if parts is not None:
+         for j in range(num_keypoints):
+             # Skip if part is invisible
+             if parts[j][-1] == 0:
+                 continue
+             parts[j][1] -= crop_left
+             parts[j][2] -= crop_top
+
+             # Mark the part as invisible if it falls outside the crop
+             if parts[j][1] > crop_width or parts[j][2] > crop_height:
+                 parts[j][-1] = 0
+             if parts[j][1] < 0 or parts[j][2] < 0:
+                 parts[j][-1] = 0
+
+             parts[j][1] = min(crop_width, parts[j][1])
+             parts[j][2] = min(crop_height, parts[j][2])
+             parts[j][1] = max(0, parts[j][1])
+             parts[j][2] = max(0, parts[j][2])
+
+     if boxes is not None:
+         boxes[1] -= crop_left
+         boxes[2] -= crop_top
+         boxes[1] = max(0, boxes[1])
+         boxes[2] = max(0, boxes[2])
+         boxes[1] = min(crop_width, boxes[1])
+         boxes[2] = min(crop_height, boxes[2])
+
+     return img, parts, boxes
+
+
+ def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448):
+     """
+     Get the parameters for center cropping the image
+     :param image_height: Height of the image
+     :param image_width: Width of the image
+     :param output_size: Output size of the cropped image
+     :return:
+     crop_top: Top coordinate of the cropped image
+     crop_left: Left coordinate of the cropped image
+     """
+     if isinstance(output_size, int):
+         output_size = (output_size, output_size)
+     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
+         output_size = (output_size[0], output_size[0])
+     elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
+         output_size = output_size
+     else:
+         raise ValueError(f"Invalid output size: {output_size}")
+
+     crop_height, crop_width = output_size
+
+     if crop_width > image_width or crop_height > image_height:
+         padding_ltrb = [
+             (crop_width - image_width) // 2 if crop_width > image_width else 0,
+             (crop_height - image_height) // 2 if crop_height > image_height else 0,
+             (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+             (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+         ]
+         crop_top, crop_left = padding_ltrb[1], padding_ltrb[0]
+         return crop_top, crop_left
+
+     if crop_width == image_width and crop_height == image_height:
+         crop_top = 0
+         crop_left = 0
+         return crop_top, crop_left
+
+     crop_top = int(round((image_height - crop_height) / 2.0))
+     crop_left = int(round((image_width - crop_width) / 2.0))
+
+     return crop_top, crop_left
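A small, hedged example of center_crop_boxes_kps on dummy data, using the keypoint and box layouts described in the docstrings above.

import torch

img = torch.rand(3, 600, 500)                           # C x H x W
parts = torch.tensor([[0, 250.0, 300.0, 1],             # visible keypoint near the centre
                      [1, 10.0, 20.0, 1]],              # keypoint that falls outside the crop
                     dtype=torch.float32)
boxes = torch.tensor([0, 100.0, 150.0, 300.0, 250.0])   # <image_id> <x> <y> <w> <h>

cropped, parts_out, boxes_out = center_crop_boxes_kps(
    img, output_size=448, parts=parts.clone(), boxes=boxes.clone(), num_keypoints=2)
# cropped is 3 x 448 x 448; keypoints are shifted by the crop offset, clamped to the crop,
# and marked invisible (last column set to 0) when they leave the cropped region.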
utils/data_utils/reversible_affine_transform.py ADDED
@@ -0,0 +1,82 @@
+ # Description: This file contains the code for the reversible affine transform
+ import torchvision.transforms as transforms
+ import torch
+ from typing import List, Optional, Tuple, Any
+
+
+ def generate_affine_trans_params(
+         degrees: List[float],
+         translate: Optional[List[float]],
+         scale_ranges: Optional[List[float]],
+         shears: Optional[List[float]],
+         img_size: List[int],
+ ) -> Tuple[float, Tuple[int, int], float, Any]:
+     """Get parameters for affine transformation
+
+     Returns:
+         params to be passed to the affine transformation
+     """
+     angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
+     if translate is not None:
+         max_dx = float(translate[0] * img_size[0])
+         max_dy = float(translate[1] * img_size[1])
+         tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
+         ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
+         translations = (tx, ty)
+     else:
+         translations = (0, 0)
+
+     if scale_ranges is not None:
+         scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
+     else:
+         scale = 1.0
+
+     shear_x = shear_y = 0.0
+     if shears is not None:
+         shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
+         if len(shears) == 4:
+             shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())
+
+     shear = (shear_x, shear_y)
+     if shear_x == 0.0 and shear_y == 0.0:
+         shear = 0.0
+
+     return angle, translations, scale, shear
+
+
+ def rigid_transform(img, angle, translate, scale, invert=False, shear=0,
+                     interpolation=transforms.InterpolationMode.BILINEAR):
+     """
+     Affine transforms input image
+     Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L54
+     Parameters
+     ----------
+     img: Tensor
+         Input image
+     angle: int
+         Rotation angle between -180 and 180 degrees
+     translate: [int]
+         Sequence of horizontal/vertical translations
+     scale: float
+         How to scale the image
+     invert: bool
+         Whether to invert the transformation
+     shear: float
+         Shear angle in degrees
+     interpolation: InterpolationMode
+         Interpolation mode to calculate output values
+     Returns
+     ----------
+     img: Tensor
+         Transformed image
+     """
+     if not invert:
+         img = transforms.functional.affine(img, angle=angle, translate=translate, scale=scale, shear=shear,
+                                            interpolation=interpolation)
+     else:
+         translate = [-t for t in translate]
+         img = transforms.functional.affine(img=img, angle=0, translate=translate, scale=1, shear=shear)
+         img = transforms.functional.affine(img=img, angle=-angle, translate=[0, 0], scale=1 / scale, shear=shear)
+
+     return img
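A hedged sketch of how the two helpers compose: sample one parameter set, warp an image (or an attention map) with it, then undo the warp with invert=True.

import torch

img = torch.rand(3, 256, 256)
angle, translations, scale, _ = generate_affine_trans_params(
    degrees=[-30.0, 30.0], translate=[0.1, 0.1], scale_ranges=[0.8, 1.2],
    shears=None, img_size=[256, 256])

warped = rigid_transform(img, angle, list(translations), scale, invert=False)
restored = rigid_transform(warped, angle, list(translations), scale, invert=True)
# restored approximates img up to interpolation error and border padding; the inverse
# undoes the translation first, then the rotation and scale, mirroring the forward order.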
utils/data_utils/transform_utils.py ADDED
@@ -0,0 +1,118 @@
+ import torch
+ from torchvision import transforms as transforms
+ from torchvision.transforms import Compose
+
+ from timm.data.constants import \
+     IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+ from timm.data import create_transform
+
+
+ def make_train_transforms(args):
+     train_transforms: Compose = transforms.Compose([
+         transforms.Resize(size=args.image_size, antialias=True),
+         transforms.RandomHorizontalFlip(p=args.hflip),
+         transforms.RandomVerticalFlip(p=args.vflip),
+         transforms.ColorJitter(),
+         transforms.RandomAffine(degrees=90, translate=(0.2, 0.2), scale=(0.8, 1.2)),
+         transforms.RandomCrop(args.image_size),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
+     ])
+     return train_transforms
+
+
+ def make_test_transforms(args):
+     test_transforms: Compose = transforms.Compose([
+         transforms.Resize(size=args.image_size, antialias=True),
+         transforms.CenterCrop(args.image_size),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
+     ])
+     return test_transforms
+
+
+ def build_transform_timm(args, is_train=True):
+     resize_im = args.image_size > 32
+     imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
+     mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
+     std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
+
+     if is_train:
+         # this should always dispatch to transforms_imagenet_train
+         transform = create_transform(
+             input_size=args.image_size,
+             is_training=True,
+             color_jitter=args.color_jitter,
+             hflip=args.hflip,
+             vflip=args.vflip,
+             auto_augment=args.aa,
+             interpolation=args.train_interpolation,
+             re_prob=args.reprob,
+             re_mode=args.remode,
+             re_count=args.recount,
+             mean=mean,
+             std=std,
+         )
+         if not resize_im:
+             transform.transforms[0] = transforms.RandomCrop(
+                 args.image_size, padding=4)
+         return transform
+
+     t = []
+     if resize_im:
+         # warping (no cropping) when evaluated at 384 or larger
+         if args.image_size >= 384:
+             t.append(
+                 transforms.Resize((args.image_size, args.image_size),
+                                   interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
+             )
+             print(f"Warping {args.image_size} size input images...")
+         else:
+             if args.crop_pct is None:
+                 args.crop_pct = 224 / 256
+             size = int(args.image_size / args.crop_pct)
+             t.append(
+                 # to maintain same ratio w.r.t. 224 images
+                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
+             )
+             t.append(transforms.CenterCrop(args.image_size))
+
+     t.append(transforms.ToTensor())
+     t.append(transforms.Normalize(mean, std))
+     return transforms.Compose(t)
+
+
+ def inverse_normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
+     mean = torch.as_tensor(mean)
+     std = torch.as_tensor(std)
+     un_normalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
+     return un_normalize
+
+
+ def normalize_only(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
+     normalize = transforms.Normalize(mean=mean, std=std)
+     return normalize
+
+
+ def inverse_normalize_w_resize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+                                resize_resolution=(256, 256)):
+     mean = torch.as_tensor(mean)
+     std = torch.as_tensor(std)
+     resize_unnorm = transforms.Compose([
+         transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()),
+         transforms.Resize(size=resize_resolution, antialias=True)])
+     return resize_unnorm
+
+
+ def load_transforms(args):
+     # Get the transforms and load the dataset
+     if args.augmentations_to_use == 'timm':
+         train_transforms = build_transform_timm(args, is_train=True)
+     elif args.augmentations_to_use == 'cub_original':
+         train_transforms = make_train_transforms(args)
+     else:
+         raise ValueError('Augmentations not supported.')
+     test_transforms = make_test_transforms(args)
+     return train_transforms, test_transforms
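A hedged usage sketch for load_transforms; the args namespace below only carries the fields read on the 'cub_original' path (the timm path needs several more, such as color_jitter, aa and crop_pct).

from types import SimpleNamespace

args = SimpleNamespace(image_size=448, hflip=0.5, vflip=0.0, augmentations_to_use='cub_original')
train_tf, test_tf = load_transforms(args)
# train_tf resizes, flips, jitters, applies a random affine and random crop, then normalizes;
# test_tf only resizes, center-crops and normalizes.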
utils/get_landmark_coordinates.py ADDED
@@ -0,0 +1,41 @@
+ # This file contains the function to generate the center coordinates as tensor for the current net.
+ import torch
+
+
+ def landmark_coordinates(maps, grid_x=None, grid_y=None):
+     """
+     Generate the center coordinates as tensor for the current net.
+     Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L19
+     Parameters
+     ----------
+     maps: torch.Tensor
+         Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability
+     grid_x: torch.Tensor
+         The grid x coordinates
+     grid_y: torch.Tensor
+         The grid y coordinates
+     Returns
+     ----------
+     loc_x: Tensor
+         The centroid x coordinates
+     loc_y: Tensor
+         The centroid y coordinates
+     grid_x: Tensor
+     grid_y: Tensor
+     """
+     return_grid = False
+     if grid_x is None or grid_y is None:
+         return_grid = True
+         grid_x, grid_y = torch.meshgrid(torch.arange(maps.shape[2]),
+                                         torch.arange(maps.shape[3]), indexing='ij')
+         grid_x = grid_x.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True)
+         grid_y = grid_y.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True)
+     map_sums = maps.sum(3).sum(2).detach()
+     maps_x = grid_x * maps
+     maps_y = grid_y * maps
+     loc_x = maps_x.sum(3).sum(2) / map_sums
+     loc_y = maps_y.sum(3).sum(2) / map_sums
+     if return_grid:
+         return loc_x, loc_y, grid_x, grid_y
+     else:
+         return loc_x, loc_y
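A brief example with random soft-assignment maps (softmax over the part dimension); the returned grids can be cached and passed back in on later calls.

import torch

maps = torch.softmax(torch.randn(2, 9, 14, 14), dim=1)    # batch x (parts + background) x H x W
loc_x, loc_y, grid_x, grid_y = landmark_coordinates(maps)  # each loc_* has shape batch x channels
# Reuse the grids to skip rebuilding them on every call:
loc_x2, loc_y2 = landmark_coordinates(maps, grid_x=grid_x, grid_y=grid_y)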
utils/misc_utils.py ADDED
@@ -0,0 +1,135 @@
+ import math
+ from functools import reduce
+
+ import torch
+ import numpy as np
+ import os
+ from pathlib import Path
+
+
+ def factors(n):
+     return reduce(list.__add__,
+                   ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0))
+
+
+ def file_line_count(filename: str) -> int:
+     """Count the number of lines in a file"""
+     with open(filename, 'rb') as f:
+         return sum(1 for _ in f)
+
+
+ def compute_attention(qkv, scale=None):
+     """
+     Compute the attention matrix (same as in the PyTorch scaled dot product attention)
+     Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+     :param qkv: Query, key and value tensors stacked along the first dimension
+     :param scale: Scale factor for the attention computation
+     :return:
+     """
+     if isinstance(qkv, torch.Tensor):
+         query, key, value = qkv.unbind(0)
+     else:
+         query, key, value = qkv
+     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+     L, S = query.size(-2), key.size(-2)
+     attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+     attn_weight = query @ key.transpose(-2, -1) * scale_factor
+     attn_weight += attn_bias
+     attn_weight = torch.softmax(attn_weight, dim=-1)
+     attn_out = attn_weight @ value
+     return attn_weight, attn_out
+
+
+ def compute_dot_product_similarity(a, b):
+     scores = a @ b.transpose(-1, -2)
+     return scores
+
+
+ def compute_cross_entropy(p, q):
+     q = torch.nn.functional.log_softmax(q, dim=-1)
+     loss = torch.sum(p * q, dim=-1)
+     return - loss.mean()
+
+
+ def rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cuda")):
+     """
+     Perform attention rollout.
+     Ref: https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L9C1-L42C16
+     Parameters
+     ----------
+     attentions : list
+         List of attention matrices, one for each transformer layer
+     discard_ratio : float
+         Ratio of lowest attention values to discard
+     head_fusion : str
+         Type of fusion to use for attention heads. One of "mean", "max", "min"
+     device : torch.device
+         Device to use for computation
+     Returns
+     -------
+     result : torch.Tensor
+         Rolled-out attention matrix; the class-token row can be reshaped into a (width, width)
+         mask, where width is the square root of the number of patches
+     """
+     result = torch.eye(attentions[0].size(-1), device=device)
+     attentions = [attention.to(device) for attention in attentions]
+     with torch.no_grad():
+         for attention in attentions:
+             if head_fusion == "mean":
+                 attention_heads_fused = attention.mean(axis=1)
+             elif head_fusion == "max":
+                 attention_heads_fused = attention.max(axis=1).values
+             elif head_fusion == "min":
+                 attention_heads_fused = attention.min(axis=1).values
+             else:
+                 raise ValueError("Attention head fusion type not supported")
+
+             # Drop the lowest attentions, but don't drop the class token
+             # (indexing row 0 of `flat` assumes a batch size of 1, as in the reference implementation)
+             flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
+             _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
+             indices = indices[indices != 0]
+             flat[0, indices] = 0
+
+             I = torch.eye(attention_heads_fused.size(-1), device=device)
+             a = (attention_heads_fused + 1.0 * I) / 2
+             a = a / a.sum(dim=-1, keepdim=True)  # row-normalize so each row is a distribution
+
+             result = torch.matmul(a, result)
+
+     # Normalize the result by the max value in each row
+     result = result / result.max(dim=-1, keepdim=True)[0]
+     return result
+
+
+ def sync_bn_conversion(model: torch.nn.Module):
+     """
+     Convert BatchNorm to SyncBatchNorm (used for DDP)
+     :param model: PyTorch model
+     :return:
+     model: PyTorch model with SyncBatchNorm layers
+     """
+     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+     return model
+
+
+ def check_snapshot(args):
+     """
+     Create the directory to save training checkpoints, otherwise load the existing checkpoint.
+     Additionally, if it is an array training job, create a new directory for each training job.
+     :param args: Arguments from the argument parser
+     :return:
+     """
+     # Check if it is an array training job (i.e. training with multiple random seeds on the same settings)
+     if args.array_training_job and not args.resume_training:
+         args.snapshot_dir = os.path.join(args.snapshot_dir, str(args.seed))
+         if not os.path.exists(args.snapshot_dir):
+             save_dir = Path(args.snapshot_dir)
+             save_dir.mkdir(parents=True, exist_ok=True)
+     else:
+         # Create the directory to save training checkpoints, otherwise load the existing checkpoint
+         if not os.path.exists(args.snapshot_dir):
+             if ".pt" not in args.snapshot_dir and ".pth" not in args.snapshot_dir:
+                 save_dir = Path(args.snapshot_dir)
+                 save_dir.mkdir(parents=True, exist_ok=True)
+             else:
+                 raise ValueError('Snapshot checkpoint does not exist.')
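A hedged sketch of the two attention helpers on dummy tensors: compute_attention expects q, k and v stacked along dim 0, and rollout expects one per-layer attention matrix. Batch size 1 and a class token at index 0 are assumptions of this example, not requirements stated in the code.

import torch

qkv = torch.randn(3, 1, 4, 17, 32)            # q, k, v for 1 image, 4 heads, 16 patches + CLS
attn_weight, attn_out = compute_attention(qkv)
# attn_weight: (1, 4, 17, 17), each row sums to 1; attn_out: (1, 4, 17, 32)

# Attention rollout over a stack of such per-layer matrices:
attentions = [torch.softmax(torch.randn(1, 4, 17, 17), dim=-1) for _ in range(12)]
mask = rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cpu"))
# mask: (1, 17, 17); mask[0, 0, 1:] gives the class-token attribution over the 16 patches.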
utils/visualize_att_maps.py ADDED
@@ -0,0 +1,135 @@
+ import matplotlib.pyplot as plt
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
+ import colorcet as cc
+ import numpy as np
+ import skimage
+ from pathlib import Path
+ import os
+ import torch
+
+ from utils.data_utils.transform_utils import inverse_normalize_w_resize
+ from utils.misc_utils import factors
+
+ # Define the colors to use for the attention maps
+ colors = cc.glasbey_category10
+
+
+ class VisualizeAttentionMaps:
+     def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, sub_path_test="",
+                  dataset_name="", bg_label=0, batch_size=32, num_parts=15, plot_ims_separately=False,
+                  plot_landmark_amaps=False):
+         """
+         Plot attention maps and optionally landmark centroids on images.
+         :param snapshot_dir: Directory to save the visualization results
+         :param save_resolution: Size of the images to save
+         :param alpha: The transparency of the attention maps
+         :param sub_path_test: The sub-path of the test dataset
+         :param dataset_name: The name of the dataset
+         :param bg_label: The background label index in the attention maps
+         :param batch_size: The batch size
+         :param num_parts: The number of parts in the attention maps
+         :param plot_ims_separately: Whether to plot the images separately
+         :param plot_landmark_amaps: Whether to plot the landmark attention maps
+         """
+         self.save_resolution = save_resolution
+         self.alpha = alpha
+         self.sub_path_test = sub_path_test
+         self.dataset_name = dataset_name
+         self.bg_label = bg_label
+         self.snapshot_dir = snapshot_dir
+
+         self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
+         self.batch_size = batch_size
+         self.nrows = factors(self.batch_size)[-1]
+         self.ncols = factors(self.batch_size)[-2]
+         self.num_parts = num_parts
+         self.req_colors = colors[:num_parts]
+         self.plot_ims_separately = plot_ims_separately
+         self.plot_landmark_amaps = plot_landmark_amaps
+         if self.nrows == 1 and self.ncols == 1:
+             self.figs_size = (10, 10)
+         else:
+             self.figs_size = (self.ncols * 2, self.nrows * 2)
+
+     def recalculate_nrows_ncols(self):
+         self.nrows = factors(self.batch_size)[-1]
+         self.ncols = factors(self.batch_size)[-2]
+         if self.nrows == 1 and self.ncols == 1:
+             self.figs_size = (10, 10)
+         else:
+             self.figs_size = (self.ncols * 2, self.nrows * 2)
+
+     @torch.no_grad()
+     def show_maps(self, ims, maps, epoch=0, curr_iter=0, extra_info=""):
+         """
+         Plot images, attention maps and landmark centroids.
+         Parameters
+         ----------
+         ims: Tensor, [batch_size, 3, width_im, height_im]
+             Input images on which to show the attention maps
+         maps: Tensor, [batch_size, number of parts + 1, width_map, height_map]
+             The attention maps to display
+         epoch: int
+             The epoch number
+         curr_iter: int
+             The current iteration number
+         extra_info: str
+             Any extra information to add to the file name
+         """
+         ims = self.resize_unnorm(ims)
+         if ims.shape[0] != self.batch_size:
+             self.batch_size = ims.shape[0]
+             self.recalculate_nrows_ncols()
+         fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
+         ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
+         map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
+                                                      mode='bilinear',
+                                                      align_corners=True).argmax(dim=1).cpu().numpy()
+         for i, ax in enumerate(axs.ravel()):
+             curr_map = skimage.color.label2rgb(label=map_argmax[i], image=ims[i], colors=self.req_colors,
+                                                bg_label=self.bg_label, alpha=self.alpha)
+             ax.imshow(curr_map)
+             ax.axis('off')
+         save_dir = Path(os.path.join(self.snapshot_dir, 'results_vis_' + self.sub_path_test))
+         save_dir.mkdir(parents=True, exist_ok=True)
+         save_path = os.path.join(save_dir, f'{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
+         fig.tight_layout()
+         if self.snapshot_dir != "":
+             plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
+         else:
+             plt.show()
+         plt.close('all')
+
+         if self.plot_ims_separately:
+             fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
+             for i, ax in enumerate(axs.ravel()):
+                 ax.imshow(ims[i])
+                 ax.axis('off')
+             save_path = os.path.join(save_dir, f'image_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.jpg')
+             fig.tight_layout()
+             if self.snapshot_dir != "":
+                 plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
+             else:
+                 plt.show()
+             plt.close('all')
+
+         if self.plot_landmark_amaps:
+             if self.batch_size > 1:
+                 raise ValueError('Not implemented for batch size > 1')
+             for i in range(self.num_parts):
+                 fig, ax = plt.subplots(1, 1, figsize=self.figs_size)
+                 divider = make_axes_locatable(ax)
+                 cax = divider.append_axes('right', size='5%', pad=0.05)
+                 im = ax.imshow(maps[0, i, ...].detach().cpu().numpy(), cmap='cet_gouldian')
+                 fig.colorbar(im, cax=cax, orientation='vertical')
+                 ax.axis('off')
+                 save_path = os.path.join(save_dir,
+                                          f'landmark_{i}_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
+                 fig.tight_layout()
+                 if self.snapshot_dir != "":
+                     plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
+                 else:
+                     plt.show()
+                 plt.close()
+
+         plt.close('all')
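A hedged usage sketch for VisualizeAttentionMaps on random tensors; with snapshot_dir="" the figures are shown via plt.show() rather than written under snapshot_dir (an empty results_vis_ folder is still created in the working directory).

import torch

from utils.data_utils.transform_utils import normalize_only

viz = VisualizeAttentionMaps(snapshot_dir="", save_resolution=(256, 256), alpha=0.5,
                             dataset_name="demo", bg_label=0, batch_size=4, num_parts=8)
ims = normalize_only()(torch.rand(4, 3, 224, 224))       # stand-in for a normalized input batch
maps = torch.softmax(torch.randn(4, 9, 14, 14), dim=1)   # num_parts + 1 background channel
viz.show_maps(ims, maps, epoch=0, curr_iter=0)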