#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""BatchNorm (BN) utility functions and custom batch-size BN implementations"""

from functools import partial

import torch
import torch.nn as nn
from pytorchvideo.layers.batch_norm import NaiveSyncBatchNorm3d


def get_norm(cfg):
    """
    Select the normalization layer constructor described by the config.

    Args:
        cfg (CfgNode): model building configs, details are in the comments of
            the config file. Reads ``cfg.BN.NORM_TYPE`` and, depending on its
            value, ``cfg.BN.NUM_SPLITS``, ``cfg.BN.NUM_SYNC_DEVICES`` and
            ``cfg.BN.GLOBAL_SYNC``.
    Returns:
        A callable that builds the normalization layer: either the
        ``nn.BatchNorm3d`` class itself, or a ``functools.partial`` of a
        custom BN class with its type-specific arguments already bound.
    Raises:
        NotImplementedError: if ``cfg.BN.NORM_TYPE`` is not a supported type.
    """
    norm_type = cfg.BN.NORM_TYPE
    # "sync_batchnorm_apex" is built as a plain BatchNorm3d here; presumably
    # the conversion to apex's synchronized BN happens elsewhere — TODO confirm.
    if norm_type in {"batchnorm", "sync_batchnorm_apex"}:
        return nn.BatchNorm3d
    elif norm_type == "sub_batchnorm":
        return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS)
    elif norm_type == "sync_batchnorm":
        return partial(
            NaiveSyncBatchNorm3d,
            num_sync_devices=cfg.BN.NUM_SYNC_DEVICES,
            global_sync=cfg.BN.GLOBAL_SYNC,
        )
    else:
        raise NotImplementedError(
            "Norm type {} is not supported".format(norm_type)
        )


class SubBatchNorm3d(nn.Module):
    """
    The standard BN layer computes stats across all examples in a GPU. In some
    cases it is desirable to compute stats across only a subset of examples
    (e.g., in multigrid training https://arxiv.org/abs/1912.00998).
    SubBatchNorm3d splits the batch dimension into N splits, and runs BN on
    each of them separately, so that the stats are computed on each subset of
    examples (1/N of the batch) independently. During evaluation, it
    aggregates the stats from all splits into one BN (call
    ``aggregate_stats()`` before switching to eval).

    In training, sample ``i`` of the batch belongs to split ``i % num_splits``.
    """

    def __init__(self, num_splits, **args):
        """
        Args:
            num_splits (int): number of splits.
            args (dict): keyword arguments forwarded to ``nn.BatchNorm3d``;
                ``num_features`` is required.
        """
        super(SubBatchNorm3d, self).__init__()
        self.num_splits = num_splits
        num_features = args["num_features"]
        # Keep only one set of weight and bias, shared by all splits and
        # applied in forward() after normalization; the inner BN layers are
        # therefore built with affine=False.
        if args.get("affine", True):
            self.affine = True
            args["affine"] = False
            self.weight = torch.nn.Parameter(torch.ones(num_features))
            self.bias = torch.nn.Parameter(torch.zeros(num_features))
        else:
            self.affine = False
        # Whole-batch BN used at eval time; its running stats are filled by
        # aggregate_stats().
        self.bn = nn.BatchNorm3d(**args)
        # Training-time BN: each split is folded into its own block of
        # num_features channels, so it gets independent statistics.
        args["num_features"] = num_features * num_splits
        self.split_bn = nn.BatchNorm3d(**args)

    def _get_aggregated_mean_std(self, means, stds, n):
        """
        Combine per-split statistics into a single mean and variance.

        Despite the parameter name, ``stds`` holds per-split *variances*
        (aggregate_stats passes ``running_var``), and the second returned
        value is a variance: the average of the per-split variances plus the
        variance of the per-split means (law of total variance with equal
        split weights).

        Args:
            means (tensor): per-split means, ``n * C`` values laid out
                split-major (split ``j`` occupies entries ``[j*C, (j+1)*C)``).
            stds (tensor): per-split variances, same layout as ``means``.
            n (int): number of splits.
        Returns:
            (mean, var): detached tensors of ``C`` values each.
        """
        split_means = means.view(n, -1)
        mean = split_means.sum(0) / n
        var = (
            stds.view(n, -1).sum(0) / n
            + ((split_means - mean) ** 2).sum(0) / n
        )
        return mean.detach(), var.detach()

    def aggregate_stats(self):
        """
        Synchronize running_mean and running_var from the per-split BN into
        the whole-batch BN. Call this before eval. Does nothing when running
        stats are not tracked.
        """
        if self.split_bn.track_running_stats:
            mean, var = self._get_aggregated_mean_std(
                self.split_bn.running_mean,
                self.split_bn.running_var,
                self.num_splits,
            )
            # Copy in place so the registered buffers keep their identity
            # (same values as rebinding .data).
            with torch.no_grad():
                self.bn.running_mean.copy_(mean)
                self.bn.running_var.copy_(var)

    def forward(self, x):
        """
        Args:
            x (tensor): input of shape ``(N, C, T, H, W)``; in training mode
                ``N`` must be divisible by ``num_splits``.
        Returns:
            tensor: normalized (and, if affine, scaled and shifted) ``x``.
        Raises:
            ValueError: in training mode, if the batch size is not divisible
                by ``num_splits``.
        """
        if self.training:
            n, c, t, h, w = x.shape
            if n % self.num_splits != 0:
                raise ValueError(
                    "SubBatchNorm3d: batch size {} is not divisible by "
                    "num_splits={}".format(n, self.num_splits)
                )
            # Fold the splits into the channel dim. reshape (not view) also
            # accepts non-contiguous inputs such as channels_last_3d tensors;
            # for contiguous inputs it returns the same view.
            x = x.reshape(n // self.num_splits, c * self.num_splits, t, h, w)
            x = self.split_bn(x)
            x = x.reshape(n, c, t, h, w)
        else:
            x = self.bn(x)
        if self.affine:
            x = x * self.weight.view((-1, 1, 1, 1))
            x = x + self.bias.view((-1, 1, 1, 1))
        return x