Vaishanth Ramaraj
initial commit
8166792
raw
history blame
1.05 kB
import timm
import torch.nn as nn
from pathlib import Path
from .utils import activations, forward_default, get_activation
from ..external.next_vit.classification.nextvit import *
def forward_next_vit(pretrained, x):
    """Run input ``x`` through the wrapped Next-ViT backbone ``pretrained``.

    Delegates to ``forward_default`` from ``.utils``, passing the string
    ``"forward"`` (presumably the name of the model method that helper
    invokes -- confirm in ``.utils``). The result is whatever
    ``forward_default`` returns.
    """
    method_name = "forward"
    return forward_default(pretrained, x, method_name)
def _make_next_vit_backbone(
    model,
    hooks=(2, 6, 36, 39),
):
    """Wrap a Next-ViT model and attach forward hooks to four of its stages.

    A forward hook built by ``get_activation(name)`` is registered on
    ``model.features[hooks[i]]`` for ``i`` in 0..3, using the names
    ``"1"``, ``"2"``, ``"3"``, ``"4"`` respectively. Only the first four
    entries of ``hooks`` are used.

    Args:
        model: A module exposing an indexable ``features`` attribute
            (presumably a timm Next-ViT instance -- confirm at call sites).
        hooks: Indices into ``model.features`` to tap. Defaults to
            ``(2, 6, 36, 39)``. Any indexable sequence (list or tuple) works.

    Returns:
        A fresh ``nn.Module`` with the wrapped model stored as ``.model`` and
        the ``activations`` object imported from ``.utils`` stored as
        ``.activations`` (presumably filled in by the hooks -- confirm in
        ``.utils``).

    Raises:
        IndexError: if ``hooks`` has fewer than four entries or an index is
            out of range for ``model.features``.
    """
    pretrained = nn.Module()
    pretrained.model = model

    features = pretrained.model.features
    # Index hooks positionally (not via zip) so a too-short ``hooks`` still
    # raises IndexError instead of silently registering fewer hooks.
    for position, name in enumerate(("1", "2", "3", "4")):
        features[hooks[position]].register_forward_hook(get_activation(name))

    pretrained.activations = activations
    return pretrained
def _make_pretrained_next_vit_large_6m(hooks=None):
    """Build a ``nextvit_large`` model and wrap it as a hooked backbone.

    Args:
        hooks: Indices into ``model.features`` whose outputs are captured.
            ``None`` selects the default ``[2, 6, 36, 39]``.

    Returns:
        The wrapper module produced by ``_make_next_vit_backbone``.
    """
    # NOTE(review): "nextvit_large" is presumably registered with timm as a
    # side effect of the wildcard import of the external nextvit module in
    # this file's imports -- confirm.
    # create_model is called without pretrained=True (timm defaults to
    # pretrained=False), so no weights are loaded here; presumably they are
    # loaded elsewhere -- verify.
    model = timm.create_model("nextvit_large")

    # Use an identity check: ``hooks == None`` is wrong for singletons and
    # raises an ambiguous-truth-value error for array-like ``hooks``.
    if hooks is None:
        hooks = [2, 6, 36, 39]

    return _make_next_vit_backbone(
        model,
        hooks=hooks,
    )