""" modified from https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/lobes/models/dual_path.py #Author: Shengkui Zhao """ import math import torch import torch.nn as nn import torch.nn.functional as F import copy from models.mossformer2_ss.mossformer2_block import ScaledSinuEmbedding, MossformerBlock_GFSMN, MossformerBlock EPS = 1e-8 class GlobalLayerNorm(nn.Module): """Calculate Global Layer Normalization. Arguments --------- dim : (int or list or torch.Size) Input shape from an expected input of size. eps : float A value added to the denominator for numerical stability. elementwise_affine : bool A boolean value that when set to True, this module has learnable per-element affine parameters initialized to ones (for weights) and zeros (for biases). Example ------- >>> x = torch.randn(5, 10, 20) >>> GLN = GlobalLayerNorm(10, 3) >>> x_norm = GLN(x) """ def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True): super(GlobalLayerNorm, self).__init__() self.dim = dim self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: if shape == 3: self.weight = nn.Parameter(torch.ones(self.dim, 1)) self.bias = nn.Parameter(torch.zeros(self.dim, 1)) if shape == 4: self.weight = nn.Parameter(torch.ones(self.dim, 1, 1)) self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1)) else: self.register_parameter("weight", None) self.register_parameter("bias", None) def forward(self, x): """Returns the normalized tensor. Arguments --------- x : torch.Tensor Tensor of size [N, C, K, S] or [N, C, L]. """ # x = N x C x K x S or N x C x L # N x 1 x 1 # cln: mean,var N x 1 x K x S # gln: mean,var N x 1 x 1 if x.dim() == 3: mean = torch.mean(x, (1, 2), keepdim=True) var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True) if self.elementwise_affine: x = ( self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias ) else: x = (x - mean) / torch.sqrt(var + self.eps) if x.dim() == 4: mean = torch.mean(x, (1, 2, 3), keepdim=True) var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True) if self.elementwise_affine: x = ( self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias ) else: x = (x - mean) / torch.sqrt(var + self.eps) return x class CumulativeLayerNorm(nn.LayerNorm): """Calculate Cumulative Layer Normalization. Arguments --------- dim : int Dimension that you want to normalize. elementwise_affine : True Learnable per-element affine parameters. Example ------- >>> x = torch.randn(5, 10, 20) >>> CLN = CumulativeLayerNorm(10) >>> x_norm = CLN(x) """ def __init__(self, dim, elementwise_affine=True): super(CumulativeLayerNorm, self).__init__( dim, elementwise_affine=elementwise_affine, eps=1e-8 ) def forward(self, x): """Returns the normalized tensor. Arguments --------- x : torch.Tensor Tensor size [N, C, K, S] or [N, C, L] """ # x: N x C x K x S or N x C x L # N x K x S x C if x.dim() == 4: x = x.permute(0, 2, 3, 1).contiguous() # N x K x S x C == only channel norm x = super().forward(x) # N x C x K x S x = x.permute(0, 3, 1, 2).contiguous() if x.dim() == 3: x = torch.transpose(x, 1, 2) # N x L x C == only channel norm x = super().forward(x) # N x C x L x = torch.transpose(x, 1, 2) return x def select_norm(norm, dim, shape): """Just a wrapper to select the normalization type. 
""" if norm == "gln": return GlobalLayerNorm(dim, shape, elementwise_affine=True) if norm == "cln": return CumulativeLayerNorm(dim, elementwise_affine=True) if norm == "ln": return nn.GroupNorm(1, dim, eps=1e-8) else: return nn.BatchNorm1d(dim) class Encoder(nn.Module): """Convolutional Encoder Layer. Arguments --------- kernel_size : int Length of filters. in_channels : int Number of input channels. out_channels : int Number of output channels. Example ------- >>> x = torch.randn(2, 1000) >>> encoder = Encoder(kernel_size=4, out_channels=64) >>> h = encoder(x) >>> h.shape torch.Size([2, 64, 499]) """ def __init__(self, kernel_size=2, out_channels=64, in_channels=1): super(Encoder, self).__init__() self.conv1d = nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=kernel_size // 2, groups=1, bias=False, ) self.in_channels = in_channels def forward(self, x): """Return the encoded output. Arguments --------- x : torch.Tensor Input tensor with dimensionality [B, L]. Return ------ x : torch.Tensor Encoded tensor with dimensionality [B, N, T_out]. where B = Batchsize L = Number of timepoints N = Number of filters T_out = Number of timepoints at the output of the encoder """ # B x L -> B x 1 x L if self.in_channels == 1: x = torch.unsqueeze(x, dim=1) # B x 1 x L -> B x N x T_out x = self.conv1d(x) x = F.relu(x) return x class Decoder(nn.ConvTranspose1d): """A decoder layer that consists of ConvTranspose1d. Arguments --------- kernel_size : int Length of filters. in_channels : int Number of input channels. out_channels : int Number of output channels. Example --------- >>> x = torch.randn(2, 100, 1000) >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1) >>> h = decoder(x) >>> h.shape torch.Size([2, 1003]) """ def __init__(self, *args, **kwargs): super(Decoder, self).__init__(*args, **kwargs) def forward(self, x): """Return the decoded output. Arguments --------- x : torch.Tensor Input tensor with dimensionality [B, N, L]. where, B = Batchsize, N = number of filters L = time points """ if x.dim() not in [2, 3]: raise RuntimeError( "{} accept 3/4D tensor as input".format(self.__name__) ) x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) if torch.squeeze(x).dim() == 1: x = torch.squeeze(x, dim=1) else: x = torch.squeeze(x) return x class IdentityBlock: """This block is used when we want to have identity transformation within the Dual_path block. Example ------- >>> x = torch.randn(10, 100) >>> IB = IdentityBlock() >>> xhat = IB(x) """ def _init__(self, **kwargs): pass def __call__(self, x): return x class MossFormerM(nn.Module): """This class implements the MossFormer2 block. Arguments --------- num_blocks : int Number of mossformer blocks to include. d_model : int The dimension of the input embedding. attn_dropout : float Dropout for the self-attention (Optional). 
class MossFormerM(nn.Module):
    """This class implements the MossFormer2 block.

    Arguments
    ---------
    num_blocks : int
        Number of MossFormer blocks to include.
    d_model : int
        Dimension of the input embedding.
    attn_dropout : float
        Dropout for the self-attention (optional).
    group_size : int
        Chunk size.
    query_key_dim : int
        Attention vector dimension.
    expansion_factor : float
        Expansion factor for the linear projection in the conv module.
    causal : bool
        True for causal, False for non-causal.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = MossFormerM(num_blocks=8, d_model=512)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_blocks,
        d_model=None,
        causal=False,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        attn_dropout=0.1,
    ):
        super().__init__()
        self.mossformerM = MossformerBlock_GFSMN(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout,
        )
        self.norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src):
        """
        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder layer (required), of shape [B, L, N],
            where B = batch size,
                  L = number of time points,
                  N = number of filters.
        """
        output = self.mossformerM(src)
        output = self.norm(output)
        return output


class MossFormerM2(nn.Module):
    """This class implements the MossFormer block.

    Arguments
    ---------
    num_blocks : int
        Number of MossFormer blocks to include.
    d_model : int
        Dimension of the input embedding.
    attn_dropout : float
        Dropout for the self-attention (optional).
    group_size : int
        Chunk size.
    query_key_dim : int
        Attention vector dimension.
    expansion_factor : float
        Expansion factor for the linear projection in the conv module.
    causal : bool
        True for causal, False for non-causal.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = MossFormerM2(num_blocks=8, d_model=512)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_blocks,
        d_model=None,
        causal=False,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        attn_dropout=0.1,
    ):
        super().__init__()
        self.mossformerM = MossformerBlock(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout,
        )
        self.norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src):
        """
        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder layer (required), of shape [B, L, N],
            where B = batch size,
                  L = number of time points,
                  N = number of filters.
        """
        output = self.mossformerM(src)
        output = self.norm(output)
        return output
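
# Illustrative comparison (not part of the original module): MossFormerM wraps
# MossformerBlock_GFSMN (attention plus the gated FSMN recurrence of
# MossFormer2), while MossFormerM2 wraps the plain MossformerBlock. A minimal
# sketch, with the helper name `_demo_compare_blocks` being ours:
def _demo_compare_blocks(num_blocks=2, d_model=512):
    m_fsmn = MossFormerM(num_blocks=num_blocks, d_model=d_model)
    m_plain = MossFormerM2(num_blocks=num_blocks, d_model=d_model)
    n_fsmn = sum(p.numel() for p in m_fsmn.parameters())
    n_plain = sum(p.numel() for p in m_plain.parameters())
    x = torch.rand(2, 60, d_model)  # [B, L, N]
    # Both variants preserve the input shape; the FSMN variant typically
    # carries extra parameters for the recurrent memory branches.
    assert m_fsmn(x).shape == m_plain(x).shape == x.shape
    return n_fsmn, n_plain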
class Computation_Block(nn.Module):
    """Computation block for single-path processing.

    This block performs single-path processing with an intra-model
    (e.g., MossFormerM) that operates both within chunks and over the
    full sequence, with configurable normalization and an optional
    skip connection.

    Arguments
    ---------
    num_blocks : int
        Number of blocks to use in the intra model.
    out_channels : int
        Dimensionality of the inter/intra model.
    norm : str, optional
        Normalization type. Default is 'ln' for layer normalization.
    skip_around_intra : bool, optional
        If True, adds a skip connection around the intra layer.
        Default is True.

    Example
    -------
    >>> comp_block = Computation_Block(num_blocks=64, out_channels=64)
    >>> x = torch.randn(10, 64, 100)  # Sample input tensor
    >>> x = comp_block(x)  # Process through the computation block
    >>> x.shape  # Output shape
    torch.Size([10, 64, 100])
    """

    def __init__(
        self,
        num_blocks: int,
        out_channels: int,
        norm: str = "ln",
        skip_around_intra: bool = True,
    ):
        """Initializes the Computation_Block.

        Args:
            num_blocks (int): Number of blocks for the intra model.
            out_channels (int): Dimensionality of the output features.
            norm (str, optional): Normalization type. Defaults to 'ln'.
            skip_around_intra (bool, optional): If True, uses a skip
                connection around the intra layer. Defaults to True.
        """
        super(Computation_Block, self).__init__()

        # Initialize the intra-model (MossFormerM with recurrence)
        self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
        self.skip_around_intra = skip_around_intra  # Flag for skip connection

        # Set normalization type
        self.norm = norm
        if norm is not None:
            self.intra_norm = select_norm(norm, out_channels, 3)  # Normalization layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the output tensor.

        Args:
            x (torch.Tensor): Input tensor of dimension [B, N, S], where:
                B = batch size,
                N = number of filters,
                S = sequence length.

        Returns:
            out (torch.Tensor): Output tensor of dimension [B, N, S].
        """
        B, N, S = x.shape

        # Permute from [B, N, S] to [B, S, N] for sequence processing
        intra = x.permute(0, 2, 1).contiguous()

        # Process through the intra model
        intra = self.intra_mdl(intra)

        # Permute back to [B, N, S]
        intra = intra.permute(0, 2, 1).contiguous()

        # Apply normalization if specified
        if self.norm is not None:
            intra = self.intra_norm(intra)

        # Add skip connection around the intra layer if enabled
        if self.skip_around_intra:
            intra = intra + x

        out = intra
        return out  # Return the processed output tensor
class MossFormer_MaskNet(nn.Module):
    """The MossFormer MaskNet for predicting masks over the encoder output
    features. This implementation uses an upgraded MaskNet structure based
    on the MossFormer2 model.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that will be input to the intra and inter blocks.
    num_blocks : int
        Number of layers in the Dual Computation Block.
    norm : str
        Type of normalization to apply.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        If True, adds skip connections around the intra layers.
    use_global_pos_enc : bool
        If True, utilizes global positional encodings.
    max_length : int
        Maximum sequence length for input data.

    Example
    -------
    >>> mossformer_masknet = MossFormer_MaskNet(64, 64, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)  # Sample input tensor
    >>> x = mossformer_masknet(x)  # Process through the MaskNet
    >>> x.shape  # Output shape
    torch.Size([2, 10, 64, 2000])
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: int = 24,
        norm: str = "ln",
        num_spks: int = 2,
        skip_around_intra: bool = True,
        use_global_pos_enc: bool = True,
        max_length: int = 20000,
    ):
        """Initializes the MossFormer_MaskNet.

        Args:
            in_channels (int): Number of input channels from the encoder.
            out_channels (int): Number of output channels used in the
                computation blocks.
            num_blocks (int): Number of layers in the Dual Computation Block. Default is 24.
            norm (str): Type of normalization to apply. Default is 'ln'.
            num_spks (int): Number of speakers. Default is 2.
            skip_around_intra (bool): If True, adds skip connections around
                intra layers. Default is True.
            use_global_pos_enc (bool): If True, enables global positional
                encoding. Default is True.
            max_length (int): Maximum sequence length. Default is 20000.
        """
        super(MossFormer_MaskNet, self).__init__()
        self.num_spks = num_spks  # Number of speakers
        self.num_blocks = num_blocks  # Number of computation blocks

        # Normalization layer for the encoder output
        self.norm = select_norm(norm, in_channels, 3)

        # 1D convolution to project input channels to output channels
        self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)

        self.use_global_pos_enc = use_global_pos_enc  # Flag for global positional encoding
        if self.use_global_pos_enc:
            # Positional encoding layer
            self.pos_enc = ScaledSinuEmbedding(out_channels)

        # Computation block for processing features
        self.mdl = Computation_Block(
            num_blocks,
            out_channels,
            norm,
            skip_around_intra=skip_around_intra,
        )

        # Output layer projecting features to num_spks speaker outputs
        self.conv1d_out = nn.Conv1d(
            out_channels, out_channels * num_spks, kernel_size=1
        )
        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)  # Decoder layer
        self.prelu = nn.PReLU()  # PReLU activation
        self.activation = nn.ReLU()  # ReLU activation for the final output

        # Gated output layer to refine predictions
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the output tensor.

        Args:
            x (torch.Tensor): Input tensor of dimension [B, N, S], where:
                B = batch size,
                N = number of channels (filters),
                S = sequence length.

        Returns:
            out (torch.Tensor): Output tensor of dimension [spks, B, N, S], where:
                spks = number of speakers,
                B = batch size,
                N = number of filters,
                S = number of time frames.
        """
        # [B, N, L] - Normalize the input tensor
        x = self.norm(x)

        # [B, N, L] - Apply 1D convolution to encode features
        x = self.conv1d_encoder(x)

        # If using global positional encoding, add the positional embeddings
        if self.use_global_pos_enc:
            base = x  # Original encoded features
            x = x.transpose(1, -1)  # [B, L, N]
            emb = self.pos_enc(x)  # Positional embeddings of shape [L, N]
            emb = emb.transpose(0, -1)  # [N, L], broadcast over the batch
            x = base + emb  # Add positional embeddings to encoded features

        # [B, N, S] - Process through the computation block
        x = self.mdl(x)
        x = self.prelu(x)  # PReLU activation

        # [B, N*spks, S] - Project features to multiple speaker outputs
        x = self.conv1d_out(x)
        B, _, S = x.shape  # Shape after convolution

        # [B*spks, N, S] - Reshape for speaker outputs
        x = x.view(B * self.num_spks, -1, S)

        # [B*spks, N, S] - Apply gated output layer
        x = self.output(x) * self.output_gate(x)

        # [B*spks, N, S] - Decode back to the original channel size
        x = self.conv1_decoder(x)

        # [B, spks, N, S] - Reshape to include the speaker dimension
        _, N, L = x.shape
        x = x.view(B, self.num_spks, N, L)
        x = self.activation(x)  # ReLU activation

        # [spks, B, N, S] - Transpose to match the output format
        x = x.transpose(0, 1)
        return x  # Return the output tensor
class MossFormer(nn.Module):
    """The End-to-End (E2E) Encoder-MaskNet-Decoder MossFormer model for
    speech separation. This implementation is based on the upgraded MaskNet
    architecture of the MossFormer2 model.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that will be input to the MossFormer2 blocks.
    num_blocks : int
        Number of layers in the Dual Computation Block.
    kernel_size : int
        Kernel size of the convolutional layers in the encoder and decoder.
    norm : str
        Type of normalization to apply (e.g., 'ln' for layer normalization).
    num_spks : int
        Number of sources (speakers) to separate.
    skip_around_intra : bool
        If True, adds skip connections around the intra layers in the
        computation blocks.
    use_global_pos_enc : bool
        If True, uses global positional encodings in the model.
    max_length : int
        Maximum sequence length for input data.

    Example
    -------
    >>> mossformer = MossFormer(num_spks=2)
    >>> x = torch.randn(1, 10000)  # Sample input tensor
    >>> outputs = mossformer(x)  # Process the input through the model
    >>> outputs[0].shape  # Output shape for the first speaker
    torch.Size([1, 10000])
    """

    def __init__(
        self,
        in_channels=512,
        out_channels=512,
        num_blocks=24,
        kernel_size=16,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        """Initializes the MossFormer model.

        Args:
            in_channels (int): Number of input channels from the encoder. Default is 512.
            out_channels (int): Number of output channels for the MaskNet blocks. Default is 512.
            num_blocks (int): Number of layers in the Dual Computation Block. Default is 24.
            kernel_size (int): Kernel size of the convolutional layers. Default is 16.
            norm (str): Type of normalization to apply. Default is 'ln'.
            num_spks (int): Number of speakers to separate. Default is 2.
            skip_around_intra (bool): If True, adds skip connections around intra layers. Default is True.
            use_global_pos_enc (bool): If True, uses global positional encoding. Default is True.
            max_length (int): Maximum sequence length. Default is 20000.
        """
        super(MossFormer, self).__init__()
        self.num_spks = num_spks  # Number of speakers

        # Encoder with 1 input channel and the specified number of filters
        self.enc = Encoder(
            kernel_size=kernel_size, out_channels=in_channels, in_channels=1
        )

        # MaskNet with the specified parameters
        self.mask_net = MossFormer_MaskNet(
            in_channels=in_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            norm=norm,
            num_spks=num_spks,
            skip_around_intra=skip_around_intra,
            use_global_pos_enc=use_global_pos_enc,
            max_length=max_length,
        )

        # Decoder projecting the masked features back to 1 channel
        self.dec = Decoder(
            in_channels=out_channels,
            out_channels=1,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            bias=False,
        )

    def forward(self, input: torch.Tensor) -> list:
        """Processes the input through the encoder, MaskNet, and decoder.

        Args:
            input (torch.Tensor): Input tensor of shape [B, T], where B is
                the batch size and T is the input length.

        Returns:
            out (list): List of output tensors, one per speaker, each of shape [B, T].
        """
""" # Pass the input through the encoder to extract features x = self.enc(input) # Generate the mask for each speaker using the mask net mask = self.mask_net(x) # Duplicate the features for each speaker x = torch.stack([x] * self.num_spks) # Apply the mask to separate the sources sep_x = x * mask # Decoding process to reconstruct the separated sources est_source = torch.cat( [self.dec(sep_x[i]).unsqueeze(-1) for i in range(self.num_spks)], dim=-1, ) # Match the estimated output length to the original input length T_origin = input.size(1) T_est = est_source.size(1) if T_origin > T_est: est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est)) # Pad if estimated length is shorter else: est_source = est_source[:, :T_origin, :] # Trim if estimated length is longer out = [] # Collect outputs for each speaker for spk in range(self.num_spks): out.append(est_source[:, :, spk]) return out # Return list of separated outputs class MossFormer2_SS_16K(nn.Module): """ Wrapper for the MossFormer2 model, facilitating external calls. Arguments --------- args : Namespace Contains the necessary arguments for initializing the MossFormer model, such as: - encoder_embedding_dim: Dimension of the encoder's output embeddings. - mossformer_sequence_dim: Dimension of the MossFormer sequence. - num_mossformer_layer: Number of layers in the MossFormer. - encoder_kernel_size: Kernel size for the encoder. - num_spks: Number of sources (speakers) to separate. """ def __init__(self, args): """ Initializes the MossFormer2_SS_16K wrapper. Args: args (Namespace): Contains configuration parameters for the model. """ super(MossFormer2_SS_16K, self).__init__() # Initialize the main MossFormer model with parameters from args self.model = MossFormer( in_channels=args.encoder_embedding_dim, out_channels=args.mossformer_sequence_dim, num_blocks=args.num_mossformer_layer, kernel_size=args.encoder_kernel_size, norm="ln", num_spks=args.num_spks, skip_around_intra=True, use_global_pos_enc=True, max_length=20000 ) def forward(self, x: torch.Tensor) -> list: """Processes the input through the MossFormer model. Args: x (torch.Tensor): Input tensor of shape [B, T], where B is the batch size and T is the input length. Returns: outputs (list): List of output tensors for each speaker. """ outputs = self.model(x) # Forward pass through the MossFormer model return outputs # Return the list of outputs