#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Import future compatibility features for Python 2/3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Import necessary libraries
import torch
import torch.nn as nn
import numpy as np
from joblib import Parallel, delayed
from pesq import pesq  # PESQ metric for speech quality evaluation
import os
import sys
import librosa  # Library for audio processing
import torchaudio  # Library for audio processing with PyTorch

# Constants
MAX_WAV_VALUE = 32768.0  # Maximum value for WAV files
EPS = 1e-6  # Small value to avoid division by zero

def read_and_config_file(input_path, decode=0):
    """Reads input paths from a file or directory and configures them for processing.

    Args:
        input_path (str): Path to the input directory or file.
        decode (int): Flag indicating if decoding should occur (1 for decode, 0 for standard read).

    Returns:
        list: A list of processed paths or dictionaries containing input and label paths.
    """
    processed_list = []

    # If decoding is requested, find files in a directory
    if decode:
        if os.path.isdir(input_path):
            processed_list = librosa.util.find_files(input_path, ext="wav")  # Look for WAV files
            if len(processed_list) == 0:
                processed_list = librosa.util.find_files(input_path, ext="flac")  # Fallback to FLAC files
        else:
            # Read paths from a file
            with open(input_path) as fid:
                for line in fid:
                    path_s = line.strip().split()  # Split line into parts
                    processed_list.append(path_s[0])  # Append the first part (input path)
        return processed_list

    # Read input-label pairs from a file
    with open(input_path) as fid:
        for line in fid:
            tmp_paths = line.strip().split()  # Split line into parts
            if len(tmp_paths) == 3:  # Expecting input, label, and duration
                sample = {'inputs': tmp_paths[0], 'labels': tmp_paths[1], 'duration': float(tmp_paths[2])}
            elif len(tmp_paths) == 2:  # Expecting input and label only
                sample = {'inputs': tmp_paths[0], 'labels': tmp_paths[1]}
            else:
                continue  # Skip malformed lines instead of appending a stale or undefined sample
            processed_list.append(sample)  # Append the sample dictionary
    return processed_list
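
# Usage sketch (hypothetical paths, not shipped with this file): a list file is
# expected to hold one whitespace-separated "<input> <label> [<duration>]" entry
# per line, e.g.
#
#   /data/noisy/utt1.wav /data/clean/utt1.wav 3.52
#   /data/noisy/utt2.wav /data/clean/utt2.wav
#
#   samples = read_and_config_file('train_list.txt')           # list of dicts
#   wav_paths = read_and_config_file('/data/noisy', decode=1)  # list of paths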

def load_checkpoint(checkpoint_path, use_cuda):
    """Loads the model checkpoint from the specified path.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        use_cuda (bool): Flag indicating whether to use CUDA for loading.

    Returns:
        dict: The loaded checkpoint containing model parameters.
    """
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)  # Load using CUDA
    else:
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)  # Load to CPU
    return checkpoint

def get_learning_rate(optimizer):
    """Retrieves the current learning rate from the optimizer.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer instance.

    Returns:
        float: The current learning rate.
    """
    return optimizer.param_groups[0]["lr"]

def reload_for_eval(model, checkpoint_dir, use_cuda):
    """Reloads a model for evaluation from the specified checkpoint directory.

    Args:
        model (nn.Module): The model to be reloaded.
        checkpoint_dir (str): Directory containing checkpoints.
        use_cuda (bool): Flag indicating whether to use CUDA.

    Returns:
        None
    """
    print('Reloading from: {}'.format(checkpoint_dir))
    best_name = os.path.join(checkpoint_dir, 'last_best_checkpoint')  # Path to the best checkpoint
    ckpt_name = os.path.join(checkpoint_dir, 'last_checkpoint')  # Path to the last checkpoint
    if os.path.isfile(best_name):
        name = best_name 
    elif os.path.isfile(ckpt_name):
        name = ckpt_name
    else:
        print('Warning: No existing checkpoint or best_model found!')
        return
    
    with open(name, 'r') as f:
        model_name = f.readline().strip()  # Read the model name from the checkpoint file
    checkpoint_path = os.path.join(checkpoint_dir, model_name)  # Construct full checkpoint path
    print('Checkpoint path: {}'.format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    if 'model' in checkpoint:
        pretrained_model = checkpoint['model']
    else:
        pretrained_model = checkpoint
    state = model.state_dict()
    # Copy parameters whose names and shapes match, tolerating the 'module.'
    # prefix mismatch introduced when a checkpoint comes from a DataParallel/DDP-wrapped model
    for key in state.keys():
        if key in pretrained_model and state[key].shape == pretrained_model[key].shape:
            state[key] = pretrained_model[key]
        elif key.replace('module.', '') in pretrained_model and state[key].shape == pretrained_model[key.replace('module.', '')].shape:
            state[key] = pretrained_model[key.replace('module.', '')]
        elif 'module.'+key in pretrained_model and state[key].shape == pretrained_model['module.'+key].shape:
            state[key] = pretrained_model['module.'+key]
    model.load_state_dict(state)
    print('=> Reloaded well-trained model {} for decoding.'.format(model_name))
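
# Usage sketch (assumed names): reloading a trained model for inference.
# `MyModel` and the checkpoint directory are placeholders, not names from this file.
#
#   model = MyModel()
#   reload_for_eval(model, './checkpoints/exp1', use_cuda=False)
#   model.eval()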

def reload_model(model, optimizer, checkpoint_dir, use_cuda=True, strict=True):
    """Reloads the model and optimizer state from a checkpoint.

    Args:
        model (nn.Module): The model to be reloaded.
        optimizer (torch.optim.Optimizer): The optimizer to be reloaded.
        checkpoint_dir (str): Directory containing checkpoints.
        use_cuda (bool): Flag indicating whether to use CUDA.
        strict (bool): If True, requires keys in state_dict to match exactly.

    Returns:
        tuple: Current epoch and step.
    """
    ckpt_name = os.path.join(checkpoint_dir, 'checkpoint')  # Path to the checkpoint file
    if os.path.isfile(ckpt_name):
        with open(ckpt_name, 'r') as f:
            model_name = f.readline().strip()  # Read model name from checkpoint file
        checkpoint_path = os.path.join(checkpoint_dir, model_name)  # Construct full checkpoint path
        checkpoint = load_checkpoint(checkpoint_path, use_cuda)  # Load the checkpoint
        model.load_state_dict(checkpoint['model'], strict=strict)  # Load model parameters
        optimizer.load_state_dict(checkpoint['optimizer'])  # Load optimizer parameters
        epoch = checkpoint['epoch']  # Get current epoch
        step = checkpoint['step']  # Get current step
        print('=> Reloaded previous model and optimizer.')
    else:
        print('[!] Checkpoint directory is empty. Train a new model ...')
        epoch = 0  # Initialize epoch
        step = 0  # Initialize step
    return epoch, step

def save_checkpoint(model, optimizer, epoch, step, checkpoint_dir, mode='checkpoint'):
    """Saves the model and optimizer state to a checkpoint file.

    Args:
        model (nn.Module): The model to be saved.
        optimizer (torch.optim.Optimizer): The optimizer to be saved.
        epoch (int): Current epoch number.
        step (int): Current training step number.
        checkpoint_dir (str): Directory to save the checkpoint.
        mode (str): Mode of the checkpoint ('checkpoint' or other).

    Returns:
        None
    """
    checkpoint_path = os.path.join(
        checkpoint_dir, 'model.ckpt-{}-{}.pt'.format(epoch, step))  # Construct checkpoint file path
    torch.save({'model': model.state_dict(),  # Save model parameters
                'optimizer': optimizer.state_dict(),  # Save optimizer parameters
                'epoch': epoch,  # Save epoch
                'step': step}, checkpoint_path)  # Save checkpoint to file

    # Save the checkpoint name to a file for easy access
    with open(os.path.join(checkpoint_dir, mode), 'w') as f:
        f.write('model.ckpt-{}-{}.pt'.format(epoch, step))
    print("=> Saved checkpoint:", checkpoint_path)

def setup_lr(opt, lr):
    """Sets the learning rate for all parameter groups in the optimizer.

    Args:
        opt (torch.optim.Optimizer): The optimizer instance whose learning rate needs to be set.
        lr (float): The new learning rate to be assigned.
    
    Returns:
        None
    """
    for param_group in opt.param_groups:
        param_group['lr'] = lr  # Update the learning rate for each parameter group


def pesq_loss(clean, noisy, sr=16000):
    """Calculates the PESQ (Perceptual Evaluation of Speech Quality) score between clean and noisy signals.

    Args:
        clean (ndarray): The clean audio signal.
        noisy (ndarray): The noisy audio signal.
        sr (int): Sample rate of the audio signals (default is 16000 Hz).

    Returns:
        float: The PESQ score or -1 in case of an error.
    """
    try:
        pesq_score = pesq(sr, clean, noisy, 'wb')  # Compute PESQ score
    except Exception:
        # PESQ may fail due to silent periods in audio
        pesq_score = -1  # Assign -1 to indicate error
    return pesq_score


def batch_pesq(clean, noisy):
    """Computes the PESQ scores for batches of clean and noisy audio signals.

    Args:
        clean (list of ndarray): List of clean audio signals.
        noisy (list of ndarray): List of noisy audio signals.

    Returns:
        torch.FloatTensor: A tensor of normalized PESQ scores or None if any score is -1.
    """
    # Parallel processing for calculating PESQ scores for each pair of clean and noisy signals
    pesq_score = Parallel(n_jobs=-1)(delayed(pesq_loss)(c, n) for c, n in zip(clean, noisy))
    pesq_score = np.array(pesq_score)  # Convert to NumPy array
    
    if -1 in pesq_score:  # Check for errors in PESQ calculations
        return None
    
    # Map PESQ scores from their nominal range (roughly 1.0 to 4.5) onto ~[0, 1]
    pesq_score = (pesq_score - 1) / 3.5
    return torch.FloatTensor(pesq_score).to('cuda')  # Return normalized scores (note: device is hardcoded to CUDA)
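
# Usage sketch: scoring a batch of 16 kHz signals. The random arrays below are
# placeholders; real inputs would be 1-D float waveforms of equal length. Note
# that the returned tensor is placed on 'cuda', so a GPU must be available.
#
#   clean = [np.random.randn(16000).astype(np.float32) for _ in range(4)]
#   noisy = [c + 0.1 * np.random.randn(16000).astype(np.float32) for c in clean]
#   scores = batch_pesq(clean, noisy)  # None if PESQ failed on any pair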


def power_compress(x):
    """Compresses the power of a complex spectrogram.

    Args:
        x (torch.Tensor): Input tensor with real and imaginary components.

    Returns:
        torch.Tensor: Power-compressed spectrogram with real and imaginary parts stacked along dim 1.
    """
    real = x[..., 0]  # Extract real part
    imag = x[..., 1]  # Extract imaginary part
    spec = torch.complex(real, imag)  # Create complex tensor from real and imaginary parts
    mag = torch.abs(spec)  # Compute magnitude
    phase = torch.angle(spec)  # Compute phase
    
    mag = mag**0.3  # Compress magnitude using power of 0.3
    real_compress = mag * torch.cos(phase)  # Reconstruct real part
    imag_compress = mag * torch.sin(phase)  # Reconstruct imaginary part
    return torch.stack([real_compress, imag_compress], 1)  # Stack compressed parts


def power_uncompress(real, imag):
    """Uncompresses the power of a compressed complex spectrogram.

    Args:
        real (torch.Tensor): Compressed real component.
        imag (torch.Tensor): Compressed imaginary component.

    Returns:
        torch.Tensor: Uncompressed spectrogram with real and imaginary parts stacked along the last dim.
    """
    spec = torch.complex(real, imag)  # Create complex tensor from real and imaginary parts
    mag = torch.abs(spec)  # Compute magnitude
    phase = torch.angle(spec)  # Compute phase
    
    mag = mag**(1./0.3)  # Uncompress magnitude by raising to the power of 1/0.3
    real_uncompress = mag * torch.cos(phase)  # Reconstruct real part
    imag_uncompress = mag * torch.sin(phase)  # Reconstruct imaginary part
    return torch.stack([real_uncompress, imag_uncompress], -1)  # Stack uncompressed parts
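
# Usage sketch: power compression round trip. `spec` is a hypothetical
# real/imaginary spectrogram of shape (batch, freq, time, 2); uncompressing the
# compressed output recovers it up to numerical error.
#
#   spec = torch.randn(1, 257, 100, 2)
#   comp = power_compress(spec)                           # (1, 2, 257, 100)
#   restored = power_uncompress(comp[:, 0], comp[:, 1])   # ~= spec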


def stft(x, args, center=False):
    """Computes the Short-Time Fourier Transform (STFT) of an audio signal.

    Args:
        x (torch.Tensor): Input audio signal.
        args (Namespace): Configuration arguments containing window type and lengths.
        center (bool): Whether to pad the signal so that frames are centered (passed to torch.stft).

    Returns:
        torch.Tensor: The computed STFT of the input signal.
    """
    win_type = args.win_type
    win_len = args.win_len
    win_inc = args.win_inc
    fft_len = args.fft_len

    # Select window type and create window tensor
    if win_type == 'hamming':
        window = torch.hamming_window(win_len, periodic=False).to(x.device)
    elif win_type == 'hanning':
        window = torch.hann_window(win_len, periodic=False).to(x.device)
    else:
        raise ValueError(f"In STFT, window type '{win_type}' is not supported!")
    
    # Compute and return the STFT
    return torch.stft(x, n_fft=fft_len, hop_length=win_inc, win_length=win_len,
                      center=center, window=window, return_complex=False)


def istft(x, args, slen=None, center=False, normalized=False, onesided=None, return_complex=False):
    """Computes the inverse Short-Time Fourier Transform (ISTFT) of a complex spectrogram.

    Args:
        x (torch.Tensor): Input complex spectrogram.
        args (Namespace): Configuration arguments containing window type and lengths.
        slen (int, optional): Length of the output signal.
        center (bool): Whether to pad the signal so that frames are centered (passed to torch.istft).
        normalized (bool): Whether to normalize the output.
        onesided (bool, optional): If True, computes only the one-sided transform.
        return_complex (bool): If True, returns complex output.

    Returns:
        torch.Tensor: The reconstructed audio signal from the spectrogram.
    """
    win_type = args.win_type
    win_len = args.win_len
    win_inc = args.win_inc
    fft_len = args.fft_len

    # Select window type and create window tensor
    if win_type == 'hamming':
        window = torch.hamming_window(win_len, periodic=False).to(x.device)
    elif win_type == 'hanning':
        window = torch.hann_window(win_len, periodic=False).to(x.device)
    else:
        raise ValueError(f"In ISTFT, window type '{win_type}' is not supported!")

    try:
        # Attempt to compute the ISTFT directly
        output = torch.istft(x, n_fft=fft_len, hop_length=win_inc, win_length=win_len,
                             window=window, center=center, normalized=normalized,
                             onesided=onesided, length=slen, return_complex=False)
    except Exception:
        # Newer PyTorch versions require a complex-valued input: convert the
        # stacked real/imaginary representation and retry
        x_complex = torch.view_as_complex(x)
        output = torch.istft(x_complex, n_fft=fft_len, hop_length=win_inc, win_length=win_len,
                             window=window, center=center, normalized=normalized,
                             onesided=onesided, length=slen, return_complex=False)
    return output
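
# Usage sketch: STFT/ISTFT round trip. The `args` namespace below is
# hypothetical; callers normally supply these fields via an argument parser.
#
#   from argparse import Namespace
#   args = Namespace(win_type='hamming', win_len=512, win_inc=128, fft_len=512)
#   x = torch.randn(2, 16000)
#   spec = stft(x, args)                     # (2, 257, frames, 2)
#   y = istft(spec, args, slen=x.shape[-1])  # ~= x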


def compute_fbank(audio_in, args):
    """Computes the filter bank features from an audio signal.

    Args:
        audio_in (torch.Tensor): Input audio signal.
        args (Namespace): Configuration arguments containing window length, shift, and sampling rate.

    Returns:
        torch.Tensor: Computed filter bank features.
    """
    frame_length = args.win_len / args.sampling_rate * 1000  # Frame length in milliseconds
    frame_shift = args.win_inc / args.sampling_rate * 1000  # Frame shift in milliseconds

    # Compute and return filter bank features using Kaldi's implementation
    return torchaudio.compliance.kaldi.fbank(audio_in, dither=1.0, frame_length=frame_length,
                                             frame_shift=frame_shift, num_mel_bins=args.num_mels,
                                             sample_frequency=args.sampling_rate, window_type=args.win_type)
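
# Usage sketch: computing 80-dim fbank features. The `args` namespace below is
# hypothetical; kaldi-compliance fbank expects a (channels, samples) waveform.
#
#   from argparse import Namespace
#   args = Namespace(win_len=400, win_inc=160, sampling_rate=16000,
#                    num_mels=80, win_type='hamming')
#   wav = torch.randn(1, 16000)
#   feats = compute_fbank(wav, args)  # (frames, 80)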