|
from abc import ABC, abstractmethod |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
|
|
|
|
class Backbone(nn.Module, ABC):
    """Abstract interface for feature-extraction backbones.

    Concrete subclasses must implement ``forward``, ``get_dimension`` and
    ``get_out_size``. ``get_transform`` is optional and returns ``None``
    unless a subclass overrides it.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        """Compute the backbone features for the input ``x``."""

    @abstractmethod
    def get_dimension(self):
        """Return the feature (channel) dimension produced by ``forward``."""

    @abstractmethod
    def get_out_size(self, in_size):
        """Return the output spatial size for an input of size ``in_size``."""

    def get_transform(self, in_size=None):
        """Return the input preprocessing transform for this backbone.

        Args:
            in_size: Optional target input size. It is accepted (and ignored
                here) so that overrides which take a size share one call
                signature with the base class.

        Returns:
            ``None`` in the base class, i.e. no transform is defined.
        """
        return None
|
|
|
|
|
|
|
class DinoV2Backbone(Backbone):
    """Backbone that wraps a DINOv2 model loaded through ``torch.hub``."""

    def __init__(self, model_name):
        """Load the DINOv2 variant called ``model_name``.

        Args:
            model_name: Entry-point name passed to ``torch.hub.load`` for the
                ``facebookresearch/dinov2`` repository.
        """
        super().__init__()
        # NOTE(review): torch.hub.load may need network access or a populated
        # hub cache the first time it runs - confirm for offline deployments.
        self.model = torch.hub.load('facebookresearch/dinov2', model_name)

    def forward(self, x):
        """Return the normalized patch tokens as a spatial feature map.

        Args:
            x: 4-D input tensor, unpacked as ``(batch, channels, height, width)``.

        Returns:
            Tensor with the patch tokens laid out as
            ``(batch, features, out_h, out_w)``, where ``(out_h, out_w)`` is
            ``get_out_size((height, width))``.
        """
        batch, _, height, width = x.shape
        out_h, out_w = self.get_out_size((height, width))
        tokens = self.model.forward_features(x)['x_norm_patchtokens']
        # reshape (rather than view) also works when the token tensor's memory
        # layout is not view-compatible; it still returns a view when it can.
        grid = tokens.reshape(batch, out_h, out_w, -1)
        # Move the feature axis in front of the two spatial axes.
        return grid.permute(0, 3, 1, 2)

    def get_dimension(self):
        """Return the wrapped model's ``embed_dim``."""
        return self.model.embed_dim

    def get_out_size(self, in_size):
        """Return the patch-grid size for an input of size ``in_size``.

        Args:
            in_size: ``(height, width)`` pair.

        Returns:
            ``(height // patch_size, width // patch_size)`` using the wrapped
            model's ``patch_size``.
        """
        height, width = in_size
        patch = self.model.patch_size
        return (height // patch, width // patch)

    def get_transform(self, in_size=None):
        """Build the preprocessing pipeline for input images.

        Args:
            in_size: Size forwarded to ``transforms.Resize``. When ``None``
                the resize step is left out, which lets this method be called
                with no arguments, as the base-class signature allows.

        Returns:
            A ``transforms.Compose`` that applies ``ToTensor``, then
            ``Normalize``, then (optionally) ``Resize``.
        """
        steps = [
            transforms.ToTensor(),
            # Per-channel mean/std; these are the usual ImageNet statistics.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        if in_size is not None:
            steps.append(transforms.Resize(in_size))
        return transforms.Compose(steps)