import torch import torch.nn as nn import numpy as np import torch.nn.functional as F from scipy.signal import get_window import scipy def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): """ Initialize the kernels for STFT and iSTFT operations. This function generates the kernel for the convolutional layers used in the short-time Fourier transform (STFT) and its inverse (iSTFT). The kernel is created based on the window type and length specified. Args: win_len (int): Length of the window. win_inc (int): Window increment (hop length). fft_len (int): Length of the FFT. win_type (str, optional): Type of window to apply (e.g., 'hanning', 'hamming'). Default is None (rectangular window). invers (bool, optional): If True, computes the pseudo-inverse of the kernel. Default is False. Returns: tuple: A tuple containing: - torch.Tensor: The kernel used for convolution, with shape (2 * win_len, 1, fft_len). - torch.Tensor: The window applied to the kernel, with shape (1, win_len, 1). """ if win_type == 'None' or win_type is None: window = np.ones(win_len) else: # Convert 'hanning' to 'hann' if using scipy version 1.10.1 or higher if scipy.__version__ >= '1.10.1' and win_type == 'hanning': win_type = 'hann' window = get_window(win_type, win_len, fftbins=True)**0.5 N = fft_len fourier_basis = np.fft.rfft(np.eye(N))[:win_len] # Compute Fourier basis for the identity matrix real_kernel = np.real(fourier_basis) # Extract real part imag_kernel = np.imag(fourier_basis) # Extract imaginary part kernel = np.concatenate([real_kernel, imag_kernel], 1).T # Combine real and imaginary parts if invers: kernel = np.linalg.pinv(kernel).T # Compute pseudo-inverse if required kernel = kernel * window # Apply window to kernel kernel = kernel[:, None, :] # Add singleton dimension for compatibility return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None,:,None].astype(np.float32)) class ConvSTFT(nn.Module): """ Convolutional layer that performs Short-Time Fourier Transform (STFT). This class applies the STFT to input signals using convolution with pre-computed kernels. It can return either the complex STFT representation or the magnitude and phase components. Attributes: weight (nn.Parameter): Learnable convolution kernel for STFT. feature_type (str): Specifies whether to return 'complex' or 'real' features. stride (int): The stride used for convolution, typically equal to win_inc. win_len (int): The length of the window used in STFT. dim (int): The FFT length, determining the number of output features. """ def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): """ Initializes the ConvSTFT layer. Args: win_len (int): Length of the window. win_inc (int): Window increment (hop length). fft_len (int, optional): Length of the FFT. If None, it's computed based on win_len. win_type (str, optional): Type of window to use (default is 'hamming'). feature_type (str, optional): Specifies the output feature type ('real' or 'complex'). Default is 'real'. fix (bool, optional): If True, the kernel weights are fixed and not learnable. Default is True. """ super(ConvSTFT, self).__init__() if fft_len is None: self.fft_len = np.int(2**np.ceil(np.log2(win_len))) # Calculate fft_len based on win_len else: self.fft_len = fft_len kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) # Initialize STFT kernel self.weight = nn.Parameter(kernel, requires_grad=(not fix)) # Create a learnable parameter for the kernel self.feature_type = feature_type self.stride = win_inc self.win_len = win_len self.dim = self.fft_len def forward(self, inputs): """ Forward pass through the ConvSTFT layer. Args: inputs (torch.Tensor): Input tensor of shape (batch_size, channels, length). Returns: tuple or torch.Tensor: Depending on feature_type, returns either: - torch.Tensor: The complex STFT output if feature_type is 'complex'. - tuple: A tuple containing the magnitude and phase tensors if feature_type is 'real'. """ if inputs.dim() == 2: inputs = torch.unsqueeze(inputs, 1) # Add channel dimension if not present outputs = F.conv1d(inputs, self.weight, stride=self.stride) # Perform convolution to compute STFT if self.feature_type == 'complex': return outputs else: dim = self.dim // 2 + 1 # Calculate the size for the real and imaginary components real = outputs[:, :dim, :] # Extract real part imag = outputs[:, dim:, :] # Extract imaginary part mags = torch.sqrt(real**2 + imag**2) # Compute magnitude phase = torch.atan2(imag, real) # Compute phase return mags, phase # Return magnitude and phase class ConviSTFT(nn.Module): """ Convolutional layer that performs Inverse Short-Time Fourier Transform (iSTFT). This class applies the iSTFT to reconstruct the time-domain signal from the frequency-domain representation obtained from the ConvSTFT layer. Attributes: weight (nn.Parameter): Learnable convolution kernel for iSTFT. feature_type (str): Specifies whether to use 'real' or 'complex' features for reconstruction. win_type (str): Type of window used during iSTFT. win_len (int): The length of the window used in iSTFT. win_inc (int): The window increment (hop length). stride (int): The stride used for transposed convolution, typically equal to win_inc. dim (int): The FFT length, determining the number of output features. window (torch.Tensor): Buffer for the window used in iSTFT. enframe (torch.Tensor): Buffer for the framing matrix. """ def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): """ Initializes the ConviSTFT layer. Args: win_len (int): Length of the window. win_inc (int): Window increment (hop length). fft_len (int, optional): Length of the FFT. If None, it's computed based on win_len. win_type (str, optional): Type of window to use (default is 'hamming'). feature_type (str, optional): Specifies the output feature type ('real' or 'complex'). Default is 'real'. fix (bool, optional): If True, the kernel weights are fixed and not learnable. Default is True. """ super(ConviSTFT, self).__init__() if fft_len is None: self.fft_len = np.int(2**np.ceil(np.log2(win_len))) # Calculate fft_len based on win_len else: self.fft_len = fft_len kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True) # Initialize iSTFT kernel self.weight = nn.Parameter(kernel, requires_grad=(not fix)) # Create a learnable parameter for the kernel self.feature_type = feature_type self.win_type = win_type self.win_len = win_len self.win_inc = win_inc self.stride = win_inc self.dim = self.fft_len self.register_buffer('window', window) # Register the window as a buffer self.register_buffer('enframe', torch.eye(win_len)[:, None, :]) # Framing matrix for overlap-add method def forward(self, inputs, phase=None): """ Forward pass through the ConviSTFT layer. Args: inputs (torch.Tensor): Input tensor of shape [B, N+2, T] for complex spectra or [B, N//2+1, T] for magnitude spectra. phase (torch.Tensor, optional): Phase tensor of shape [B, N//2+1, T]. If provided, used to reconstruct the complex spectra. Returns: torch.Tensor: Reconstructed time-domain signal. """ if phase is not None: # Reconstruct real and imaginary components from magnitude and phase real = inputs * torch.cos(phase) # Real part imag = inputs * torch.sin(phase) # Imaginary part inputs = torch.cat([real, imag], 1) # Concatenate to form complex input outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) # Perform transposed convolution for iSTFT # Compute the overlap-add normalization t = self.window.repeat(1, 1, inputs.size(-1))**2 coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) # Apply the framing matrix for overlap-add outputs = outputs / (coff + 1e-8) # Normalize the output to prevent division by zero return outputs def test_fft(): """ Test the ConvSTFT layer against Librosa's STFT implementation. This function generates a random input signal and computes its STFT using the ConvSTFT layer, then compares the output with the STFT computed using Librosa to ensure correctness. """ torch.manual_seed(20) win_len = 320 win_inc = 160 fft_len = 512 inputs = torch.randn([1, 1, 16000*4]) # Random input tensor fft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='real') # Initialize ConvSTFT outputs1 = fft(inputs)[0] # Compute STFT using ConvSTFT outputs1 = outputs1.numpy()[0] # Convert to NumPy array for comparison np_inputs = inputs.numpy().reshape([-1]) # Reshape input for Librosa librosa_stft = librosa.stft(np_inputs, win_length=win_len, n_fft=fft_len, hop_length=win_inc, center=False) # Compute STFT using Librosa print(np.mean((outputs1 - np.abs(librosa_stft))**2)) # Print mean squared error between the two STFT outputs def test_ifft1(): """ Test the ConviSTFT layer by reconstructing a waveform from the STFT output. This function reads an audio file, applies the ConvSTFT to compute its STFT, and then uses the ConviSTFT to reconstruct the time-domain signal. The reconstructed signal is saved to a file and compared to the original to evaluate the reconstruction accuracy. """ import soundfile as sf N = 100 inc = 75 fft_len = 512 torch.manual_seed(N) # Read input audio file and reshape data = sf.read('../../wavs/ori.wav')[0] inputs = data.reshape([1, 1, -1]) # Reshape to [1, 1, length] fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex') # Initialize ConvSTFT ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex') # Initialize ConviSTFT inputs = torch.from_numpy(inputs.astype(np.float32)) # Convert to torch tensor outputs1 = fft(inputs) # Compute STFT outputs2 = ifft(outputs1) # Reconstruct waveform from STFT sf.write('conv_stft.wav', outputs2.numpy()[0, 0, :], 16000) # Save reconstructed waveform to file print('wav MSE', torch.mean(torch.abs(inputs[..., :outputs2.size(2)] - outputs2))) # Print mean squared error def test_ifft2(): """ Test the iSTFT reconstruction from a random input signal. This function generates a random signal, computes its STFT, and then reconstructs it using the ConviSTFT layer. The reconstructed waveform is saved to a file, and the mean squared error is printed to evaluate accuracy. """ N = 400 inc = 100 fft_len = 512 np.random.seed(20) torch.manual_seed(20) # Generate a random signal t = np.random.randn(16000*4) * 0.005 t = np.clip(t, -1, 1) # Clip to [-1, 1] range input = torch.from_numpy(t[None, None, :].astype(np.float32)) # Reshape for input fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex') # Initialize ConvSTFT ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex') # Initialize ConviSTFT out1 = fft(input) # Compute STFT output = ifft(out1) # Reconstruct waveform from STFT print('random MSE', torch.mean(torch.abs(input - output)**2)) # Print mean squared error import soundfile as sf sf.write('zero.wav', output[0, 0].numpy(), 16000) # Save reconstructed waveform to file if __name__ == '__main__': #test_fft() test_ifft1() # Run the iSTFT reconstruction test #test_ifft2()