import torch.nn as nn import torch import torch.nn.functional as F import os import sys sys.path.append(os.path.dirname(__file__)) from models.frcrn_se.conv_stft import ConvSTFT, ConviSTFT import numpy as np from models.frcrn_se.unet import UNet class FRCRN_Wrapper_StandAlone(nn.Module): """ A wrapper class for the DCCRN model used in standalone mode. This class initializes the DCCRN model with predefined parameters and provides a forward method to process input audio signals for speech enhancement. Args: args: Arguments containing model configuration (not used in this wrapper). """ def __init__(self, args): super(FRCRN_Wrapper_StandAlone, self).__init__() # Initialize the DCCRN model with specific parameters self.model = DCCRN( complex=True, model_complexity=45, model_depth=14, log_amp=False, padding_mode="zeros", win_len=640, win_inc=320, fft_len=640, win_type='hanning' ) def forward(self, x): """ Forward pass of the model. Args: x (torch.Tensor): Input tensor representing audio signals. Returns: torch.Tensor: Processed output tensor after applying the model. """ output = self.model(x) return output[1][0] # Return estimated waveform class FRCRN_SE_16K(nn.Module): """ A class for the FRCRN model specifically configured for 16 kHz input signals. This class allows for customization of model parameters based on provided arguments. Args: args: Configuration parameters for the model. """ def __init__(self, args): super(FRCRN_SE_16K, self).__init__() # Initialize the DCCRN model with parameters from args self.model = DCCRN( complex=True, model_complexity=45, model_depth=14, log_amp=False, padding_mode="zeros", win_len=args.win_len, win_inc=args.win_inc, fft_len=args.fft_len, win_type=args.win_type ) def forward(self, x): """ Forward pass of the model. Args: x (torch.Tensor): Input tensor representing audio signals. Returns: torch.Tensor: Processed output tensor after applying the model. """ output = self.model(x) return output[1][0] # Return estimated waveform class DCCRN(nn.Module): """ We implemented our FRCRN model on the basis of DCCRN rep (https://github.com/huyanxin/DeepComplexCRN) for complex speech enhancement. The DCCRN model (Paper: https://arxiv.org/abs/2008.00264) employs a convolutional short-time Fourier transform (STFT) and a UNet architecture for estimating clean speech from noisy inputs, FRCRN uses an enhanced Unet architecture. Args: complex (bool): Flag to determine whether to use complex numbers. model_complexity (int): Complexity level for the model. model_depth (int): Depth of the UNet model (14 or 20). log_amp (bool): Whether to use log amplitude to estimate signals. padding_mode (str): Padding mode for convolutions ('zeros', 'reflect'). win_len (int): Window length for STFT. win_inc (int): Window increment for STFT. fft_len (int): FFT length. win_type (str): Window type for STFT (e.g., 'hanning'). """ def __init__(self, complex, model_complexity, model_depth, log_amp, padding_mode, win_len=400, win_inc=100, fft_len=512, win_type='hanning'): super().__init__() self.feat_dim = fft_len // 2 + 1 self.win_len = win_len self.win_inc = win_inc self.fft_len = fft_len self.win_type = win_type # Initialize STFT and iSTFT layers fix = True # Fixed STFT parameters self.stft = ConvSTFT(self.win_len, self.win_inc, self.fft_len, self.win_type, feature_type='complex', fix=fix) self.istft = ConviSTFT(self.win_len, self.win_inc, self.fft_len, self.win_type, feature_type='complex', fix=fix) # Initialize two UNet models for estimating complex masks self.unet = UNet(1, complex=complex, model_complexity=model_complexity, model_depth=model_depth, padding_mode=padding_mode) self.unet2 = UNet(1, complex=complex, model_complexity=model_complexity, model_depth=model_depth, padding_mode=padding_mode) def forward(self, inputs): """ Forward pass of the FRCRN model. Args: inputs (torch.Tensor): Input tensor representing audio signals. Returns: list: A list containing estimated spectral features, waveform, and masks. """ out_list = [] # Compute the complex spectrogram using STFT cmp_spec = self.stft(inputs) # [B, D*2, T] cmp_spec = torch.unsqueeze(cmp_spec, 1) # [B, 1, D*2, T] # Split into real and imaginary parts cmp_spec = torch.cat([ cmp_spec[:, :, :self.feat_dim, :], # Real part cmp_spec[:, :, self.feat_dim:, :], # Imaginary part ], 1) # [B, 2, D, T] cmp_spec = torch.unsqueeze(cmp_spec, 4) # [B, 2, D, T, 1] cmp_spec = torch.transpose(cmp_spec, 1, 4) # [B, 1, D, T, 2] # Pass through the UNet to estimate masks unet1_out = self.unet(cmp_spec) # First UNet output cmp_mask1 = torch.tanh(unet1_out) # First mask unet2_out = self.unet2(unet1_out) # Second UNet output cmp_mask2 = torch.tanh(unet2_out) # Second mask cmp_mask2 = cmp_mask2 + cmp_mask1 # Combine masks # Apply the estimated mask to the complex spectrogram est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2) out_list.append(est_spec) out_list.append(est_wav) out_list.append(est_mask) return out_list def inference(self, inputs): """ Inference method for the FRCRN model. This method performs a forward pass through the model to estimate the clean waveform from the noisy input. Args: inputs (torch.Tensor): Input tensor representing audio signals. Returns: torch.Tensor: Estimated waveform after processing. """ # Compute the complex spectrogram using STFT cmp_spec = self.stft(inputs) # [B, D*2, T] cmp_spec = torch.unsqueeze(cmp_spec, 1) # [B, 1, D*2, T] # Split into real and imaginary parts cmp_spec = torch.cat([ cmp_spec[:, :, :self.feat_dim, :], # Real part cmp_spec[:, :, self.feat_dim:, :], # Imaginary part ], 1) # [B, 2, D, T] cmp_spec = torch.unsqueeze(cmp_spec, 4) # [B, 2, D, T, 1] cmp_spec = torch.transpose(cmp_spec, 1, 4) # [B, 1, D, T, 2] # Pass through the UNet to estimate masks unet1_out = self.unet(cmp_spec) cmp_mask1 = torch.tanh(unet1_out) unet2_out = self.unet2(unet1_out) cmp_mask2 = torch.tanh(unet2_out) cmp_mask2 = cmp_mask2 + cmp_mask1 # Combine masks # Apply the estimated mask to compute the estimated waveform _, est_wav, _ = self.apply_mask(cmp_spec, cmp_mask2) return est_wav[0] # Return the estimated waveform def apply_mask(self, cmp_spec, cmp_mask): """ Apply the estimated masks to the complex spectrogram. Args: cmp_spec (torch.Tensor): Complex spectrogram tensor. cmp_mask (torch.Tensor): Estimated mask tensor. Returns: tuple: Estimated spectrogram, waveform, and mask. """ # Compute the estimated complex spectrogram using masks est_spec = torch.cat([ cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 0] - cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 1], cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 1] + cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 0] ], 1) # Combine real and imaginary parts est_spec = torch.cat([est_spec[:, 0, :, :], est_spec[:, 1, :, :]], 1) # Flatten dimensions cmp_mask = torch.squeeze(cmp_mask, 1) cmp_mask = torch.cat([cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], 1) # Combine masks est_wav = self.istft(est_spec) # Inverse STFT to obtain waveform est_wav = torch.squeeze(est_wav, 1) # Remove unnecessary dimensions return est_spec, est_wav, cmp_mask def get_params(self, weight_decay=0.0): """ Get parameters for optimization with optional weight decay. Args: weight_decay (float): Weight decay for L2 regularization. Returns: list: List of dictionaries containing parameters and their weight decay settings. """ weights, biases = [], [] for name, param in self.named_parameters(): if 'bias' in name: biases += [param] else: weights += [param] params = [{ 'params': weights, 'weight_decay': weight_decay, }, { 'params': biases, 'weight_decay': 0.0, }] return params