"""Convolution-based STFT and inverse STFT layers for PyTorch."""

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from scipy.signal import get_window
import scipy

def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    """
    Build the convolution kernel and window used by the STFT / iSTFT layers.

    The kernel is the real-FFT basis of size ``fft_len`` truncated to the first
    ``win_len`` time samples, with real and imaginary parts stacked along the
    output-channel axis and each row multiplied by the window.

    Args:
        win_len (int): Length of the window (frame) in samples. Must not exceed fft_len.
        win_inc (int): Window increment (hop length). Not used by this function;
            kept so callers can pass the same arguments everywhere.
        fft_len (int): Length of the FFT.
        win_type (str, optional): Window name understood by ``scipy.signal.get_window``
            (e.g. 'hann', 'hamming'). ``None`` or the string 'None' selects a
            rectangular window. The legacy alias 'hanning' is mapped to 'hann'.
            For named windows the square root of the window is used.
        invers (bool, optional): If True, the pseudo-inverse of the basis is used
            (synthesis kernel for the iSTFT). Default is False.

    Returns:
        tuple: A tuple containing:
            - torch.Tensor: float32 kernel of shape (2 * (fft_len // 2 + 1), 1, win_len);
              the first half of the rows is the real part, the second half the imaginary part.
            - torch.Tensor: float32 window of shape (1, win_len, 1).

    Raises:
        ValueError: If win_len is larger than fft_len.
    """
    if win_len > fft_len:
        # The basis only has fft_len time samples, so a longer window cannot be applied.
        raise ValueError(
            'win_len ({}) must not exceed fft_len ({})'.format(win_len, fft_len)
        )

    if win_type == 'None' or win_type is None:
        window = np.ones(win_len)
    else:
        # 'hann' is accepted by every SciPy release, while the 'hanning' alias is
        # not available in newer ones. Map it unconditionally instead of comparing
        # version strings, which orders lexicographically ('1.9' > '1.10').
        if win_type == 'hanning':
            win_type = 'hann'
        window = get_window(win_type, win_len, fftbins=True) ** 0.5

    # rfft of the identity gives the DFT basis, shape (fft_len, fft_len // 2 + 1);
    # keep only the first win_len time samples.
    fourier_basis = np.fft.rfft(np.eye(fft_len))[:win_len]
    real_kernel = np.real(fourier_basis)
    imag_kernel = np.imag(fourier_basis)
    # Stack real and imaginary parts -> (2 * (fft_len // 2 + 1), win_len).
    kernel = np.concatenate([real_kernel, imag_kernel], 1).T

    if invers:
        # Pseudo-inverse, transposed back to the (channels, win_len) layout.
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel * window  # Window every basis row (broadcast over the last axis)
    kernel = kernel[:, None, :]  # Insert the singleton in/out channel axis for conv1d
    return (
        torch.from_numpy(kernel.astype(np.float32)),
        torch.from_numpy(window[None, :, None].astype(np.float32)),
    )


class ConvSTFT(nn.Module):
    """
    Convolutional layer that performs Short-Time Fourier Transform (STFT).

    The STFT is computed as a strided 1-D convolution of the signal with the
    pre-computed kernel returned by ``init_kernels``. The layer returns either the
    stacked real/imaginary spectrum or the magnitude and phase.

    Attributes:
        weight (nn.Parameter): Convolution kernel for the STFT; trainable only when fix=False.
        feature_type (str): 'complex' returns the stacked real/imaginary output,
            any other value returns magnitude and phase.
        stride (int): The stride used for convolution (equal to win_inc).
        win_len (int): The length of the window used in STFT.
        fft_len (int): The FFT length.
        dim (int): Same value as fft_len; used to split real and imaginary rows.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        """
        Initializes the ConvSTFT layer.

        Args:
            win_len (int): Length of the window.
            win_inc (int): Window increment (hop length).
            fft_len (int, optional): Length of the FFT. If None, the next power of two
                that is >= win_len is used.
            win_type (str, optional): Type of window to use (default is 'hamming').
            feature_type (str, optional): Specifies the output feature type ('real' or 'complex'). Default is 'real'.
            fix (bool, optional): If True, the kernel weights are fixed and not learnable. Default is True.
        """
        super().__init__()

        if fft_len is None:
            # Built-in int(): the np.int alias was removed in NumPy 1.24.
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len

        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)  # Analysis kernel
        self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        """
        Forward pass through the ConvSTFT layer.

        Args:
            inputs (torch.Tensor): Signal of shape (batch_size, length) or
                (batch_size, 1, length); a 2-D input gets a channel axis inserted.

        Returns:
            tuple or torch.Tensor: Depending on feature_type, returns either:
                - torch.Tensor: if feature_type is 'complex', a real-valued tensor whose
                  channel axis holds the real rows followed by the imaginary rows.
                - tuple: (magnitude, phase) tensors for any other feature_type.
        """
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)  # Add channel dimension if not present

        # Each output channel is one windowed basis row applied to every frame.
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == 'complex':
            return outputs

        n_bins = self.dim // 2 + 1  # Number of rows holding the real part
        real = outputs[:, :n_bins, :]
        imag = outputs[:, n_bins:, :]
        mags = torch.sqrt(real ** 2 + imag ** 2)
        phase = torch.atan2(imag, real)
        return mags, phase


class ConviSTFT(nn.Module):
    """
    Convolutional layer that performs Inverse Short-Time Fourier Transform (iSTFT).

    The time-domain signal is rebuilt with a strided transposed convolution using
    the pseudo-inverse kernel from ``init_kernels``, then divided by the
    overlap-added squared window.

    Attributes:
        weight (nn.Parameter): Convolution kernel for the iSTFT; trainable only when fix=False.
        feature_type (str): Stored feature type ('real' or 'complex'); not read by forward().
        win_type (str): Type of window used during iSTFT.
        win_len (int): The length of the window used in iSTFT.
        win_inc (int): The window increment (hop length).
        stride (int): The stride used for transposed convolution (equal to win_inc).
        fft_len (int): The FFT length.
        dim (int): Same value as fft_len.
        window (torch.Tensor): Buffer holding the window, shape (1, win_len, 1).
        enframe (torch.Tensor): Buffer holding an identity framing kernel, shape (win_len, 1, win_len).
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        """
        Initializes the ConviSTFT layer.

        Args:
            win_len (int): Length of the window.
            win_inc (int): Window increment (hop length).
            fft_len (int, optional): Length of the FFT. If None, the next power of two
                that is >= win_len is used.
            win_type (str, optional): Type of window to use (default is 'hamming').
            feature_type (str, optional): Specifies the feature type ('real' or 'complex'). Default is 'real'.
            fix (bool, optional): If True, the kernel weights are fixed and not learnable. Default is True.
        """
        super().__init__()

        if fft_len is None:
            # Built-in int(): the np.int alias was removed in NumPy 1.24.
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len

        kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)  # Synthesis kernel
        self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.win_inc = win_inc
        self.stride = win_inc
        self.dim = self.fft_len
        # Buffers follow the module across devices but are not trained.
        self.register_buffer('window', window)
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])  # Identity kernel for overlap-add

    def forward(self, inputs, phase=None):
        """
        Forward pass through the ConviSTFT layer.

        Args:
            inputs (torch.Tensor): Stacked real/imaginary spectrum of shape [B, N+2, T],
                or a magnitude spectrum of shape [B, N//2+1, T] when phase is given.
            phase (torch.Tensor, optional): Phase tensor of shape [B, N//2+1, T]. If provided,
                inputs is treated as magnitude and combined with it into real/imaginary parts.

        Returns:
            torch.Tensor: Reconstructed time-domain signal of shape [B, 1, length].
        """
        if phase is not None:
            # Polar -> rectangular, stacked in the layout the kernel expects.
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)

        # The transposed convolution synthesizes each frame and overlap-adds them.
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # Overlap-add the squared window over the same number of frames to obtain
        # the per-sample normalization factor.
        t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)  # Small epsilon guards against division by zero
        return outputs


def test_fft():
    """
    Compare the ConvSTFT magnitude with Librosa's STFT magnitude.

    A seeded random signal is transformed by the ConvSTFT layer and by
    ``librosa.stft`` (with ``center=False``), and the mean squared difference
    between the two magnitude spectrograms is printed. Nothing is asserted.
    """
    # Imported locally so the module itself does not require librosa; the name
    # was previously used here without ever being imported.
    import librosa

    torch.manual_seed(20)
    win_len = 320
    win_inc = 160
    fft_len = 512
    inputs = torch.randn([1, 1, 16000*4])  # Random input tensor
    fft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='real')

    outputs1 = fft(inputs)[0]  # Magnitude is the first element of (mags, phase)
    outputs1 = outputs1.numpy()[0]  # Drop the batch axis for comparison
    np_inputs = inputs.numpy().reshape([-1])  # Flatten to a 1-D signal for Librosa
    librosa_stft = librosa.stft(np_inputs, win_length=win_len, n_fft=fft_len, hop_length=win_inc, center=False)
    print(np.mean((outputs1 - np.abs(librosa_stft))**2))  # Mean squared difference of the magnitudes


def test_ifft1():
    """
    Round-trip an audio file through ConvSTFT and ConviSTFT.

    Reads '../../wavs/ori.wav', computes its STFT, reconstructs the waveform,
    writes the result to 'conv_stft.wav' and prints the mean squared error
    between the original (truncated to the reconstructed length) and the
    reconstruction.
    """
    import soundfile as sf
    N = 100
    inc = 75
    fft_len = 512
    torch.manual_seed(N)

    # Read input audio file and reshape to [1, 1, length]
    data = sf.read('../../wavs/ori.wav')[0]
    inputs = data.reshape([1, 1, -1])

    fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
    ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')

    inputs = torch.from_numpy(inputs.astype(np.float32))
    outputs1 = fft(inputs)  # Compute STFT
    outputs2 = ifft(outputs1)  # Reconstruct waveform from STFT
    sf.write('conv_stft.wav', outputs2.numpy()[0, 0, :], 16000)  # Save reconstructed waveform to file
    # The reconstruction can be shorter than the input, so truncate the input
    # first. The difference is squared so the value matches the 'MSE' label.
    print('wav MSE', torch.mean(torch.abs(inputs[..., :outputs2.size(2)] - outputs2)**2))


def test_ifft2():
    """
    Round-trip a seeded random signal through ConvSTFT and ConviSTFT.

    Low-amplitude Gaussian noise is transformed with ConvSTFT and rebuilt with
    ConviSTFT. The mean squared error between the noise and its reconstruction
    is printed, and the reconstruction is written to 'zero.wav'.
    """
    frame_len, hop, fft_len = 400, 100, 512
    np.random.seed(20)
    torch.manual_seed(20)

    # Four seconds' worth of samples at 16 kHz, scaled down and clipped to [-1, 1]
    noise = np.clip(np.random.randn(16000*4) * 0.005, -1, 1)
    signal = torch.from_numpy(noise[None, None, :].astype(np.float32))  # Shape [1, 1, length]

    stft_kwargs = dict(fft_len=fft_len, win_type='hanning', feature_type='complex')
    analysis = ConvSTFT(frame_len, hop, **stft_kwargs)
    synthesis = ConviSTFT(frame_len, hop, **stft_kwargs)

    spectrum = analysis(signal)
    rebuilt = synthesis(spectrum)
    squared_error = torch.abs(signal - rebuilt)**2
    print('random MSE', torch.mean(squared_error))

    import soundfile as sf
    sf.write('zero.wav', rebuilt[0, 0].numpy(), 16000)  # Save reconstructed waveform to file


if __name__ == '__main__':
    # Script entry point. Only the file-based round trip runs by default; it
    # reads '../../wavs/ori.wav'. The other two checks are left disabled.
    #test_fft()
    test_ifft1()  # Run the iSTFT reconstruction test
    #test_ifft2()