import torch
from torch import nn
class SELayer(nn.Module):
    """
    Squeeze-and-Excitation (SE) layer with dual pooling.

    This layer implements the Squeeze-and-Excitation mechanism, which adaptively
    recalibrates channel-wise feature responses by explicitly modeling
    interdependencies between channels. It enhances the representational power
    of a neural network by emphasizing informative features while suppressing
    less useful ones. Unlike the original SE block, which squeezes with average
    pooling alone, this variant combines average- and max-pooled descriptors
    (similar to the channel attention in CBAM), each with its own excitation MLP.

    Args:
        channel (int): The number of input channels.
        reduction (int, optional): Reduction ratio for the dimensionality
            of the intermediate representation. Default is 16.
    """
    def __init__(self, channel, reduction=16):
        super().__init__()
        # Squeeze (average branch): adaptive average pooling to 1x1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation MLP for the average-pooled descriptor
        self.avg_pool_layer = nn.Sequential(
            nn.Linear(channel, channel // reduction),  # Bottleneck projection
            nn.ReLU(inplace=True),                     # Non-linearity
            nn.Linear(channel // reduction, channel),  # Restore channel dimension
            nn.Sigmoid()                               # Per-channel gate in [0, 1]
        )
        # Squeeze (max branch): adaptive max pooling to 1x1
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Excitation MLP for the max-pooled descriptor
        self.max_pool_layer = nn.Sequential(
            nn.Linear(channel, channel // reduction),  # Bottleneck projection
            nn.ReLU(inplace=True),                     # Non-linearity
            nn.Linear(channel // reduction, channel),  # Restore channel dimension
            nn.Sigmoid()                               # Per-channel gate in [0, 1]
        )
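
    # Shape walk-through for one branch (illustrative note, not in the original):
    #   x: (B, C, H, W) --pool--> (B, C, 1, 1) --view--> (B, C)
    #        --excitation MLP--> (B, C) --view--> (B, C, 1, 1),
    # which then broadcasts against x in forward().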
    def forward(self, x):
        """
        Forward pass for the SE layer.

        Args:
            x (Tensor): Input tensor of shape (B, C, H, W), where
                B = batch size,
                C = number of channels,
                H = height,
                W = width.

        Returns:
            Tensor: Output tensor of the same shape as the input `x` after
            applying the squeeze-and-excitation mechanism.
        """
        b, c, _, _ = x.size()  # Unpack input dimensions
        # Average branch: squeeze to (B, C), excite, reshape for broadcasting
        x_avg = self.avg_pool(x).view(b, c)
        x_avg = self.avg_pool_layer(x_avg).view(b, c, 1, 1)
        # Max branch: squeeze to (B, C), excite, reshape for broadcasting
        x_max = self.max_pool(x).view(b, c)
        x_max = self.max_pool_layer(x_max).view(b, c, 1, 1)
        # Scale the input features by the summed channel gates; note that the
        # sum of the two sigmoid outputs lies in [0, 2], not [0, 1]
        y = (x_avg + x_max) * x
        return y  # Return the recalibrated output
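
# --- Usage sketch (not part of the original module) ---
# A minimal smoke test, assuming a 64-channel feature map; it checks that the
# channel gates broadcast over H and W and that recalibration preserves shape.
if __name__ == "__main__":
    se = SELayer(channel=64, reduction=16)
    x = torch.randn(2, 64, 32, 32)  # (B, C, H, W)
    y = se(x)
    assert y.shape == x.shape       # output matches input shape
    print(y.shape)                  # torch.Size([2, 64, 32, 32])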