# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn

import timm.models.vision_transformer


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer with support for global average pooling.

    Args:
        global_pool: When true, ``forward_features`` averages the non-class
            tokens and normalises the result with ``fc_norm``; the parent's
            ``norm`` layer is removed. When false, the parent's ``norm`` is
            applied and the class token (index 0) is returned.
        **kwargs: Forwarded unchanged to the timm ``VisionTransformer``
            constructor.
    """

    def __init__(self, global_pool=False, **kwargs):
        super().__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            # Fall back to sensible values instead of raising KeyError when the
            # caller relies on the parent's defaults for these arguments.
            norm_layer = kwargs.get('norm_layer')
            if norm_layer is None:
                # Same factory that every builder in this module passes.
                norm_layer = partial(nn.LayerNorm, eps=1e-6)
            embed_dim = kwargs.get('embed_dim')
            if embed_dim is None:
                # NOTE(review): relies on the timm parent exposing `embed_dim`.
                embed_dim = self.embed_dim
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm

    def forward_features(self, x):
        """Encode a batch and return one feature vector per sample.

        Args:
            x: Input batch accepted by ``self.patch_embed``; axis 0 is the
                batch axis. (Presumably images of the configured size and
                channel count; verify against the caller.)

        Returns:
            The pooled feature tensor: the normalised mean over the non-class
            tokens when ``global_pool`` is true, otherwise the normalised
            class token.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)  # class token becomes token 0
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome


def vit_small_patch16(**kwargs):
    """Build a ViT-Small/16 (embed_dim=384, depth=12, 12 heads); kwargs go to ``VisionTransformer``."""
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,  # ViT-small config in MOCO_V3
        # patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, qkv_bias=True,  # ViT-small config in timm
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base_patch16(**kwargs):
    """Build a ViT-Base/16 (embed_dim=768, depth=12, 12 heads); kwargs go to ``VisionTransformer``."""
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_large_patch16(**kwargs):
    """Build a ViT-Large/16 (embed_dim=1024, depth=24, 16 heads); kwargs go to ``VisionTransformer``."""
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_huge_patch14(**kwargs):
    """Build a ViT-Huge/14 (embed_dim=1280, depth=32, 16 heads); kwargs go to ``VisionTransformer``."""
    model = VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# class VisionTransformerWithProjector(VisionTransformer):
#     def __init__(self, vit_model, model_encoder, feat_cl_dim=128):
#         super(VisionTransformerWithProjector, self).__init__()
#         self.encoder = vit_model
#         embed_dim = {'vit_base_patch16': 768, 'vit_large_patch16': 1024, 'vit_huge_patch14': 1280}
#         self.projection_head = nn.Sequential(
#             nn.Linear(embed_dim[model_encoder], embed_dim[model_encoder]),
#             nn.ReLU(inplace=True),
#             nn.Linear(embed_dim[model_encoder], feat_cl_dim)
#         )
#
#     def forward(self, x):
#         x = self.encoder(x)
#         latent_cl = self.projection_head(x)  # [N, feat_cl_dim]
#         features = nn.functional.normalize(latent_cl, dim=-1)  # [N, feat_cl_dim]
#         return features