ananthu-aniraj committed · Commit 20239f9 · 1 Parent(s): b507f8e

add initial files
Browse files
- .gitignore +40 -0
- files/images/Laysan_Albatross_0050_870.jpg +0 -0
- layers/__init__.py +2 -0
- layers/independent_mlp.py +69 -0
- layers/transformer_layers.py +54 -0
- load_model.py +226 -0
- models/__init__.py +4 -0
- models/individual_landmark_convnext.py +110 -0
- models/individual_landmark_resnet.py +141 -0
- models/individual_landmark_vit.py +366 -0
- models/vit_baseline.py +239 -0
- requirements.txt +5 -1
- utils/__init__.py +6 -0
- utils/data_utils/__init__.py +5 -0
- utils/data_utils/class_balanced_distributed_sampler.py +100 -0
- utils/data_utils/class_balanced_sampler.py +31 -0
- utils/data_utils/dataset_utils.py +161 -0
- utils/data_utils/reversible_affine_transform.py +82 -0
- utils/data_utils/transform_utils.py +118 -0
- utils/get_landmark_coordinates.py +41 -0
- utils/misc_utils.py +135 -0
- utils/visualize_att_maps.py +135 -0
.gitignore
ADDED
@@ -0,0 +1,40 @@
# editor settings
.idea
.vscode
_darcs

# compilation and distribution
__pycache__
_ext
*.pyc
*.pyd
*.so
*.dll
*.egg-info/
build/
dist/
wheels/

# pytorch/python/numpy formats
*.pth
*.pkl
*.npy
*.ts
*.pt

# ipython/jupyter notebooks
*.ipynb
**/.ipynb_checkpoints/

# Editor temporaries
*.swn
*.swo
*.swp
*~

# Results temporary
*.png
*.txt
*.tsv
wandb/
exps/
files/images/Laysan_Albatross_0050_870.jpg
ADDED
layers/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .transformer_layers import *
from .independent_mlp import *
layers/independent_mlp.py
ADDED
@@ -0,0 +1,69 @@
# This file contains the implementation of the IndependentMLPs class
import torch


class IndependentMLPs(torch.nn.Module):
    """
    This class implements the MLP used for classification with the option to use an additional independent MLP layer
    """

    def __init__(self, part_dim, latent_dim, bias=False, num_lin_layers=1, act_layer=True, out_dim=None, stack_dim=-1):
        """

        :param part_dim: Number of parts
        :param latent_dim: Latent dimension
        :param bias: Whether to use bias
        :param num_lin_layers: Number of linear layers
        :param act_layer: Whether to use activation layer
        :param out_dim: Output dimension (default: None)
        :param stack_dim: Dimension to stack the outputs (default: -1)
        """

        super().__init__()

        self.bias = bias
        self.latent_dim = latent_dim
        if out_dim is None:
            out_dim = latent_dim
        self.out_dim = out_dim
        self.part_dim = part_dim
        self.stack_dim = stack_dim

        layer_stack = torch.nn.ModuleList()
        for i in range(part_dim):
            layer_stack.append(torch.nn.Sequential())
            for j in range(num_lin_layers):
                layer_stack[i].add_module(f"fc_{j}", torch.nn.Linear(latent_dim, self.out_dim, bias=bias))
                if act_layer:
                    layer_stack[i].add_module(f"act_{j}", torch.nn.GELU())
        self.feature_layers = layer_stack
        self.reset_weights()

    def __repr__(self):
        return f"IndependentMLPs(part_dim={self.part_dim}, latent_dim={self.latent_dim}, bias={self.bias})"

    def reset_weights(self):
        """ Initialize weights with a truncated normal distribution """
        for layer in self.feature_layers:
            for m in layer.modules():
                if isinstance(m, torch.nn.Linear):
                    # Initialize weights with a truncated normal distribution
                    torch.nn.init.trunc_normal_(m.weight, std=0.02)
                    if m.bias is not None:
                        torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """ Input X has the dimensions batch x latent_dim x part_dim """

        outputs = []
        for i, layer in enumerate(self.feature_layers):
            if self.stack_dim == -1:
                in_ = x[..., i]
            else:
                in_ = x[:, i, ...]  # Select feature i
            out = layer(in_)  # Apply MLP to feature i
            outputs.append(out)

        x = torch.stack(outputs, dim=self.stack_dim)  # Stack the outputs

        return x
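Usage note (not part of the committed file): a minimal sketch of how IndependentMLPs consumes a batch x latent_dim x part_dim tensor with the default stack_dim=-1, so each part gets its own linear layer. The sizes below are made up for illustration.

import torch
from layers.independent_mlp import IndependentMLPs

# Hypothetical sizes: 4 parts, 16-dim features, batch of 2
mlp = IndependentMLPs(part_dim=4, latent_dim=16, num_lin_layers=1, act_layer=False)
x = torch.randn(2, 16, 4)   # batch x latent_dim x part_dim
y = mlp(x)                  # each part index is routed through its own Linear(16, 16)
print(y.shape)              # torch.Size([2, 16, 4]); out_dim defaults to latent_dim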
layers/transformer_layers.py
ADDED
@@ -0,0 +1,54 @@
# Attention Block with option to return the mean of k over heads from attention

import torch
from timm.models.vision_transformer import Attention, Block
import torch.nn.functional as F
from typing import Tuple


class AttentionWQKVReturn(Attention):
    """
    Modifications:
    - Return the qkv tensors from the attention
    """

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, torch.stack((q, k, v), dim=0)


class BlockWQKVReturn(Block):
    """
    Modifications:
    - Use AttentionWQKVReturn instead of Attention
    - Return the qkv tensors from the attention
    """

    def forward(self, x: torch.Tensor, return_qkv: bool = False) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        # Note: this is copied from timm.models.vision_transformer.Block with modifications.
        x_attn, qkv = self.attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(x_attn))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        if return_qkv:
            return x, qkv
        else:
            return x
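Illustrative sketch (not part of the committed file, assumes a recent timm where Attention exposes head_dim, q_norm/k_norm and fused_attn): the classes above are normally swapped in by re-assigning __class__ on stock timm modules, which is the same trick the convert_blocks_and_attention methods in the models below use.

import torch
from timm.models.vision_transformer import Block
from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn

# Hypothetical token tensor: batch of 2, 16 tokens, 64-dim embedding, 4 heads
blk = Block(dim=64, num_heads=4)
blk.__class__ = BlockWQKVReturn          # re-class the block so forward can return qkv
blk.attn.__class__ = AttentionWQKVReturn  # re-class its attention to stack (q, k, v)

x = torch.randn(2, 16, 64)
y, qkv = blk(x, return_qkv=True)
print(y.shape, qkv.shape)  # [2, 16, 64] and [3, 2, 4, 16, 16] (q/k/v, batch, heads, tokens, head_dim)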
load_model.py
ADDED
@@ -0,0 +1,226 @@
import copy
import os
from pathlib import Path

import torch
from timm.models import create_model
from torchvision.models import get_model

from models import pdiscoformer_vit_bb, pdisconet_vit_bb, pdisconet_resnet_torchvision_bb
from models.individual_landmark_resnet import IndividualLandmarkResNet
from models.individual_landmark_convnext import IndividualLandmarkConvNext
from models.individual_landmark_vit import IndividualLandmarkViT
from utils import load_state_dict_pdisco


def load_model_arch(args, num_cls):
    """
    Function to load the model
    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return:
    """
    if 'resnet' in args.model_arch:
        num_layers_split = [int(s) for s in args.model_arch if s.isdigit()]
        num_layers = int(''.join(map(str, num_layers_split)))
        if num_layers >= 100:
            timm_model_arch = args.model_arch + ".a1h_in1k"
        else:
            timm_model_arch = args.model_arch + ".a1_in1k"

    if "resnet" in args.model_arch and args.use_torchvision_resnet_model:
        weights = "DEFAULT" if args.pretrained_start_weights else None
        base_model = get_model(args.model_arch, weights=weights)
    elif "resnet" in args.model_arch and not args.use_torchvision_resnet_model:
        if args.eval_only:
            base_model = create_model(
                timm_model_arch,
                pretrained=args.pretrained_start_weights,
                num_classes=num_cls,
                output_stride=args.output_stride,
            )
        else:
            base_model = create_model(
                timm_model_arch,
                pretrained=args.pretrained_start_weights,
                drop_path_rate=args.drop_path,
                num_classes=num_cls,
                output_stride=args.output_stride,
            )

    elif "convnext" in args.model_arch:
        if args.eval_only:
            base_model = create_model(
                args.model_arch,
                pretrained=args.pretrained_start_weights,
                num_classes=num_cls,
                output_stride=args.output_stride,
            )
        else:
            base_model = create_model(
                args.model_arch,
                pretrained=args.pretrained_start_weights,
                drop_path_rate=args.drop_path,
                num_classes=num_cls,
                output_stride=args.output_stride,
            )
    elif "vit" in args.model_arch:
        if args.eval_only:
            base_model = create_model(
                args.model_arch,
                pretrained=args.pretrained_start_weights,
                img_size=args.image_size,
            )
        else:
            base_model = create_model(
                args.model_arch,
                pretrained=args.pretrained_start_weights,
                drop_path_rate=args.drop_path,
                img_size=args.image_size,
            )
        vit_patch_size = base_model.patch_embed.proj.kernel_size[0]
        if args.image_size % vit_patch_size != 0:
            raise ValueError(f"Image size {args.image_size} must be divisible by patch size {vit_patch_size}")
    else:
        raise ValueError('Model not supported.')

    return base_model


def init_pdisco_model(base_model, args, num_cls):
    """
    Function to initialize the model
    :param base_model: Base model
    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return:
    """
    # Initialize the network
    if 'convnext' in args.model_arch:
        sl_channels = base_model.stages[-1].downsample[-1].in_channels
        fl_channels = base_model.head.in_features
        model = IndividualLandmarkConvNext(base_model, args.num_parts, num_classes=num_cls,
                                           sl_channels=sl_channels, fl_channels=fl_channels,
                                           part_dropout=args.part_dropout, modulation_type=args.modulation_type,
                                           gumbel_softmax=args.gumbel_softmax,
                                           gumbel_softmax_temperature=args.gumbel_softmax_temperature,
                                           gumbel_softmax_hard=args.gumbel_softmax_hard,
                                           modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
                                           noise_variance=args.noise_variance)
    elif 'resnet' in args.model_arch:
        sl_channels = base_model.layer4[0].conv1.in_channels
        fl_channels = base_model.fc.in_features
        model = IndividualLandmarkResNet(base_model, args.num_parts, num_classes=num_cls,
                                         sl_channels=sl_channels, fl_channels=fl_channels,
                                         use_torchvision_model=args.use_torchvision_resnet_model,
                                         part_dropout=args.part_dropout, modulation_type=args.modulation_type,
                                         gumbel_softmax=args.gumbel_softmax,
                                         gumbel_softmax_temperature=args.gumbel_softmax_temperature,
                                         gumbel_softmax_hard=args.gumbel_softmax_hard,
                                         modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
                                         noise_variance=args.noise_variance)
    elif 'vit' in args.model_arch:
        model = IndividualLandmarkViT(base_model, num_landmarks=args.num_parts, num_classes=num_cls,
                                      part_dropout=args.part_dropout,
                                      modulation_type=args.modulation_type, gumbel_softmax=args.gumbel_softmax,
                                      gumbel_softmax_temperature=args.gumbel_softmax_temperature,
                                      gumbel_softmax_hard=args.gumbel_softmax_hard,
                                      modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
                                      noise_variance=args.noise_variance)
    else:
        raise ValueError('Model not supported.')

    return model


def load_model_pdisco(args, num_cls):
    """
    Function to load the model
    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return:
    """
    base_model = load_model_arch(args, num_cls)
    model = init_pdisco_model(base_model, args, num_cls)

    return model


def pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8, model_url="", img_size=224, num_cls=200):
    """
    Function to load the PDiscoFormer model with ViT backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL to load the model weights from
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: PDiscoFormer model with ViT backbone
    """
    model = pdiscoformer_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdiscoformer_{model_dataset}")

        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            state_dict = copy.deepcopy(snapshot_data)
        model.load_state_dict(state_dict, strict=True)
    return model


def pdisconet_vit(pretrained=True, model_dataset="nabirds", k=8, model_url="", img_size=224, num_cls=555):
    """
    Function to load the PDiscoNet model with ViT backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL to load the model weights from
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: PDiscoNet model with ViT backbone
    """
    model = pdisconet_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")

        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            state_dict = copy.deepcopy(snapshot_data)
        model.load_state_dict(state_dict, strict=True)
    return model


def pdisconet_resnet101(pretrained=True, model_dataset="nabirds", k=8, model_url="", num_cls=555):
    """
    Function to load the PDiscoNet model with ResNet-101 backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL to load the model weights from
    :param num_cls: Number of classes in the dataset
    :return: PDiscoNet model with ResNet-101 backbone
    """
    model = pdisconet_resnet_torchvision_bb("resnet101", num_cls=num_cls, k=k)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")

        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            state_dict = copy.deepcopy(snapshot_data)
        model.load_state_dict(state_dict, strict=True)
    return model
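Usage note (not part of the committed file, assumes a recent timm that registers vit_base_patch14_reg4_dinov2): a hedged sketch of the hub-style entry point above. The checkpoint URLs are not part of this commit, so pretrained=False is used here; passing pretrained=True together with a real model_url would download and load a snapshot.

import torch
from load_model import pdiscoformer_vit

# Build the CUB configuration (k=8 parts, 200 classes) without downloading weights
model = pdiscoformer_vit(pretrained=False, model_dataset="cub", k=8, img_size=224, num_cls=200)
model.eval()

with torch.no_grad():
    feats, maps, scores, dist = model(torch.randn(1, 3, 224, 224))
print(maps.shape)   # [1, k + 1, H_patches, W_patches]; the extra channel is the background part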
models/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .individual_landmark_resnet import *
from .individual_landmark_convnext import *
from .vit_baseline import *
from .individual_landmark_vit import *
models/individual_landmark_convnext.py
ADDED
@@ -0,0 +1,110 @@
import torch
from torch import Tensor
from torch.nn import Parameter
from typing import Any
from layers.independent_mlp import IndependentMLPs


# Baseline model, a modified convnext with reduced downsampling for a spatially larger feature tensor in the last layer
class IndividualLandmarkConvNext(torch.nn.Module):
    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
                 num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048, part_dropout: float = 0.3,
                 modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()

        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.stem = init_model.stem
        self.stages = init_model.stages
        self.feature_dim = sl_channels + fl_channels
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type == "parallel_mlp":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=True)
        elif modulation_type == "parallel_mlp_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=False)
        elif modulation_type == "parallel_mlp_no_act":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=True)
        elif modulation_type == "parallel_mlp_no_act_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=False)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")
        self.modulation_orth = modulation_orth
        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")

    def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any, Parameter, int | Any]:
        # Pretrained ConvNeXt part of the model
        x = self.stem(x)
        x = self.stages[0](x)
        x = self.stages[1](x)
        l3 = self.stages[2](x)
        x = self.stages[3](l3)
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False)
        x = torch.cat((x, l3), dim=1)

        # Compute per landmark attention maps
        # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps convnext, a = convolution kernel
        batch_size = x.shape[0]
        ab = self.fc_landmarks(x)
        b_sq = x.pow(2).sum(1, keepdim=True)
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
        a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2],
                                                                          x.shape[-1]).contiguous()
        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()

        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax so that the attention maps for each pixel add up to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Use maps to get weighted average features per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()
        if self.noise_variance > 0.0:
            all_features += torch.randn_like(all_features,
                                             device=all_features.device) * x.std().detach() * self.noise_variance

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation
        else:
            all_features_mod = self.modulation(all_features)

        # Classification based on the landmark features
        scores = self.fc_class_landmarks(
            self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
                                                                                                            1).contiguous()
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist
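Side note (illustration only, with made-up sizes): the `(b - a)^2 = b^2 - 2ab + a^2` expansion used in forward relies on the 1x1 convolution producing the cross term `ab`, so the per-pixel squared distance to each landmark prototype is assembled without an explicit loop. A small self-contained check:

import torch

feature_dim, num_parts, h, w = 6, 3, 4, 4                      # hypothetical sizes
x = torch.randn(2, feature_dim, h, w)                           # b: feature map
fc = torch.nn.Conv2d(feature_dim, num_parts, 1, bias=False)     # a: one prototype per output channel

ab = fc(x)                                                      # cross term, [B, num_parts, H, W]
b_sq = x.pow(2).sum(1, keepdim=True)                            # [B, 1, H, W]
a_sq = fc.weight.pow(2).sum(1).view(1, num_parts, 1, 1)         # squared prototype norms
dist = b_sq - 2 * ab + a_sq                                     # squared distance per pixel and prototype

# Reference computation with an explicit loop over prototypes
ref = torch.stack([(x - fc.weight[i].view(1, -1, 1, 1)).pow(2).sum(1) for i in range(num_parts)], dim=1)
assert torch.allclose(dist, ref, atol=1e-5)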
models/individual_landmark_resnet.py
ADDED
@@ -0,0 +1,141 @@
# Modified from https://github.com/robertdvdk/part_detection/blob/main/nets.py
import torch
from torch import Tensor
from timm.models import create_model
from torchvision.models import get_model
from torch.nn import Parameter
from typing import Any
from layers.independent_mlp import IndependentMLPs


# Baseline model, a modified ResNet with reduced downsampling for a spatially larger feature tensor in the last layer
class IndividualLandmarkResNet(torch.nn.Module):
    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
                 num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048,
                 use_torchvision_model: bool = False, part_dropout: float = 0.3,
                 modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()

        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.conv1 = init_model.conv1
        self.bn1 = init_model.bn1
        if use_torchvision_model:
            self.act1 = init_model.relu
        else:
            self.act1 = init_model.act1
        self.maxpool = init_model.maxpool
        self.layer1 = init_model.layer1
        self.layer2 = init_model.layer2
        self.layer3 = init_model.layer3
        self.layer4 = init_model.layer4
        self.feature_dim = sl_channels + fl_channels
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type == "parallel_mlp":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=True)
        elif modulation_type == "parallel_mlp_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=False)
        elif modulation_type == "parallel_mlp_no_act":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=True)
        elif modulation_type == "parallel_mlp_no_act_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=False)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")

        self.modulation_orth = modulation_orth

        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")

    def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any, Parameter, int | Any]:
        # Pretrained ResNet part of the model
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        l3 = self.layer3(x)
        x = self.layer4(l3)
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False)
        x = torch.cat((x, l3), dim=1)

        # Compute per landmark attention maps
        # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps resnet, a = convolution kernel
        batch_size = x.shape[0]

        ab = self.fc_landmarks(x)
        b_sq = x.pow(2).sum(1, keepdim=True)
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
        a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2],
                                                                          x.shape[-1]).contiguous()
        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()

        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax so that the attention maps for each pixel add up to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Use maps to get weighted average features per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()
        if self.noise_variance > 0.0:
            all_features += torch.randn_like(all_features,
                                             device=all_features.device) * x.std().detach() * self.noise_variance

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation
        else:
            all_features_mod = self.modulation(all_features)

        # Classification based on the landmark features
        scores = self.fc_class_landmarks(
            self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
                                                                                                            1).contiguous()
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist


def pdisconet_resnet_torchvision_bb(backbone, num_cls=200, k=8, **kwargs):
    base_model = get_model(backbone)
    return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls,
                                    modulation_type="original")


def pdisconet_resnet_timm_bb(backbone, num_cls=200, k=8, output_stride=32, **kwargs):
    base_model = create_model(backbone, pretrained=True, output_stride=output_stride)
    return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls,
                                    modulation_type="original")
models/individual_landmark_vit.py
ADDED
@@ -0,0 +1,366 @@
# Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
from pathlib import Path
import os
import torch
import torch.nn as nn
from torch import Tensor
from typing import Any, Union, Sequence, Optional, Dict

from huggingface_hub import PyTorchModelHubMixin, hf_hub_download

from timm.models import create_model
from timm.models.vision_transformer import Block, Attention
from utils.misc_utils import compute_attention

from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn
from layers.independent_mlp import IndependentMLPs

SAFETENSORS_SINGLE_FILE = "model.safetensors"


class IndividualLandmarkViT(torch.nn.Module, PyTorchModelHubMixin,
                            pipeline_tag='image-classification',
                            repo_url='https://github.com/ananthu-aniraj/pdiscoformer'):

    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8, num_classes: int = 200,
                 part_dropout: float = 0.3, return_transformer_qkv: bool = False,
                 modulation_type: str = "original", gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 modulation_orth: bool = False, classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()
        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance
        self.num_prefix_tokens = init_model.num_prefix_tokens
        self.num_reg_tokens = init_model.num_reg_tokens
        self.has_class_token = init_model.has_class_token
        self.no_embed_class = init_model.no_embed_class
        self.cls_token = init_model.cls_token
        self.reg_token = init_model.reg_token

        self.feature_dim = init_model.embed_dim
        self.patch_embed = init_model.patch_embed
        self.pos_embed = init_model.pos_embed
        self.pos_drop = init_model.pos_drop
        self.norm_pre = init_model.norm_pre
        self.blocks = init_model.blocks
        self.norm = init_model.norm
        self.return_transformer_qkv = return_transformer_qkv
        self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
        self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])

        self.unflatten = nn.Unflatten(1, (self.h_fmap, self.w_fmap))
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard
        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type == "parallel_mlp":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=True)
        elif modulation_type == "parallel_mlp_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=False)
        elif modulation_type == "parallel_mlp_no_act":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=True)
        elif modulation_type == "parallel_mlp_no_act_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=False)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")
        self.modulation_orth = modulation_orth
        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)
        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")
        self.convert_blocks_and_attention()
        self._init_weights()

    def _init_weights_head(self):
        # Initialize weights with a truncated normal distribution
        if self.classifier_type == "independent_mlp":
            self.fc_class_landmarks.reset_weights()
        else:
            torch.nn.init.trunc_normal_(self.fc_class_landmarks.weight, std=0.02)
            if self.fc_class_landmarks.bias is not None:
                torch.nn.init.zeros_(self.fc_class_landmarks.bias)

    def _init_weights(self):
        self._init_weights_head()

    def convert_blocks_and_attention(self):
        for module in self.modules():
            if isinstance(module, Block):
                module.__class__ = BlockWQKVReturn
            elif isinstance(module, Attention):
                module.__class__ = AttentionWQKVReturn

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        pos_embed = self.pos_embed
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any, int | Any] | tuple[Any, Any, Any, Any, int | Any]:

        x = self.patch_embed(x)

        # Position Embedding
        x = self._pos_embed(x)

        # Forward pass through transformer
        x = self.norm_pre(x)

        x = self.blocks(x)
        x = self.norm(x)

        # Compute per landmark attention maps
        # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps vit, a = convolution kernel
        batch_size = x.shape[0]
        x = x[:, self.num_prefix_tokens:, :]  # [B, num_patch_tokens, embed_dim]
        x = self.unflatten(x)  # [B, H, W, embed_dim]
        x = x.permute(0, 3, 1, 2).contiguous()  # [B, embed_dim, H, W]
        ab = self.fc_landmarks(x)  # [B, num_landmarks + 1, H, W]
        b_sq = x.pow(2).sum(1, keepdim=True)
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
        a_sq = self.fc_landmarks.weight.pow(2).sum(1, keepdim=True).expand(-1, batch_size, x.shape[-2],
                                                                           x.shape[-1]).contiguous()
        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()

        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax so that the attention maps for each pixel add up to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Use maps to get weighted average features per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).contiguous()
        if self.noise_variance > 0.0:
            all_features += torch.randn_like(all_features,
                                             device=all_features.device) * x.std().detach() * self.noise_variance

        all_features = all_features.mean(-1).mean(-1).contiguous()  # [B, embed_dim, num_landmarks + 1]

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation  # [B, embed_dim, num_landmarks + 1]
        else:
            all_features_mod = self.modulation(all_features)  # [B, embed_dim, num_landmarks + 1]

        # Classification based on the landmark features
        scores = self.fc_class_landmarks(
            self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2,
                                                                                                            1).contiguous()
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist

    def get_specific_intermediate_layer(
            self,
            x: torch.Tensor,
            n: int = 1,
            return_qkv: bool = False,
            return_att_weights: bool = False,
    ):
        num_blocks = len(self.blocks)
        attn_weights = []
        if n >= num_blocks:
            raise ValueError(f"n must be less than {num_blocks}")

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        if n == -1:
            if return_qkv:
                raise ValueError("take_indice cannot be -1 if return_transformer_qkv is True")
            else:
                return x

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)

                if return_att_weights:
                    attn_weight, _ = compute_attention(qkv)
                    attn_weights.append(attn_weight.detach())
            else:
                x = blk(x)
            if i == n:
                output = x.clone()
                if self.return_transformer_qkv and return_qkv:
                    qkv_output = qkv.clone()
                break
        if self.return_transformer_qkv and return_qkv and return_att_weights:
            return output, qkv_output, attn_weights
        elif self.return_transformer_qkv and return_qkv:
            return output, qkv_output
        elif self.return_transformer_qkv and return_att_weights:
            return output, attn_weights
        else:
            return output

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ):
        outputs, num_blocks = [], len(self.blocks)
        if self.return_transformer_qkv:
            qkv_outputs = []
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)
            else:
                x = blk(x)
            if i in take_indices:
                outputs.append(x)
                if self.return_transformer_qkv:
                    qkv_outputs.append(qkv)
        if self.return_transformer_qkv:
            return outputs, qkv_outputs
        else:
            return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
    ) -> tuple[tuple, Any]:
        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface
        """
        # take last n blocks if n is an int, if n is a sequence, select by matching indices
        if self.return_transformer_qkv:
            outputs, qkv = self._intermediate_layers(x, n)
        else:
            outputs = self._intermediate_layers(x, n)

        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]

        if return_prefix_tokens:
            return_out = tuple(zip(outputs, prefix_tokens))
        else:
            return_out = tuple(outputs)

        if self.return_transformer_qkv:
            return return_out, qkv
        else:
            return return_out

    @classmethod
    def _from_pretrained(
            cls,
            *,
            model_id: str,
            revision: Optional[str],
            cache_dir: Optional[Union[str, Path]],
            force_download: bool,
            proxies: Optional[Dict],
            resume_download: Optional[bool],
            local_files_only: bool,
            token: Union[str, bool, None],
            map_location: str = "cpu",
            strict: bool = False,
            timm_backbone: str = "hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m",
            input_size: int = 518,
            **model_kwargs):
        base_model = create_model(timm_backbone, pretrained=False, img_size=input_size)
        model = cls(base_model, **model_kwargs)
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            return cls._load_as_safetensor(model, model_file, map_location, strict)
        else:
            model_file = hf_hub_download(
                repo_id=model_id,
                filename=SAFETENSORS_SINGLE_FILE,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
            return cls._load_as_safetensor(model, model_file, map_location, strict)


def pdiscoformer_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs):
    base_model = create_model(
        backbone,
        pretrained=False,
        img_size=img_size,
    )

    model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls,
                                  modulation_type="layer_norm", gumbel_softmax=True,
                                  modulation_orth=True)
    return model


def pdisconet_vit_bb(backbone, img_size=224, num_cls=200, k=8, **kwargs):
    base_model = create_model(
        backbone,
        pretrained=False,
        img_size=img_size,
    )

    model = IndividualLandmarkViT(base_model, num_landmarks=k, num_classes=num_cls,
                                  modulation_type="original")
    return model
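Usage note (not part of the committed file): because the class mixes in PyTorchModelHubMixin with a custom _from_pretrained, a checkpoint exported as model.safetensors can be pulled from the Hub. The repo id below is a hypothetical placeholder, and the keyword arguments simply mirror the signature above.

from models.individual_landmark_vit import IndividualLandmarkViT

# "user/some-pdiscoformer-checkpoint" is a made-up repo id used purely for illustration
model = IndividualLandmarkViT.from_pretrained(
    "user/some-pdiscoformer-checkpoint",
    timm_backbone="hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m",
    input_size=518,
    num_landmarks=8,
    num_classes=200,
)
model.eval()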
models/vit_baseline.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Compostion of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from typing import Tuple, Union, Sequence, Any
|
5 |
+
from timm.layers import trunc_normal_
|
6 |
+
from timm.models.vision_transformer import Block, Attention
|
7 |
+
from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn
|
8 |
+
|
9 |
+
from utils.misc_utils import compute_attention
|
10 |
+
|
11 |
+
|
12 |
+
class BaselineViT(torch.nn.Module):
|
13 |
+
"""
|
14 |
+
Modifications:
|
15 |
+
- Use PDiscoBlock instead of Block
|
16 |
+
- Use PDiscoAttention instead of Attention
|
17 |
+
- Return the mean of k over heads from attention
|
18 |
+
- Option to use only class tokens or only patch tokens or both (concat) for classification
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, init_model: torch.nn.Module, num_classes: int,
|
22 |
+
class_tokens_only: bool = False,
|
23 |
+
patch_tokens_only: bool = False, return_transformer_qkv: bool = False) -> None:
|
24 |
+
super().__init__()
|
25 |
+
self.num_classes = num_classes
|
26 |
+
self.class_tokens_only = class_tokens_only
|
27 |
+
self.patch_tokens_only = patch_tokens_only
|
28 |
+
self.num_prefix_tokens = init_model.num_prefix_tokens
|
29 |
+
self.num_reg_tokens = init_model.num_reg_tokens
|
30 |
+
self.has_class_token = init_model.has_class_token
|
31 |
+
self.no_embed_class = init_model.no_embed_class
|
32 |
+
self.cls_token = init_model.cls_token
|
33 |
+
self.reg_token = init_model.reg_token
|
34 |
+
|
35 |
+
self.patch_embed = init_model.patch_embed
|
36 |
+
|
37 |
+
self.pos_embed = init_model.pos_embed
|
38 |
+
self.pos_drop = init_model.pos_drop
|
39 |
+
self.part_embed = nn.Identity()
|
40 |
+
self.patch_prune = nn.Identity()
|
41 |
+
self.norm_pre = init_model.norm_pre
|
42 |
+
self.blocks = init_model.blocks
|
43 |
+
self.norm = init_model.norm
|
44 |
+
|
45 |
+
self.fc_norm = init_model.fc_norm
|
46 |
+
if class_tokens_only or patch_tokens_only:
|
47 |
+
            self.head = nn.Linear(init_model.embed_dim, num_classes)
        else:
            self.head = nn.Linear(init_model.embed_dim * 2, num_classes)

        self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
        self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])

        self.return_transformer_qkv = return_transformer_qkv
        self.convert_blocks_and_attention()
        self._init_weights_head()

    def convert_blocks_and_attention(self):
        for module in self.modules():
            if isinstance(module, Block):
                module.__class__ = BlockWQKVReturn
            elif isinstance(module, Attention):
                module.__class__ = AttentionWQKVReturn

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        pos_embed = self.pos_embed
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def _init_weights_head(self):
        trunc_normal_(self.head.weight, std=.02)
        if self.head.bias is not None:
            nn.init.constant_(self.head.bias, 0.)

    def forward(self, x: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:

        x = self.patch_embed(x)

        # Position Embedding
        x = self._pos_embed(x)

        x = self.part_embed(x)
        x = self.patch_prune(x)

        # Forward pass through transformer
        x = self.norm_pre(x)

        if self.return_transformer_qkv:
            # Return keys of last attention layer
            for i, blk in enumerate(self.blocks):
                x, qkv = blk(x, return_qkv=True)
        else:
            x = self.blocks(x)

        x = self.norm(x)

        # Classification head
        x = self.fc_norm(x)
        if self.class_tokens_only:  # only use class token
            x = x[:, 0, :]
        elif self.patch_tokens_only:  # only use patch tokens
            x = x[:, self.num_prefix_tokens:, :].mean(dim=1)
        else:
            x = torch.cat([x[:, 0, :], x[:, self.num_prefix_tokens:, :].mean(dim=1)], dim=1)
        x = self.head(x)
        if self.return_transformer_qkv:
            return x, qkv
        else:
            return x

    def get_specific_intermediate_layer(
            self,
            x: torch.Tensor,
            n: int = 1,
            return_qkv: bool = False,
            return_att_weights: bool = False,
    ):
        num_blocks = len(self.blocks)
        attn_weights = []
        if n >= num_blocks:
            raise ValueError(f"n must be less than {num_blocks}")

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        if n == -1:
            if return_qkv:
                raise ValueError("take_indice cannot be -1 if return_transformer_qkv is True")
            else:
                return x

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)

                if return_att_weights:
                    attn_weight, _ = compute_attention(qkv)
                    attn_weights.append(attn_weight.detach())
            else:
                x = blk(x)
            if i == n:
                output = x.clone()
                if self.return_transformer_qkv and return_qkv:
                    qkv_output = qkv.clone()
                break
        if self.return_transformer_qkv and return_qkv and return_att_weights:
            return output, qkv_output, attn_weights
        elif self.return_transformer_qkv and return_qkv:
            return output, qkv_output
        elif self.return_transformer_qkv and return_att_weights:
            return output, attn_weights
        else:
            return output

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ):
        outputs, num_blocks = [], len(self.blocks)
        if self.return_transformer_qkv:
            qkv_outputs = []
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)

        for i, blk in enumerate(self.blocks):
            if self.return_transformer_qkv:
                x, qkv = blk(x, return_qkv=True)
            else:
                x = blk(x)
            if i in take_indices:
                outputs.append(x)
                if self.return_transformer_qkv:
                    qkv_outputs.append(qkv)
        if self.return_transformer_qkv:
            return outputs, qkv_outputs
        else:
            return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
    ) -> tuple[tuple, Any]:
        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface
        """
        # take last n blocks if n is an int, if in is a sequence, select by matching indices
        if self.return_transformer_qkv:
            outputs, qkv = self._intermediate_layers(x, n)
        else:
            outputs = self._intermediate_layers(x, n)

        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]

        if return_prefix_tokens:
            return_out = tuple(zip(outputs, prefix_tokens))
        else:
            return_out = tuple(outputs)

        if self.return_transformer_qkv:
            return return_out, qkv
        else:
            return return_out
requirements.txt
CHANGED
@@ -3,4 +3,8 @@ timm
 colorcet
 matplotlib
 torchvision
-streamlit
+streamlit
+numpy
+pillow
+scikit-image
+huggingface-hub
utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .data_utils import *
from .visualize_att_maps import *
from .misc_utils import *
from .get_landmark_coordinates import *
utils/data_utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .dataset_utils import *
from .reversible_affine_transform import *
from .transform_utils import *
from .class_balanced_distributed_sampler import *
from .class_balanced_sampler import *
utils/data_utils/class_balanced_distributed_sampler.py
ADDED
@@ -0,0 +1,100 @@
import torch
from torch.utils.data import Dataset
from typing import Optional
import math
import torch.distributed as dist


class ClassBalancedDistributedSampler(torch.utils.data.Sampler):
    """
    A custom sampler that sub-samples a given dataset based on class labels. Based on the DistributedSampler class
    Ref: https://github.com/pytorch/pytorch/blob/04c1df651aa58bea50977f4efcf19b09ce27cefd/torch/utils/data/distributed.py#L13
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None,
                 shuffle: bool = True, seed: int = 0, drop_last: bool = False, num_samples_per_class=100) -> None:

        if not shuffle:
            raise ValueError("ClassBalancedDatasetSubSampler requires shuffling, otherwise use DistributedSampler")

        # Check if the dataset has a generate_class_balanced_indices method
        if not hasattr(dataset, 'generate_class_balanced_indices'):
            raise ValueError("Dataset does not have a generate_class_balanced_indices method")

        self.shuffle = shuffle
        self.seed = seed
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last

        # Calculate the number of samples
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        self.num_samples_per_class = num_samples_per_class
        indices = dataset.generate_class_balanced_indices(torch.Generator(),
                                                          num_samples_per_class=num_samples_per_class)
        dataset_size = len(indices)

        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (dataset_size - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(dataset_size / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch and seed, here shuffle is assumed to be True
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = self.dataset.generate_class_balanced_indices(g, num_samples_per_class=self.num_samples_per_class)

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
utils/data_utils/class_balanced_sampler.py
ADDED
@@ -0,0 +1,31 @@
import torch
from torch.utils.data import Dataset


class ClassBalancedRandomSampler(torch.utils.data.Sampler):
    """
    A custom sampler that sub-samples a given dataset based on class labels. Based on the RandomSampler class
    This is essentially the non-ddp version of ClassBalancedDistributedSampler
    Ref: https://github.com/pytorch/pytorch/blob/abe3c55a6a01c5b625eeb4fc9aab1421a5965cd2/torch/utils/data/sampler.py#L117
    """

    def __init__(self, dataset: Dataset, num_samples_per_class=100, seed: int = 0) -> None:
        self.dataset = dataset
        self.seed = seed
        # Calculate the number of samples
        self.generator = torch.Generator()
        self.generator.manual_seed(self.seed)
        self.num_samples_per_class = num_samples_per_class
        indices = dataset.generate_class_balanced_indices(self.generator,
                                                          num_samples_per_class=num_samples_per_class)
        self.num_samples = len(indices)

    def __iter__(self):
        # Change seed for every function call
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        self.generator.manual_seed(seed)
        indices = self.dataset.generate_class_balanced_indices(self.generator,
                                                               num_samples_per_class=self.num_samples_per_class)
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples
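
A short usage sketch for the two samplers above (not part of the committed files). `train_dataset` is a hypothetical dataset object that implements generate_class_balanced_indices(generator, num_samples_per_class) and returns (image, label) pairs.

# Hypothetical usage sketch for the class-balanced samplers.
from torch.utils.data import DataLoader

sampler = ClassBalancedRandomSampler(train_dataset, num_samples_per_class=50, seed=0)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler, num_workers=4)

# The DDP variant is driven the same way, plus a deterministic reshuffle per epoch:
# sampler = ClassBalancedDistributedSampler(train_dataset, num_samples_per_class=50)
# sampler.set_epoch(epoch)
for images, labels in loader:
    pass  # training step goes here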
utils/data_utils/dataset_utils.py
ADDED
@@ -0,0 +1,161 @@
from PIL import Image
from torch import Tensor
from typing import List, Optional
import numpy as np
import torchvision
import json


def load_json(path: str):
    """
    Load json file from path and return the data
    :param path: Path to the json file
    :return:
    data: Data in the json file
    """
    with open(path, 'r') as f:
        data = json.load(f)
    return data


def save_json(data: dict, path: str):
    """
    Save data to a json file
    :param data: Data to be saved
    :param path: Path to save the data
    :return:
    """
    with open(path, "w") as f:
        json.dump(data, f)


def pil_loader(path):
    """
    Load image from path using PIL
    :param path: Path to the image
    :return:
    img: PIL Image
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def get_dimensions(image: Tensor):
    """
    Get the dimensions of the image
    :param image: Tensor or PIL Image or np.ndarray
    :return:
    h: Height of the image
    w: Width of the image
    """
    if isinstance(image, Tensor):
        _, h, w = image.shape
    elif isinstance(image, np.ndarray):
        h, w, _ = image.shape
    elif isinstance(image, Image.Image):
        w, h = image.size
    else:
        raise ValueError(f"Invalid image type: {type(image)}")
    return h, w


def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None,
                          boxes: Optional[Tensor] = None, num_keypoints: int = 15):
    """
    Calculate the center crop parameters for the bounding boxes and landmarks and update them
    :param img: Image
    :param output_size: Output size of the cropped image
    :param parts: Locations of the landmarks of following format: <part_id> <x> <y> <visible>
    :param boxes: Bounding boxes of the landmarks of following format: <image_id> <x> <y> <width> <height>
    :param num_keypoints: Number of keypoints
    :return:
    cropped_img: Center cropped image
    parts: Updated locations of the landmarks
    boxes: Updated bounding boxes of the landmarks
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
        output_size = output_size
    else:
        raise ValueError(f"Invalid output size: {output_size}")

    crop_height, crop_width = output_size
    image_height, image_width = get_dimensions(img)
    img = torchvision.transforms.functional.center_crop(img, output_size)

    crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size)

    if parts is not None:
        for j in range(num_keypoints):
            # Skip if part is invisible
            if parts[j][-1] == 0:
                continue
            parts[j][1] -= crop_left
            parts[j][2] -= crop_top

            # Skip if part is outside the crop
            if parts[j][1] > crop_width or parts[j][2] > crop_height:
                parts[j][-1] = 0
            if parts[j][1] < 0 or parts[j][2] < 0:
                parts[j][-1] = 0

            parts[j][1] = min(crop_width, parts[j][1])
            parts[j][2] = min(crop_height, parts[j][2])
            parts[j][1] = max(0, parts[j][1])
            parts[j][2] = max(0, parts[j][2])

    if boxes is not None:
        boxes[1] -= crop_left
        boxes[2] -= crop_top
        boxes[1] = max(0, boxes[1])
        boxes[2] = max(0, boxes[2])
        boxes[1] = min(crop_width, boxes[1])
        boxes[2] = min(crop_height, boxes[2])

    return img, parts, boxes


def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448):
    """
    Get the parameters for center cropping the image
    :param image_height: Height of the image
    :param image_width: Width of the image
    :param output_size: Output size of the cropped image
    :return:
    crop_top: Top coordinate of the cropped image
    crop_left: Left coordinate of the cropped image
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
        output_size = output_size
    else:
        raise ValueError(f"Invalid output size: {output_size}")

    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        crop_top, crop_left = padding_ltrb[1], padding_ltrb[0]
        return crop_top, crop_left

    if crop_width == image_width and crop_height == image_height:
        crop_top = 0
        crop_left = 0
        return crop_top, crop_left

    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))

    return crop_top, crop_left
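
A usage sketch for center_crop_boxes_kps (not part of the committed files), assuming CUB-style annotations: keypoint rows of the form <part_id> <x> <y> <visible> and one <image_id> <x> <y> <w> <h> box per image. The tensors below are dummy values for illustration.

# Hypothetical usage sketch with dummy annotations.
import torch

img = torch.rand(3, 600, 600)                        # dummy image tensor (C, H, W)
parts = torch.tensor([[1, 120.0, 200.0, 1]] * 15)    # 15 dummy keypoints
box = torch.tensor([1, 50.0, 60.0, 300.0, 400.0])    # <image_id> <x> <y> <w> <h>
cropped, parts_c, box_c = center_crop_boxes_kps(img, output_size=448, parts=parts,
                                                boxes=box, num_keypoints=15)
# `cropped` is 448x448; keypoint and box coordinates are shifted into the crop
# frame and clamped, and keypoints falling outside the crop are marked invisible.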
utils/data_utils/reversible_affine_transform.py
ADDED
@@ -0,0 +1,82 @@
# Description: This file contains the code for the reversible affine transform
import torchvision.transforms as transforms
import torch
from typing import List, Optional, Tuple, Any


def generate_affine_trans_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
) -> Tuple[float, Tuple[int, int], float, Any]:
    """Get parameters for affine transformation

    Returns:
        params to be passed to the affine transformation
    """
    angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
    if translate is not None:
        max_dx = float(translate[0] * img_size[0])
        max_dy = float(translate[1] * img_size[1])
        tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
        ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
        translations = (tx, ty)
    else:
        translations = (0, 0)

    if scale_ranges is not None:
        scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
    else:
        scale = 1.0

    shear_x = shear_y = 0.0
    if shears is not None:
        shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
        if len(shears) == 4:
            shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

    shear = (shear_x, shear_y)
    if shear_x == 0.0 and shear_y == 0.0:
        shear = 0.0

    return angle, translations, scale, shear


def rigid_transform(img, angle, translate, scale, invert=False, shear=0,
                    interpolation=transforms.InterpolationMode.BILINEAR):
    """
    Affine transforms input image
    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L54
    Parameters
    ----------
    img: Tensor
        Input image
    angle: int
        Rotation angle between -180 and 180 degrees
    translate: [int]
        Sequence of horizontal/vertical translations
    scale: float
        How to scale the image
    invert: bool
        Whether to invert the transformation
    shear: float
        Shear angle in degrees
    interpolation: InterpolationMode
        Interpolation mode to calculate output values
    Returns
    ----------
    img: Tensor
        Transformed image
    """
    if not invert:
        img = transforms.functional.affine(img, angle=angle, translate=translate, scale=scale, shear=shear,
                                           interpolation=interpolation)
    else:
        translate = [-t for t in translate]
        img = transforms.functional.affine(img=img, angle=0, translate=translate, scale=1, shear=shear)
        img = transforms.functional.affine(img=img, angle=-angle, translate=[0, 0], scale=1 / scale, shear=shear)

    return img
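
A round-trip sketch for the reversible transform (not part of the committed files): sample one set of parameters, apply them to an image, and use invert=True to map model outputs back to the original frame. All tensors and the `maps_t` name are illustrative.

# Hypothetical usage sketch for the reversible affine transform.
import torch

img = torch.rand(1, 3, 256, 256)
angle, translations, scale, shear = generate_affine_trans_params(
    degrees=[-90, 90], translate=[0.2, 0.2], scale_ranges=[0.8, 1.2],
    shears=None, img_size=[256, 256])

img_t = rigid_transform(img, angle, list(translations), scale, invert=False)
# ... run the model on img_t to obtain attention maps `maps_t`, then undo the transform:
# maps = rigid_transform(maps_t, angle, list(translations), scale, invert=True)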
utils/data_utils/transform_utils.py
ADDED
@@ -0,0 +1,118 @@
import torch
from torchvision import transforms as transforms
from torchvision.transforms import Compose

from timm.data.constants import \
    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data import create_transform


def make_train_transforms(args):
    train_transforms: Compose = transforms.Compose([
        transforms.Resize(size=args.image_size, antialias=True),
        transforms.RandomHorizontalFlip(p=args.hflip),
        transforms.RandomVerticalFlip(p=args.vflip),
        transforms.ColorJitter(),
        transforms.RandomAffine(degrees=90, translate=(0.2, 0.2), scale=(0.8, 1.2)),
        transforms.RandomCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
    ])
    return train_transforms


def make_test_transforms(args):
    test_transforms: Compose = transforms.Compose([
        transforms.Resize(size=args.image_size, antialias=True),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
    ])
    return test_transforms


def build_transform_timm(args, is_train=True):
    resize_im = args.image_size > 32
    imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
    mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
    std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.image_size,
            is_training=True,
            color_jitter=args.color_jitter,
            hflip=args.hflip,
            vflip=args.vflip,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not resize_im:
            transform.transforms[0] = transforms.RandomCrop(
                args.image_size, padding=4)
        return transform

    t = []
    if resize_im:
        # warping (no cropping) when evaluated at 384 or larger
        if args.image_size >= 384:
            t.append(
                transforms.Resize((args.image_size, args.image_size),
                                  interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
            )
            print(f"Warping {args.image_size} size input images...")
        else:
            if args.crop_pct is None:
                args.crop_pct = 224 / 256
            size = int(args.image_size / args.crop_pct)
            t.append(
                # to maintain same ratio w.r.t. 224 images
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
            )
            t.append(transforms.CenterCrop(args.image_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)


def inverse_normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    mean = torch.as_tensor(mean)
    std = torch.as_tensor(std)
    un_normalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    return un_normalize


def normalize_only(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize


def inverse_normalize_w_resize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
                               resize_resolution=(256, 256)):
    mean = torch.as_tensor(mean)
    std = torch.as_tensor(std)
    resize_unnorm = transforms.Compose([
        transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()),
        transforms.Resize(size=resize_resolution, antialias=True)])
    return resize_unnorm


def load_transforms(args):
    # Get the transforms and load the dataset
    if args.augmentations_to_use == 'timm':
        train_transforms = build_transform_timm(args, is_train=True)
    elif args.augmentations_to_use == 'cub_original':
        train_transforms = make_train_transforms(args)
    else:
        raise ValueError('Augmentations not supported.')
    test_transforms = make_test_transforms(args)
    return train_transforms, test_transforms
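
A usage sketch for the transform builders (not part of the committed files). It assumes `args` is an argparse.Namespace carrying the fields referenced above (image_size, hflip, vflip, augmentations_to_use); the field values chosen here are illustrative.

# Hypothetical usage sketch with a minimal args namespace.
from argparse import Namespace
from PIL import Image

args = Namespace(image_size=448, hflip=0.5, vflip=0.0, augmentations_to_use='cub_original')
train_tf, test_tf = load_transforms(args)

img = Image.open('files/images/Laysan_Albatross_0050_870.jpg').convert('RGB')
x = test_tf(img).unsqueeze(0)                         # [1, 3, 448, 448], ImageNet-normalized
unnorm = inverse_normalize_w_resize(resize_resolution=(256, 256))
x_vis = unnorm(x)                                     # un-normalized pixels for plotting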
utils/get_landmark_coordinates.py
ADDED
@@ -0,0 +1,41 @@
# This file contains the function to generate the center coordinates as tensor for the current net.
import torch


def landmark_coordinates(maps, grid_x=None, grid_y=None):
    """
    Generate the center coordinates as tensor for the current net.
    Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L19
    Parameters
    ----------
    maps: torch.Tensor
        Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability
    grid_x: torch.Tensor
        The grid x coordinates
    grid_y: torch.Tensor
        The grid y coordinates
    Returns
    ----------
    loc_x: Tensor
        The centroid x coordinates
    loc_y: Tensor
        The centroid y coordinates
    grid_x: Tensor
    grid_y: Tensor
    """
    return_grid = False
    if grid_x is None or grid_y is None:
        return_grid = True
        grid_x, grid_y = torch.meshgrid(torch.arange(maps.shape[2]),
                                        torch.arange(maps.shape[3]), indexing='ij')
        grid_x = grid_x.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True)
        grid_y = grid_y.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True)
    map_sums = maps.sum(3).sum(2).detach()
    maps_x = grid_x * maps
    maps_y = grid_y * maps
    loc_x = maps_x.sum(3).sum(2) / map_sums
    loc_y = maps_y.sum(3).sum(2) / map_sums
    if return_grid:
        return loc_x, loc_y, grid_x, grid_y
    else:
        return loc_x, loc_y
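
A short sketch of how the centroid helper above is driven (not part of the committed files). The dummy `maps` tensor assumes 15 parts plus one background channel and a 28x28 feature map; any per-channel attention maps with the same layout work.

# Hypothetical usage sketch for landmark_coordinates.
import torch

maps = torch.softmax(torch.randn(2, 16, 28, 28), dim=1)    # batch of 2, 15 parts + background
loc_x, loc_y, grid_x, grid_y = landmark_coordinates(maps)
# loc_x, loc_y: [2, 16] centroid coordinates in feature-map pixels.
# The returned grids can be reused on later batches to avoid rebuilding the meshgrid:
loc_x2, loc_y2 = landmark_coordinates(maps, grid_x, grid_y)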
utils/misc_utils.py
ADDED
@@ -0,0 +1,135 @@
import math
from functools import reduce

import torch
import numpy as np
import os
from pathlib import Path


def factors(n):
    return reduce(list.__add__,
                  ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0))


def file_line_count(filename: str) -> int:
    """Count the number of lines in a file"""
    with open(filename, 'rb') as f:
        return sum(1 for _ in f)


def compute_attention(qkv, scale=None):
    """
    Compute attention matrix (same as in the pytorch scaled dot product attention)
    Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    :param qkv: Query, key and value tensors concatenated along the first dimension
    :param scale: Scale factor for the attention computation
    :return:
    """
    if isinstance(qkv, torch.Tensor):
        query, key, value = qkv.unbind(0)
    else:
        query, key, value = qkv
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    L, S = query.size(-2), key.size(-2)
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_out = attn_weight @ value
    return attn_weight, attn_out


def compute_dot_product_similarity(a, b):
    scores = a @ b.transpose(-1, -2)
    return scores


def compute_cross_entropy(p, q):
    q = torch.nn.functional.log_softmax(q, dim=-1)
    loss = torch.sum(p * q, dim=-1)
    return - loss.mean()


def rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cuda")):
    """
    Perform attention rollout.
    Ref: https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L9C1-L42C16
    Parameters
    ----------
    attentions : list
        List of attention matrices, one for each transformer layer
    discard_ratio : float
        Ratio of lowest attention values to discard
    head_fusion : str
        Type of fusion to use for attention heads. One of "mean", "max", "min"
    device : torch.device
        Device to use for computation
    Returns
    -------
    mask : np.ndarray
        Mask of shape (width, width), where width is the square root of the number of patches
    """
    result = torch.eye(attentions[0].size(-1), device=device)
    attentions = [attention.to(device) for attention in attentions]
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(axis=1).values
            elif head_fusion == "min":
                attention_heads_fused = attention.min(axis=1).values
            else:
                raise ValueError("Attention head fusion type not supported")

            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0

            I = torch.eye(attention_heads_fused.size(-1), device=device)
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1)

            result = torch.matmul(a, result)

    # Normalize the result by max value in each row
    result = result / result.max(dim=-1, keepdim=True)[0]
    return result


def sync_bn_conversion(model: torch.nn.Module):
    """
    Convert BatchNorm to SyncBatchNorm (used for DDP)
    :param model: PyTorch model
    :return:
    model: PyTorch model with SyncBatchNorm layers
    """
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model


def check_snapshot(args):
    """
    Create directory to save training checkpoints, otherwise load the existing checkpoint.
    Additionally, if it is an array training job, create a new directory for each training job.
    :param args: Arguments from the argument parser
    :return:
    """
    # Check if it is an array training job (i.e. training with multiple random seeds on the same settings)
    if args.array_training_job and not args.resume_training:
        args.snapshot_dir = os.path.join(args.snapshot_dir, str(args.seed))
        if not os.path.exists(args.snapshot_dir):
            save_dir = Path(args.snapshot_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
    else:
        # Create directory to save training checkpoints, otherwise load the existing checkpoint
        if not os.path.exists(args.snapshot_dir):
            if ".pt" not in args.snapshot_dir and ".pth" not in args.snapshot_dir:  # path is a directory, not a checkpoint file
                save_dir = Path(args.snapshot_dir)
                save_dir.mkdir(parents=True, exist_ok=True)
            else:
                raise ValueError('Snapshot checkpoint does not exist.')
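
A usage sketch tying compute_attention and rollout together (not part of the committed files). The per-layer qkv tensors are random stand-ins for the ones the ViT wrapper collects; the layer/head/token counts are illustrative.

# Hypothetical usage sketch: per-layer attention rollout from stacked qkv tensors.
import torch

num_layers, batch, heads, tokens, head_dim = 12, 1, 6, 197, 64
attn_per_layer = []
for _ in range(num_layers):
    qkv = torch.randn(3, batch, heads, tokens, head_dim)    # stacked query/key/value
    attn_weight, _ = compute_attention(qkv)                 # [1, 6, 197, 197]
    attn_per_layer.append(attn_weight)
mask = rollout(attn_per_layer, discard_ratio=0.9, head_fusion="max",
               device=torch.device("cpu"))                  # [1, 197, 197] relevance map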
utils/visualize_att_maps.py
ADDED
@@ -0,0 +1,135 @@
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import colorcet as cc
import numpy as np
import skimage
from pathlib import Path
import os
import torch

from utils.data_utils.transform_utils import inverse_normalize_w_resize
from utils.misc_utils import factors

# Define the colors to use for the attention maps
colors = cc.glasbey_category10


class VisualizeAttentionMaps:
    def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, sub_path_test="",
                 dataset_name="", bg_label=0, batch_size=32, num_parts=15, plot_ims_separately=False,
                 plot_landmark_amaps=False):
        """
        Plot attention maps and optionally landmark centroids on images.
        :param snapshot_dir: Directory to save the visualization results
        :param save_resolution: Size of the images to save
        :param alpha: The transparency of the attention maps
        :param sub_path_test: The sub-path of the test dataset
        :param dataset_name: The name of the dataset
        :param bg_label: The background label index in the attention maps
        :param batch_size: The batch size
        :param num_parts: The number of parts in the attention maps
        :param plot_ims_separately: Whether to plot the images separately
        :param plot_landmark_amaps: Whether to plot the landmark attention maps
        """
        self.save_resolution = save_resolution
        self.alpha = alpha
        self.sub_path_test = sub_path_test
        self.dataset_name = dataset_name
        self.bg_label = bg_label
        self.snapshot_dir = snapshot_dir

        self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution)
        self.batch_size = batch_size
        self.nrows = factors(self.batch_size)[-1]
        self.ncols = factors(self.batch_size)[-2]
        self.num_parts = num_parts
        self.req_colors = colors[:num_parts]
        self.plot_ims_separately = plot_ims_separately
        self.plot_landmark_amaps = plot_landmark_amaps
        if self.nrows == 1 and self.ncols == 1:
            self.figs_size = (10, 10)
        else:
            self.figs_size = (self.ncols * 2, self.nrows * 2)

    def recalculate_nrows_ncols(self):
        self.nrows = factors(self.batch_size)[-1]
        self.ncols = factors(self.batch_size)[-2]
        if self.nrows == 1 and self.ncols == 1:
            self.figs_size = (10, 10)
        else:
            self.figs_size = (self.ncols * 2, self.nrows * 2)

    @torch.no_grad()
    def show_maps(self, ims, maps, epoch=0, curr_iter=0, extra_info=""):
        """
        Plot images, attention maps and landmark centroids.
        Parameters
        ----------
        ims: Tensor, [batch_size, 3, width_im, height_im]
            Input images on which to show the attention maps
        maps: Tensor, [batch_size, number of parts + 1, width_map, height_map]
            The attention maps to display
        epoch: int
            The epoch number
        curr_iter: int
            The current iteration number
        extra_info: str
            Any extra information to add to the file name
        """
        ims = self.resize_unnorm(ims)
        if ims.shape[0] != self.batch_size:
            self.batch_size = ims.shape[0]
            self.recalculate_nrows_ncols()
        fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
        ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution,
                                                     mode='bilinear',
                                                     align_corners=True).argmax(dim=1).cpu().numpy()
        for i, ax in enumerate(axs.ravel()):
            curr_map = skimage.color.label2rgb(label=map_argmax[i], image=ims[i], colors=self.req_colors,
                                               bg_label=self.bg_label, alpha=self.alpha)
            ax.imshow(curr_map)
            ax.axis('off')
        save_dir = Path(os.path.join(self.snapshot_dir, 'results_vis_' + self.sub_path_test))
        save_dir.mkdir(parents=True, exist_ok=True)
        save_path = os.path.join(save_dir, f'{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
        fig.tight_layout()
        if self.snapshot_dir != "":
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        else:
            plt.show()
        plt.close('all')

        if self.plot_ims_separately:
            fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size)
            for i, ax in enumerate(axs.ravel()):
                ax.imshow(ims[i])
                ax.axis('off')
            save_path = os.path.join(save_dir, f'image_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.jpg')
            fig.tight_layout()
            if self.snapshot_dir != "":
                plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
            else:
                plt.show()
            plt.close('all')

        if self.plot_landmark_amaps:
            if self.batch_size > 1:
                raise ValueError('Not implemented for batch size > 1')
            for i in range(self.num_parts):
                fig, ax = plt.subplots(1, 1, figsize=self.figs_size)
                divider = make_axes_locatable(ax)
                cax = divider.append_axes('right', size='5%', pad=0.05)
                im = ax.imshow(maps[0, i, ...].detach().cpu().numpy(), cmap='cet_gouldian')
                fig.colorbar(im, cax=cax, orientation='vertical')
                ax.axis('off')
                save_path = os.path.join(save_dir,
                                         f'landmark_{i}_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png')
                fig.tight_layout()
                if self.snapshot_dir != "":
                    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
                else:
                    plt.show()
                plt.close()

        plt.close('all')
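
A usage sketch for the visualizer above (not part of the committed files). The image and map tensors are random stand-ins for a normalized batch and the model's part-attention maps; with snapshot_dir="" the figures are shown instead of written to disk.

# Hypothetical usage sketch for VisualizeAttentionMaps.
import torch

viz = VisualizeAttentionMaps(snapshot_dir="", save_resolution=(256, 256), alpha=0.5,
                             dataset_name="cub", bg_label=0, batch_size=4, num_parts=15)
ims = torch.randn(4, 3, 448, 448)                           # normalized images, e.g. from the test transform
maps = torch.softmax(torch.randn(4, 16, 28, 28), dim=1)     # 15 parts + background channel
viz.show_maps(ims, maps, epoch=0, curr_iter=0)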