x3d / x3d.py
zhong-al
Get rid of helper dir
4f41c8c
raw
history blame
11.6 kB
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math
import torch
from torch import nn
from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks_default
from fvcore.nn.weight_init import c2_msra_fill, c2_xavier_fill
from .norm import get_norm
from .stem import VideoModelStem
from .resnet import ResStage
from .head import X3DHead
def round_width(width, multiplier, min_width=1, divisor=1):
    """
    Scale a channel width by ``multiplier`` and snap it to a multiple of
    ``divisor``.

    Args:
        width (int or float): base width before scaling.
        multiplier (float or None): scaling factor. A falsy value (``None`` or
            ``0``) disables scaling and ``width`` is returned unchanged.
        min_width (int): lower bound on the result; a falsy value falls back
            to ``divisor``.
        divisor (int): rounding granularity for the scaled width.

    Returns:
        int: the scaled and rounded width, or the untouched ``width`` when
        ``multiplier`` is falsy.
    """
    if not multiplier:
        return width

    scaled = width * multiplier
    lower_bound = min_width if min_width else divisor
    # Round half-up to a multiple of ``divisor`` (int() truncates).
    snapped = int(scaled + divisor / 2) // divisor * divisor
    rounded = max(lower_bound, snapped)
    # If rounding dropped more than 10% of the scaled width, step up once.
    if rounded < 0.9 * scaled:
        rounded += divisor
    return int(rounded)
def init_weights(
    model, fc_init_std=0.01, zero_init_final_bn=True, zero_init_final_conv=False
):
    """
    Performs ResNet style weight initialization, in place.

    Args:
        model (nn.Module): the model; every submodule yielded by
            ``model.modules()`` (including ``model`` itself) is visited.
        fc_init_std (float): the expected standard deviation for fc layer
            weights (used unless the layer sets a truthy ``xavier_init``).
        zero_init_final_bn (bool): if True, zero initialize the weight of every
            BatchNorm1d/2d/3d layer whose ``transform_final_bn`` attribute is
            truthy (the final bn of every bottleneck). Other BN weights are 1.
        zero_init_final_conv (bool): if True, zero initialize the weight of
            every ``nn.Conv3d`` that has a ``final_conv`` attribute instead of
            applying MSRA initialization.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv3d):
            # Note that there is no bias due to BN.
            if hasattr(m, "final_conv") and zero_init_final_conv:
                m.weight.data.zero_()
            else:
                # Follow the initialization method proposed in:
                # He, Kaiming, et al. "Delving deep into rectifiers: Surpassing
                # human-level performance on imagenet classification."
                # arXiv preprint arXiv:1502.01852 (2015).
                c2_msra_fill(m)
        elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
            # BN layers tagged `transform_final_bn` get a zero scale when
            # requested; all other BN layers get a scale of 1.
            zero_scale = zero_init_final_bn and getattr(
                m, "transform_final_bn", False
            )
            batchnorm_weight = 0.0 if zero_scale else 1.0
            # weight/bias are None for BN layers built with affine=False.
            if m.weight is not None:
                m.weight.data.fill_(batchnorm_weight)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            if getattr(m, "xavier_init", False):
                c2_xavier_fill(m)
            else:
                m.weight.data.normal_(mean=0.0, std=fc_init_std)
            if m.bias is not None:
                m.bias.data.zero_()
# pool1
_POOL1 = {
"2d": [[1, 1, 1]],
"c2d": [[2, 1, 1]],
"slow_c2d": [[1, 1, 1]],
"i3d": [[2, 1, 1]],
"slow_i3d": [[1, 1, 1]],
"slow": [[1, 1, 1]],
"slowfast": [[1, 1, 1], [1, 1, 1]],
"x3d": [[1, 1, 1]],
}
# temporal kernel basis
_TEMPORAL_KERNEL_BASIS = {
"2d": [
[[1]], # conv1 temporal kernel.
[[1]], # res2 temporal kernel.
[[1]], # res3 temporal kernel.
[[1]], # res4 temporal kernel.
[[1]], # res5 temporal kernel.
],
"c2d": [
[[1]], # conv1 temporal kernel.
[[1]], # res2 temporal kernel.
[[1]], # res3 temporal kernel.
[[1]], # res4 temporal kernel.
[[1]], # res5 temporal kernel.
],
"slow_c2d": [
[[1]], # conv1 temporal kernel.
[[1]], # res2 temporal kernel.
[[1]], # res3 temporal kernel.
[[1]], # res4 temporal kernel.
[[1]], # res5 temporal kernel.
],
"i3d": [
[[5]], # conv1 temporal kernel.
[[3]], # res2 temporal kernel.
[[3, 1]], # res3 temporal kernel.
[[3, 1]], # res4 temporal kernel.
[[1, 3]], # res5 temporal kernel.
],
"slow_i3d": [
[[5]], # conv1 temporal kernel.
[[3]], # res2 temporal kernel.
[[3, 1]], # res3 temporal kernel.
[[3, 1]], # res4 temporal kernel.
[[1, 3]], # res5 temporal kernel.
],
"slow": [
[[1]], # conv1 temporal kernel.
[[1]], # res2 temporal kernel.
[[1]], # res3 temporal kernel.
[[3]], # res4 temporal kernel.
[[3]], # res5 temporal kernel.
],
"slowfast": [
[[1], [5]], # conv1 temporal kernel for slow and fast pathway.
[[1], [3]], # res2 temporal kernel for slow and fast pathway.
[[1], [3]], # res3 temporal kernel for slow and fast pathway.
[[3], [3]], # res4 temporal kernel for slow and fast pathway.
[[3], [3]], # res5 temporal kernel for slow and fast pathway.
],
"x3d": [
[[5]], # conv1 temporal kernels.
[[3]], # res2 temporal kernels.
[[3]], # res3 temporal kernels.
[[3]], # res4 temporal kernels.
[[3]], # res5 temporal kernels.
],
}
# model stage depth
_MODEL_STAGE_DEPTH = {18: (2, 2, 2, 2), 50: (3, 4, 6, 3), 101: (3, 4, 23, 3)}
class X3D(nn.Module):
    """
    X3D model builder. It builds a X3D network backbone, which is a ResNet.

    Christoph Feichtenhofer.
    "X3D: Expanding Architectures for Efficient Video Recognition."
    https://arxiv.org/abs/2004.04730
    """

    def __init__(self, cfg):
        """
        The `__init__` method of any subclass should also contain these
        arguments.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.

        Raises:
            NotImplementedError: if ``cfg.DETECTION.ENABLE`` is set; this
                builder has no detection head.
        """
        super().__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.num_pathways = 1

        # Each successive stage doubles the base channel count, rounded to a
        # multiple of 8 (the width multiplier is applied later).
        exp_stage = 2.0
        self.dim_c1 = cfg.X3D.DIM_C1

        self.dim_res2 = (
            round_width(self.dim_c1, exp_stage, divisor=8)
            if cfg.X3D.SCALE_RES2
            else self.dim_c1
        )
        self.dim_res3 = round_width(self.dim_res2, exp_stage, divisor=8)
        self.dim_res4 = round_width(self.dim_res3, exp_stage, divisor=8)
        self.dim_res5 = round_width(self.dim_res4, exp_stage, divisor=8)

        self.block_basis = [
            # blocks, c, stride
            [1, self.dim_res2, 2],
            [2, self.dim_res3, 2],
            [5, self.dim_res4, 2],
            [3, self.dim_res5, 2],
        ]
        self._construct_network(cfg)
        init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _round_repeats(self, repeats, multiplier):
        """
        Round number of layers based on depth multiplier.

        Returns ``ceil(multiplier * repeats)`` as an int, or ``repeats``
        unchanged when ``multiplier`` is falsy (``None`` or ``0``).
        """
        if not multiplier:
            return repeats
        return int(math.ceil(multiplier * repeats))

    def _construct_network(self, cfg):
        """
        Builds a single pathway X3D model.

        Registers, in this order, the stem as ``s1``, one residual stage per
        ``self.block_basis`` entry as ``s2``..``s5``, and the classification
        head as ``head``. ``forward`` relies on this registration order.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.

        Raises:
            NotImplementedError: if detection is enabled.
        """
        assert cfg.MODEL.ARCH in _POOL1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH

        num_groups = cfg.RESNET.NUM_GROUPS
        w_mul = cfg.X3D.WIDTH_FACTOR
        d_mul = cfg.X3D.DEPTH_FACTOR
        dim_res1 = round_width(self.dim_c1, w_mul)

        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        # Stem: temporal kernel from the table, 3x3 in the last two dims,
        # stride 2 in the last two dims, "same" temporal padding.
        self.s1 = VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[dim_res1],
            kernel=[temp_kernel[0][0] + [3, 3]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 1, 1]],
            norm_module=self.norm_module,
            stem_func_name="x3d_stem",
        )

        # blob_in = s1
        dim_in = dim_res1
        for stage, block in enumerate(self.block_basis):
            num_blocks, base_dim, stage_stride = block
            dim_out = round_width(base_dim, w_mul)
            dim_inner = int(cfg.X3D.BOTTLENECK_FACTOR * dim_out)
            n_rep = self._round_repeats(num_blocks, d_mul)
            # start w res2 to follow convention
            prefix = "s{}".format(stage + 2)

            s = ResStage(
                dim_in=[dim_in],
                dim_out=[dim_out],
                dim_inner=[dim_inner],
                temp_kernel_sizes=temp_kernel[1],
                stride=[stage_stride],
                num_blocks=[n_rep],
                # Channel-wise 3x3x3 conv: one group per inner channel.
                num_groups=(
                    [dim_inner] if cfg.X3D.CHANNELWISE_3x3x3 else [num_groups]
                ),
                num_block_temp_kernel=[n_rep],
                nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
                nonlocal_group=cfg.NONLOCAL.GROUP[0],
                nonlocal_pool=cfg.NONLOCAL.POOL[0],
                instantiation=cfg.NONLOCAL.INSTANTIATION,
                trans_func_name=cfg.RESNET.TRANS_FUNC,
                stride_1x1=cfg.RESNET.STRIDE_1X1,
                norm_module=self.norm_module,
                dilation=cfg.RESNET.SPATIAL_DILATIONS[stage],
                # Drop-connect rate grows linearly with stage depth.
                drop_connect_rate=cfg.MODEL.DROPCONNECT_RATE
                * (stage + 2)
                / (len(self.block_basis) + 1),
            )
            dim_in = dim_out
            self.add_module(prefix, s)

        if self.enable_detection:
            # No detection head exists for X3D here; fail loudly instead of
            # silently returning a backbone without any head.
            raise NotImplementedError(
                "Detection is not supported by the X3D model builder."
            )

        # 32 is presumably the backbone's total spatial stride
        # (stem x2 and four stages x2 each).
        spat_sz = int(math.ceil(cfg.DATA.TRAIN_CROP_SIZE / 32.0))
        self.head = X3DHead(
            dim_in=dim_out,
            dim_inner=dim_inner,
            dim_out=cfg.X3D.DIM_C5,
            num_classes=cfg.MODEL.NUM_CLASSES,
            pool_size=[cfg.DATA.NUM_FRAMES, spat_sz, spat_sz],
            dropout_rate=cfg.MODEL.DROPOUT_RATE,
            act_func=cfg.MODEL.HEAD_ACT,
            bn_lin5_on=cfg.X3D.BN_LIN5,
        )

    def forward(self, x, bboxes=None):
        """
        Run ``x`` through every registered child module in registration order
        (``s1``, ``s2``..``s5``, ``head``) and return the final output.

        ``bboxes`` is accepted but unused.
        """
        # NOTE(review): assumes get_norm() returns a class/factory rather than
        # an nn.Module instance; an instance stored in self.norm_module would
        # be registered as a child and applied here as well.
        for module in self.children():
            x = module(x)
        return x
def build_model(cfg, gpu_id=None):
    """
    Build an X3D model from ``cfg`` and place it on the requested device(s).

    Args:
        cfg (CfgNode): model building configs.
        gpu_id (int or None): CUDA device to move the model to; when None the
            current CUDA device is used. Ignored when ``cfg.NUM_GPUS`` is 0.

    Returns:
        nn.Module: the X3D model. It is converted to Apex SyncBN when
        ``cfg.BN.NORM_TYPE == "sync_batchnorm_apex"``, moved to the GPU when
        ``cfg.NUM_GPUS`` is non-zero, and wrapped in
        ``DistributedDataParallel`` when ``cfg.NUM_GPUS > 1``.

    Raises:
        AssertionError: if ``cfg.NUM_GPUS`` exceeds the visible CUDA devices,
            or is non-zero while CUDA is unavailable.
        ImportError: if Apex SyncBN is requested but Apex is not installed.
    """
    if torch.cuda.is_available():
        assert (
            cfg.NUM_GPUS <= torch.cuda.device_count()
        ), "Cannot use more GPU devices than available"
    else:
        assert (
            cfg.NUM_GPUS == 0
        ), "Cuda is not available. Please set `NUM_GPUS: 0` for running on CPUs."

    # Construct the model
    model = X3D(cfg)

    if cfg.BN.NORM_TYPE == "sync_batchnorm_apex":
        try:
            import apex
        except ImportError as err:
            raise ImportError(
                "APEX is required for this model, please install"
            ) from err
        process_group = apex.parallel.create_syncbn_process_group(
            group_size=cfg.BN.NUM_SYNC_DEVICES
        )
        model = apex.parallel.convert_syncbn_model(
            model, process_group=process_group
        )

    if cfg.NUM_GPUS:
        # Determine the GPU used by the current process unless one was given.
        cur_device = torch.cuda.current_device() if gpu_id is None else gpu_id
        # Transfer the model to the current GPU device
        model = model.cuda(device=cur_device)

    # Use multi-process data parallel model in the multi-gpu setting
    if cfg.NUM_GPUS > 1:
        # Make model replica operate on the current device
        model = torch.nn.parallel.DistributedDataParallel(
            module=model,
            device_ids=[cur_device],
            output_device=cur_device,
            # NOTE(review): presumably these setups leave some parameters
            # without gradients, which DDP only tolerates with this flag.
            find_unused_parameters=bool(
                cfg.MODEL.DETACH_FINAL_FC
                or cfg.MODEL.MODEL_NAME == "ContrastiveModel"
            ),
        )
        if cfg.MODEL.FP16_ALLREDUCE:
            # Compress gradients to fp16 for the DDP all-reduce.
            model.register_comm_hook(
                state=None, hook=comm_hooks_default.fp16_compress_hook
            )
    return model