import copy
import os
from pathlib import Path

import torch
from timm.models import create_model
from torchvision.models import get_model

from models import pdiscoformer_vit_bb, pdisconet_vit_bb, pdisconet_resnet_torchvision_bb
from models.individual_landmark_resnet import IndividualLandmarkResNet
from models.individual_landmark_convnext import IndividualLandmarkConvNext
from models.individual_landmark_vit import IndividualLandmarkViT
from utils import load_state_dict_pdisco


def load_model_arch(args, num_cls):
    """
    Function to load the backbone architecture
    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return: Backbone model
    """
    if 'resnet' in args.model_arch:
        # Extract the depth from the architecture name (e.g. "resnet101" -> 101)
        # and pick the corresponding timm pretrained-weight tag.
        num_layers_split = [int(s) for s in args.model_arch if s.isdigit()]
        num_layers = int(''.join(map(str, num_layers_split)))
        if num_layers >= 100:
            timm_model_arch = args.model_arch + ".a1h_in1k"
        else:
            timm_model_arch = args.model_arch + ".a1_in1k"
if "resnet" in args.model_arch and args.use_torchvision_resnet_model: | |
weights = "DEFAULT" if args.pretrained_start_weights else None | |
base_model = get_model(args.model_arch, weights=weights) | |
elif "resnet" in args.model_arch and not args.use_torchvision_resnet_model: | |
if args.eval_only: | |
base_model = create_model( | |
timm_model_arch, | |
pretrained=args.pretrained_start_weights, | |
num_classes=num_cls, | |
output_stride=args.output_stride, | |
) | |
else: | |
base_model = create_model( | |
timm_model_arch, | |
pretrained=args.pretrained_start_weights, | |
drop_path_rate=args.drop_path, | |
num_classes=num_cls, | |
output_stride=args.output_stride, | |
) | |
elif "convnext" in args.model_arch: | |
if args.eval_only: | |
base_model = create_model( | |
args.model_arch, | |
pretrained=args.pretrained_start_weights, | |
num_classes=num_cls, | |
output_stride=args.output_stride, | |
) | |
else: | |
base_model = create_model( | |
args.model_arch, | |
pretrained=args.pretrained_start_weights, | |
drop_path_rate=args.drop_path, | |
num_classes=num_cls, | |
output_stride=args.output_stride, | |
) | |
elif "vit" in args.model_arch: | |
if args.eval_only: | |
base_model = create_model( | |
args.model_arch, | |
pretrained=args.pretrained_start_weights, | |
img_size=args.image_size, | |
) | |
else: | |
base_model = create_model( | |
args.model_arch, | |
pretrained=args.pretrained_start_weights, | |
drop_path_rate=args.drop_path, | |
img_size=args.image_size, | |
) | |
vit_patch_size = base_model.patch_embed.proj.kernel_size[0] | |
if args.image_size % vit_patch_size != 0: | |
raise ValueError(f"Image size {args.image_size} must be divisible by patch size {vit_patch_size}") | |
else: | |
raise ValueError('Model not supported.') | |
return base_model | |
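

# Example (sketch): calling load_model_arch with a hand-built namespace in place of
# the argparse result. The attribute names mirror the ones read above; the values are
# illustrative assumptions. The torchvision ResNet branch only needs the first four
# attributes; the timm/ViT branches additionally read output_stride, drop_path and
# image_size.
#
#   from types import SimpleNamespace
#   example_args = SimpleNamespace(
#       model_arch="resnet101",
#       use_torchvision_resnet_model=True,  # True -> torchvision weights, False -> timm
#       pretrained_start_weights=True,
#       eval_only=True,
#       output_stride=32,
#       drop_path=0.0,
#       image_size=224,
#   )
#   backbone = load_model_arch(example_args, num_cls=200)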


def init_pdisco_model(base_model, args, num_cls):
    """
    Function to initialize the part-discovery model on top of the backbone
    :param base_model: Base model
    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return: Part-discovery model
    """
    # Initialize the network
    if 'convnext' in args.model_arch:
        sl_channels = base_model.stages[-1].downsample[-1].in_channels
        fl_channels = base_model.head.in_features
        model = IndividualLandmarkConvNext(base_model, args.num_parts, num_classes=num_cls,
                                           sl_channels=sl_channels, fl_channels=fl_channels,
                                           part_dropout=args.part_dropout, modulation_type=args.modulation_type,
                                           gumbel_softmax=args.gumbel_softmax,
                                           gumbel_softmax_temperature=args.gumbel_softmax_temperature,
                                           gumbel_softmax_hard=args.gumbel_softmax_hard,
                                           modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
                                           noise_variance=args.noise_variance)
    elif 'resnet' in args.model_arch:
        sl_channels = base_model.layer4[0].conv1.in_channels
        fl_channels = base_model.fc.in_features
        model = IndividualLandmarkResNet(base_model, args.num_parts, num_classes=num_cls,
                                         sl_channels=sl_channels, fl_channels=fl_channels,
                                         use_torchvision_model=args.use_torchvision_resnet_model,
                                         part_dropout=args.part_dropout, modulation_type=args.modulation_type,
                                         gumbel_softmax=args.gumbel_softmax,
                                         gumbel_softmax_temperature=args.gumbel_softmax_temperature,
                                         gumbel_softmax_hard=args.gumbel_softmax_hard,
                                         modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
                                         noise_variance=args.noise_variance)
    elif 'vit' in args.model_arch:
        model = IndividualLandmarkViT(base_model, num_landmarks=args.num_parts, num_classes=num_cls,
                                      part_dropout=args.part_dropout,
                                      modulation_type=args.modulation_type, gumbel_softmax=args.gumbel_softmax,
                                      gumbel_softmax_temperature=args.gumbel_softmax_temperature,
                                      gumbel_softmax_hard=args.gumbel_softmax_hard,
                                      modulation_orth=args.modulation_orth, classifier_type=args.classifier_type,
                                      noise_variance=args.noise_variance)
    else:
        raise ValueError('Model not supported.')
    return model


def load_model_pdisco(args, num_cls):
    """
    Function to load the complete part-discovery model (backbone + landmark head)
    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return: Part-discovery model
    """
    base_model = load_model_arch(args, num_cls)
    model = init_pdisco_model(base_model, args, num_cls)
    return model
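

# Example (sketch, ViT configuration): every attribute below is read somewhere in
# load_model_arch or init_pdisco_model. The values are illustrative assumptions; in
# particular, the strings for modulation_type and classifier_type are hypothetical
# placeholders, so check the training CLI for the accepted choices.
#
#   from types import SimpleNamespace
#   example_args = SimpleNamespace(
#       model_arch="vit_base_patch14_reg4_dinov2.lvd142m",
#       pretrained_start_weights=True, eval_only=True,
#       image_size=224, drop_path=0.0,
#       num_parts=8, part_dropout=0.3,
#       modulation_type="layer_norm",  # hypothetical value
#       modulation_orth=False,
#       gumbel_softmax=False, gumbel_softmax_temperature=1.0, gumbel_softmax_hard=False,
#       classifier_type="linear",  # hypothetical value
#       noise_variance=0.0,
#   )
#   pdisco_model = load_model_pdisco(example_args, num_cls=200)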


def pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8, model_url="", img_size=224, num_cls=200):
    """
    Function to load the PDiscoFormer model with ViT backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL to load the model weights from
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: PDiscoFormer model with ViT backbone
    """
    model = pdiscoformer_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdiscoformer_{model_dataset}")
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        # Full training snapshots keep the weights under 'model_state'; otherwise the
        # downloaded file is treated as a plain state dict.
        if 'model_state' in snapshot_data:
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            state_dict = copy.deepcopy(snapshot_data)
        model.load_state_dict(state_dict, strict=True)
    return model
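

# Example (sketch): loading PDiscoFormer for CUB with k=8 parts. The URL below is a
# placeholder, not a real endpoint; the caller must pass the actual checkpoint prefix
# so that model_url + "8_parts_snapshot_best.pt" resolves to the released weights.
#
#   model = pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8,
#                            model_url="https://<checkpoint-host>/cub/",
#                            img_size=224, num_cls=200)
#   model.eval()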


def pdisconet_vit(pretrained=True, model_dataset="nabirds", k=8, model_url="", img_size=224, num_cls=555):
    """
    Function to load the PDiscoNet model with ViT backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL to load the model weights from
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: PDiscoNet model with ViT backbone
    """
    model = pdisconet_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            state_dict = copy.deepcopy(snapshot_data)
        model.load_state_dict(state_dict, strict=True)
    return model


def pdisconet_resnet101(pretrained=True, model_dataset="nabirds", k=8, model_url="", num_cls=555):
    """
    Function to load the PDiscoNet model with ResNet-101 backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL to load the model weights from
    :param num_cls: Number of classes in the dataset
    :return: PDiscoNet model with ResNet-101 backbone
    """
    model = pdisconet_resnet_torchvision_bb("resnet101", num_cls=num_cls, k=k)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            state_dict = copy.deepcopy(snapshot_data)
        model.load_state_dict(state_dict, strict=True)
    return model
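

if __name__ == "__main__":
    # Minimal smoke test (sketch): build a PDiscoNet ResNet-101 without downloading
    # the part-discovery checkpoint (pretrained=False, so no model_url is needed) and
    # report its size. The k and num_cls values are illustrative assumptions.
    smoke_model = pdisconet_resnet101(pretrained=False, k=8, num_cls=555)
    num_params = sum(p.numel() for p in smoke_model.parameters())
    print(f"pdisconet_resnet101 built with {num_params / 1e6:.1f}M parameters")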