gheinrich commited on
Commit
db40549
·
1 Parent(s): 6b8dbd3

Upload model

Browse files
cls_token.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class ClsToken(nn.Module):
6
+ def __init__(self, ndim: int,
7
+ num_tokens: int = 1,
8
+ enabled: bool = True,
9
+ register_multiple: int = 0,
10
+ ):
11
+ super().__init__()
12
+
13
+ self.ndim = ndim
14
+ self.enabled = enabled
15
+ self.num_registers = 0
16
+ self.num_tokens = num_tokens
17
+ if enabled:
18
+ if register_multiple > 0:
19
+ self.num_registers = register_multiple - (num_tokens % register_multiple)
20
+
21
+ scale = ndim ** -0.5
22
+ self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale)
23
+ else:
24
+ self.token = None
25
+
26
+ self.num_patches = self.num_tokens + self.num_registers
27
+
28
+ def disable(self):
29
+ self.token = None
30
+ self.enabled = False
31
+
32
+ def forward(self, x: torch.Tensor):
33
+ if self.token is None:
34
+ return x
35
+
36
+ token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
37
+ x = torch.cat([
38
+ token,
39
+ x,
40
+ ], dim=1)
41
+
42
+ return x
43
+
44
+ def no_weight_decay(self):
45
+ return [
46
+ 'token',
47
+ ]
config.json ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RADIOModel"
4
+ ],
5
+ "args": {
6
+ "aa": null,
7
+ "amp": true,
8
+ "amp_dtype": "bfloat16",
9
+ "amp_impl": "native",
10
+ "aug_repeats": 0,
11
+ "aug_splits": 0,
12
+ "auto_loss_balance_mode": "adaloss",
13
+ "batch_size": 32,
14
+ "bn_eps": null,
15
+ "bn_momentum": null,
16
+ "cache_dir": null,
17
+ "channels_last": false,
18
+ "checkpoint_hist": 10,
19
+ "class_map": "",
20
+ "clip_grad": null,
21
+ "clip_mode": "norm",
22
+ "cls_token_per_teacher": true,
23
+ "coco_annotations_file": "/datasets/coco2017-adlsa/annotations/captions_val2017.json",
24
+ "coco_image_dir": "/datasets/coco2017-adlsa/val2017",
25
+ "color_jitter": 0.4,
26
+ "cooldown_epochs": 0,
27
+ "cpe_max_size": 1050,
28
+ "crd_loss": false,
29
+ "crd_loss_weight": 0.8,
30
+ "crop_pct": null,
31
+ "cutmix": 0.0,
32
+ "cutmix_minmax": null,
33
+ "data_dir": "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/captioning/datacomp/dc1b/stage2",
34
+ "dataset": "nvgpt4",
35
+ "dataset_download": false,
36
+ "debug_full_knn": false,
37
+ "decay_epochs": 90,
38
+ "decay_milestones": [
39
+ 90,
40
+ 180,
41
+ 270
42
+ ],
43
+ "decay_rate": 0.1,
44
+ "device": "cuda:0",
45
+ "dist_bn": "reduce",
46
+ "distributed": true,
47
+ "drop": 0.0,
48
+ "drop_block": null,
49
+ "drop_connect": null,
50
+ "drop_path": null,
51
+ "epoch_repeats": 0.0,
52
+ "epochs": 300,
53
+ "eval": false,
54
+ "eval_metric": "knn_top1",
55
+ "eval_teacher": false,
56
+ "eval_teacher_only": false,
57
+ "eval_throughput": false,
58
+ "experiment": "checkpoints",
59
+ "fast_norm": false,
60
+ "feature_summarizer": "cls_token",
61
+ "feature_upscale_factor": null,
62
+ "fuser": "",
63
+ "gp": "avg",
64
+ "grad_accum_steps": 1,
65
+ "grad_checkpointing": false,
66
+ "head_init_bias": null,
67
+ "head_init_scale": null,
68
+ "hflip": 0.5,
69
+ "img_size": null,
70
+ "in_chans": null,
71
+ "initial_checkpoint": "",
72
+ "input_size": null,
73
+ "interpolation": "",
74
+ "layer_decay": null,
75
+ "local_rank": 0,
76
+ "log_interval": 50,
77
+ "log_mlflow": false,
78
+ "log_wandb": true,
79
+ "loss": "cosine",
80
+ "loss_auto_balance": false,
81
+ "lr": 0.001,
82
+ "lr_base": 0.1,
83
+ "lr_base_scale": "",
84
+ "lr_base_size": 256,
85
+ "lr_cycle_decay": 0.5,
86
+ "lr_cycle_limit": 1,
87
+ "lr_cycle_mul": 1.0,
88
+ "lr_k_decay": 1.0,
89
+ "lr_noise": null,
90
+ "lr_noise_pct": 0.67,
91
+ "lr_noise_std": 1.0,
92
+ "mean": null,
93
+ "min_lr": 0,
94
+ "mixup": 0.0,
95
+ "mixup_mode": "batch",
96
+ "mixup_off_epoch": 0,
97
+ "mixup_prob": 1.0,
98
+ "mixup_switch_prob": 0.5,
99
+ "mlp_hidden_size": 1520,
100
+ "mlp_num_inner": 3,
101
+ "mlp_version": "v2",
102
+ "model": "vit_huge_patch14_224",
103
+ "model_ema": false,
104
+ "model_ema_decay": 0.9998,
105
+ "model_ema_force_cpu": false,
106
+ "model_kwargs": {},
107
+ "momentum": 0.9,
108
+ "no_aug": false,
109
+ "no_ddp_bb": false,
110
+ "no_prefetcher": false,
111
+ "no_resume_opt": false,
112
+ "num_classes": null,
113
+ "opt": "fusedlamb",
114
+ "opt_betas": null,
115
+ "opt_eps": null,
116
+ "opt_kwargs": {},
117
+ "output": "/lustre/fs6/portfolios/llmservice/users/mranzinger/output/evfm/dfn_oai/11-29-23_vit-h-14-cpe_dfn-oai-dino_maxres",
118
+ "patience_epochs": 10,
119
+ "pin_mem": false,
120
+ "prefetcher": true,
121
+ "pretrained": false,
122
+ "rank": 0,
123
+ "ratio": [
124
+ 0.75,
125
+ 1.3333333333333333
126
+ ],
127
+ "recount": 1,
128
+ "recovery_interval": 0,
129
+ "register_multiple": 8,
130
+ "remode": "pixel",
131
+ "reprob": 0.0,
132
+ "resplit": false,
133
+ "resume": "/lustre/fs6/portfolios/llmservice/users/mranzinger/output/evfm/dfn_oai/11-29-23_vit-h-14-cpe_dfn-oai-dino_maxres/checkpoints/last.pth.tar",
134
+ "save_images": false,
135
+ "scale": [
136
+ 0.5,
137
+ 1.0
138
+ ],
139
+ "sched": "cosine",
140
+ "sched_on_updates": true,
141
+ "seed": 42,
142
+ "smoothing": 0.1,
143
+ "split_bn": false,
144
+ "start_epoch": null,
145
+ "std": null,
146
+ "steps_per_epoch": 2000,
147
+ "sync_bn": false,
148
+ "synchronize_step": false,
149
+ "teachers": [
150
+ {
151
+ "amp": true,
152
+ "amp_dtype": "bfloat16",
153
+ "batch_size": 16,
154
+ "fd_loss_weight": 1.0,
155
+ "fd_normalize": false,
156
+ "feature_distillation": true,
157
+ "input_size": 378,
158
+ "model": "ViT-H-14-378-quickgelu",
159
+ "name": "clip",
160
+ "pretrained": "dfn5b",
161
+ "sample_rate": 16,
162
+ "summary_loss_weight": 1.0,
163
+ "type": "open_clip",
164
+ "vitdet_prob": 0.05,
165
+ "vitdet_window_sizes": [
166
+ 3,
167
+ 9,
168
+ 9,
169
+ 9
170
+ ]
171
+ },
172
+ {
173
+ "amp": false,
174
+ "amp_dtype": "bfloat16",
175
+ "batch_size": 16,
176
+ "fd_loss_weight": 0.8,
177
+ "fd_normalize": false,
178
+ "feature_distillation": true,
179
+ "input_size": 336,
180
+ "model": "ViT-L/14@336px",
181
+ "name": "openai_clip",
182
+ "pretrained": "openai",
183
+ "sample_rate": 16,
184
+ "summary_loss_weight": 0.8,
185
+ "type": "openai_clip",
186
+ "use_summary": false
187
+ },
188
+ {
189
+ "amp": true,
190
+ "amp_dtype": "bfloat16",
191
+ "batch_size": 16,
192
+ "fd_loss_weight": 1.0,
193
+ "fd_normalize": false,
194
+ "feature_distillation": true,
195
+ "input_size": 224,
196
+ "model": "dinov2_vitg14",
197
+ "name": "dino_v2",
198
+ "sample_rate": 16,
199
+ "summary_loss_weight": 1.0,
200
+ "type": "dino_v2"
201
+ }
202
+ ],
203
+ "torchcompile": null,
204
+ "torchscript": false,
205
+ "train_interpolation": "random",
206
+ "train_split": "train",
207
+ "tta": 0,
208
+ "use_coco": false,
209
+ "use_multi_epochs_loader": false,
210
+ "val_data_dir": "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/classification/imagenet-1k/webdataset",
211
+ "val_img_size": 378,
212
+ "val_split": "val",
213
+ "validation_batch_size": 128,
214
+ "vflip": 0.0,
215
+ "wandb_entity": "",
216
+ "wandb_group": "dfn_oai",
217
+ "wandb_job_type": "",
218
+ "wandb_name": "",
219
+ "wandb_project": "",
220
+ "warmup_epochs": 2.5,
221
+ "warmup_lr": 1e-05,
222
+ "warmup_prefix": false,
223
+ "weight_decay": 2e-05,
224
+ "worker_seeding": "all",
225
+ "workers": 4,
226
+ "world_size": 64
227
+ },
228
+ "auto_map": {
229
+ "AutoConfig": "hf_model.RADIOConfig",
230
+ "AutoModel": "hf_model.RADIOModel"
231
+ },
232
+ "torch_dtype": "float32",
233
+ "transformers_version": "4.29.0",
234
+ "version": "v1"
235
+ }
enable_cpe_support.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Tuple
2
+ from types import MethodType
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from timm.models import VisionTransformer, checkpoint_seq
8
+
9
+ from .vit_patch_generator import ViTPatchGenerator
10
+
11
+
12
+ def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
13
+ x = self.patch_generator(x)
14
+ if self.grad_checkpointing and not torch.jit.is_scripting():
15
+ x = checkpoint_seq(self.blocks, x)
16
+ else:
17
+ x = self.blocks(x)
18
+ x = self.norm(x)
19
+ return x
20
+
21
+
22
+ def enable_cpe(model: nn.Module,
23
+ max_img_size: Union[int, Tuple[int, int]] = 1024,
24
+ num_cls_tokens: int = 1,
25
+ pos_dropout: float = 0.1,
26
+ register_multiple: int = 0,
27
+ ):
28
+ if not isinstance(model, VisionTransformer):
29
+ raise ValueError("CPE only support for VisionTransformer models!")
30
+
31
+ patch_size = model.patch_embed.patch_size[0]
32
+ embed_dim = model.embed_dim
33
+ input_dims = model.patch_embed.img_size
34
+ normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity)
35
+ cls_token = model.cls_token is not None
36
+
37
+ max_img_size = int(round(max_img_size / patch_size) * patch_size)
38
+
39
+ patch_generator = ViTPatchGenerator(
40
+ patch_size=patch_size,
41
+ embed_dim=embed_dim,
42
+ input_dims=input_dims,
43
+ normalize_patches=normalize_patches,
44
+ cls_token=cls_token,
45
+ max_input_dims=max_img_size,
46
+ pos_dropout=pos_dropout,
47
+ num_cls_tokens=num_cls_tokens,
48
+ register_multiple=register_multiple,
49
+ )
50
+
51
+ model.patch_generator = patch_generator
52
+ model.patch_embed = None
53
+ model.cls_token = None
54
+ model.pos_embed = None
55
+ model.pos_drop = None
56
+ model.num_cls_tokens = num_cls_tokens
57
+ model.num_registers = patch_generator.num_registers
58
+
59
+ model.forward_features = MethodType(_forward_cpe, model)
hf_model.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from collections import namedtuple
15
+ from typing import Optional
16
+
17
+ from timm.models import VisionTransformer
18
+ import torch
19
+ from transformers import PretrainedConfig, PreTrainedModel
20
+
21
+
22
+ from .model import create_model_from_args
23
+ from .input_conditioner import get_default_conditioner, InputConditioner
24
+
25
+
26
+ resource_map = {
27
+ 'radio_v1': 'https://huggingface.co/nvidia/RADIO/raw/main/radio_v1.pth.tar'
28
+ }
29
+
30
+
31
+ class RADIOConfig(PretrainedConfig):
32
+ """Pretrained Hugging Face configuration for RADIO models."""
33
+
34
+ def __init__(
35
+ self,
36
+ args: Optional[dict] = None,
37
+ version: Optional[str]="v1",
38
+ **kwargs,
39
+ ):
40
+ self.args = args
41
+ self.version = version
42
+ super().__init__(**kwargs)
43
+
44
+
45
+ class RADIOModel(PreTrainedModel):
46
+ """Pretrained Hugging Face model for RADIO."""
47
+
48
+ def __init__(self, config):
49
+ super().__init__(config)
50
+
51
+ RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
52
+ args = RADIOArgs(**config.args)
53
+ self.model = create_model_from_args(args)
54
+
55
+ self.input_conditioner: InputConditioner = get_default_conditioner()
56
+
57
+ #return RADIOModel(mod, conditioner, return_summary=return_summary, return_spatial_features=return_spatial_features)
58
+
59
+ def forward(self, x: torch.Tensor):
60
+ x = self.input_conditioner(x)
61
+
62
+ y = self.model.forward_features(x)
63
+
64
+ if isinstance(y, (list, tuple)):
65
+ summary, all_feat = y
66
+ elif isinstance(self.model, VisionTransformer):
67
+ patch_gen = getattr(self.model, 'patch_generator', None)
68
+ if patch_gen is not None:
69
+ summary = y[:, :patch_gen.num_cls_tokens].flatten(1)
70
+ all_feat = y[:, patch_gen.num_skip:]
71
+ elif self.model.global_pool == 'avg':
72
+ summary = y[:, self.model.num_prefix_tokens:].mean(dim=1)
73
+ all_feat = y
74
+ else:
75
+ summary = y[:, 0]
76
+ all_feat = y[:, 1:]
77
+ else:
78
+ raise ValueError("Unsupported model type")
79
+
80
+ if self.return_summary and self.return_spatial_features:
81
+ return summary, all_feat
82
+ elif self.return_summary:
83
+ return summary
84
+ return all_feat
input_conditioner.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ norm_t = Union[Tuple[float, float, float], torch.Tensor]
8
+
9
+ class InputConditioner(nn.Module):
10
+ def __init__(self,
11
+ input_scale: float,
12
+ norm_mean: norm_t,
13
+ norm_std: norm_t,
14
+ dtype: torch.dtype = torch.float32,
15
+ ):
16
+ super().__init__()
17
+
18
+ self.dtype = dtype
19
+
20
+ # self.input_scale = input_scale
21
+ self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
22
+ self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
23
+
24
+ def forward(self, x: torch.Tensor):
25
+ # x = x * self.input_scale
26
+ y = (x - self.norm_mean) / self.norm_std
27
+ return y.to(self.dtype)
28
+
29
+
30
+ def get_default_conditioner():
31
+ from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
32
+
33
+ return InputConditioner(
34
+ input_scale=1.0,
35
+ norm_mean=OPENAI_CLIP_MEAN,
36
+ norm_std=OPENAI_CLIP_STD,
37
+ )
38
+
39
+
40
+ def _to_tensor(v: norm_t):
41
+ return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1)
model.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ from timm.models import create_model
4
+
5
+ from .enable_cpe_support import enable_cpe
6
+
7
+
8
+ def create_model_from_args(args) -> nn.Module:
9
+ in_chans = 3
10
+ if args.in_chans is not None:
11
+ in_chans = args.in_chans
12
+ elif args.input_size is not None:
13
+ in_chans = args.input_size[0]
14
+
15
+ model = create_model(
16
+ args.model,
17
+ pretrained=args.pretrained,
18
+ in_chans=in_chans,
19
+ num_classes=args.num_classes,
20
+ drop_rate=args.drop,
21
+ drop_path_rate=args.drop_path,
22
+ drop_block_rate=args.drop_block,
23
+ global_pool=args.gp,
24
+ bn_momentum=args.bn_momentum,
25
+ bn_eps=args.bn_eps,
26
+ scriptable=args.torchscript,
27
+ checkpoint_path=args.initial_checkpoint,
28
+ **args.model_kwargs,
29
+ )
30
+
31
+ assert not args.cls_token_per_teacher or args.cpe_max_size is not None, "CPE must be enabled for multiple CLS tokens!"
32
+
33
+ if args.cpe_max_size is not None:
34
+ enable_cpe(model,
35
+ args.cpe_max_size,
36
+ num_cls_tokens=len(args.teachers) if args.cls_token_per_teacher else 1,
37
+ register_multiple=args.register_multiple,
38
+ )
39
+
40
+ return model
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:242360b04b7f78204b535ce8a96e28ef3316520d55be43e6873fd45696fb9d61
3
+ size 2662619441
vit_patch_generator.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Union, Tuple, Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from einops import rearrange
8
+
9
+ from .cls_token import ClsToken
10
+
11
+ input_dim_t = Union[int, Tuple[int, int]]
12
+
13
+ try:
14
+ # raise ImportError()
15
+ from indirect_grid_sample import indirect_grid_sample
16
+ except ImportError:
17
+ indirect_grid_sample = None
18
+
19
+ class ViTPatchGenerator(nn.Module):
20
+ def __init__(self,
21
+ patch_size: int,
22
+ embed_dim: int,
23
+ input_dims: input_dim_t,
24
+ abs_pos: bool = True,
25
+ normalize_patches: bool = False,
26
+ cls_token: bool = False,
27
+ max_input_dims: Optional[input_dim_t] = None,
28
+ pos_dropout: float = 0.0,
29
+ return_pos_enc: bool = False,
30
+ num_cls_tokens: int = 1,
31
+ register_multiple: int = 0,
32
+ device=None, dtype=None,
33
+ ):
34
+ super().__init__()
35
+
36
+ if isinstance(input_dims, int):
37
+ input_dims = (input_dims, input_dims)
38
+
39
+ if max_input_dims is None:
40
+ max_input_dims = input_dims
41
+ if isinstance(max_input_dims, int):
42
+ max_input_dims = (max_input_dims, max_input_dims)
43
+
44
+ max_input_dims = tuple(
45
+ int(math.ceil(d / patch_size) * patch_size)
46
+ for d in max_input_dims
47
+ )
48
+
49
+ self.cpe_mode = max_input_dims != input_dims
50
+ self.pos_dropout = pos_dropout
51
+ self.return_pos_enc = return_pos_enc
52
+
53
+ factory = dict(device=device, dtype=dtype)
54
+
55
+ self.patch_size = patch_size
56
+ self.abs_pos = abs_pos
57
+ self.embed_dim = embed_dim
58
+
59
+ self.num_rows = max_input_dims[0] // patch_size
60
+ self.num_cols = max_input_dims[1] // patch_size
61
+ self.input_dims = tuple(d // patch_size for d in input_dims)
62
+ self.num_patches = self.num_rows * self.num_cols
63
+ self.max_input_dims = max_input_dims
64
+
65
+ self.im_to_patches = Im2Patches(patch_size)
66
+ self.embedder = ViTPatchLinear(patch_size, embed_dim, **factory)
67
+
68
+ if abs_pos:
69
+ scale = embed_dim ** -0.5
70
+ self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim, **factory) * scale)
71
+
72
+ self.cls_token = ClsToken(
73
+ embed_dim,
74
+ num_tokens=num_cls_tokens,
75
+ enabled=cls_token,
76
+ register_multiple=register_multiple,
77
+ )
78
+
79
+ self.patch_normalizer = nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
80
+
81
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
82
+ patches = self.embed_patches(x)
83
+ patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
84
+ patches = self.cls_token(patches)
85
+ patches = self.patch_normalizer(patches)
86
+ if self.return_pos_enc:
87
+ return patches, pos_enc
88
+ return patches
89
+
90
+ @property
91
+ def apply_cls_token(self):
92
+ return self.cls_token.enabled
93
+
94
+ @property
95
+ def num_cls_tokens(self):
96
+ return self.cls_token.num_tokens
97
+
98
+ @property
99
+ def num_registers(self):
100
+ return self.cls_token.num_registers
101
+
102
+ @property
103
+ def num_skip(self):
104
+ return self.num_cls_tokens + self.num_registers
105
+
106
+ def no_weight_decay(self):
107
+ return [
108
+ 'pos_embed',
109
+ ]
110
+
111
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
112
+ if self.abs_pos:
113
+ self._load_embed(state_dict[f'{prefix}pos_embed'], self.pos_embed)
114
+
115
+ def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
116
+ if src_embed.shape != targ_embed.shape:
117
+ src_size = int(math.sqrt(src_embed.shape[1]))
118
+
119
+ assert src_size ** 2 == src_embed.shape[1], 'Unable to interpolate non-square embedding'
120
+
121
+ src_embed = rearrange(src_embed, 'b (h w) c -> b c h w', h=src_size, w=src_size)
122
+ src_embed = F.interpolate(src_embed, size=(self.num_rows, self.num_cols), mode='bicubic', align_corners=True, antialias=False)
123
+ src_embed = rearrange(src_embed, 'b c h w -> b (h w) c')
124
+ targ_embed.data.copy_(src_embed)
125
+
126
+ def _load_projection(self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor):
127
+ if src_proj_weight.shape != targ_proj_weight.shape:
128
+ src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))
129
+
130
+ assert (src_patch_size ** 2) * 3 == src_proj_weight.shape[1], 'Unable to interpolate non-square patch size'
131
+
132
+ src_proj_weight = rearrange(src_proj_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
133
+ src_proj_weight = F.interpolate(src_proj_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
134
+ src_proj_weight = rearrange(src_proj_weight, 'b c h w -> b (c h w)')
135
+ targ_proj_weight.data.copy_(src_proj_weight)
136
+
137
+ def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
138
+ patches = self.im_to_patches(x)
139
+ patches = self.embedder(patches)
140
+ return patches
141
+
142
+ def apply_pos_enc(self,
143
+ patches: torch.Tensor,
144
+ patch_idxs: Optional[torch.Tensor] = None,
145
+ input_size: Optional[Tuple[int, int]] = None,
146
+ ) -> torch.Tensor:
147
+ if not self.abs_pos:
148
+ return patches
149
+
150
+ pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)
151
+
152
+ if self.training and self.pos_dropout > 0:
153
+ keeps = torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device) > self.pos_dropout
154
+ pos_enc_drop = torch.where(keeps, pos_enc, 0)
155
+ else:
156
+ pos_enc_drop = pos_enc
157
+
158
+ return patches + pos_enc_drop, pos_enc
159
+
160
+ def get_pos_enc(self,
161
+ batch_size: int,
162
+ patch_idxs: Optional[torch.Tensor] = None,
163
+ input_size: Optional[Tuple[int, int]] = None,
164
+ ) -> torch.Tensor:
165
+ if input_size is None:
166
+ input_dims = self.input_dims
167
+ else:
168
+ input_dims = tuple(d // self.patch_size for d in input_size)
169
+
170
+ pos_embed = self._get_pos_embeddings(batch_size, input_dims)
171
+
172
+ if patch_idxs is None:
173
+ return pos_embed
174
+
175
+ exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])
176
+
177
+ pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs)
178
+ return pos_embed
179
+
180
+
181
+ def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]):
182
+ if (self.num_rows, self.num_cols) == input_dims:
183
+ return self.pos_embed
184
+
185
+ pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(0, 3, 1, 2)
186
+
187
+ def window_select(pos_embed):
188
+ if input_dims[0] < pos_embed.shape[-2]:
189
+ pos_embed = pos_embed[..., :input_dims[0], :]
190
+ if input_dims[1] < pos_embed.shape[-1]:
191
+ pos_embed = pos_embed[..., :, :input_dims[1]]
192
+ return pos_embed
193
+
194
+ if self.cpe_mode:
195
+ if self.training:
196
+ min_scale = math.sqrt(0.1)
197
+ scale = torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale) + min_scale
198
+ aspect_min = math.log(3 / 4)
199
+ aspect_max = -aspect_min
200
+ aspect = torch.exp(torch.rand(batch_size, 1, 1, device=pos_embed.device) * (aspect_max - aspect_min) + aspect_min)
201
+
202
+ scale_x = scale * aspect
203
+ scale_y = scale * (1 / aspect)
204
+ scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
205
+
206
+ pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy)
207
+
208
+ lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[None, None].expand(batch_size, input_dims[0], -1)
209
+ lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[None, :, None].expand(batch_size, -1, input_dims[1])
210
+
211
+ lin_xy = torch.stack([lin_x, lin_y], dim=-1)
212
+
213
+ grid_xy = lin_xy * scale_xy + pos_xy
214
+
215
+ # Convert to [-1, 1] range
216
+ grid_xy.mul_(2).sub_(1)
217
+
218
+ pos_embed = F.grid_sample(
219
+ pos_embed.expand(batch_size, -1, -1, -1),
220
+ grid=grid_xy,
221
+ mode='bilinear',
222
+ padding_mode='zeros',
223
+ align_corners=True,
224
+ )
225
+ else:
226
+ # i_rows, i_cols = input_dims
227
+ # p_rows, p_cols = pos_embed.shape[2:]
228
+ # if i_rows <= p_rows and i_cols <= p_cols:
229
+ # left = (p_cols - i_cols) // 2
230
+ # top = (p_rows - i_rows) // 2
231
+ # pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols]
232
+ # else:
233
+ max_dim = max(input_dims)
234
+ pos_embed = F.interpolate(pos_embed, size=(max_dim, max_dim), align_corners=True, mode='bilinear')
235
+
236
+ pos_embed = window_select(pos_embed)
237
+ else:
238
+ pos_embed = window_select(pos_embed)
239
+
240
+ if pos_embed.shape[-2:] != input_dims:
241
+ pos_embed = F.interpolate(pos_embed, size=input_dims, align_corners=True, mode='bilinear')
242
+
243
+ pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
244
+
245
+ return pos_embed
246
+
247
+
248
+ class Im2Patches(nn.Module):
249
+ def __init__(self, patch_size: int):
250
+ super().__init__()
251
+ self.patch_size = patch_size
252
+
253
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
254
+ if self.patch_size == 1:
255
+ patches = x.flatten(2)
256
+ patches = patches.permute(0, 2, 1)
257
+ return patches
258
+
259
+ py = x.shape[-2] // self.patch_size
260
+ px = x.shape[-1] // self.patch_size
261
+ patches = rearrange(x, 'b c (py yy) (px xx) -> b (py px) (c yy xx)',
262
+ py=py, yy=self.patch_size,
263
+ px=px, xx=self.patch_size,
264
+ )
265
+ return patches
266
+
267
+
268
+ class ViTPatchLinear(nn.Linear):
269
+ def __init__(self, patch_size: int, embed_dim: int, **factory):
270
+ super().__init__(
271
+ 3 * (patch_size ** 2),
272
+ embed_dim,
273
+ bias=False,
274
+ **factory
275
+ )
276
+ self.patch_size = patch_size
277
+
278
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
279
+ if self.bias is not None:
280
+ self.bias.data.copy_(state_dict[f'{prefix}bias'])
281
+
282
+ chk_weight = state_dict[f'{prefix}weight']
283
+ if chk_weight.shape != self.weight.shape:
284
+ src_patch_size = int(math.sqrt(chk_weight.shape[1] // 3))
285
+
286
+ assert (src_patch_size ** 2) * 3 == chk_weight.shape[1], 'Unable to interpolate non-square patch size'
287
+
288
+ chk_weight = rearrange(chk_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
289
+ chk_weight = F.interpolate(chk_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
290
+ chk_weight = rearrange(chk_weight, 'b c h w -> b (c h w)')
291
+ self.weight.data.copy_(chk_weight)