import math

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging.version import parse as V
from torch.nn import init
from torch.nn.parameter import Parameter

from models.mossformer_gan_se.fsmn import UniDeepFsmn
from models.mossformer_gan_se.conv_module import ConvModule
from models.mossformer_gan_se.mossformer import MossFormer
from models.mossformer_gan_se.se_layer import SELayer
from models.mossformer_gan_se.get_layer_from_string import get_layer
from models.mossformer_gan_se.discriminator import Discriminator

# Check if the installed version of PyTorch is 1.9.0 or higher
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")

class MossFormerGAN_SE_16K(nn.Module):
    """
    MossFormerGAN_SE_16K: A GAN-based speech enhancement model for 16 kHz input audio.

    This model integrates a synchronous attention network (SyncANet) for
    feature extraction. Depending on the mode (train or inference), it may
    also include a discriminator for adversarial training.

    Args:
        args (Namespace): Arguments containing configuration parameters,
                          including 'fft_len' and 'mode'.
    """
    def __init__(self, args):
        """Initializes the MossFormerGAN_SE_16K model."""
        super(MossFormerGAN_SE_16K, self).__init__()
        # Initialize SyncANet with the specified number of channels and features
        self.model = SyncANet(num_channel=64, num_features=args.fft_len // 2 + 1)
        # Initialize the discriminator only in training mode
        if args.mode == 'train':
            self.discriminator = Discriminator(ndf=16)
        else:
            self.discriminator = None

    def forward(self, x):
        """
        Defines the forward pass of the MossFormerGAN_SE_16K model.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, 2, time, freq],
                              holding the real and imaginary STFT components.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Output tensors representing the real and imaginary parts.
        """
        output_real, output_imag = self.model(x)  # Get real and imaginary outputs from the model
        return output_real, output_imag
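
# A minimal usage sketch (not part of the original file): it assumes an
# argparse-style namespace with 'fft_len' and 'mode', and a dummy spectrogram
# of shape [batch, 2, time, freq] with freq = fft_len // 2 + 1. The batch
# size, frame count, and fft_len value below are illustrative assumptions.
def _demo_mossformer_gan_se_16k():
    from argparse import Namespace
    args = Namespace(fft_len=400, mode='inference')  # assumed config values
    model = MossFormerGAN_SE_16K(args)
    x = torch.randn(1, 2, 50, args.fft_len // 2 + 1)  # [B, 2, T, F] dummy input
    out_real, out_imag = model(x)
    print(out_real.shape, out_imag.shape)  # expected: [1, 1, 50, 201] each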

class FSMN_Wrap(nn.Module):
    """
    FSMN_Wrap: A wrapper around the UniDeepFsmn module to facilitate
    integration into the larger model architecture.

    Args:
        nIn (int): Number of input features.
        nHidden (int): Number of hidden features in the FSMN (default is 128).
        lorder (int): Order of the FSMN (default is 20).
        nOut (int): Number of output features (default is 128).
    """
    def __init__(self, nIn, nHidden=128, lorder=20, nOut=128):
        """Initializes the FSMN_Wrap module with specified parameters."""
        super(FSMN_Wrap, self).__init__()
        # Initialize the UniDeepFsmn module
        self.fsmn = UniDeepFsmn(nIn, nHidden, lorder, nHidden)

    def forward(self, x):
        """
        Defines the forward pass of the FSMN_Wrap module.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, time, height].

        Returns:
            torch.Tensor: Output tensor of the same shape [batch_size, channels, time, height].
        """
        # Shape of input x: [b, c, T, h]
        b, c, T, h = x.size()
        # Permute x to reshape it for FSMN processing: [b, T, h, c]
        x = x.permute(0, 2, 3, 1)
        x = torch.reshape(x, (b * T, h, c))  # Reshape to [b*T, h, c]
        # Pass through the FSMN
        output = self.fsmn(x)  # output: [b*T, h, c]
        # Reshape output back to the original layout
        output = torch.reshape(output, (b, T, h, c))  # output: [b, T, h, c]
        return output.permute(0, 3, 1, 2)  # Final output shape: [b, c, T, h]

class DilatedDenseNet(nn.Module):
    """
    DilatedDenseNet: A dilated dense network for feature extraction.

    This network consists of a series of dilated convolutions organized in a dense block structure,
    allowing for efficient feature reuse and capturing multi-scale information.

    Args:
        depth (int): The number of layers in the dense block (default is 4).
        in_channels (int): The number of input channels for the first layer (default is 64).
    """
    def __init__(self, depth=4, in_channels=64):
        """Initializes the DilatedDenseNet with specified depth and input channels."""
        super(DilatedDenseNet, self).__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)  # Padding for the first layer
        self.twidth = 2  # Temporal width for convolutions
        self.kernel_size = (self.twidth, 3)  # Kernel size for convolutions
        # Initialize dilated convolutions, padding, normalization, and FSMN for each layer
        for i in range(self.depth):
            dil = 2 ** i  # Dilation factor for the current layer
            pad_length = self.twidth + (dil - 1) * (self.twidth - 1) - 1  # Causal padding length along time
            setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((1, 1, pad_length, 0), value=0.))
            setattr(self, 'conv{}'.format(i + 1),
                    nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=self.kernel_size,
                              dilation=(dil, 1)))  # Convolution layer
            setattr(self, 'norm{}'.format(i + 1), nn.InstanceNorm2d(in_channels, affine=True))  # Normalization
            setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))  # Activation function
            setattr(self, 'fsmn{}'.format(i + 1),
                    FSMN_Wrap(nIn=self.in_channels, nHidden=self.in_channels, lorder=5, nOut=self.in_channels))

    def forward(self, x):
        """
        Defines the forward pass for the DilatedDenseNet.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: Output tensor after processing through the dense network.
        """
        skip = x  # Initialize skip connection with input
        for i in range(self.depth):
            # Apply padding, convolution, normalization, activation, and FSMN in sequence
            out = getattr(self, 'pad{}'.format(i + 1))(skip)
            out = getattr(self, 'conv{}'.format(i + 1))(out)
            out = getattr(self, 'norm{}'.format(i + 1))(out)
            out = getattr(self, 'prelu{}'.format(i + 1))(out)
            out = getattr(self, 'fsmn{}'.format(i + 1))(out)
            skip = torch.cat([out, skip], dim=1)  # Concatenate outputs for dense connectivity
        return out  # Return the final output
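
# A minimal shape check (not part of the original file): because each layer
# pads causally along time and symmetrically along frequency, every dense
# layer preserves the [B, C, T, F] shape, so the block output matches the
# input shape. The sizes below are illustrative assumptions.
def _demo_dilated_dense_net():
    net = DilatedDenseNet(depth=4, in_channels=64)
    x = torch.randn(2, 64, 10, 101)  # [B, C, T, F] dummy input
    y = net(x)
    assert y.shape == x.shape  # the dense block is shape-preserving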

class DenseEncoder(nn.Module):
    """
    DenseEncoder: A dense encoding module for feature extraction from input data.

    This module consists of a series of convolutional layers followed by a
    dilated dense network for robust feature learning.

    Args:
        in_channel (int): Number of input channels for the encoder.
        channels (int): Number of output channels for each convolutional layer (default is 64).
    """
    def __init__(self, in_channel, channels=64):
        """Initializes the DenseEncoder with specified input channels and feature size."""
        super(DenseEncoder, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channel, channels, (1, 1), (1, 1)),  # Initial convolution layer
            nn.InstanceNorm2d(channels, affine=True),  # Normalization layer
            nn.PReLU(channels)  # Activation function
        )
        self.dilated_dense = DilatedDenseNet(depth=4, in_channels=channels)  # Dilated dense network
        self.conv_2 = nn.Sequential(
            nn.Conv2d(channels, channels, (1, 3), (1, 2), padding=(0, 1)),  # Second convolution; downsamples frequency by 2
            nn.InstanceNorm2d(channels, affine=True),  # Normalization layer
            nn.PReLU(channels)  # Activation function
        )

    def forward(self, x):
        """
        Defines the forward pass for the DenseEncoder.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, in_channel, height, width].

        Returns:
            torch.Tensor: Output tensor after processing through the encoder.
        """
        x = self.conv_1(x)  # Process through the first convolutional layer
        x = self.dilated_dense(x)  # Process through the dilated dense network
        x = self.conv_2(x)  # Process through the second convolutional layer
        return x  # Return the final output
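
# A minimal shape check (not part of the original file): the stride-(1, 2)
# convolution in conv_2 halves the frequency axis, e.g. 201 -> 101 bins.
# The channel count (3 = magnitude + real + imaginary) follows SyncANet's usage.
def _demo_dense_encoder():
    enc = DenseEncoder(in_channel=3, channels=64)
    x = torch.randn(2, 3, 10, 201)  # [B, 3, T, F] dummy input
    y = enc(x)
    print(y.shape)  # expected: [2, 64, 10, 101]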

class SPConvTranspose2d(nn.Module):
    """
    SPConvTranspose2d: A sub-pixel transposed convolution layer.

    This module emulates a transposed convolution by computing r times as many
    channels with an ordinary convolution and interleaving them along the last
    (frequency) axis, i.e. sub-pixel shuffling, for efficient upsampling.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (tuple): Size of the convolution kernel.
        r (int): Upsampling rate (default is 1).
    """
    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        """Initializes the SPConvTranspose2d with specified parameters."""
        super(SPConvTranspose2d, self).__init__()
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)  # Padding for input
        self.out_channels = out_channels  # Store number of output channels
        self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))  # Convolution layer
        self.r = r  # Store the upsampling rate

    def forward(self, x):
        """
        Defines the forward pass for the SPConvTranspose2d module.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, in_channels, height, width].

        Returns:
            torch.Tensor: Output tensor after sub-pixel upsampling along the width axis.
        """
        x = self.pad1(x)  # Apply padding to input
        out = self.conv(x)  # Perform convolution operation
        batch_size, nchannels, H, W = out.shape  # Get output shape
        out = out.view((batch_size, self.r, nchannels // self.r, H, W))  # Split channels into r groups
        out = out.permute(0, 2, 3, 4, 1)  # Move the group axis next to the width axis
        out = out.contiguous().view((batch_size, nchannels // self.r, H, -1))  # Interleave groups along width
        return out  # Return the final output
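
# A minimal shape check (not part of the original file): with kernel (1, 3),
# the symmetric width padding keeps the convolution width unchanged, so the
# sub-pixel shuffle doubles the width when r=2. Sizes are illustrative.
def _demo_sp_conv_transpose2d():
    up = SPConvTranspose2d(in_channels=64, out_channels=64, kernel_size=(1, 3), r=2)
    x = torch.randn(2, 64, 10, 101)  # [B, C, T, F] dummy input
    y = up(x)
    print(y.shape)  # expected: [2, 64, 10, 202]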

class MaskDecoder(nn.Module):
    """
    MaskDecoder: A decoder module for estimating masks used in audio processing.

    This module utilizes a dilated dense network to capture features and
    applies sub-pixel convolution to upscale the output. It produces
    a mask that can be applied to the magnitude of audio signals.

    Args:
        num_features (int): The number of features in the output mask.
        num_channel (int): The number of channels in intermediate layers (default is 64).
        out_channel (int): The number of output channels for the final output mask (default is 1).
    """
    def __init__(self, num_features, num_channel=64, out_channel=1):
        """Initializes the MaskDecoder with specified parameters."""
        super(MaskDecoder, self).__init__()
        self.dense_block = DilatedDenseNet(depth=4, in_channels=num_channel)  # Dense feature extraction
        self.sub_pixel = SPConvTranspose2d(num_channel, num_channel, (1, 3), 2)  # Sub-pixel convolution for upsampling
        self.conv_1 = nn.Conv2d(num_channel, out_channel, (1, 2))  # Convolution layer to produce mask
        self.norm = nn.InstanceNorm2d(out_channel, affine=True)  # Normalization layer
        self.prelu = nn.PReLU(out_channel)  # Activation function
        self.final_conv = nn.Conv2d(out_channel, out_channel, (1, 1))  # Final convolution layer
        self.prelu_out = nn.PReLU(num_features, init=-0.25)  # Final activation for output mask

    def forward(self, x):
        """
        Defines the forward pass for the MaskDecoder.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: Output mask tensor after processing through the decoder.
        """
        x = self.dense_block(x)  # Feature extraction using dilated dense block
        x = self.sub_pixel(x)  # Upsample the features
        x = self.conv_1(x)  # Convolution to estimate the mask
        x = self.prelu(self.norm(x))  # Apply normalization and activation
        x = self.final_conv(x).permute(0, 3, 2, 1).squeeze(-1)  # Final convolution and rearrangement
        return self.prelu_out(x).permute(0, 2, 1).unsqueeze(1)  # Final output shape: [B, 1, T, num_features]
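
# A minimal shape check (not part of the original file): starting from the
# encoder's half-resolution frequency axis, sub-pixel upsampling followed by
# the (1, 2) convolution restores exactly num_features bins (2*101 - 1 = 201).
def _demo_mask_decoder():
    dec = MaskDecoder(num_features=201, num_channel=64, out_channel=1)
    x = torch.randn(2, 64, 10, 101)  # [B, C, T, F//2+1] dummy encoder output
    mask = dec(x)
    print(mask.shape)  # expected: [2, 1, 10, 201]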

class ComplexDecoder(nn.Module):
    """
    ComplexDecoder: A decoder module for estimating complex-valued outputs.

    This module processes features through a dilated dense network and a
    sub-pixel convolution layer to generate two output channels representing
    the real and imaginary parts of the complex output.

    Args:
        num_channel (int): The number of channels in intermediate layers (default is 64).
    """
    def __init__(self, num_channel=64):
        """Initializes the ComplexDecoder with specified parameters."""
        super(ComplexDecoder, self).__init__()
        self.dense_block = DilatedDenseNet(depth=4, in_channels=num_channel)  # Dense feature extraction
        self.sub_pixel = SPConvTranspose2d(num_channel, num_channel, (1, 3), 2)  # Sub-pixel convolution for upsampling
        self.prelu = nn.PReLU(num_channel)  # Activation function
        self.norm = nn.InstanceNorm2d(num_channel, affine=True)  # Normalization layer
        self.conv = nn.Conv2d(num_channel, 2, (1, 2))  # Convolution layer to produce complex outputs

    def forward(self, x):
        """
        Defines the forward pass for the ComplexDecoder.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: Output tensor containing real and imaginary parts.
        """
        x = self.dense_block(x)  # Feature extraction using dilated dense block
        x = self.sub_pixel(x)  # Upsample the features
        x = self.prelu(self.norm(x))  # Apply normalization and activation
        x = self.conv(x)  # Generate complex output
        return x  # Return the output tensor
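
# A minimal shape check (not part of the original file): the decoder mirrors
# MaskDecoder's upsampling but emits two channels (real and imaginary).
def _demo_complex_decoder():
    dec = ComplexDecoder(num_channel=64)
    x = torch.randn(2, 64, 10, 101)  # [B, C, T, F//2+1] dummy encoder output
    y = dec(x)
    print(y.shape)  # expected: [2, 2, 10, 201]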

class SyncANet(nn.Module):
    """
    SyncANet: A synchronous attention network for enhancing audio signals.

    This network integrates dense encoding, synchronous attention blocks,
    and separate decoders for estimating masks and complex-valued outputs.

    Args:
        num_channel (int): The number of channels in the network (default is 64).
        num_features (int): The number of features for the mask decoder (default is 201).
    """
    def __init__(self, num_channel=64, num_features=201):
        """Initializes the SyncANet with specified parameters."""
        super(SyncANet, self).__init__()
        self.dense_encoder = DenseEncoder(in_channel=3, channels=num_channel)  # Dense encoder for input
        self.n_layers = 6  # Number of synchronous attention layers
        self.blocks = nn.ModuleList([])  # List to hold attention blocks
        # Initialize attention blocks
        for _ in range(self.n_layers):
            self.blocks.append(
                SyncANetBlock(
                    emb_dim=num_channel,
                    emb_ks=2,
                    emb_hs=1,
                    n_freqs=int(num_features // 2) + 1,
                    hidden_channels=num_channel * 2,
                    n_head=4,
                    approx_qk_dim=512,
                    activation='prelu',
                    eps=1.0e-5,
                )
            )
        self.mask_decoder = MaskDecoder(num_features, num_channel=num_channel, out_channel=1)  # Mask decoder
        self.complex_decoder = ComplexDecoder(num_channel=num_channel)  # Complex decoder

    def forward(self, x):
        """
        Defines the forward pass for the SyncANet.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, 2, time, freq] representing complex signals.

        Returns:
            list: List containing the real and imaginary parts of the output tensor.
        """
        out_list = []  # List to store outputs
        mag = torch.sqrt(x[:, 0, :, :]**2 + x[:, 1, :, :]**2).unsqueeze(1)  # Calculate magnitude
        noisy_phase = torch.angle(torch.complex(x[:, 0, :, :], x[:, 1, :, :])).unsqueeze(1)  # Calculate phase
        x_in = torch.cat([mag, x], dim=1)  # Concatenate magnitude and input for processing
        x = self.dense_encoder(x_in)  # Feature extraction using dense encoder
        for ii in range(self.n_layers):
            x = self.blocks[ii](x)  # Pass through attention blocks
        mask = self.mask_decoder(x)  # Estimate mask from features
        out_mag = mask * mag  # Apply mask to magnitude
        complex_out = self.complex_decoder(x)  # Generate complex output
        mag_real = out_mag * torch.cos(noisy_phase)  # Real part of the masked magnitude
        mag_imag = out_mag * torch.sin(noisy_phase)  # Imaginary part of the masked magnitude
        final_real = mag_real + complex_out[:, 0, :, :].unsqueeze(1)  # Final real output
        final_imag = mag_imag + complex_out[:, 1, :, :].unsqueeze(1)  # Final imaginary output
        out_list.append(final_real)  # Append real output to list
        out_list.append(final_imag)  # Append imaginary output to list
        return out_list  # Return list of outputs
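
# A minimal end-to-end sketch of the enhancement decomposition (not part of
# the original file): the network refines the masked noisy spectrum with an
# additive complex correction, i.e.
#   enhanced = mask * |X| * exp(j * angle(X)) + complex_out.
# Input sizes below are illustrative (fft_len=400 -> 201 frequency bins).
def _demo_syncanet():
    net = SyncANet(num_channel=64, num_features=201)
    x = torch.randn(1, 2, 50, 201)  # [B, 2, T, F] dummy complex spectrogram
    real, imag = net(x)
    print(real.shape, imag.shape)  # expected: [1, 1, 50, 201] each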

class FFConvM(nn.Module):
    """
    FFConvM: A feedforward convolutional module combining linear layers, normalization,
    non-linear activation, and convolution operations.

    This module processes input tensors through a sequence of transformations, including
    normalization, a linear layer with a SiLU activation, a convolutional operation, and
    dropout for regularization.

    Args:
        dim_in (int): The number of input features (dimensionality of input).
        dim_out (int): The number of output features (dimensionality of output).
        norm_klass (nn.Module): The normalization class to be applied (default is nn.LayerNorm).
        dropout (float): The dropout probability for regularization (default is 0.1).
    """
    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass=nn.LayerNorm,
        dropout=0.1
    ):
        """Initializes the FFConvM with specified parameters."""
        super().__init__()
        # Define the sequential model
        self.mdl = nn.Sequential(
            norm_klass(dim_in),  # Apply normalization to input
            nn.Linear(dim_in, dim_out),  # Linear transformation to dim_out
            nn.SiLU(),  # Non-linear activation using SiLU (Sigmoid Linear Unit)
            ConvModule(dim_out),  # Convolution operation on the output
            nn.Dropout(dropout)  # Dropout layer for regularization
        )

    def forward(self, x):
        """
        Defines the forward pass for the FFConvM.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, dim_in].

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, seq_len, dim_out] after processing.
        """
        output = self.mdl(x)  # Pass input through the sequential model
        return output  # Return the processed output
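
# A minimal usage sketch (not part of the original file): FFConvM acts on the
# last dimension, as SyncANetBlock feeds it [batch*frames, length, channels]
# tensors. It assumes ConvModule preserves the sequence length.
def _demo_ffconvm():
    ff = FFConvM(dim_in=128, dim_out=128)
    x = torch.randn(4, 101, 128)  # [batch, seq_len, dim_in] dummy input
    y = ff(x)
    print(y.shape)  # expected: [4, 101, 128]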

class SyncANetBlock(nn.Module):
    """
    SyncANetBlock implements a modified version of the MossFormer (GatedFormer) module,
    inspired by the TF-GridNet architecture (https://arxiv.org/abs/2211.12433).
    It combines gated triple-attention schemes and feedforward sequential memory
    network (FSMN) modules to enhance computational efficiency and overall
    performance in audio processing tasks.

    Attributes:
        emb_dim (int): Dimensionality of the embedding.
        emb_ks (int): Kernel size for embeddings.
        emb_hs (int): Stride size for embeddings.
        n_freqs (int): Number of frequency bands.
        hidden_channels (int): Number of hidden channels.
        n_head (int): Number of attention heads.
        approx_qk_dim (int): Approximate dimension for query-key matrices.
        activation (str): Activation function to use.
        eps (float): Small value to avoid division by zero in normalization layers.
    """
    def __getitem__(self, key):
        """
        Allows accessing module attributes using indexing.

        Args:
            key: Attribute name to retrieve.

        Returns:
            The requested attribute.
        """
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
    ):
        """
        Initializes the SyncANetBlock with the specified parameters.

        Args:
            emb_dim (int): Dimensionality of the embedding.
            emb_ks (int): Kernel size for embeddings.
            emb_hs (int): Stride size for embeddings.
            n_freqs (int): Number of frequency bands.
            hidden_channels (int): Number of hidden channels.
            n_head (int): Number of attention heads. Default is 4.
            approx_qk_dim (int): Approximate dimension for query-key matrices. Default is 512.
            activation (str): Activation function to use. Default is "prelu".
            eps (float): Small value to avoid division by zero in normalization layers. Default is 1e-5.
        """
        super().__init__()
        in_channels = emb_dim * emb_ks  # Calculate the number of input channels

        ## Intra modules: processing within each frame (along the frequency axis)
        self.Fconv = nn.Conv2d(emb_dim, in_channels, kernel_size=(1, emb_ks), stride=(1, 1), groups=emb_dim)
        self.intra_norm = LayerNormalization4D(emb_dim, eps=eps)  # Layer normalization
        self.intra_to_u = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.intra_to_v = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.intra_rnn = self._build_repeats(in_channels, hidden_channels, 20, hidden_channels, repeats=1)  # FSMN layers
        self.intra_mossformer = MossFormer(dim=emb_dim, group_size=n_freqs)  # MossFormer module
        # Linear transformation for intra module output
        self.intra_linear = nn.ConvTranspose1d(
            hidden_channels, emb_dim, emb_ks, stride=emb_hs
        )
        self.intra_se = SELayer(channel=emb_dim, reduction=1)  # Squeeze-and-excitation layer

        ## Inter modules: processing across time frames
        self.inter_norm = LayerNormalization4D(emb_dim, eps=eps)  # Layer normalization
        self.inter_to_u = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.inter_to_v = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.inter_rnn = self._build_repeats(in_channels, hidden_channels, 20, hidden_channels, repeats=1)  # FSMN layers
        self.inter_mossformer = MossFormer(dim=emb_dim, group_size=256)  # MossFormer module
        # Linear transformation for inter module output
        self.inter_linear = nn.ConvTranspose1d(
            hidden_channels, emb_dim, emb_ks, stride=emb_hs
        )
        self.inter_se = SELayer(channel=emb_dim, reduction=1)  # Squeeze-and-excitation layer

        # Approximate query-key dimension calculation
        E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        assert emb_dim % n_head == 0  # Ensure emb_dim is divisible by n_head
        # Define attention convolution layers for each head
        for ii in range(n_head):
            self.add_module(
                f"attn_conv_Q_{ii}",
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                f"attn_conv_K_{ii}",
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                f"attn_conv_V_{ii}",
                nn.Sequential(
                    nn.Conv2d(emb_dim, emb_dim // n_head, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((emb_dim // n_head, n_freqs), eps=eps),
                ),
            )
        # Final attention concatenation projection
        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )

        # Store parameters for further processing
        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
        """
        Constructs a sequence of UniDeepFsmn modules.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            lorder (int): Order of the filter.
            hidden_size (int): Hidden size for the FSMN.
            repeats (int): Number of times to repeat the module. Default is 1.

        Returns:
            nn.Sequential: A sequence of UniDeepFsmn modules.
        """
        repeats = [
            UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
            for _ in range(repeats)
        ]
        return nn.Sequential(*repeats)

    def forward(self, x):
        """Performs a forward pass through the SyncANetBlock.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T, Q] where
                B is batch size, C is number of channels,
                T is the temporal dimension, and Q is the frequency dimension.

        Returns:
            torch.Tensor: Output tensor of the same shape [B, C, T, Q].
        """
        B, C, old_T, old_Q = x.shape
        # Calculate padded dimensions so windows of size emb_ks with stride emb_hs fit exactly
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        # Pad the input tensor to match the new dimensions
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # Intra-process (along the frequency axis)
        input_ = x
        intra_rnn = self.intra_norm(input_)  # Normalize input for intra-process
        intra_rnn = self.Fconv(intra_rnn)  # Apply grouped convolution over frequency
        intra_rnn = (
            intra_rnn.transpose(1, 2).contiguous().view(B * T, C * self.emb_ks, -1)
        )  # Fold batch and time together: [B*T, C*emb_ks, Q']
        intra_rnn = intra_rnn.transpose(1, 2)  # [B*T, Q', C*emb_ks]
        intra_rnn_u = self.intra_to_u(intra_rnn)  # Gating branch u
        intra_rnn_v = self.intra_to_v(intra_rnn)  # Gating branch v
        intra_rnn_u = self.intra_rnn(intra_rnn_u)  # Apply FSMN
        intra_rnn = intra_rnn_v * intra_rnn_u  # Element-wise gating
        intra_rnn = intra_rnn.transpose(1, 2)  # [B*T, hidden, Q']
        intra_rnn = self.intra_linear(intra_rnn)  # Transposed-conv projection back to emb_dim
        intra_rnn = intra_rnn.transpose(1, 2)  # Reshape for MossFormer
        intra_rnn = intra_rnn.view([B, T, Q, C])  # [B, T, Q, C]
        intra_rnn = self.intra_mossformer(intra_rnn)  # Apply MossFormer
        intra_rnn = intra_rnn.transpose(1, 2)  # Reshape back
        intra_rnn = intra_rnn.view([B, T, C, Q])  # [B, T, C, Q]
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # [B, C, T, Q]
        intra_rnn = self.intra_se(intra_rnn)  # Squeeze-and-excitation layer
        intra_rnn = intra_rnn + input_  # Residual connection

        # Inter-process (across time frames)
        input_ = intra_rnn
        inter_rnn = self.inter_norm(input_)  # Normalize input for inter-process
        inter_rnn = (
            inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        )  # Fold batch and frequency together: [B*Q, C, T]
        inter_rnn = F.unfold(
            inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
        )  # Extract sliding windows over time: [B*Q, C*emb_ks, T']
        inter_rnn = inter_rnn.transpose(1, 2)  # [B*Q, T', C*emb_ks]
        inter_rnn_u = self.inter_to_u(inter_rnn)  # Gating branch u
        inter_rnn_v = self.inter_to_v(inter_rnn)  # Gating branch v
        inter_rnn_u = self.inter_rnn(inter_rnn_u)  # Apply FSMN
        inter_rnn = inter_rnn_v * inter_rnn_u  # Element-wise gating
        inter_rnn = inter_rnn.transpose(1, 2)  # [B*Q, hidden, T']
        inter_rnn = self.inter_linear(inter_rnn)  # Transposed-conv projection back to emb_dim
        inter_rnn = inter_rnn.transpose(1, 2)  # Reshape for MossFormer
        inter_rnn = inter_rnn.view([B, Q, T, C])  # [B, Q, T, C]
        inter_rnn = self.inter_mossformer(inter_rnn)  # Apply MossFormer
        inter_rnn = inter_rnn.transpose(1, 2)  # Reshape back
        inter_rnn = inter_rnn.view([B, Q, C, T])  # [B, Q, C, T]
        inter_rnn = inter_rnn.permute(0, 2, 3, 1).contiguous()  # [B, C, T, Q]
        inter_rnn = self.inter_se(inter_rnn)  # Squeeze-and-excitation layer
        inter_rnn = inter_rnn + input_  # Residual connection

        # Full-band attention mechanism
        inter_rnn = inter_rnn[..., :old_T, :old_Q]  # Trim padding back to the original shape
        batch = inter_rnn

        all_Q, all_K, all_V = [], [], []
        # Compute query, key, and value for each attention head
        for ii in range(self.n_head):
            all_Q.append(self["attn_conv_Q_%d" % ii](batch))  # Query
            all_K.append(self["attn_conv_K_%d" % ii](batch))  # Key
            all_V.append(self["attn_conv_V_%d" % ii](batch))  # Value

        Q = torch.cat(all_Q, dim=0)  # Stack heads along the batch dimension
        K = torch.cat(all_K, dim=0)
        V = torch.cat(all_V, dim=0)

        # Reshape for attention calculation
        Q = Q.transpose(1, 2)
        Q = Q.flatten(start_dim=2)  # [n_head*B, T, E*Q]
        K = K.transpose(1, 2)
        K = K.flatten(start_dim=2)  # [n_head*B, T, E*Q]
        V = V.transpose(1, 2)
        old_shape = V.shape
        V = V.flatten(start_dim=2)  # [n_head*B, T, (C/n_head)*Q]
        emb_dim = Q.shape[-1]

        # Compute scaled dot-product attention over time frames
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)  # Attention matrix
        attn_mat = F.softmax(attn_mat, dim=2)  # Softmax over attention scores
        V = torch.matmul(attn_mat, V)  # Weighted sum of values

        V = V.reshape(old_shape)  # Restore [n_head*B, T, C/n_head, Q]
        V = V.transpose(1, 2)  # [n_head*B, C/n_head, T, Q]
        emb_dim = V.shape[1]

        batch = V.view([self.n_head, B, emb_dim, old_T, -1])  # Split off the head dimension
        batch = batch.transpose(0, 1)  # [B, n_head, C/n_head, T, Q]
        batch = batch.contiguous().view(
            [B, self.n_head * emb_dim, old_T, -1]
        )  # Concatenate heads along the channel dimension
        batch = self["attn_concat_proj"](batch)  # Final projection

        # Combine the inter-process result with the attention output
        out = batch + inter_rnn
        return out  # Return the output tensor
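
# A minimal shape check (not part of the original file): each block is
# shape-preserving since the internal padding is trimmed before the attention
# stage. The hyperparameters below mirror SyncANet's defaults.
def _demo_syncanet_block():
    block = SyncANetBlock(emb_dim=64, emb_ks=2, emb_hs=1, n_freqs=101,
                          hidden_channels=128, n_head=4)
    x = torch.randn(1, 64, 50, 101)  # [B, C, T, Q] dummy features
    y = block(x)
    print(y.shape)  # expected: [1, 64, 50, 101]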

class LayerNormalization4D(nn.Module):
    """
    LayerNormalization4D applies layer normalization to 4D tensors
    (e.g., [B, C, T, F]), where B is the batch size, C is the number of channels,
    T is the temporal dimension, and F is the frequency dimension. Statistics
    are computed over the channel dimension only.

    Attributes:
        gamma (nn.Parameter): Learnable scaling parameter.
        beta (nn.Parameter): Learnable shifting parameter.
        eps (float): Small value for numerical stability during variance calculation.
    """
    def __init__(self, input_dimension, eps=1e-5):
        """
        Initializes the LayerNormalization4D layer.

        Args:
            input_dimension (int): The number of channels in the input tensor.
            eps (float, optional): Small constant added for numerical stability.
        """
        super().__init__()
        param_size = [1, input_dimension, 1, 1]
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))  # Scale parameter
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))  # Shift parameter
        init.ones_(self.gamma)  # Initialize gamma to 1
        init.zeros_(self.beta)  # Initialize beta to 0
        self.eps = eps  # Set the epsilon value

    def forward(self, x):
        """
        Forward pass for the layer normalization.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T, F].

        Returns:
            torch.Tensor: Normalized output tensor of the same shape.
        """
        if x.ndim == 4:
            _, C, _, _ = x.shape  # Extract the number of channels
            stat_dim = (1,)  # Compute statistics over the channel dimension
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        # Compute mean and standard deviation along the specified dimension
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B, 1, T, F]
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B, 1, T, F]
        # Normalize the input tensor and apply learnable parameters
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta  # [B, C, T, F]
        return x_hat
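
# A minimal numerical check (not part of the original file): after
# normalization, each (b, t, f) position has near-zero mean across channels
# (up to eps), before gamma/beta are applied (here gamma=1, beta=0).
def _demo_layer_norm_4d():
    ln = LayerNormalization4D(input_dimension=64)
    x = torch.randn(2, 64, 10, 101)
    y = ln(x)
    print(y.mean(dim=1).abs().max())  # expected: close to 0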

class LayerNormalization4DCF(nn.Module):
    """
    LayerNormalization4DCF applies layer normalization to 4D tensors
    (e.g., [B, C, T, F]), computing statistics jointly over the channel (C)
    and frequency (F) dimensions, with per-channel, per-frequency affine
    parameters.

    Attributes:
        gamma (nn.Parameter): Learnable scaling parameter.
        beta (nn.Parameter): Learnable shifting parameter.
        eps (float): Small value for numerical stability during variance calculation.
    """
    def __init__(self, input_dimension, eps=1e-5):
        """
        Initializes the LayerNormalization4DCF layer.

        Args:
            input_dimension (tuple): A tuple containing the dimensions of the input tensor
                                     (number of channels, frequency dimension).
            eps (float, optional): Small constant added for numerical stability.
        """
        super().__init__()
        assert len(input_dimension) == 2, "Input dimension must be a tuple of length 2."
        param_size = [1, input_dimension[0], 1, input_dimension[1]]  # Shape based on input dimensions
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))  # Scale parameter
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))  # Shift parameter
        init.ones_(self.gamma)  # Initialize gamma to 1
        init.zeros_(self.beta)  # Initialize beta to 0
        self.eps = eps  # Set the epsilon value

    def forward(self, x):
        """
        Forward pass for the layer normalization.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T, F].

        Returns:
            torch.Tensor: Normalized output tensor of the same shape.
        """
        if x.ndim == 4:
            stat_dim = (1, 3)  # Compute statistics over channel and frequency dimensions
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        # Compute mean and standard deviation along the specified dimensions
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B, 1, T, 1]
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B, 1, T, 1]
        # Normalize the input tensor and apply learnable parameters
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta  # [B, C, T, F]
        return x_hat
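
# An end-to-end usage sketch (not part of the original file): it wires a
# waveform through torch.stft into the network's [B, 2, T, F] layout. The
# n_fft, hop length, and window below are illustrative assumptions, not the
# pipeline's actual front-end configuration.
def _demo_end_to_end():
    from argparse import Namespace
    args = Namespace(fft_len=400, mode='inference')  # assumed config values
    model = MossFormerGAN_SE_16K(args)
    wav = torch.randn(1, 16000)  # 1 second of dummy 16 kHz audio
    spec = torch.stft(wav, n_fft=args.fft_len, hop_length=100,
                      window=torch.hamming_window(args.fft_len),
                      return_complex=True)  # [B, F, T_frames] complex
    x = torch.stack([spec.real, spec.imag], dim=1)  # [B, 2, F, T_frames]
    x = x.permute(0, 1, 3, 2)  # [B, 2, T_frames, F] as the model expects
    real, imag = model(x)
    print(real.shape, imag.shape)  # each [B, 1, T_frames, F]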