File size: 1,045 Bytes
8166792 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import timm
import torch.nn as nn
from pathlib import Path
from .utils import activations, forward_default, get_activation
from ..external.next_vit.classification.nextvit import *
def forward_next_vit(pretrained, x):
    """Run ``x`` through a wrapped Next-ViT backbone.

    Thin adapter around ``forward_default``: it forwards ``pretrained`` and
    ``x`` unchanged and passes the method name ``"forward"`` so that the
    model's plain ``forward`` is the method invoked.

    Args:
        pretrained: Wrapper module, presumably the one built by
            ``_make_next_vit_backbone`` (TODO confirm against callers).
        x: Input batch handed straight to ``forward_default``.

    Returns:
        Whatever ``forward_default`` returns.
    """
    method_name = "forward"
    return forward_default(pretrained, x, method_name)
def _make_next_vit_backbone(
model,
hooks=[2, 6, 36, 39],
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
return pretrained
def _make_pretrained_next_vit_large_6m(hooks=None):
    """Build a Next-ViT-Large backbone with four hooked intermediate layers.

    Args:
        hooks: Optional indices into the model's ``features``. When ``None``,
            ``[2, 6, 36, 39]`` is used.

    Returns:
        The wrapper module produced by ``_make_next_vit_backbone``.

    NOTE(review): ``timm.create_model`` is called without ``pretrained=True``
    (timm's default is ``pretrained=False``), so no weights are loaded here
    despite the function name; presumably the caller loads a full checkpoint
    afterwards -- confirm. The ``"nextvit_large"`` name is presumably made
    available to timm by this module's wildcard import of the external
    Next-ViT package -- confirm.
    """
    model = timm.create_model("nextvit_large")

    # Identity check, not ``== None``: an array/tensor argument would be
    # compared element-wise and could not be reduced to a single bool.
    if hooks is None:
        hooks = [2, 6, 36, 39]
    return _make_next_vit_backbone(
        model,
        hooks=hooks,
    )
|