ianpan's picture
update models, output, examples
455e8ef
raw
history blame
876 Bytes
"""
Contains commonly used neural net modules.
"""
import math
import torch
import torch.nn as nn
class FeatureReduction(nn.Module):
"""
Reduce feature dimensionality
Intended use is after the last layer of the neural net backbone, before pooling
Grouped convolution is used to reduce # of extra parameters
"""
def __init__(self, feature_dim: int, reduce_feature_dim: int):
super().__init__()
groups = math.gcd(feature_dim, reduce_feature_dim)
self.reduce = nn.Conv2d(
feature_dim,
reduce_feature_dim,
groups=groups,
kernel_size=1,
stride=1,
bias=False,
)
self.bn = nn.BatchNorm2d(reduce_feature_dim)
self.act = nn.ReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.act(self.bn(self.reduce(x)))