RADIO / enable_cpe_support.py
gheinrich's picture
Upload model
d3b8c8f verified
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import Union, Tuple
from types import MethodType
import torch
from torch import nn
from timm.models import VisionTransformer, checkpoint_seq
from .vit_patch_generator import ViTPatchGenerator
def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
    """Replacement ``forward_features`` that routes input through the patch generator.

    The input is first passed to ``self.patch_generator``, then through
    ``self.blocks`` (wrapped in ``checkpoint_seq`` when gradient checkpointing
    is enabled and the model is not being scripted), and finally through
    ``self.norm``.

    Args:
        self: The model this function is bound to; it must expose
            ``patch_generator``, ``grad_checkpointing``, ``blocks`` and ``norm``.
        x: Input tensor handed to ``self.patch_generator``.

    Returns:
        The output of ``self.norm`` applied to the block outputs.
    """
    tokens = self.patch_generator(x)

    # Plain execution unless checkpointing is requested; checkpointing is
    # skipped while TorchScript is scripting the model.
    if not self.grad_checkpointing or torch.jit.is_scripting():
        tokens = self.blocks(tokens)
    else:
        tokens = checkpoint_seq(self.blocks, tokens)

    return self.norm(tokens)
def enable_cpe(model: nn.Module,
               max_img_size: Union[int, Tuple[int, int]] = 1024,
               num_cls_tokens: int = 1,
               pos_dropout: float = 0.1,
               register_multiple: int = 0,
) -> None:
    """Patch a timm ``VisionTransformer`` in place to use a ``ViTPatchGenerator``.

    A ``ViTPatchGenerator`` is built from the model's existing patch-embedding
    configuration and attached as ``model.patch_generator``. The model's own
    ``patch_embed``, ``cls_token``, ``pos_embed`` and ``pos_drop`` attributes
    are then set to ``None``, and ``model.forward_features`` is rebound to
    ``_forward_cpe``.

    Args:
        model: The model to modify. Must be a ``VisionTransformer`` instance.
        max_img_size: Maximum supported image size. It is rounded to the
            nearest multiple of the patch size before being handed to the
            patch generator. A ``(height, width)``-style pair is rounded per
            dimension and forwarded as a tuple.
            NOTE(review): assumes ``ViTPatchGenerator`` accepts a tuple for
            ``max_input_dims`` -- confirm against its constructor.
        num_cls_tokens: Forwarded to ``ViTPatchGenerator`` and stored on the
            model as ``model.num_cls_tokens``.
        pos_dropout: Forwarded to ``ViTPatchGenerator`` as ``pos_dropout``.
        register_multiple: Forwarded to ``ViTPatchGenerator`` as
            ``register_multiple``.

    Raises:
        ValueError: If ``model`` is not a ``VisionTransformer``.
    """
    if not isinstance(model, VisionTransformer):
        raise ValueError("CPE only support for VisionTransformer models!")

    # Capture everything needed from the existing embedding modules *before*
    # they are cleared further down.
    # Only the first entry of patch_size is used (presumably square patches).
    patch_size = model.patch_embed.patch_size[0]
    embed_dim = model.embed_dim
    input_dims = model.patch_embed.img_size
    normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity)
    cls_token = model.cls_token is not None

    def _snap_to_patch_grid(dim):
        # Round a single dimension to the nearest multiple of the patch size.
        return int(round(dim / patch_size) * patch_size)

    # The signature allows a pair of sizes; dividing a tuple by an int would
    # raise TypeError, so each dimension is rounded on its own in that case.
    if isinstance(max_img_size, (tuple, list)):
        max_img_size = tuple(_snap_to_patch_grid(dim) for dim in max_img_size)
    else:
        max_img_size = _snap_to_patch_grid(max_img_size)

    patch_generator = ViTPatchGenerator(
        patch_size=patch_size,
        embed_dim=embed_dim,
        input_dims=input_dims,
        normalize_patches=normalize_patches,
        cls_token=cls_token,
        max_input_dims=max_img_size,
        pos_dropout=pos_dropout,
        num_cls_tokens=num_cls_tokens,
        register_multiple=register_multiple,
    )

    model.patch_generator = patch_generator

    # These attributes are cleared because _forward_cpe uses only
    # model.patch_generator ahead of the transformer blocks.
    model.patch_embed = None
    model.cls_token = None
    model.pos_embed = None
    model.pos_drop = None

    model.num_cls_tokens = num_cls_tokens
    model.num_registers = patch_generator.num_registers

    # Bind as an instance method so `self` inside _forward_cpe is this model.
    model.forward_features = MethodType(_forward_cpe, model)