# pdiscoformer / load_model.py
# Author: ananthu-aniraj -- commit 20239f9 ("add initial files"), 10.5 kB
import copy
import os
from pathlib import Path
import torch
from timm.models import create_model
from torchvision.models import get_model
from models import pdiscoformer_vit_bb, pdisconet_vit_bb, pdisconet_resnet_torchvision_bb
from models.individual_landmark_resnet import IndividualLandmarkResNet
from models.individual_landmark_convnext import IndividualLandmarkConvNext
from models.individual_landmark_vit import IndividualLandmarkViT
from utils import load_state_dict_pdisco
def _timm_resnet_name(model_arch):
    """
    Return the timm model name for a ResNet architecture, including its pretrained tag.

    All digits in ``model_arch`` are concatenated and read as the layer count:
    ``.a1h_in1k`` is appended for 100 or more layers, ``.a1_in1k`` otherwise.

    :param model_arch: Architecture name such as ``"resnet50"`` or ``"resnet101"``
    :return: ``model_arch`` with the pretrained tag appended
    :raises ValueError: If ``model_arch`` contains no digits
    """
    digits = ''.join(ch for ch in model_arch if ch.isdigit())
    if not digits:
        raise ValueError(f"Cannot infer the number of layers from model_arch {model_arch!r}.")
    tag = ".a1h_in1k" if int(digits) >= 100 else ".a1_in1k"
    return model_arch + tag


def _drop_path_kwargs(args):
    """
    Return the extra keyword arguments for timm ``create_model``.

    In eval-only mode no ``drop_path_rate`` is passed. Otherwise ``args.drop_path`` is forwarded.
    """
    return {} if args.eval_only else {"drop_path_rate": args.drop_path}


def load_model_arch(args, num_cls):
    """
    Build the backbone network selected by ``args.model_arch``.

    Families are matched by substring, checked in this order:

    * ``resnet``: torchvision ``get_model`` when ``args.use_torchvision_resnet_model``
      is set (``weights="DEFAULT"`` if pretrained, else ``None``). Otherwise timm
      ``create_model`` with an ``.a1_in1k`` / ``.a1h_in1k`` pretrained tag.
    * ``convnext``: timm ``create_model``.
    * ``vit``: timm ``create_model`` with ``img_size=args.image_size``. The image size
      must be a multiple of the patch size.

    :param args: Arguments from the command line. Reads ``model_arch`` and, depending on
        the branch, ``use_torchvision_resnet_model``, ``pretrained_start_weights``,
        ``eval_only``, ``drop_path``, ``output_stride`` and ``image_size``.
    :param num_cls: Number of classes in the dataset. It is passed to the timm ResNet and
        ConvNeXt constructors only.
    :return: The instantiated backbone model
    :raises ValueError: If the architecture is unsupported, a timm ResNet name has no
        layer count, or the ViT image size is not divisible by the patch size
    """
    arch = args.model_arch
    if "resnet" in arch:
        if args.use_torchvision_resnet_model:
            weights = "DEFAULT" if args.pretrained_start_weights else None
            return get_model(arch, weights=weights)
        # Only the timm path needs the tagged name; computing it lazily keeps the
        # torchvision path independent of the digit parsing.
        timm_model_arch = _timm_resnet_name(arch)
        return create_model(
            timm_model_arch,
            pretrained=args.pretrained_start_weights,
            num_classes=num_cls,
            output_stride=args.output_stride,
            **_drop_path_kwargs(args),
        )
    if "convnext" in arch:
        return create_model(
            arch,
            pretrained=args.pretrained_start_weights,
            num_classes=num_cls,
            output_stride=args.output_stride,
            **_drop_path_kwargs(args),
        )
    if "vit" in arch:
        base_model = create_model(
            arch,
            pretrained=args.pretrained_start_weights,
            img_size=args.image_size,
            **_drop_path_kwargs(args),
        )
        # The patch size is read from the patch-embedding convolution's kernel.
        vit_patch_size = base_model.patch_embed.proj.kernel_size[0]
        if args.image_size % vit_patch_size != 0:
            raise ValueError(f"Image size {args.image_size} must be divisible by patch size {vit_patch_size}")
        return base_model
    raise ValueError('Model not supported.')
def init_pdisco_model(base_model, args, num_cls):
    """
    Wrap a backbone in the landmark head that matches ``args.model_arch``.

    The family is chosen by substring, checking ``convnext`` first, then ``resnet``,
    then ``vit``. The ConvNeXt and ResNet wrappers also get ``sl_channels`` and
    ``fl_channels``, which are read from the backbone:

    * ConvNeXt: ``stages[-1].downsample[-1].in_channels`` and ``head.in_features``.
    * ResNet: ``layer4[0].conv1.in_channels`` and ``fc.in_features``.

    The ResNet wrapper is additionally told whether the backbone came from torchvision.

    :param base_model: Backbone returned by the architecture loader
    :param args: Arguments from the command line (``model_arch``, ``num_parts`` and the
        part/modulation/gumbel/classifier/noise options)
    :param num_cls: Number of classes in the dataset
    :return: The wrapped landmark model
    :raises ValueError: If ``args.model_arch`` matches none of the supported families
    """
    arch = args.model_arch
    if not any(family in arch for family in ('convnext', 'resnet', 'vit')):
        raise ValueError('Model not supported.')

    # Options passed identically to every landmark wrapper.
    head_options = dict(
        part_dropout=args.part_dropout,
        modulation_type=args.modulation_type,
        gumbel_softmax=args.gumbel_softmax,
        gumbel_softmax_temperature=args.gumbel_softmax_temperature,
        gumbel_softmax_hard=args.gumbel_softmax_hard,
        modulation_orth=args.modulation_orth,
        classifier_type=args.classifier_type,
        noise_variance=args.noise_variance,
    )

    if 'convnext' in arch:
        return IndividualLandmarkConvNext(
            base_model,
            args.num_parts,
            num_classes=num_cls,
            sl_channels=base_model.stages[-1].downsample[-1].in_channels,
            fl_channels=base_model.head.in_features,
            **head_options,
        )
    if 'resnet' in arch:
        return IndividualLandmarkResNet(
            base_model,
            args.num_parts,
            num_classes=num_cls,
            sl_channels=base_model.layer4[0].conv1.in_channels,
            fl_channels=base_model.fc.in_features,
            use_torchvision_model=args.use_torchvision_resnet_model,
            **head_options,
        )
    # Only 'vit' remains after the guard above.
    return IndividualLandmarkViT(
        base_model,
        num_landmarks=args.num_parts,
        num_classes=num_cls,
        **head_options,
    )
def load_model_pdisco(args, num_cls):
    """
    Build the complete PDisco model.

    The backbone is created by ``load_model_arch`` and then wrapped by
    ``init_pdisco_model``.

    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return: The wrapped landmark model
    """
    return init_pdisco_model(load_model_arch(args, num_cls), args, num_cls)
def pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8, model_url="", img_size=224, num_cls=200):
    """
    Function to load the PDiscoFormer model with ViT backbone.

    The backbone is ``vit_base_patch14_reg4_dinov2.lvd142m``.

    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained. It selects the cache
        directory ``<torch hub dir>/pdiscoformer_checkpoints/pdiscoformer_<model_dataset>``.
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL prefix to load the model weights from. The file name
        ``<k>_parts_snapshot_best.pt`` is appended to it verbatim.
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: PDiscoFormer model with ViT backbone
    """
    model = pdiscoformer_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdiscoformer_{model_dataset}")
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            # The checkpoint wraps the weights under 'model_state'. The helper returns a
            # pair whose second item is the state dict.
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            # The freshly loaded dict is not referenced anywhere else, so it is used
            # directly. A deep copy would only duplicate every tensor in memory.
            state_dict = snapshot_data
        model.load_state_dict(state_dict, strict=True)
    return model
def pdisconet_vit(pretrained=True, model_dataset="nabirds", k=8, model_url="", img_size=224, num_cls=555):
    """
    Function to load the PDiscoNet model with ViT backbone.

    The backbone is ``vit_base_patch14_reg4_dinov2.lvd142m``.

    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained. It selects the cache
        directory ``<torch hub dir>/pdiscoformer_checkpoints/pdisconet_<model_dataset>``.
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL prefix to load the model weights from. The file name
        ``<k>_parts_snapshot_best.pt`` is appended to it verbatim.
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: PDiscoNet model with ViT backbone
    """
    model = pdisconet_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            # The checkpoint wraps the weights under 'model_state'. The helper returns a
            # pair whose second item is the state dict.
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            # The freshly loaded dict is not referenced anywhere else, so it is used
            # directly. A deep copy would only duplicate every tensor in memory.
            state_dict = snapshot_data
        model.load_state_dict(state_dict, strict=True)
    return model
def pdisconet_resnet101(pretrained=True, model_dataset="nabirds", k=8, model_url="", num_cls=555):
    """
    Function to load the PDiscoNet model with a torchvision ResNet-101 backbone.

    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained. It selects the cache
        directory ``<torch hub dir>/pdiscoformer_checkpoints/pdisconet_<model_dataset>``.
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL prefix to load the model weights from. The file name
        ``<k>_parts_snapshot_best.pt`` is appended to it verbatim.
    :param num_cls: Number of classes in the dataset
    :return: PDiscoNet model with ResNet-101 backbone
    """
    model = pdisconet_resnet_torchvision_bb("resnet101", num_cls=num_cls, k=k)
    if pretrained:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}")
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        url_path = model_url + str(k) + "_parts_snapshot_best.pt"
        snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
        if 'model_state' in snapshot_data:
            # The checkpoint wraps the weights under 'model_state'. The helper returns a
            # pair whose second item is the state dict.
            _, state_dict = load_state_dict_pdisco(snapshot_data)
        else:
            # The freshly loaded dict is not referenced anywhere else, so it is used
            # directly. A deep copy would only duplicate every tensor in memory.
            state_dict = snapshot_data
        model.load_state_dict(state_dict, strict=True)
    return model