# pdiscoformer/models/vit_baseline.py
# Composition of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
import torch
import torch.nn as nn
from typing import Tuple, Union, Sequence, Any
from timm.layers import trunc_normal_
from timm.models.vision_transformer import Block, Attention
from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn
from utils.misc_utils import compute_attention


class BaselineViT(torch.nn.Module):
"""
Modifications:
    - Use BlockWQKVReturn instead of Block
    - Use AttentionWQKVReturn instead ofAttention
    - Optionally return the QKV tensors of the transformer blocks (return_transformer_qkv)
    - Option to use only the class token, only the patch tokens, or both (concatenated) for classification
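
    Example (a minimal usage sketch; assumes a recent timm VisionTransformer
    backbone, e.g. ``vit_base_patch16_224``, that exposes the attributes read in
    ``__init__`` such as ``num_prefix_tokens``, ``reg_token`` and ``norm_pre``):
        >>> import timm
        >>> backbone = timm.create_model("vit_base_patch16_224", pretrained=False)
        >>> model = BaselineViT(backbone, num_classes=200)
        >>> logits = model(torch.randn(2, 3, 224, 224))  # shape (2, 200)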
"""

    def __init__(self, init_model: torch.nn.Module, num_classes: int,
class_tokens_only: bool = False,
patch_tokens_only: bool = False, return_transformer_qkv: bool = False) -> None:
super().__init__()
self.num_classes = num_classes
self.class_tokens_only = class_tokens_only
self.patch_tokens_only = patch_tokens_only
self.num_prefix_tokens = init_model.num_prefix_tokens
self.num_reg_tokens = init_model.num_reg_tokens
self.has_class_token = init_model.has_class_token
self.no_embed_class = init_model.no_embed_class
self.cls_token = init_model.cls_token
self.reg_token = init_model.reg_token
self.patch_embed = init_model.patch_embed
self.pos_embed = init_model.pos_embed
self.pos_drop = init_model.pos_drop
self.part_embed = nn.Identity()
self.patch_prune = nn.Identity()
self.norm_pre = init_model.norm_pre
self.blocks = init_model.blocks
self.norm = init_model.norm
self.fc_norm = init_model.fc_norm
if class_tokens_only or patch_tokens_only:
self.head = nn.Linear(init_model.embed_dim, num_classes)
else:
self.head = nn.Linear(init_model.embed_dim * 2, num_classes)
self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0])
self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1])
self.return_transformer_qkv = return_transformer_qkv
self.convert_blocks_and_attention()
self._init_weights_head()

    def convert_blocks_and_attention(self):
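        # Re-class the existing modules in place (parameters are untouched) so
        # that each Block / Attention can also return its QKV tensors on demand.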
for module in self.modules():
if isinstance(module, Block):
module.__class__ = BlockWQKVReturn
elif isinstance(module, Attention):
module.__class__ = AttentionWQKVReturn

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
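        """
        Prepend the class / register tokens and add position embeddings,
        mirroring timm's VisionTransformer._pos_embed: with `no_embed_class`
        (deit-3 / big-vision style) the position embedding covers only the
        patch tokens, so it is added before the prefix tokens are concatenated.
        """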
pos_embed = self.pos_embed
to_cat = []
if self.cls_token is not None:
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
if self.reg_token is not None:
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + pos_embed
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
x = x + pos_embed
return self.pos_drop(x)

    def _init_weights_head(self):
trunc_normal_(self.head.weight, std=.02)
if self.head.bias is not None:
nn.init.constant_(self.head.bias, 0.)

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
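        """
        Compute classification logits; when `return_transformer_qkv` is set,
        the QKV tensors of the last transformer block are returned as well.
        """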
x = self.patch_embed(x)
# Position Embedding
x = self._pos_embed(x)
x = self.part_embed(x)
x = self.patch_prune(x)
# Forward pass through transformer
x = self.norm_pre(x)
if self.return_transformer_qkv:
# Return keys of last attention layer
for i, blk in enumerate(self.blocks):
x, qkv = blk(x, return_qkv=True)
else:
x = self.blocks(x)
x = self.norm(x)
# Classification head
x = self.fc_norm(x)
if self.class_tokens_only: # only use class token
x = x[:, 0, :]
elif self.patch_tokens_only: # only use patch tokens
x = x[:, self.num_prefix_tokens:, :].mean(dim=1)
else:
x = torch.cat([x[:, 0, :], x[:, self.num_prefix_tokens:, :].mean(dim=1)], dim=1)
x = self.head(x)
if self.return_transformer_qkv:
return x, qkv
else:
return x

    def get_specific_intermediate_layer(
self,
x: torch.Tensor,
n: int = 1,
return_qkv: bool = False,
return_att_weights: bool = False,
):
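        """
        Return the output of block `n` (or the embeddings before any block if
        `n == -1`). When `return_transformer_qkv` is set, the QKV tensors of
        block `n` and, optionally, the attention weights of blocks 0..n can be
        returned as well.
        """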
num_blocks = len(self.blocks)
attn_weights = []
        if n >= num_blocks or n < -1:
            raise ValueError(f"n must be -1 or an index in [0, {num_blocks})")
# forward pass
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.norm_pre(x)
if n == -1:
if return_qkv:
                raise ValueError("n cannot be -1 if return_qkv is True")
else:
return x
for i, blk in enumerate(self.blocks):
if self.return_transformer_qkv:
x, qkv = blk(x, return_qkv=True)
if return_att_weights:
attn_weight, _ = compute_attention(qkv)
attn_weights.append(attn_weight.detach())
else:
x = blk(x)
if i == n:
output = x.clone()
if self.return_transformer_qkv and return_qkv:
qkv_output = qkv.clone()
break
if self.return_transformer_qkv and return_qkv and return_att_weights:
return output, qkv_output, attn_weights
elif self.return_transformer_qkv and return_qkv:
return output, qkv_output
elif self.return_transformer_qkv and return_att_weights:
return output, attn_weights
else:
return output

    def _intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
):
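        """
        Collect the (pre-norm) outputs of the last `n` blocks when `n` is an
        int, or of the blocks at the given indices when `n` is a sequence,
        together with their QKV tensors when `return_transformer_qkv` is set.
        """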
outputs, num_blocks = [], len(self.blocks)
if self.return_transformer_qkv:
qkv_outputs = []
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
# forward pass
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks):
if self.return_transformer_qkv:
x, qkv = blk(x, return_qkv=True)
else:
x = blk(x)
if i in take_indices:
outputs.append(x)
if self.return_transformer_qkv:
qkv_outputs.append(qkv)
if self.return_transformer_qkv:
return outputs, qkv_outputs
else:
return outputs

    def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
reshape: bool = False,
return_prefix_tokens: bool = False,
norm: bool = False,
    ) -> Union[Tuple, Tuple[Tuple, Any]]:
""" Intermediate layer accessor (NOTE: This is a WIP experiment).
Inspired by DINO / DINOv2 interface
"""
        # take the last n blocks if n is an int; if n is a sequence, select blocks by matching indices
if self.return_transformer_qkv:
outputs, qkv = self._intermediate_layers(x, n)
else:
outputs = self._intermediate_layers(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
if reshape:
grid_size = self.patch_embed.grid_size
outputs = [
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_prefix_tokens:
return_out = tuple(zip(outputs, prefix_tokens))
else:
return_out = tuple(outputs)
if self.return_transformer_qkv:
return return_out, qkv
else:
return return_out
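

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes
    # timm is installed and that `vit_base_patch16_224` exposes the attributes
    # accessed in `BaselineViT.__init__` (recent timm versions do).
    import timm

    backbone = timm.create_model("vit_base_patch16_224", pretrained=False)
    model = BaselineViT(backbone, num_classes=10).eval()
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(dummy)
        # DINO-style access to the features of the last two blocks, reshaped to
        # (B, C, H, W) maps with the prefix tokens stripped.
        feats = model.get_intermediate_layers(dummy, n=2, reshape=True, norm=True)
    print(logits.shape, [f.shape for f in feats])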