File size: 12,734 Bytes
8e8cd3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 |
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from scipy.signal import get_window
import scipy
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    """Build the fixed convolution kernels used for STFT and iSTFT.

    The forward (analysis) kernel stacks the real and imaginary parts of the
    first ``win_len`` rows of the size-``fft_len`` real-FFT basis and multiplies
    every row by the analysis window. With ``invers=True`` the pseudo-inverse of
    the (unwindowed) basis is taken first, which gives a synthesis kernel
    suitable for a transposed convolution.

    Args:
        win_len (int): Window length in samples; this is the kernel width.
        win_inc (int): Hop length. Not used to build the kernel; accepted for
            signature compatibility with the callers.
        fft_len (int): FFT size. Must be >= ``win_len``.
        win_type (str, optional): Any window name accepted by
            ``scipy.signal.get_window`` (e.g. 'hamming', 'hann'). 'hanning' is
            accepted as an alias of 'hann'. ``None`` or the string 'None'
            selects a rectangular window. For named windows the square root
            of the window is used.
        invers (bool, optional): If True, build the inverse (synthesis) kernel
            from the pseudo-inverse of the Fourier basis. Default is False.

    Returns:
        tuple:
            - torch.Tensor: float32 kernel of shape
              ``(2 * (fft_len // 2 + 1), 1, win_len)``. The first
              ``fft_len // 2 + 1`` channels hold the real part and the rest
              hold the imaginary part.
            - torch.Tensor: float32 window of shape ``(1, win_len, 1)``.

    Raises:
        ValueError: If ``fft_len < win_len``. In that case the basis has fewer
            rows than the window has taps.
    """
    if fft_len < win_len:
        raise ValueError(
            f'fft_len ({fft_len}) must be >= win_len ({win_len})')

    if win_type is None or win_type == 'None':
        window = np.ones(win_len)
    else:
        # 'hanning' is a legacy alias; 'hann' is the canonical scipy name and
        # is accepted by every scipy release. Map it unconditionally instead
        # of comparing version strings lexicographically, which mis-orders
        # versions such as '1.9.0' vs '1.10.1'.
        if win_type == 'hanning':
            win_type = 'hann'
        # Square root: the window is applied both in the analysis and in the
        # synthesis kernel, so the overall weighting is the named window.
        window = get_window(win_type, win_len, fftbins=True) ** 0.5

    n_fft = fft_len
    # Row n of rfft(eye(N)) is the spectrum of a unit impulse at sample n,
    # i.e. exp(-2j*pi*k*n/N). Keep one row per window tap.
    fourier_basis = np.fft.rfft(np.eye(n_fft))[:win_len]
    kernel = np.concatenate(
        [np.real(fourier_basis), np.imag(fourier_basis)], axis=1).T
    if invers:
        # Synthesis kernel: pseudo-inverse of the analysis basis.
        kernel = np.linalg.pinv(kernel).T
    kernel = kernel * window
    # (channels, 1, win_len): usable both as a conv1d weight (out, in, k)
    # and as a conv_transpose1d weight (in, out, k).
    kernel = kernel[:, None, :]
    return (torch.from_numpy(kernel.astype(np.float32)),
            torch.from_numpy(window[None, :, None].astype(np.float32)))
class ConvSTFT(nn.Module):
    """Short-time Fourier transform implemented as a strided 1-D convolution.

    The analysis kernel from ``init_kernels`` is applied with
    ``stride=win_inc`` and no padding. Frame ``t`` therefore covers samples
    ``[t * win_inc, t * win_inc + win_len)``.

    Attributes:
        weight (nn.Parameter): Kernel of shape
            ``(2 * (fft_len // 2 + 1), 1, win_len)``; frozen unless ``fix=False``.
        feature_type (str): 'complex' returns the stacked real/imag channels;
            any other value returns ``(magnitude, phase)``.
        stride (int): Convolution stride (equal to ``win_inc``).
        win_len (int): Window length in samples.
        dim (int): FFT length.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming',
                 feature_type='real', fix=True):
        """Create the STFT layer.

        Args:
            win_len (int): Window length in samples.
            win_inc (int): Hop length in samples.
            fft_len (int, optional): FFT size. If None, the smallest power of
                two that is >= ``win_len`` is used.
            win_type (str, optional): Window name passed to ``init_kernels``.
                Default is 'hamming'.
            feature_type (str, optional): 'complex' or 'real' (see class
                docstring). Default is 'real'.
            fix (bool, optional): If True, the kernel is not trainable.
                Default is True.
        """
        super(ConvSTFT, self).__init__()
        if fft_len is None:
            # Smallest power of two that is >= win_len.
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
        self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        """Compute the STFT of a batch of waveforms.

        Args:
            inputs (torch.Tensor): Waveforms of shape ``(B, L)`` or
                ``(B, 1, L)``. A 2-D input gets a channel axis inserted.

        Returns:
            torch.Tensor or tuple: If ``feature_type == 'complex'``, a tensor of
            shape ``(B, 2 * (fft_len // 2 + 1), T)`` with the real parts
            followed by the imaginary parts. Otherwise ``(mags, phase)``, each
            of shape ``(B, fft_len // 2 + 1, T)``. Here
            ``T = (L - win_len) // win_inc + 1``.
        """
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)
        if self.feature_type == 'complex':
            return outputs
        # Number of rFFT bins; the kernel stacks real bins then imag bins.
        dim = self.dim // 2 + 1
        real = outputs[:, :dim, :]
        imag = outputs[:, dim:, :]
        mags = torch.sqrt(real ** 2 + imag ** 2)
        phase = torch.atan2(imag, real)
        return mags, phase
class ConviSTFT(nn.Module):
    """Inverse STFT implemented as a transposed 1-D convolution with overlap-add.

    The synthesis kernel (from ``init_kernels(..., invers=True)``) maps each
    spectral frame back to ``win_len`` windowed samples. The frames are
    overlap-added by the transposed convolution. The result is then divided
    by the overlap-added squared window to undo the analysis/synthesis
    windowing.

    Attributes:
        weight (nn.Parameter): Synthesis kernel of shape
            ``(2 * (fft_len // 2 + 1), 1, win_len)``; frozen unless ``fix=False``.
        feature_type (str): Stored for API symmetry. ``forward`` does not
            consult it; passing ``phase`` selects the magnitude/phase path.
        win_type (str): Window name used to build the kernel.
        win_len (int): Window length in samples.
        win_inc (int): Hop length in samples.
        stride (int): Transposed-convolution stride (equal to ``win_inc``).
        dim (int): FFT length.
        window (torch.Tensor): Buffer, window of shape ``(1, win_len, 1)``.
        enframe (torch.Tensor): Buffer, identity of shape
            ``(win_len, 1, win_len)``. It is used as a transposed-convolution
            weight so that each frame's ``win_len`` values are overlap-added
            into time.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming',
                 feature_type='real', fix=True):
        """Create the iSTFT layer.

        Args:
            win_len (int): Window length in samples.
            win_inc (int): Hop length in samples.
            fft_len (int, optional): FFT size. If None, the smallest power of
                two that is >= ``win_len`` is used.
            win_type (str, optional): Window name passed to ``init_kernels``.
                Default is 'hamming'.
            feature_type (str, optional): Stored as an attribute. Default is
                'real'.
            fix (bool, optional): If True, the kernel is not trainable.
                Default is True.
        """
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            # Smallest power of two that is >= win_len.
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, win_inc, self.fft_len,
                                      win_type, invers=True)
        self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.win_inc = win_inc
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """Reconstruct waveforms from spectra.

        Args:
            inputs (torch.Tensor): Either stacked real/imag spectra of shape
                ``(B, 2 * (fft_len // 2 + 1), T)``, or magnitudes of shape
                ``(B, fft_len // 2 + 1, T)`` when ``phase`` is given.
            phase (torch.Tensor, optional): Phase of shape
                ``(B, fft_len // 2 + 1, T)``. If given, it is combined with the
                magnitudes in ``inputs`` to form the real/imag spectra.

        Returns:
            torch.Tensor: Waveforms of shape
            ``(B, 1, (T - 1) * win_inc + win_len)``.
        """
        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # Analysis and synthesis kernels each carry the window, so every
        # frame contributes window**2 * x. Overlap-add window**2 over the
        # same frames and divide it out.
        t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        # The epsilon guards samples where the summed window is zero.
        outputs = outputs / (coff + 1e-8)
        return outputs
def test_fft():
    """Compare ConvSTFT magnitudes with ``librosa.stft`` on random noise.

    ConvSTFT frames start at ``t * win_inc`` and span ``win_len`` samples,
    weighted by the square-root periodic Hann window. By default librosa
    centres a ``win_length`` window inside each ``n_fft`` frame and uses the
    full (not square-rooted) window. To compare like with like, the same
    square-root window is therefore passed to librosa, left-aligned and
    zero-padded to ``fft_len``. ConvSTFT frames only need ``win_len`` samples
    while librosa frames need ``fft_len``, so ConvSTFT yields more frames;
    only the frames both produce are compared.

    Prints the mean squared error between the two magnitude spectrograms.
    It should be close to zero, up to float32 rounding.
    """
    import librosa  # third-party; only needed by this manual check

    win_len = 320
    win_inc = 160
    fft_len = 512
    inputs = torch.randn([1, 1, 16000 * 4])
    fft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning',
                   feature_type='real')
    # Magnitudes of the single batch item: (fft_len // 2 + 1, frames).
    mags = fft(inputs)[0].numpy()[0]

    # Same square-root Hann window as ConvSTFT, left-aligned in an
    # fft_len-long frame so both transforms see identical samples.
    window = np.zeros(fft_len)
    window[:win_len] = get_window('hann', win_len, fftbins=True) ** 0.5
    np_inputs = inputs.numpy().reshape([-1])
    librosa_stft = librosa.stft(np_inputs, n_fft=fft_len, hop_length=win_inc,
                                win_length=fft_len, window=window,
                                center=False)

    n_frames = min(mags.shape[-1], librosa_stft.shape[-1])
    diff = mags[:, :n_frames] - np.abs(librosa_stft[:, :n_frames])
    print(np.mean(diff ** 2))
def test_ifft1(wav_path='../../wavs/ori.wav', out_path='conv_stft.wav'):
    """Round-trip an audio file through ConvSTFT and ConviSTFT.

    The file is read, transformed to complex spectra and reconstructed. The
    reconstruction is written to ``out_path`` at the file's own sample rate,
    and the mean squared error against the original samples is printed.

    Args:
        wav_path (str, optional): Audio file to read.
        out_path (str, optional): Where to write the reconstructed waveform.
    """
    import soundfile as sf

    N = 100
    inc = 75
    fft_len = 512

    data, sample_rate = sf.read(wav_path)
    if data.ndim > 1:
        # sf.read returns (frames, channels) for multi-channel files.
        # Reshaping that directly would interleave channels, so keep the
        # first channel only.
        data = data[:, 0]
    inputs = torch.from_numpy(data.reshape([1, 1, -1]).astype(np.float32))

    fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning',
                   feature_type='complex')
    ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning',
                     feature_type='complex')
    spec = fft(inputs)
    outputs = ifft(spec)

    sf.write(out_path, outputs.numpy()[0, 0, :], sample_rate)
    # Only whole frames are reconstructed, so the output can be shorter than
    # the input; compare the overlapping part.
    err = inputs[..., :outputs.size(2)] - outputs
    print('wav MSE', torch.mean(err ** 2))
def test_ifft2():
    """Round-trip low-level random noise through ConvSTFT and ConviSTFT.

    Prints the mean squared error between the input and the reconstruction,
    then writes the reconstruction to 'zero.wav' at 16 kHz.
    """
    N = 400
    inc = 100
    fft_len = 512

    # Low-amplitude Gaussian noise, clipped to the valid audio range.
    signal = np.clip(np.random.randn(16000 * 4) * 0.005, -1, 1)
    x = torch.from_numpy(signal[None, None, :].astype(np.float32))

    fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning',
                   feature_type='complex')
    ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning',
                     feature_type='complex')
    spec = fft(x)
    output = ifft(spec)

    # The reconstruction covers whole frames only; compare the overlap so
    # that lengths not aligned to the hop do not break the subtraction.
    err = x[..., :output.size(-1)] - output
    print('random MSE', torch.mean(err ** 2))

    # Imported here so the numeric check above runs even without soundfile.
    import soundfile as sf
    sf.write('zero.wav', output[0, 0].numpy(), 16000)
if __name__ == '__main__':
    # Running the module as a script performs the file-based round-trip check.
    test_ifft1()