import torch
from torch import nn

# Reference: Squeeze-and-Excitation Networks, https://arxiv.org/pdf/1709.01507
class SELayer(nn.Module):
    """
    Squeeze-and-Excitation layer (SELayer) for recalibrating channel-wise feature responses.

    Implements the Squeeze-and-Excitation block from the referenced paper for inputs whose
    last dimension holds real and imaginary parts: channel interdependencies are modeled
    with two excitation branches (one per part), whose outputs are combined in a
    complex-multiplication style.

    Args:
        channel (int): Number of input channels.
        reduction (int): Reduction ratio for the bottleneck in the excitation branches.
            Default: 16.
    """
    def __init__(self, channel, reduction=16):
        super().__init__()
        # Adaptive average pooling to squeeze each channel into a single global descriptor
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation branch operating on the real part
        self.fc_r = nn.Sequential(
            nn.Linear(channel, channel // reduction),  # Reduce channels by the reduction ratio
            nn.ReLU(inplace=True),                     # Non-linearity in the bottleneck
            nn.Linear(channel // reduction, channel),  # Restore the original channel count
            nn.Sigmoid()                               # Scale outputs to (0, 1)
        )
        # Excitation branch operating on the imaginary part
        self.fc_i = nn.Sequential(
            nn.Linear(channel, channel // reduction),  # Reduce channels by the reduction ratio
            nn.ReLU(inplace=True),                     # Non-linearity in the bottleneck
            nn.Linear(channel // reduction, channel),  # Restore the original channel count
            nn.Sigmoid()                               # Scale outputs to (0, 1)
        )
    def forward(self, x):
        """
        Forward pass of the SELayer.

        Applies the squeeze-and-excitation operation to the input tensor `x`, computing
        channel-wise attention weights for the real and imaginary parts and combining
        them in a complex-multiplication style before rescaling the input.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W, 2), where:
                B - batch size,
                C - number of channels,
                H - height,
                W - width,
                and the last dimension holds the real (index 0) and imaginary (index 1) parts.

        Returns:
            torch.Tensor: Output tensor after applying channel-wise attention,
                same shape as the input `x`.
        """
        # Extract the batch size and number of channels
        b, c, _, _, _ = x.size()
        # Squeeze: global average pooling of the real part
        x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c)
        # Squeeze: global average pooling of the imaginary part
        x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c)
        # Excitation: real component of the complex-style product of the two branch outputs
        y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(b, c, 1, 1, 1)
        # Excitation: imaginary component of the complex-style product of the two branch outputs
        y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(b, c, 1, 1, 1)
        # Stack the real and imaginary attention weights along the last (real/imaginary) dimension
        y = torch.cat([y_r, y_i], 4)
        # Scale the input features by the broadcasted attention weights
        return x * y
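
# Minimal usage sketch (illustrative, not part of the original module): shows the expected
# input layout of (B, C, H, W, 2) with real/imaginary parts in the last dimension. All sizes
# below are arbitrary example values.
if __name__ == "__main__":
    layer = SELayer(channel=64, reduction=16)
    x = torch.randn(8, 64, 32, 32, 2)  # (B, C, H, W, 2); real at index 0, imaginary at index 1
    out = layer(x)
    print(out.shape)  # torch.Size([8, 64, 32, 32, 2]) - same shape as the input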