x3d / resnet.py
zhong-al
Get rid of helper dir
4f41c8c
raw
history blame
32.3 kB
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Video models."""
import torch
import torch.nn as nn
from pytorchvideo.layers.swish import Swish
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Stochastic Depth per sample.

    Zeroes whole samples (entries along dim 0) of `x` at random and rescales
    the surviving samples by `1 / (1 - drop_prob)`.

    Args:
        x (Tensor): input whose first dimension indexes the samples.
        drop_prob (float): probability of dropping a sample, in [0, 1].
        training (bool): the input is returned untouched unless this is True.
    Returns:
        Tensor: `x` itself when nothing can be dropped, otherwise a new tensor
            with the same shape as `x`.
    Raises:
        ValueError: if dropping is active and `drop_prob` lies outside [0, 1].
    """
    if drop_prob == 0.0 or not training:
        return x
    if not 0.0 <= drop_prob <= 1.0:
        raise ValueError(
            "drop_prob must be in [0, 1], got {}".format(drop_prob)
        )
    keep_prob = 1 - drop_prob
    # One mask entry per sample, broadcast over all remaining dims so this
    # works with tensors of any rank, not just 2D ConvNets.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize
    # With keep_prob == 0 every sample is dropped; dividing by zero first
    # would turn the result into NaN instead of zeros, so skip the rescale.
    scaled = x.div(keep_prob) if keep_prob > 0.0 else x
    return scaled * mask
class Nonlocal(nn.Module):
    """
    Non-local block: a generic building block for capturing long-range
    dependencies. The response at each position is a weighted sum of the
    features at all positions, and the result is added back to the input.
    More details in the paper: https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(
        self,
        dim,
        dim_inner,
        pool_size=None,
        instantiation="softmax",
        zero_init_final_conv=False,
        zero_init_final_norm=True,
        norm_eps=1e-5,
        norm_momentum=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        Args:
            dim (int): number of channels of the input (and of the output).
            dim_inner (int): number of channels used inside the block.
            pool_size (list): kernel size (and stride) of the max pooling
                applied before the phi / g heads, ordered temporal, spatial,
                spatial. None, or all sizes <= 1, disables pooling.
            instantiation (string): how the affinity matrix is normalized:
                "dot_product": divide by the number of key positions.
                "softmax": scale by dim_inner ** -0.5, then softmax.
            zero_init_final_conv (bool): stored as `conv_out.zero_init`.
                NOTE(review): presumably read by an external weight
                initializer; nothing is zeroed in this class.
            zero_init_final_norm (bool): stored as `bn.transform_final_bn`
                (same caveat as above).
            norm_eps (float): `eps` handed to the normalization layer.
            norm_momentum (float): `momentum` handed to the normalization
                layer.
            norm_module (nn.Module): nn.Module for the normalization layer.
                The default is nn.BatchNorm3d.
        """
        super().__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.pool_size = pool_size
        self.instantiation = instantiation
        # Pooling is only worthwhile if at least one kernel extent exceeds 1.
        self.use_pool = pool_size is not None and any(
            size > 1 for size in pool_size
        )
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self._construct_nonlocal(
            zero_init_final_conv, zero_init_final_norm, norm_module
        )

    def _construct_nonlocal(
        self, zero_init_final_conv, zero_init_final_norm, norm_module
    ):
        """Create the projection heads, output conv, norm and optional pool."""

        def pointwise(channels_in, channels_out):
            # Every projection in this block is a 1x1x1 convolution.
            return nn.Conv3d(
                channels_in, channels_out, kernel_size=1, stride=1, padding=0
            )

        # Three convolution heads: theta, phi, and g.
        self.conv_theta = pointwise(self.dim, self.dim_inner)
        self.conv_phi = pointwise(self.dim, self.dim_inner)
        self.conv_g = pointwise(self.dim, self.dim_inner)

        # Final convolution, mapping back to the input width.
        self.conv_out = pointwise(self.dim_inner, self.dim)
        # Flag for zero initialization of the final convolution.
        self.conv_out.zero_init = zero_init_final_conv

        # TODO: change the name to `norm`
        self.bn = norm_module(
            num_features=self.dim,
            eps=self.norm_eps,
            momentum=self.norm_momentum,
        )
        # Flag for zero initialization of the final norm.
        self.bn.transform_final_bn = zero_init_final_norm

        # Optional spatial-temporal pooling.
        if self.use_pool:
            self.pool = nn.MaxPool3d(
                kernel_size=self.pool_size,
                stride=self.pool_size,
                padding=[0, 0, 0],
            )

    def forward(self, x):
        """Apply the non-local operation to a (N, C, T, H, W) tensor."""
        residual = x
        batch, _, frames, height, width = x.size()

        # Queries always come from the full-resolution input.
        theta = self.conv_theta(x)

        # Keys / values may come from a pooled input to reduce computation.
        pooled = self.pool(x) if self.use_pool else x
        phi = self.conv_phi(pooled)
        g = self.conv_g(pooled)

        # Flatten every head to (N, C, positions).
        theta = theta.view(batch, self.dim_inner, -1)
        phi = phi.view(batch, self.dim_inner, -1)
        g = g.view(batch, self.dim_inner, -1)

        # (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW).
        affinity = torch.einsum("nct,ncp->ntp", theta, phi)

        # The original Non-local paper normalizes the affinity tensor either
        # with a softmax or by the number of positions (dot_product).
        if self.instantiation == "dot_product":
            affinity = affinity / affinity.shape[2]
        elif self.instantiation == "softmax":
            # Scale before the softmax.
            affinity = affinity * (self.dim_inner**-0.5)
            affinity = nn.functional.softmax(affinity, dim=2)
        else:
            raise NotImplementedError("Unknown norm type {}".format(self.instantiation))

        # (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW).
        attended = torch.einsum("ntg,ncg->nct", affinity, g)
        # (N, C, TxHxW) => (N, C, T, H, W).
        attended = attended.view(batch, self.dim_inner, frames, height, width)

        projected = self.bn(self.conv_out(attended))
        return residual + projected
class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU or Swish, FC, Sigmoid."""

    def _round_width(self, width, multiplier, min_width=8, divisor=8):
        """
        Scale a channel width and round it to a multiple of `divisor`.
        Args:
            width (int): the channel dimensions of the input.
            multiplier (float): the multiplication factor. A falsy value
                returns `width` unchanged.
            min_width (int): the minimum width after multiplication. A falsy
                value falls back to `divisor`.
            divisor (int): the new width is a multiple of divisor.
        Returns:
            int: the rounded width (or `width` itself if multiplier is falsy).
        """
        if not multiplier:
            return width

        scaled = width * multiplier
        lower_bound = min_width or divisor
        # Round to the nearest multiple of `divisor`, clamped from below.
        rounded = max(lower_bound, int(scaled + divisor / 2) // divisor * divisor)
        # Never shrink by more than 10% through rounding.
        if rounded < 0.9 * scaled:
            rounded += divisor
        return int(rounded)

    def __init__(self, dim_in, ratio, relu_act=True):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            ratio (float): the channel reduction ratio for squeeze.
            relu_act (bool): if True (the default) use ReLU between the two
                FC layers, otherwise use Swish.
        """
        super().__init__()
        squeezed = self._round_width(dim_in, ratio)
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc1 = nn.Conv3d(dim_in, squeezed, 1, bias=True)
        self.fc1_act = nn.ReLU() if relu_act else Swish()
        self.fc2 = nn.Conv3d(squeezed, dim_in, 1, bias=True)
        self.fc2_sig = nn.Sigmoid()

    def forward(self, x):
        """Rescale `x` channel-wise by the learned gate."""
        gate = self.avg_pool(x)
        gate = self.fc1_act(self.fc1(gate))
        gate = self.fc2_sig(self.fc2(gate))
        return x * gate
def get_trans_func(name):
    """
    Look up a transformation module class by its registered name.
    Args:
        name (string): one of "bottleneck_transform", "basic_transform" or
            "x3d_transform".
    Returns:
        the matching transformation class (not an instance).
    """
    registry = {
        "bottleneck_transform": BottleneckTransform,
        "basic_transform": BasicTransform,
        "x3d_transform": X3DTransform,
    }
    assert name in registry, "Transformation function '{}' not supported".format(
        name
    )
    return registry[name]
class BasicTransform(nn.Module):
    """
    Basic transformation: Tx3x3, 1x3x3, where T is the size of temporal kernel.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        dim_inner=None,
        num_groups=1,
        stride_1x1=None,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        block_idx=0,
    ):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel size of the first
                convolution in the basic block.
            stride (int): the spatial stride of the first convolution.
            dim_inner (None): accepted for signature compatibility; unused.
            num_groups (int): accepted for signature compatibility; unused.
            stride_1x1 (None): accepted for signature compatibility; unused.
            inplace_relu (bool): if True, calculate the relu on the original
                input without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): spatial dilation (and padding) of the second
                convolution.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            block_idx (int): accepted for signature compatibility; unused.
        """
        super().__init__()
        self.temp_kernel_size = temp_kernel_size
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._construct(dim_in, dim_out, stride, dilation, norm_module)

    def _construct(self, dim_in, dim_out, stride, dilation, norm_module):
        """Register the two conv/norm stages in execution order."""
        norm_kwargs = {"eps": self._eps, "momentum": self._bn_mmt}
        temporal_pad = int(self.temp_kernel_size // 2)

        # Stage a: Tx3x3 conv, BN, ReLU. Carries the spatial stride.
        self.a = nn.Conv3d(
            dim_in,
            dim_out,
            kernel_size=[self.temp_kernel_size, 3, 3],
            stride=[1, stride, stride],
            padding=[temporal_pad, 1, 1],
            bias=False,
        )
        self.a_bn = norm_module(num_features=dim_out, **norm_kwargs)
        self.a_relu = nn.ReLU(inplace=self._inplace_relu)

        # Stage b: 1x3x3 conv, BN. No activation here.
        self.b = nn.Conv3d(
            dim_out,
            dim_out,
            kernel_size=[1, 3, 3],
            stride=[1, 1, 1],
            padding=[0, dilation, dilation],
            dilation=[1, dilation, dilation],
            bias=False,
        )
        # Marker attributes on the last conv / norm of the transform.
        self.b.final_conv = True
        self.b_bn = norm_module(num_features=dim_out, **norm_kwargs)
        self.b_bn.transform_final_bn = True

    def forward(self, x):
        """Run stage a (conv, BN, ReLU) followed by stage b (conv, BN)."""
        out = self.a_relu(self.a_bn(self.a(x)))
        return self.b_bn(self.b(out))
class X3DTransform(nn.Module):
    """
    X3D transformation: 1x1x1, Tx3x3 (channelwise, num_groups=dim_in), 1x1x1,
    augmented with (optional) SE (squeeze-excitation) on the 3x3x3 output.
    T is the temporal kernel size (defaulting to 3)
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        dim_inner,
        num_groups,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        se_ratio=0.0625,
        swish_inner=True,
        block_idx=0,
    ):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel size of the middle
                convolution in the bottleneck.
            stride (int): the spatial stride of the bottleneck.
            dim_inner (int): the inner dimension of the block.
            num_groups (int): number of groups for the middle convolution.
            stride_1x1 (bool): if True, apply stride to the first 1x1x1 conv,
                otherwise apply stride to the Tx3x3 conv.
            inplace_relu (bool): if True, calculate the relu on the original
                input without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): size of dilation.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            se_ratio (float): if > 0, apply SE to the Tx3x3 conv, with the SE
                channel dimensionality being se_ratio times the Tx3x3 conv dim.
            swish_inner (bool): if True, apply swish after the Tx3x3 conv,
                otherwise apply ReLU.
            block_idx (int): index of this block; SE is only attached when
                `block_idx + 1` is odd.
        """
        super().__init__()
        self.temp_kernel_size = temp_kernel_size
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._se_ratio = se_ratio
        self._swish_inner = swish_inner
        self._stride_1x1 = stride_1x1
        self._block_idx = block_idx
        self._construct(
            dim_in,
            dim_out,
            stride,
            dim_inner,
            num_groups,
            dilation,
            norm_module,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        dilation,
        norm_module,
    ):
        """
        Register all layers. The registration order below IS the execution
        order, because `forward` simply walks `self.children()`.
        """
        # Decide which convolution carries the spatial stride.
        if self._stride_1x1:
            stride_a, stride_b = stride, 1
        else:
            stride_a, stride_b = 1, stride
        norm_kwargs = {"eps": self._eps, "momentum": self._bn_mmt}

        # 1x1x1, BN, ReLU.
        self.a = nn.Conv3d(
            dim_in,
            dim_inner,
            kernel_size=[1, 1, 1],
            stride=[1, stride_a, stride_a],
            padding=[0, 0, 0],
            bias=False,
        )
        self.a_bn = norm_module(num_features=dim_inner, **norm_kwargs)
        self.a_relu = nn.ReLU(inplace=self._inplace_relu)

        # Tx3x3, BN, (SE), Swish or ReLU.
        self.b = nn.Conv3d(
            dim_inner,
            dim_inner,
            [self.temp_kernel_size, 3, 3],
            stride=[1, stride_b, stride_b],
            padding=[int(self.temp_kernel_size // 2), dilation, dilation],
            groups=num_groups,
            bias=False,
            dilation=[1, dilation, dilation],
        )
        self.b_bn = norm_module(num_features=dim_inner, **norm_kwargs)

        # SE attention goes on every other block, starting with block 0.
        se_on_this_block = (self._block_idx + 1) % 2 != 0
        if self._se_ratio > 0.0 and se_on_this_block:
            self.se = SE(dim_inner, self._se_ratio)

        if self._swish_inner:
            self.b_relu = Swish()
        else:
            self.b_relu = nn.ReLU(inplace=self._inplace_relu)

        # 1x1x1, BN.
        self.c = nn.Conv3d(
            dim_inner,
            dim_out,
            kernel_size=[1, 1, 1],
            stride=[1, 1, 1],
            padding=[0, 0, 0],
            bias=False,
        )
        self.c_bn = norm_module(num_features=dim_out, **norm_kwargs)
        self.c_bn.transform_final_bn = True

    def forward(self, x):
        """Feed `x` through every registered child, in registration order."""
        out = x
        for layer in self.children():
            out = layer(out)
        return out
class BottleneckTransform(nn.Module):
    """
    Bottleneck transformation: Tx1x1, 1x3x3, 1x1x1, where T is the size of
    temporal kernel.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        dim_inner,
        num_groups,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        block_idx=0,
    ):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel size of the first
                convolution in the bottleneck.
            stride (int): the spatial stride of the bottleneck.
            dim_inner (int): the inner dimension of the block.
            num_groups (int): number of groups for the 1x3x3 convolution.
                num_groups=1 is for standard ResNet like networks, and
                num_groups>1 is for ResNeXt like networks.
            stride_1x1 (bool): if True, apply stride to the Tx1x1 conv,
                otherwise apply stride to the 1x3x3 conv.
            inplace_relu (bool): if True, calculate the relu on the original
                input without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): size of dilation.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            block_idx (int): accepted for signature compatibility; unused.
        """
        super().__init__()
        self.temp_kernel_size = temp_kernel_size
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._stride_1x1 = stride_1x1
        self._construct(
            dim_in,
            dim_out,
            stride,
            dim_inner,
            num_groups,
            dilation,
            norm_module,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        dilation,
        norm_module,
    ):
        """Register the three conv/norm stages (a, b, c)."""
        # Decide which convolution carries the spatial stride.
        if self._stride_1x1:
            stride_a, stride_b = stride, 1
        else:
            stride_a, stride_b = 1, stride
        norm_kwargs = {"eps": self._eps, "momentum": self._bn_mmt}

        # Tx1x1, BN, ReLU.
        self.a = nn.Conv3d(
            dim_in,
            dim_inner,
            kernel_size=[self.temp_kernel_size, 1, 1],
            stride=[1, stride_a, stride_a],
            padding=[int(self.temp_kernel_size // 2), 0, 0],
            bias=False,
        )
        self.a_bn = norm_module(num_features=dim_inner, **norm_kwargs)
        self.a_relu = nn.ReLU(inplace=self._inplace_relu)

        # 1x3x3, BN, ReLU.
        self.b = nn.Conv3d(
            dim_inner,
            dim_inner,
            [1, 3, 3],
            stride=[1, stride_b, stride_b],
            padding=[0, dilation, dilation],
            groups=num_groups,
            bias=False,
            dilation=[1, dilation, dilation],
        )
        self.b_bn = norm_module(num_features=dim_inner, **norm_kwargs)
        self.b_relu = nn.ReLU(inplace=self._inplace_relu)

        # 1x1x1, BN.
        self.c = nn.Conv3d(
            dim_inner,
            dim_out,
            kernel_size=[1, 1, 1],
            stride=[1, 1, 1],
            padding=[0, 0, 0],
            bias=False,
        )
        # Marker attributes on the last conv / norm of the transform.
        self.c.final_conv = True
        self.c_bn = norm_module(num_features=dim_out, **norm_kwargs)
        self.c_bn.transform_final_bn = True

    def forward(self, x):
        """Explicitly run branch2a, branch2b and branch2c in sequence."""
        # Branch2a.
        out = self.a_relu(self.a_bn(self.a(x)))
        # Branch2b.
        out = self.b_relu(self.b_bn(self.b(out)))
        # Branch2c.
        return self.c_bn(self.c(out))
class ResBlock(nn.Module):
    """
    Residual block: output = ReLU(shortcut(x) + transform(x)), where the
    shortcut is the identity or, when the shape changes, a 1x1x1 projection
    followed by a normalization layer.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        trans_func,
        dim_inner,
        num_groups=1,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        block_idx=0,
        drop_connect_rate=0.0,
    ):
        """
        ResBlock class constructs residual blocks. More details can be found in:
            Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.
            "Deep residual learning for image recognition."
            https://arxiv.org/abs/1512.03385
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel sizes of the middle
                convolution in the bottleneck.
            stride (int): the stride of the bottleneck.
            trans_func (callable): class (or factory) used to build the
                residual transform; it is called with the positional and
                keyword arguments shown in `_construct`.
            dim_inner (int): the inner dimension of the block.
            num_groups (int): number of groups for the convolution. num_groups=1
                is for standard ResNet like networks, and num_groups>1 is for
                ResNeXt like networks.
            stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise
                apply stride to the 3x3 conv.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): size of dilation.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            block_idx (int): index of the block, forwarded to `trans_func`.
            drop_connect_rate (float): probability with which the residual
                branch of a sample is dropped during training.
        """
        super(ResBlock, self).__init__()
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._drop_connect_rate = drop_connect_rate
        self._construct(
            dim_in,
            dim_out,
            temp_kernel_size,
            stride,
            trans_func,
            dim_inner,
            num_groups,
            stride_1x1,
            inplace_relu,
            dilation,
            norm_module,
            block_idx,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        trans_func,
        dim_inner,
        num_groups,
        stride_1x1,
        inplace_relu,
        dilation,
        norm_module,
        block_idx,
    ):
        """Create the optional projection shortcut, the transform and ReLU."""
        # Use skip connection with projection if dim or res change.
        if (dim_in != dim_out) or (stride != 1):
            self.branch1 = nn.Conv3d(
                dim_in,
                dim_out,
                kernel_size=1,
                stride=[1, stride, stride],
                padding=0,
                bias=False,
                dilation=1,
            )
            self.branch1_bn = norm_module(
                num_features=dim_out, eps=self._eps, momentum=self._bn_mmt
            )
        self.branch2 = trans_func(
            dim_in,
            dim_out,
            temp_kernel_size,
            stride,
            dim_inner,
            num_groups,
            stride_1x1=stride_1x1,
            inplace_relu=inplace_relu,
            dilation=dilation,
            norm_module=norm_module,
            block_idx=block_idx,
        )
        self.relu = nn.ReLU(self._inplace_relu)

    def forward(self, x):
        """Return ReLU(shortcut(x) + residual(x))."""
        f_x = self.branch2(x)
        if self.training and self._drop_connect_rate > 0.0:
            # `drop_path` defaults to training=False, in which case it returns
            # its input unchanged; the flag must be passed explicitly or the
            # residual branch is never actually dropped.
            f_x = drop_path(
                f_x, self._drop_connect_rate, training=self.training
            )
        if hasattr(self, "branch1"):
            x = self.branch1_bn(self.branch1(x)) + f_x
        else:
            x = x + f_x
        x = self.relu(x)
        return x
class ResStage(nn.Module):
    """
    Stage of 3D ResNet. It expects to have one or more tensors as input for
    single pathway (C2D, I3D, Slow), and multi-pathway (SlowFast) cases.
    More details can be found here:
    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        stride,
        temp_kernel_sizes,
        num_blocks,
        dim_inner,
        num_groups,
        num_block_temp_kernel,
        nonlocal_inds,
        nonlocal_group,
        nonlocal_pool,
        dilation,
        instantiation="softmax",
        trans_func_name="bottleneck_transform",
        stride_1x1=False,
        inplace_relu=True,
        norm_module=nn.BatchNorm3d,
        drop_connect_rate=0.0,
    ):
        """
        The `__init__` method of any subclass should also contain these arguments.
        ResStage builds p streams, where p can be greater or equal to one.
        Args:
            dim_in (list): list of p the channel dimensions of the input.
                Different channel dimensions control the input dimension of
                different pathways.
            dim_out (list): list of p the channel dimensions of the output.
                Different channel dimensions control the output dimension of
                different pathways.
            stride (list): list of the p strides of the bottleneck. Only the
                first block of each pathway uses it; later blocks use 1.
            temp_kernel_sizes (list): list of the p temporal kernel sizes of the
                convolution in the bottleneck. Different temp_kernel_sizes
                control different pathway.
            num_blocks (list): list of p numbers of blocks for each of the
                pathway.
            dim_inner (list): list of the p inner channel dimensions of the
                blocks.
            num_groups (list): list of number of p groups for the convolution.
                num_groups=1 is for standard ResNet like networks, and
                num_groups>1 is for ResNeXt like networks.
            num_block_temp_kernel (list): extend the temp_kernel_sizes to
                num_block_temp_kernel blocks, then fill temporal kernel size
                of 1 for the rest of the layers.
            nonlocal_inds (list): per pathway, the block indices after which
                a nonlocal layer is added. An empty entry adds none.
            nonlocal_group (list): list of number of p nonlocal groups. Each
                number controls how to fold temporal dimension to batch
                dimension before applying nonlocal transformation.
                https://github.com/facebookresearch/video-nonlocal-net.
            nonlocal_pool (list): per pathway, the pool size handed to the
                nonlocal layer.
            dilation (list): size of dilation for each pathway.
            instantiation (string): different instantiation for nonlocal layer.
                Supports two different instantiation method:
                    "dot_product": normalizing correlation matrix with L2.
                    "softmax": normalizing correlation matrix with Softmax.
            trans_func_name (string): name of the transformation function
                applied in every block.
            stride_1x1 (bool): forwarded to every block.
            inplace_relu (bool): forwarded to every block.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            drop_connect_rate (float): drop rate forwarded unchanged to every
                block of the stage.
        """
        super().__init__()
        pathway_ids = range(len(temp_kernel_sizes))
        assert all(
            num_block_temp_kernel[i] <= num_blocks[i] for i in pathway_ids
        )
        self.num_blocks = num_blocks
        self.nonlocal_group = nonlocal_group
        self._drop_connect_rate = drop_connect_rate
        # Per pathway: repeat the given kernel sizes over the first
        # `num_block_temp_kernel` blocks, then pad with temporal kernel 1.
        self.temp_kernel_sizes = [
            (temp_kernel_sizes[i] * num_blocks[i])[: num_block_temp_kernel[i]]
            + [1] * (num_blocks[i] - num_block_temp_kernel[i])
            for i in pathway_ids
        ]
        # Every per-pathway argument checked here must have the same length.
        per_pathway_lengths = {
            len(dim_in),
            len(dim_out),
            len(temp_kernel_sizes),
            len(stride),
            len(num_blocks),
            len(dim_inner),
            len(num_groups),
            len(num_block_temp_kernel),
            len(nonlocal_inds),
            len(nonlocal_group),
        }
        assert len(per_pathway_lengths) == 1
        self.num_pathways = len(self.num_blocks)
        self._construct(
            dim_in,
            dim_out,
            stride,
            dim_inner,
            num_groups,
            trans_func_name,
            stride_1x1,
            inplace_relu,
            nonlocal_inds,
            nonlocal_pool,
            instantiation,
            dilation,
            norm_module,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        trans_func_name,
        stride_1x1,
        inplace_relu,
        nonlocal_inds,
        nonlocal_pool,
        instantiation,
        dilation,
        norm_module,
    ):
        """Register `pathway{p}_res{i}` and optional `pathway{p}_nonlocal{i}`."""
        for pathway in range(self.num_pathways):
            for i in range(self.num_blocks[pathway]):
                first_block = i == 0
                # Retrieve the transformation function.
                trans_func = get_trans_func(trans_func_name)
                # Only the first block changes width and resolution.
                block_dim_in = dim_in[pathway] if first_block else dim_out[pathway]
                block_stride = stride[pathway] if first_block else 1
                res_block = ResBlock(
                    block_dim_in,
                    dim_out[pathway],
                    self.temp_kernel_sizes[pathway][i],
                    block_stride,
                    trans_func,
                    dim_inner[pathway],
                    num_groups[pathway],
                    stride_1x1=stride_1x1,
                    inplace_relu=inplace_relu,
                    dilation=dilation[pathway],
                    norm_module=norm_module,
                    block_idx=i,
                    drop_connect_rate=self._drop_connect_rate,
                )
                self.add_module("pathway{}_res{}".format(pathway, i), res_block)

                if i not in nonlocal_inds[pathway]:
                    continue
                nln = Nonlocal(
                    dim_out[pathway],
                    dim_out[pathway] // 2,
                    nonlocal_pool[pathway],
                    instantiation=instantiation,
                    norm_module=norm_module,
                )
                self.add_module("pathway{}_nonlocal{}".format(pathway, i), nln)

    def forward(self, inputs):
        """
        Run each pathway's input through its blocks.
        Args:
            inputs (list): one tensor per pathway; each is unpacked as
                (batch, channels, time, height, width) where a nonlocal
                layer is present.
        Returns:
            list: one output tensor per pathway, in pathway order.
        """
        outputs = []
        for pathway in range(self.num_pathways):
            feat = inputs[pathway]
            for i in range(self.num_blocks[pathway]):
                res_block = getattr(self, "pathway{}_res{}".format(pathway, i))
                feat = res_block(feat)

                nln = getattr(
                    self, "pathway{}_nonlocal{}".format(pathway, i), None
                )
                if nln is None:
                    continue

                b, c, t, h, w = feat.shape
                groups = self.nonlocal_group[pathway]
                fold = groups > 1
                if fold:
                    # Fold temporal dimension into batch dimension.
                    feat = feat.permute(0, 2, 1, 3, 4)
                    feat = feat.reshape(b * groups, t // groups, c, h, w)
                    feat = feat.permute(0, 2, 1, 3, 4)
                feat = nln(feat)
                if fold:
                    # Fold back to temporal dimension.
                    feat = feat.permute(0, 2, 1, 3, 4)
                    feat = feat.reshape(b, t, c, h, w)
                    feat = feat.permute(0, 2, 1, 3, 4)
            outputs.append(feat)
        return outputs