File size: 14,458 Bytes
37cca23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from functools import partial
import json
import logging
import os
import sys
from typing import List, Optional
import torch
from torch.nn.functional import one_hot, softmax
import dinov2.distributed as distributed
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.transforms import make_classification_eval_transform
from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
logger = logging.getLogger("dinov2")
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """
    Build the command-line parser for the k-NN evaluation.

    The parser inherits every option of the evaluation setup parser (obtained from
    `get_setup_args_parser`, itself built on top of `parents`) and adds the k-NN
    specific options below.

    Args:
        description: Text shown at the top of the `--help` output.
        parents: Optional parsers forwarded to the setup parser.
        add_help: Whether the returned parser registers `-h/--help`.

    Returns:
        The configured `argparse.ArgumentParser`.
    """
    parents = parents or []
    # The setup parser is used as a parent, so it must not register its own -h/--help.
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--nb_knn",
        nargs="+",
        type=int,
        help="Number of NN to use. 20 is usually working the best.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature used in the voting coefficient",
    )
    parser.add_argument(
        "--gather-on-cpu",
        action="store_true",
        # Adjacent literals are concatenated verbatim: the trailing space after
        # "slower" is required to keep the two halves from running together.
        help="Whether to gather the train features on cpu, slower "
        "but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size.",
    )
    parser.add_argument(
        "--n-per-class-list",
        nargs="+",
        type=int,
        help="Number to take per class",
    )
    parser.add_argument(
        "--n-tries",
        type=int,
        help="Number of tries",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        nb_knn=[10, 20, 100, 200],
        temperature=0.07,
        batch_size=256,
        n_per_class_list=[-1],
        n_tries=1,
    )
    return parser
class KnnModule(torch.nn.Module):
    """
    k-NN lookup over train features that are split across all processes.

    Every rank keeps one chunk of the train features (stored transposed) together
    with the matching chunk of train labels, and receives its own chunk of test
    features. In `compute_neighbors`, the ranks take turns: the current rank's test
    features are broadcast to every rank, each rank computes a partial top-k against
    its own train chunk, and the partial results are gathered back on the rank that
    owns the test features, where a final top-k is taken.
    """

    def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
        super().__init__()

        rank = distributed.get_global_rank()
        world_size = distributed.get_global_size()
        self.global_rank = rank
        self.global_size = world_size
        self.device = device

        # Keep only this rank's slice of the train set on `device`.
        local_features = train_features.chunk(world_size)[rank]
        local_labels = train_labels.chunk(world_size)[rank]
        self.train_features_rank_T = local_features.T.to(device)
        self.candidates = local_labels.view(1, -1).to(device)

        self.nb_knn = nb_knn
        self.max_k = max(nb_knn)
        self.T = T
        self.num_classes = num_classes

    def _get_knn_sims_and_labels(self, similarity, train_labels):
        """Return the `max_k` largest similarities per row and the labels at those positions."""
        best = similarity.topk(self.max_k, largest=True, sorted=True)
        return best.values, train_labels.gather(1, best.indices)

    def _similarity_for_rank(self, features_rank, source_rank):
        """Receive the test features of `source_rank` and compute a partial top-k on the local train chunk."""
        # The shape is broadcast first so that receivers can allocate a matching buffer.
        shape_msg = torch.tensor(features_rank.shape).to(self.device)
        torch.distributed.broadcast(shape_msg, source_rank)

        if self.global_rank == source_rank:
            payload = features_rank
        else:
            payload = torch.zeros(*shape_msg, dtype=features_rank.dtype, device=self.device)
        torch.distributed.broadcast(payload, source_rank)

        # Similarities of the broadcast features against this rank's train chunk
        partial_sims = torch.mm(payload, self.train_features_rank_T)
        labels_per_row = self.candidates.expand(len(partial_sims), -1)
        return self._get_knn_sims_and_labels(partial_sims, labels_per_row)

    def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
        """Gather every rank's partial top-k on `target_rank`; other ranks get `None` back."""
        is_target = self.global_rank == target_rank

        # Only the destination rank provides receive buffers.
        sims_per_rank = labels_per_rank = None
        if is_target:
            sims_per_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
            labels_per_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]

        torch.distributed.gather(topk_sims, sims_per_rank, dst=target_rank)
        torch.distributed.gather(neighbors_labels, labels_per_rank, dst=target_rank)

        if not is_target:
            return None

        # Second top-k, this time over the k * global_size gathered candidates
        merged_sims = torch.cat(sims_per_rank, dim=1)
        merged_labels = torch.cat(labels_per_rank, dim=1)
        return self._get_knn_sims_and_labels(merged_sims, merged_labels)

    def compute_neighbors(self, features_rank):
        """Run one broadcast/gather round per rank and return the round addressed to this rank."""
        for source_rank in range(self.global_size):
            partial = self._similarity_for_rank(features_rank, source_rank)
            merged = self._gather_all_knn_for_rank(*partial, source_rank)
            if merged is not None:
                own_sims, own_labels = merged
        return own_sims, own_labels

    def forward(self, features_rank):
        """
        Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
        """
        assert all(k <= self.max_k for k in self.nb_knn)

        topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
        batch_size = neighbors_labels.shape[0]

        # Softmax over the temperature-scaled similarities gives one weight per neighbor.
        weights = softmax(topk_sims / self.T, 1).view(batch_size, -1, 1)
        votes = one_hot(neighbors_labels, num_classes=self.num_classes) * weights

        probas_for_k = {}
        for k in self.nb_knn:
            # Neighbors are sorted by similarity, so the first k rows are the k nearest.
            probas_for_k[k] = votes[:, :k, :].sum(1)
        return probas_for_k
class DictKeysModule(torch.nn.Module):
    """Select one entry of a nested dict by following a fixed sequence of keys."""

    def __init__(self, keys):
        super().__init__()
        self.keys = keys

    def forward(self, features_dict, targets):
        """Descend into `features_dict` along `self.keys` and pair the result with `targets`."""
        selected = features_dict
        for key in self.keys:
            selected = selected[key]
        return dict(preds=selected, target=targets)
def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
    """
    Build one k-NN module per (training-set size, try) combination.

    Args:
        module: Callable returning a module when given `train_features`, `train_labels`
            and `nb_knn` keyword arguments.
        n_per_class_list: Numbers of training samples to keep per class. A negative
            value means "use the full training set".
        n_tries: Number of random subsets drawn for each non-negative entry.
        nb_knn: Values of k to evaluate; any sequence of ints (list or tuple).
        train_features: Training features, indexable by a tensor of sample indices.
        train_labels: Training labels aligned with `train_features`.

    Returns:
        A `ModuleDictWithForward` keyed by `"full"` or `"<n> per class"`, each value
        being a `ModuleDictWithForward` keyed by the try identifier.
    """
    modules = {}
    mapping = None  # class -> sample indices; only needed for subsampled settings
    for npc in n_per_class_list:
        if npc < 0:  # Only one try needed when using the full data
            full_module = module(
                train_features=train_features,
                train_labels=train_labels,
                nb_knn=nb_knn,
            )
            modules["full"] = ModuleDictWithForward({"1": full_module})
            continue

        if mapping is None:
            mapping = create_class_indices_mapping(train_labels)

        # Evaluate the requested k values plus k == npc, dropping those above npc.
        # Built from a set so that `nb_knn` may be a tuple as well as a list
        # (`nb_knn + [npc]` raises TypeError for tuples).
        k_list = sorted({k for k in (*nb_knn, npc) if k <= npc})

        all_tries = {}
        for t in range(n_tries):
            # The try index doubles as the seed, so every try draws its own subset.
            final_indices = filter_train(mapping, npc, seed=t)
            all_tries[str(t)] = module(
                train_features=train_features[final_indices],
                train_labels=train_labels[final_indices],
                nb_knn=k_list,
            )
        modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)
    return ModuleDictWithForward(modules)
def filter_train(mapping, n_per_class, seed):
    """
    Draw up to `n_per_class` random sample indices from every class.

    Args:
        mapping: Dict whose values are tensors of sample indices, one entry per class.
        n_per_class: Maximum number of indices kept for each class.
        seed: Seed passed to `torch.manual_seed`; note that this reseeds the global
            torch RNG.

    Returns:
        A 1-D tensor holding the selected indices of all classes, in `mapping` order.
    """
    torch.manual_seed(seed)
    final_indices = []
    for class_indices in mapping.values():
        index = torch.randperm(len(class_indices))[:n_per_class]
        final_indices.append(class_indices[index])
    # `reshape(-1)` rather than `squeeze()`: with a single selected index, `squeeze()`
    # would return a 0-dim tensor and indexing with it would drop a dimension.
    return torch.cat(final_indices).reshape(-1)
def create_class_indices_mapping(labels):
    """Map every distinct label (as a 0-dim tensor key) to the positions where it occurs."""
    unique_labels, inverse = torch.unique(labels, return_inverse=True)
    mapping = {}
    for position, label in enumerate(unique_labels):
        # `inverse` holds, for each sample, the position of its label in `unique_labels`.
        mapping[label] = (inverse == position).nonzero()
    return mapping
class ModuleDictWithForward(torch.nn.ModuleDict):
    """A `ModuleDict` that can be called: every child receives the same arguments."""

    def forward(self, *args, **kwargs):
        """Return a dict with the output of each child module under the child's key."""
        outputs = {}
        for name, child in self._modules.items():
            outputs[name] = child(*args, **kwargs)
        return outputs
def eval_knn(
    model,
    train_dataset,
    val_dataset,
    accuracy_averaging,
    nb_knn,
    temperature,
    batch_size,
    num_workers,
    gather_on_cpu,
    n_per_class_list=(-1,),
    n_tries=1,
):
    """
    Run the k-NN evaluation of `model` and return the averaged metrics.

    Train features are extracted once, one k-NN module is built per
    (training-set size, try) combination, and the validation set is evaluated with
    all of them in a single pass.

    Args:
        model: Feature extractor; it is wrapped in `ModelWithNormalize`.
        train_dataset: Dataset whose features and labels form the k-NN index.
        val_dataset: Dataset to classify.
        accuracy_averaging: Averaging mode forwarded to `build_topk_accuracy_metric`.
        nb_knn: Values of k to evaluate.
        temperature: Temperature forwarded to `KnnModule` as `T`.
        batch_size: Batch size for feature extraction and for the validation loader.
        num_workers: Number of data loading workers.
        gather_on_cpu: Forwarded to `extract_features`.
        n_per_class_list: Numbers of training samples per class to evaluate with;
            a negative value selects the full training set.
        n_tries: Number of random subsets per non-negative `n_per_class_list` entry.

    Returns:
        The results dict produced by `evaluate`, in which the per-try entries
        `(n_per_class, try, k)` have been replaced by `(n_per_class, k)` entries
        holding the mean of each metric over the tries.
    """
    model = ModelWithNormalize(model)

    logger.info("Extracting features for train set...")
    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    logger.info(f"Train features created, shape {train_features.shape}.")

    val_dataloader = make_data_loader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=True,
    )
    # Convert to a plain int: `one_hot` and the metric builder take an integer class
    # count, and `train_labels.max() + 1` alone would be a 0-dim tensor.
    num_classes = int(train_labels.max()) + 1
    metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)

    device = torch.cuda.current_device()
    partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
    knn_module_dict = create_module_dict(
        module=partial_module,
        n_per_class_list=n_per_class_list,
        n_tries=n_tries,
        nb_knn=nb_knn,
        train_features=train_features,
        train_labels=train_labels,
    )

    # One postprocessor and one metric collection per (n_per_class, try, k) setting.
    postprocessors, metrics = {}, {}
    for n_per_class, knn_module in knn_module_dict.items():
        for t, knn_try in knn_module.items():
            for k in knn_try.nb_knn:
                setting = (n_per_class, t, k)
                postprocessors[setting] = DictKeysModule([n_per_class, t, k])
                metrics[setting] = metric_collection.clone()

    model_with_knn = torch.nn.Sequential(model, knn_module_dict)

    # ============ evaluation ... ============
    logger.info("Start the k-NN classification.")
    _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)

    # Averaging the results over the n tries for each value of n_per_class
    for n_per_class, knn_module in knn_module_dict.items():
        tries = list(knn_module.keys())
        first_try = tries[0]
        for k in knn_module[first_try].nb_knn:
            keys = results_dict[(n_per_class, first_try, k)].keys()  # keys are e.g. `top-1` and `top-5`
            results_dict[(n_per_class, k)] = {
                key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in tries]))
                for key in keys
            }
            # The per-try entries are dropped once they have been averaged.
            for t in tries:
                del results_dict[(n_per_class, t, k)]

    return results_dict
def eval_knn_with_model(
    model,
    output_dir,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    nb_knn=(10, 20, 100, 200),
    temperature=0.07,
    autocast_dtype=torch.float,
    accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
    transform=None,
    gather_on_cpu=False,
    batch_size=256,
    num_workers=5,
    n_per_class_list=(-1,),
    n_tries=1,
):
    """
    Build the datasets, run `eval_knn` under autocast and report the results.

    On the main process, the top-1 and top-5 accuracies (in percent) of every
    setting are logged and appended, one JSON object per line, to
    `results_eval_knn.json` inside `output_dir`.

    Args:
        model: Model to evaluate.
        output_dir: Directory receiving `results_eval_knn.json`; created if missing.
        train_dataset_str: Dataset string of the training set.
        val_dataset_str: Dataset string of the validation set.
        nb_knn: Values of k to evaluate.
        temperature: Temperature used in the voting coefficient.
        autocast_dtype: dtype passed to `torch.cuda.amp.autocast`.
        accuracy_averaging: Averaging mode of the accuracy metrics.
        transform: Transform applied to both datasets; defaults to
            `make_classification_eval_transform()`.
        gather_on_cpu: Whether train features are gathered on CPU.
        batch_size: Batch size.
        num_workers: Number of data loading workers.
        n_per_class_list: Numbers of training samples per class to evaluate with.
        n_tries: Number of tries for each subsampled setting.

    Returns:
        A dict mapping `"<setting> Top 1"` / `"<setting> Top 5"` to accuracies in
        percent. It is empty on processes other than the main one.
    """
    transform = transform or make_classification_eval_transform()

    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=transform,
    )
    val_dataset = make_dataset(
        dataset_str=val_dataset_str,
        transform=transform,
    )

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_knn = eval_knn(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            accuracy_averaging=accuracy_averaging,
            nb_knn=nb_knn,
            temperature=temperature,
            batch_size=batch_size,
            num_workers=num_workers,
            gather_on_cpu=gather_on_cpu,
            n_per_class_list=n_per_class_list,
            n_tries=n_tries,
        )

    results_dict = {}
    if distributed.is_main_process():
        for knn_, knn_metrics in results_dict_knn.items():
            top1 = knn_metrics["top-1"].item() * 100.0
            top5 = knn_metrics["top-5"].item() * 100.0
            results_dict[f"{knn_} Top 1"] = top1
            results_dict[f"{knn_} Top 5"] = top5
            logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")

    # Make sure the directory exists so that a missing folder does not discard the
    # results of an evaluation that has already completed.
    os.makedirs(output_dir, exist_ok=True)
    metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
    # Append mode: earlier results in the file are kept. `results_dict` is empty on
    # non-main processes, so they write nothing.
    with open(metrics_file_path, "a", encoding="utf-8") as f:
        for k, v in results_dict.items():
            f.write(json.dumps({k: v}) + "\n")

    if distributed.is_enabled():
        torch.distributed.barrier()
    return results_dict
def main(args):
    """Build the model from the parsed `args`, run the k-NN evaluation and return exit code 0."""
    model, autocast_dtype = setup_and_build_model(args)
    # Everything below comes from the command line except the averaging mode, the
    # transform and the number of workers, which are fixed here.
    knn_kwargs = dict(
        model=model,
        output_dir=args.output_dir,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        nb_knn=args.nb_knn,
        temperature=args.temperature,
        autocast_dtype=autocast_dtype,
        accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
        transform=None,
        gather_on_cpu=args.gather_on_cpu,
        batch_size=args.batch_size,
        num_workers=5,
        n_per_class_list=args.n_per_class_list,
        n_tries=args.n_tries,
    )
    eval_knn_with_model(**knn_kwargs)
    return 0
if __name__ == "__main__":
    # Script entry point: parse the command line, run the evaluation and use the
    # value returned by `main` as the process exit status.
    cli_parser = get_args_parser(description="DINOv2 k-NN evaluation")
    sys.exit(main(cli_parser.parse_args()))
|