#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Import future compatibility features for Python 2/3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Import necessary libraries
import torch
import torch.nn as nn
import numpy as np
from joblib import Parallel, delayed
from pesq import pesq  # PESQ metric for speech quality evaluation
import os
import sys
import librosa  # Library for audio processing
import torchaudio  # Library for audio processing with PyTorch

# Constants
MAX_WAV_VALUE = 32768.0  # Maximum value for 16-bit WAV files
EPS = 1e-6  # Small value to avoid division by zero

def read_and_config_file(input_path, decode=0):
    """Reads input paths from a file or directory and configures them for processing.

    Args:
        input_path (str): Path to the input directory or file.
        decode (int): Flag indicating if decoding should occur (1 for decode, 0 for standard read).

    Returns:
        list: A list of processed paths, or of dictionaries containing input and label paths.
    """
    processed_list = []

    # If decoding is requested, find files in a directory
    if decode:
        if os.path.isdir(input_path):
            processed_list = librosa.util.find_files(input_path, ext="wav")  # Look for WAV files
            if len(processed_list) == 0:
                processed_list = librosa.util.find_files(input_path, ext="flac")  # Fall back to FLAC files
        else:
            # Read paths from a file
            with open(input_path) as fid:
                for line in fid:
                    path_s = line.strip().split()  # Split line into parts
                    processed_list.append(path_s[0])  # Append the first part (input path)
        return processed_list

    # Read input-label pairs from a file
    with open(input_path) as fid:
        for line in fid:
            tmp_paths = line.strip().split()  # Split line into parts
            if len(tmp_paths) == 3:  # Expecting input, label, and duration
                sample = {'inputs': tmp_paths[0], 'labels': tmp_paths[1], 'duration': float(tmp_paths[2])}
            elif len(tmp_paths) == 2:  # Expecting input and label only
                sample = {'inputs': tmp_paths[0], 'labels': tmp_paths[1]}
            else:  # Skip malformed lines instead of re-appending the previous sample
                continue
            processed_list.append(sample)  # Append the sample dictionary
    return processed_list
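
# Hedged usage sketch for read_and_config_file; the paths below are
# illustrative placeholders, not files shipped with this module.
#
#   # Decode mode: collect all WAV (or, failing that, FLAC) files.
#   wav_list = read_and_config_file('/path/to/noisy_dir', decode=1)
#
#   # Training mode: each line of the list file reads
#   # "<input.wav> <label.wav> [duration_seconds]".
#   samples = read_and_config_file('/path/to/train_list.txt')
#   # samples[0] -> {'inputs': '...', 'labels': '...', 'duration': ...}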

def load_checkpoint(checkpoint_path, use_cuda):
    """Loads the model checkpoint from the specified path.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        use_cuda (bool): Flag indicating whether to use CUDA for loading.

    Returns:
        dict: The loaded checkpoint containing model parameters.
    """
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)  # Load onto the device the checkpoint was saved from
    else:
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)  # Load to CPU
    return checkpoint
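
# Minimal sketch; the path is a placeholder matching the naming scheme
# used by save_checkpoint below.
#
#   ckpt = load_checkpoint('exp/checkpoints/model.ckpt-10-5000.pt',
#                          use_cuda=torch.cuda.is_available())
#   # ckpt['model'], ckpt['optimizer'], ckpt['epoch'], ckpt['step']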

def get_learning_rate(optimizer):
    """Retrieves the current learning rate from the optimizer.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer instance.

    Returns:
        float: The current learning rate.
    """
    return optimizer.param_groups[0]["lr"]

def reload_for_eval(model, checkpoint_dir, use_cuda):
    """Reloads a model for evaluation from the specified checkpoint directory.

    Args:
        model (nn.Module): The model to be reloaded.
        checkpoint_dir (str): Directory containing checkpoints.
        use_cuda (bool): Flag indicating whether to use CUDA.

    Returns:
        None
    """
    print('Reloading from: {}'.format(checkpoint_dir))
    best_name = os.path.join(checkpoint_dir, 'last_best_checkpoint')  # Index file for the best checkpoint
    ckpt_name = os.path.join(checkpoint_dir, 'last_checkpoint')  # Index file for the last checkpoint
    if os.path.isfile(best_name):
        name = best_name
    elif os.path.isfile(ckpt_name):
        name = ckpt_name
    else:
        print('Warning: No existing checkpoint or best_model found!')
        return
    with open(name, 'r') as f:
        model_name = f.readline().strip()  # Read the checkpoint file name from the index file
    checkpoint_path = os.path.join(checkpoint_dir, model_name)  # Construct full checkpoint path
    print('Checkpoint path: {}'.format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    if 'model' in checkpoint:
        pretrained_model = checkpoint['model']
    else:
        pretrained_model = checkpoint
    # Copy matching parameters, tolerating a 'module.' prefix mismatch
    # (in either direction) between DataParallel and plain checkpoints.
    state = model.state_dict()
    for key in state.keys():
        if key in pretrained_model and state[key].shape == pretrained_model[key].shape:
            state[key] = pretrained_model[key]
        elif key.replace('module.', '') in pretrained_model \
                and state[key].shape == pretrained_model[key.replace('module.', '')].shape:
            state[key] = pretrained_model[key.replace('module.', '')]
        elif 'module.' + key in pretrained_model \
                and state[key].shape == pretrained_model['module.' + key].shape:
            state[key] = pretrained_model['module.' + key]
    model.load_state_dict(state)
    print('=> Reloaded well-trained model {} for decoding.'.format(model_name))
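
# Hedged example of preparing a model for decoding; 'MyModel' and the
# directory are illustrative assumptions, not part of this module.
#
#   model = MyModel()
#   reload_for_eval(model, 'exp/checkpoints', use_cuda=False)
#   model.eval()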

def reload_model(model, optimizer, checkpoint_dir, use_cuda=True, strict=True):
    """Reloads the model and optimizer state from a checkpoint.

    Args:
        model (nn.Module): The model to be reloaded.
        optimizer (torch.optim.Optimizer): The optimizer to be reloaded.
        checkpoint_dir (str): Directory containing checkpoints.
        use_cuda (bool): Flag indicating whether to use CUDA.
        strict (bool): If True, requires keys in state_dict to match exactly.

    Returns:
        tuple: Current epoch and step.
    """
    ckpt_name = os.path.join(checkpoint_dir, 'checkpoint')  # Index file naming the latest checkpoint
    if os.path.isfile(ckpt_name):
        with open(ckpt_name, 'r') as f:
            model_name = f.readline().strip()  # Read checkpoint file name from the index file
        checkpoint_path = os.path.join(checkpoint_dir, model_name)  # Construct full checkpoint path
        checkpoint = load_checkpoint(checkpoint_path, use_cuda)  # Load the checkpoint
        model.load_state_dict(checkpoint['model'], strict=strict)  # Load model parameters
        optimizer.load_state_dict(checkpoint['optimizer'])  # Load optimizer state
        epoch = checkpoint['epoch']  # Get current epoch
        step = checkpoint['step']  # Get current step
        print('=> Reloaded previous model and optimizer.')
    else:
        print('[!] Checkpoint directory is empty. Train a new model ...')
        epoch = 0  # Initialize epoch
        step = 0  # Initialize step
    return epoch, step
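
# Hedged example of resuming training; the optimizer and directory are
# illustrative assumptions.
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   epoch, step = reload_model(model, optimizer, 'exp/checkpoints')
#   # Returns (0, 0) when no 'checkpoint' index file exists yet.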

def save_checkpoint(model, optimizer, epoch, step, checkpoint_dir, mode='checkpoint'):
    """Saves the model and optimizer state to a checkpoint file.

    Args:
        model (nn.Module): The model to be saved.
        optimizer (torch.optim.Optimizer): The optimizer to be saved.
        epoch (int): Current epoch number.
        step (int): Current training step number.
        checkpoint_dir (str): Directory to save the checkpoint.
        mode (str): Name of the index file ('checkpoint', or e.g. 'last_best_checkpoint').

    Returns:
        None
    """
    checkpoint_path = os.path.join(
        checkpoint_dir, 'model.ckpt-{}-{}.pt'.format(epoch, step))  # Construct checkpoint file path
    torch.save({'model': model.state_dict(),  # Save model parameters
                'optimizer': optimizer.state_dict(),  # Save optimizer state
                'epoch': epoch,  # Save epoch
                'step': step}, checkpoint_path)  # Write checkpoint to file
    # Record the checkpoint name in an index file for easy lookup
    with open(os.path.join(checkpoint_dir, mode), 'w') as f:
        f.write('model.ckpt-{}-{}.pt'.format(epoch, step))
    print("=> Saved checkpoint:", checkpoint_path)

def setup_lr(opt, lr):
    """Sets the learning rate for all parameter groups in the optimizer.

    Args:
        opt (torch.optim.Optimizer): The optimizer instance whose learning rate needs to be set.
        lr (float): The new learning rate to be assigned.

    Returns:
        None
    """
    for param_group in opt.param_groups:
        param_group['lr'] = lr  # Update the learning rate for each parameter group
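
# Sketch of a simple halving schedule built from the two helpers above;
# the decay-on-plateau policy is an illustrative assumption, not part of
# this module.
#
#   if val_loss > best_val_loss:
#       setup_lr(optimizer, get_learning_rate(optimizer) * 0.5)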

def pesq_loss(clean, noisy, sr=16000):
    """Calculates the PESQ (Perceptual Evaluation of Speech Quality) score between clean and noisy signals.

    Args:
        clean (ndarray): The clean audio signal.
        noisy (ndarray): The noisy audio signal.
        sr (int): Sample rate of the audio signals (default is 16000 Hz).

    Returns:
        float: The PESQ score, or -1 in case of an error.
    """
    try:
        pesq_score = pesq(sr, clean, noisy, 'wb')  # Compute wideband PESQ score
    except Exception:
        # PESQ may fail, e.g. on silent or all-zero segments
        pesq_score = -1  # Use -1 to signal failure
    return pesq_score
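
# Hedged sketch: scoring one 16 kHz utterance pair. The variable names are
# placeholders; real speech is needed for a meaningful score.
#
#   score = pesq_loss(clean_wav, noisy_wav)  # ~4.5 is best; -1 signals failure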

def batch_pesq(clean, noisy):
    """Computes the PESQ scores for batches of clean and noisy audio signals.

    Args:
        clean (list of ndarray): List of clean audio signals.
        noisy (list of ndarray): List of noisy audio signals.

    Returns:
        torch.FloatTensor: A tensor of normalized PESQ scores, or None if any score is -1.
    """
    # Compute PESQ for each clean/noisy pair in parallel
    pesq_score = Parallel(n_jobs=-1)(delayed(pesq_loss)(c, n) for c, n in zip(clean, noisy))
    pesq_score = np.array(pesq_score)  # Convert to NumPy array
    if -1 in pesq_score:  # Any failed utterance invalidates the batch
        return None
    # Map wideband PESQ scores from roughly [1, 4.5] to [0, 1]
    pesq_score = (pesq_score - 1) / 3.5
    return torch.FloatTensor(pesq_score).to('cuda')  # Note: assumes a CUDA device is available
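
# Sketch: batch scoring, e.g. as a target signal for metric-based training.
# Inputs are lists of NumPy waveforms at 16 kHz; a None return means at
# least one utterance failed, so callers typically skip that batch.
#
#   scores = batch_pesq([clean1, clean2], [noisy1, noisy2])
#   if scores is not None:
#       ...  # CUDA tensor of shape [2], values roughly in [0, 1]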

def power_compress(x):
    """Compresses the power of a complex spectrogram.

    Args:
        x (torch.Tensor): Input tensor with real and imaginary components in the last dimension.

    Returns:
        torch.Tensor: Compressed real/imaginary representation, stacked along dim 1.
    """
    real = x[..., 0]  # Extract real part
    imag = x[..., 1]  # Extract imaginary part
    spec = torch.complex(real, imag)  # Build complex tensor from real and imaginary parts
    mag = torch.abs(spec)  # Compute magnitude
    phase = torch.angle(spec)  # Compute phase
    mag = mag**0.3  # Compress magnitude with a power of 0.3
    real_compress = mag * torch.cos(phase)  # Reconstruct real part
    imag_compress = mag * torch.sin(phase)  # Reconstruct imaginary part
    return torch.stack([real_compress, imag_compress], 1)  # Stack compressed parts along dim 1
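
# Shape sketch for power_compress, assuming an STFT tensor laid out as
# [batch, freq, time, 2] with real/imag last (as produced by stft below):
#
#   spec = stft(torch.randn(1, 16000), args)   # [1, F, T, 2]
#   compressed = power_compress(spec)          # [1, 2, F, T]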

def power_uncompress(real, imag):
    """Uncompresses the power of a compressed complex spectrogram.

    Args:
        real (torch.Tensor): Compressed real component.
        imag (torch.Tensor): Compressed imaginary component.

    Returns:
        torch.Tensor: Uncompressed spectrogram with real/imag stacked in the last dimension.
    """
    spec = torch.complex(real, imag)  # Build complex tensor from real and imaginary parts
    mag = torch.abs(spec)  # Compute magnitude
    phase = torch.angle(spec)  # Compute phase
    mag = mag**(1. / 0.3)  # Undo compression by raising to the power 1/0.3
    real_uncompress = mag * torch.cos(phase)  # Reconstruct real part
    imag_uncompress = mag * torch.sin(phase)  # Reconstruct imaginary part
    return torch.stack([real_uncompress, imag_uncompress], -1)  # Stack uncompressed parts along the last dim
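
# Round-trip sketch, continuing the example above. Note the layout flips
# back: power_uncompress stacks real/imag in the last dimension, which is
# what torch.view_as_complex (and hence istft below) expects.
#
#   real, imag = compressed[:, 0], compressed[:, 1]
#   spec_rt = power_uncompress(real, imag)     # [1, F, T, 2]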

def stft(x, args, center=False):
    """Computes the Short-Time Fourier Transform (STFT) of an audio signal.

    Args:
        x (torch.Tensor): Input audio signal.
        args (Namespace): Configuration arguments containing window type and lengths.
        center (bool): Whether to center the window.

    Returns:
        torch.Tensor: The computed STFT of the input signal.
    """
    win_type = args.win_type
    win_len = args.win_len
    win_inc = args.win_inc
    fft_len = args.fft_len

    # Select window type and create the window tensor
    if win_type == 'hamming':
        window = torch.hamming_window(win_len, periodic=False).to(x.device)
    elif win_type == 'hanning':
        window = torch.hann_window(win_len, periodic=False).to(x.device)
    else:
        raise ValueError(f"In STFT, window type '{win_type}' is not supported!")

    # Compute and return the STFT
    return torch.stft(x, fft_len, win_inc, win_len, center=center, window=window, return_complex=False)
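
# Minimal sketch of driving stft(); the Namespace fields mirror the args
# object this module expects, with illustrative 16 kHz settings.
#
#   from argparse import Namespace
#   args = Namespace(win_type='hamming', win_len=400, win_inc=100,
#                    fft_len=512)
#   spec = stft(torch.randn(1, 16000), args)   # [1, fft_len//2+1, T, 2]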

def istft(x, args, slen=None, center=False, normalized=False, onesided=None, return_complex=False):
    """Computes the inverse Short-Time Fourier Transform (ISTFT) of a complex spectrogram.

    Args:
        x (torch.Tensor): Input complex spectrogram.
        args (Namespace): Configuration arguments containing window type and lengths.
        slen (int, optional): Length of the output signal.
        center (bool): Whether to center the window.
        normalized (bool): Whether to normalize the output.
        onesided (bool, optional): If True, computes only the one-sided transform.
        return_complex (bool): If True, returns complex output.

    Returns:
        torch.Tensor: The reconstructed audio signal from the spectrogram.
    """
    win_type = args.win_type
    win_len = args.win_len
    win_inc = args.win_inc
    fft_len = args.fft_len

    # Select window type and create the window tensor
    if win_type == 'hamming':
        window = torch.hamming_window(win_len, periodic=False).to(x.device)
    elif win_type == 'hanning':
        window = torch.hann_window(win_len, periodic=False).to(x.device)
    else:
        raise ValueError(f"In ISTFT, window type '{win_type}' is not supported!")

    try:
        # Attempt to compute the ISTFT directly
        output = torch.istft(x, n_fft=fft_len, hop_length=win_inc, win_length=win_len,
                             window=window, center=center, normalized=normalized,
                             onesided=onesided, length=slen, return_complex=False)
    except Exception:
        # Newer PyTorch versions require complex input: reinterpret the
        # trailing real/imag dimension as a complex tensor and retry
        x_complex = torch.view_as_complex(x)
        output = torch.istft(x_complex, n_fft=fft_len, hop_length=win_inc, win_length=win_len,
                             window=window, center=center, normalized=normalized,
                             onesided=onesided, length=slen, return_complex=False)
    return output
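
# Round-trip sketch: stft followed by istft reconstructs the waveform.
# center=True is used here so the window/FFT sizes from the sketch above
# satisfy PyTorch's overlap-add constraints at the signal edges.
#
#   wav = torch.randn(1, 16000)
#   spec = stft(wav, args, center=True)
#   rec = istft(spec, args, slen=wav.shape[-1], center=True)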

def compute_fbank(audio_in, args):
    """Computes the filter bank features from an audio signal.

    Args:
        audio_in (torch.Tensor): Input audio signal.
        args (Namespace): Configuration arguments containing window length, shift, and sampling rate.

    Returns:
        torch.Tensor: Computed filter bank features.
    """
    frame_length = args.win_len / args.sampling_rate * 1000  # Frame length in milliseconds
    frame_shift = args.win_inc / args.sampling_rate * 1000  # Frame shift in milliseconds

    # Compute and return filter bank features using torchaudio's Kaldi-compatible implementation
    return torchaudio.compliance.kaldi.fbank(audio_in, dither=1.0, frame_length=frame_length,
                                             frame_shift=frame_shift, num_mel_bins=args.num_mels,
                                             sample_frequency=args.sampling_rate, window_type=args.win_type)
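
if __name__ == '__main__':
    # Smoke test: a minimal sketch with illustrative 16 kHz settings, not
    # values shipped with this module.
    from argparse import Namespace

    demo_args = Namespace(win_type='hamming', win_len=400, win_inc=100,
                          fft_len=512, num_mels=80, sampling_rate=16000)
    wav = torch.randn(1, 16000)  # 1 second of noise at 16 kHz

    spec = stft(wav, demo_args, center=True)
    print('STFT:  ', tuple(spec.shape))   # (1, 257, T, 2)
    rec = istft(spec, demo_args, slen=wav.shape[-1], center=True)
    print('ISTFT: ', tuple(rec.shape))    # (1, 16000)

    fbank = compute_fbank(wav, demo_args)
    print('FBANK: ', tuple(fbank.shape))  # (frames, 80)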