import math
from functools import reduce
from operator import mul

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.models.backbones import VisionTransformer as _VisionTransformer
from mmcls.models.utils import to_2tuple
from mmcv.cnn.bricks.transformer import PatchEmbed
from torch.nn.modules.batchnorm import _BatchNorm


def build_2d_sincos_position_embedding(patches_resolution,
                                        embed_dims,
                                        temperature=10000.,
                                        cls_token=False):
    """Build the fixed 2D sine-cosine position embedding that provides the
    model with the position information of the image patches."""
    if isinstance(patches_resolution, int):
        patches_resolution = (patches_resolution, patches_resolution)
    h, w = patches_resolution
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    assert embed_dims % 4 == 0, \
        'Embed dimension must be divisible by 4.'
    pos_dim = embed_dims // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature**omega)
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
    pos_emb = torch.cat(
        [
            torch.sin(out_w),
            torch.cos(out_w),
            torch.sin(out_h),
            torch.cos(out_h)
        ],
        dim=1,
    )[None, :, :]
    if cls_token:
        cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
        pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)
    return pos_emb
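

# A minimal usage sketch (illustrative only, not part of the original module):
# for a 14x14 patch grid, 768-dim embeddings and a class token, the returned
# embedding has shape (1, 14 * 14 + 1, 768) and is typically copied into a
# frozen ``pos_embed`` parameter, e.g.
#
#   pos_emb = build_2d_sincos_position_embedding(
#       patches_resolution=(14, 14), embed_dims=768, cls_token=True)
#   assert pos_emb.shape == (1, 197, 768)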


class VisionTransformer(_VisionTransformer):
    """Vision Transformer.

    A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Part of the code is modified from:
    `<https://github.com/facebookresearch/moco-v3/blob/main/vits.py>`_.

    Args:
        stop_grad_conv1 (bool, optional): Whether to stop the gradient of the
            convolution layer in `PatchEmbed`. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['mocov3-s', 'mocov3-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 1536,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
    }
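
    # Note (added for clarity): the 'mocov3-small' variant keeps 12 attention
    # heads (head dim 32) rather than the 6 heads commonly used for
    # DeiT-Small, following the MoCo v3 reference implementation linked above.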

    def __init__(self,
                 stop_grad_conv1=False,
                 frozen_stages=-1,
                 norm_eval=False,
                 init_cfg=None,
                 **kwargs):
        super(VisionTransformer, self).__init__(init_cfg=init_cfg, **kwargs)
        self.patch_size = kwargs['patch_size']
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.init_cfg = init_cfg

        # Optionally stop the gradient of the patch projection conv,
        # the MoCo v3 trick for training stability.
        if isinstance(self.patch_embed, PatchEmbed):
            if stop_grad_conv1:
                self.patch_embed.projection.weight.requires_grad = False
                self.patch_embed.projection.bias.requires_grad = False

        self._freeze_stages()

    def init_weights(self):
        super(VisionTransformer, self).init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):

            # Use fixed 2D sin-cos position embedding
            pos_emb = build_2d_sincos_position_embedding(
                patches_resolution=self.patch_resolution,
                embed_dims=self.embed_dims,
                cls_token=True)
            self.pos_embed.data.copy_(pos_emb)
            self.pos_embed.requires_grad = False

            # xavier_uniform initialization for PatchEmbed
            if isinstance(self.patch_embed, PatchEmbed):
                val = math.sqrt(
                    6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) +
                               self.embed_dims))
                nn.init.uniform_(self.patch_embed.projection.weight, -val, val)
                nn.init.zeros_(self.patch_embed.projection.bias)
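
            # Illustrative check (not in the original code): with a 16x16
            # patch and embed_dims=768, fan_in = 3 * 16 * 16 = 768, so the
            # xavier bound above is sqrt(6 / (768 + 768)) = 0.0625.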
            # initialization for linear layers
            for name, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    if 'qkv' in name:
                        # treat the weights of Q, K, V separately
                        val = math.sqrt(
                            6. /
                            float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                        nn.init.uniform_(m.weight, -val, val)
                    else:
                        nn.init.xavier_uniform_(m.weight)
                    nn.init.zeros_(m.bias)
            nn.init.normal_(self.cls_token, std=1e-6)

    def _freeze_stages(self):
        """Freeze patch_embed layer, some parameters and stages."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

            self.cls_token.requires_grad = False
            self.pos_embed.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            if i == self.num_layers and self.final_norm:
                for param in getattr(self, 'norm1').parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(VisionTransformer, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm layers only
                if isinstance(m, _BatchNorm):
                    m.eval()
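

# A hedged usage sketch (not part of the original file). The keyword
# arguments below follow the standard mmcls VisionTransformer interface and
# are assumptions; adjust them to the actual training config:
#
#   model = VisionTransformer(
#       arch='mocov3-small',
#       img_size=224,
#       patch_size=16,
#       stop_grad_conv1=True)
#   model.init_weights()
#   out = model(torch.randn(1, 3, 224, 224))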