import numpy as np
import math, os, csv
import torchaudio
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.distributed as dist
import soundfile as sf
from torch.utils.data import Dataset
import sys
sys.path.append(os.path.dirname(__file__))
from dataloader.misc import read_and_config_file
import librosa
import random

EPS = 1e-6
MAX_WAV_VALUE = 32768.0

def audioread(path, sampling_rate):
    """
    Reads an audio file from the specified path, normalizes the audio,
    converts it to single-channel, and resamples it to the desired sampling
    rate if necessary.

    Parameters:
    path (str): The file path of the audio file to be read.
    sampling_rate (int): The target sampling rate for the audio.

    Returns:
    numpy.ndarray: The processed audio data: normalized, mono, and resampled
    to the target rate if the file's rate differs.
    """
    # Read audio data and its sample rate from the file.
    data, fs = sf.read(path)

    # Normalize the audio data.
    data = audio_norm(data)

    # Convert to mono by selecting the first channel if the audio has multiple channels.
    # This is done before resampling so librosa operates on a 1-D signal.
    if len(data.shape) > 1:
        data = data[:, 0]

    # Resample the audio if the sample rate is different from the target sampling rate.
    if fs != sampling_rate:
        data = librosa.resample(data, orig_sr=fs, target_sr=sampling_rate)

    # Return the processed audio data.
    return data
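
# Usage sketch for audioread (illustrative only; 'noisy.wav' is a hypothetical path
# and 16000 Hz is an assumed target rate):
#
#   wav = audioread('noisy.wav', 16000)
#   print(wav.shape)   # (num_samples,): mono, normalized to roughly -25 dB RMS
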

def audio_norm(x):
    """
    Normalizes the input audio signal to a target Root Mean Square (RMS) level,
    applying two stages of scaling. This ensures the audio signal is neither too quiet
    nor too loud, keeping its amplitude consistent.

    Parameters:
    x (numpy.ndarray): Input audio signal to be normalized.

    Returns:
    numpy.ndarray: Normalized audio signal.
    """
    # Compute the root mean square (RMS) of the input audio signal.
    rms = (x ** 2).mean() ** 0.5

    # Calculate the scalar to adjust the signal to the target level (-25 dB).
    scalar = 10 ** (-25 / 20) / (rms + EPS)

    # Scale the input audio by the computed scalar.
    x = x * scalar

    # Compute the power of the scaled audio signal.
    pow_x = x ** 2

    # Calculate the average power of the audio signal.
    avg_pow_x = pow_x.mean()

    # Compute RMS only for audio segments with higher-than-average power.
    rmsx = pow_x[pow_x > avg_pow_x].mean() ** 0.5

    # Calculate another scalar to further normalize based on higher-power segments.
    scalarx = 10 ** (-25 / 20) / (rmsx + EPS)

    # Apply the second scalar to the audio.
    x = x * scalarx

    # Return the doubly normalized audio signal.
    return x
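
# Note on the two-stage scaling above (a sketch of the intended effect, not a spec):
# the first stage targets an overall RMS of 10 ** (-25 / 20) ~ 0.056, and the second
# stage recomputes the gain from only the above-average-power samples, so long
# stretches of silence do not inflate the gain and the active portion of the signal
# lands near -25 dB. A quick, illustrative check:
#
#   x = np.random.randn(16000) * 0.5      # synthetic signal, assumed 1 s at 16 kHz
#   y = audio_norm(x)
#   print((y ** 2).mean() ** 0.5)         # overall RMS, below 0.056 since the second
#                                         # stage keys on the louder samples
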

class DataReader(object):
    """
    A class for reading audio data from a list of files, normalizing it,
    and extracting features for further processing. It supports extracting
    features from each file, reshaping the data, and returning metadata
    like utterance ID and data length.

    Parameters:
    args: Arguments containing the input path and target sampling rate.

    Attributes:
    file_list (list): A list of audio file paths to process.
    sampling_rate (int): The target sampling rate for audio files.
    """

    def __init__(self, args):
        # Read and configure the file list from the input path provided in the arguments.
        # The file list is decoded, if necessary.
        self.file_list = read_and_config_file(args, args.input_path, decode=True)

        # Store the target sampling rate.
        self.sampling_rate = args.sampling_rate

        # Store the full argument namespace for task-specific handling in __getitem__.
        self.args = args

    def __len__(self):
        """
        Returns the number of audio files in the file list.

        Returns:
        int: Number of files to process.
        """
        return len(self.file_list)

    def __getitem__(self, index):
        """
        Retrieves the features of the audio file at the given index.

        Parameters:
        index (int): Index of the file in the file list.

        Returns:
        tuple: Features (inputs, utterance ID, data length) for the selected audio file.
        """
        # For lip-cued target speaker extraction, defer loading and return the raw entry.
        if self.args.task == 'target_speaker_extraction':
            if self.args.network_reference.cue == 'lip':
                return self.file_list[index]
        return self.extract_feature(self.file_list[index])

    def extract_feature(self, path):
        """
        Extracts features from the given audio file path.

        Parameters:
        path (str): The file path of the audio file.

        Returns:
        inputs (numpy.ndarray): Reshaped audio data for further processing.
        utt_id (str): The unique identifier of the audio file, usually the filename.
        length (int): The length of the original audio data.
        """
        # Extract the utterance ID from the file path (usually the filename).
        utt_id = path.split('/')[-1]

        # Read and normalize the audio data, converting it to float32 for processing.
        data = audioread(path, self.sampling_rate).astype(np.float32)

        # Reshape the data to ensure it's in the format [1, data_length].
        inputs = np.reshape(data, [1, data.shape[0]])

        # Return the reshaped audio data, utterance ID, and the length of the original data.
        return inputs, utt_id, data.shape[0]
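
# Decoding-time usage sketch for DataReader (illustrative; the attribute names below
# mirror those accessed above, and the list path and task value are assumptions):
#
#   from argparse import Namespace
#   args = Namespace(input_path='noisy_wavs.scp', sampling_rate=16000,
#                    task='speech_enhancement')
#   reader = DataReader(args)
#   inputs, utt_id, length = reader[0]   # inputs: [1, length] float32 array
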

class Wave_Processor(object):
    """
    A class for processing audio data, specifically for reading input and label audio files,
    segmenting them into fixed-length segments, and applying padding or trimming as necessary.

    Methods:
    process(path, segment_length, sampling_rate):
        Processes audio data by reading, padding, or segmenting it to match the specified segment length.

    Parameters:
    path (dict): A dictionary containing file paths for 'inputs' and 'labels' audio files.
    segment_length (int): The desired length of audio segments to extract.
    sampling_rate (int): The target sampling rate for reading the audio files.
    """

    def process(self, path, segment_length, sampling_rate):
        """
        Reads input and label audio files, and ensures the audio is segmented into
        the desired length, padding if necessary or extracting random segments if
        the audio is longer than the target segment length.

        Parameters:
        path (dict): Dictionary containing the paths to 'inputs' and 'labels' audio files.
        segment_length (int): Desired length of the audio segment in samples.
        sampling_rate (int): Target sample rate for the audio.

        Returns:
        tuple: A pair of numpy arrays representing the processed input and label audio,
        either padded to the segment length or trimmed.
        """
        # Read the input and label audio files using the target sampling rate.
        wave_inputs = audioread(path['inputs'], sampling_rate)
        wave_labels = audioread(path['labels'], sampling_rate)

        # Get the length of the label audio (assumed both inputs and labels have similar lengths).
        len_wav = wave_labels.shape[0]

        # If the input audio is shorter than the desired segment length, pad it with zeros.
        if wave_inputs.shape[0] < segment_length:
            # Create zero-padded arrays for inputs and labels.
            padded_inputs = np.zeros(segment_length, dtype=np.float32)
            padded_labels = np.zeros(segment_length, dtype=np.float32)

            # Copy the original audio into the padded arrays.
            padded_inputs[:wave_inputs.shape[0]] = wave_inputs
            padded_labels[:wave_labels.shape[0]] = wave_labels
        else:
            # Randomly select a start index for segmenting the audio if it's longer than the segment length.
            st_idx = random.randint(0, len_wav - segment_length)

            # Extract a segment of the desired length from the inputs and labels.
            padded_inputs = wave_inputs[st_idx:st_idx + segment_length]
            padded_labels = wave_labels[st_idx:st_idx + segment_length]

        # Return the processed (padded or segmented) input and label audio.
        return padded_inputs, padded_labels
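
# Training-time usage sketch for Wave_Processor (illustrative; the file paths are
# hypothetical and 64000 samples corresponds to 4 s at an assumed 16 kHz rate):
#
#   wp = Wave_Processor()
#   noisy, clean = wp.process({'inputs': 'noisy.wav', 'labels': 'clean.wav'},
#                             segment_length=64000, sampling_rate=16000)
#   # noisy and clean are both length-64000 arrays: zero-padded if the file is
#   # shorter than the segment, or a random crop if it is longer.
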

class Fbank_Processor(object):
    """
    A class for processing input audio data into mel-filterbank (Fbank) features,
    including the computation of delta and delta-delta features.

    Methods:
    process(inputs, args):
        Processes the raw audio input and returns the mel-filterbank features
        along with delta and delta-delta features.
    """

    def process(self, inputs, args):
        # Convert the window length and shift from samples to milliseconds.
        frame_length = int(args.win_len / args.sampling_rate * 1000)
        frame_shift = int(args.win_inc / args.sampling_rate * 1000)

        # Set up the configuration for the mel-filterbank computation.
        fbank_config = {
            "dither": 1.0,
            "frame_length": frame_length,
            "frame_shift": frame_shift,
            "num_mel_bins": args.num_mels,
            "sample_frequency": args.sampling_rate,
            "window_type": args.win_type
        }

        # Convert the input audio to a FloatTensor and scale it to match the expected input range.
        inputs = torch.FloatTensor(inputs * MAX_WAV_VALUE)

        # Compute the mel-filterbank features using Kaldi's fbank function.
        fbank = torchaudio.compliance.kaldi.fbank(inputs.unsqueeze(0), **fbank_config)

        # Add delta and delta-delta features (computed along the time axis).
        fbank_tr = torch.transpose(fbank, 0, 1)
        fbank_delta = torchaudio.functional.compute_deltas(fbank_tr)
        fbank_delta_delta = torchaudio.functional.compute_deltas(fbank_delta)
        fbank_delta = torch.transpose(fbank_delta, 0, 1)
        fbank_delta_delta = torch.transpose(fbank_delta_delta, 0, 1)

        # Concatenate the original Fbank, delta, and delta-delta features.
        fbanks = torch.cat([fbank, fbank_delta, fbank_delta_delta], dim=1)
        return fbanks.numpy()
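
# Fbank usage sketch (illustrative; every hyper-parameter below is an assumption,
# e.g. a 48 kHz model with a 1920-sample window and a 384-sample hop):
#
#   from argparse import Namespace
#   fb_args = Namespace(win_len=1920, win_inc=384, num_mels=60,
#                       sampling_rate=48000, win_type='hamming')
#   wav = np.random.randn(48000).astype(np.float32)   # 1 s of synthetic audio
#   fbanks = Fbank_Processor().process(wav, fb_args)
#   # fbanks has shape [num_frames, 3 * num_mels]: fbank + delta + delta-delta
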

class AudioDataset(Dataset):
    """
    A dataset class for loading and processing audio data from different data types
    (train, validation, test). Supports audio processing and feature extraction
    (e.g., waveform processing, Fbank feature extraction).

    Parameters:
    args: Arguments containing dataset configuration (paths, sampling rate, etc.).
    data_type (str): The type of data to load ('train', 'val', 'test').
    """

    def __init__(self, args, data_type):
        self.args = args
        self.sampling_rate = args.sampling_rate

        # Read the list of audio files based on the data type.
        if data_type == 'train':
            self.wav_list = read_and_config_file(args.tr_list)
        elif data_type == 'val':
            self.wav_list = read_and_config_file(args.cv_list)
        elif data_type == 'test':
            self.wav_list = read_and_config_file(args.tt_list)
        else:
            raise ValueError(f'Data type: {data_type} is unknown! Use train, val, or test.')

        # Initialize processors for waveform and Fbank features.
        self.wav_processor = Wave_Processor()
        self.fbank_processor = Fbank_Processor()

        # Clip data to a fixed segment length based on the sampling rate and max length.
        self.segment_length = self.sampling_rate * self.args.max_length
        print(f'No. {data_type} files: {len(self.wav_list)}')

    def __len__(self):
        # Return the number of audio files in the dataset.
        return len(self.wav_list)

    def __getitem__(self, index):
        # Get the input and label paths from the list.
        data_info = self.wav_list[index]

        # Process the waveform inputs and labels.
        inputs, labels = self.wav_processor.process(
            {'inputs': data_info['inputs'], 'labels': data_info['labels']},
            self.segment_length,
            self.sampling_rate
        )

        # Optionally compute Fbank features if specified.
        if self.args.load_fbank is not None:
            fbanks = self.fbank_processor.process(inputs, self.args)
            return inputs * MAX_WAV_VALUE, labels * MAX_WAV_VALUE, fbanks
        return inputs, labels

    def zero_pad_concat(self, inputs):
        """
        Concatenates a list of input arrays, applying zero-padding as needed to ensure
        they all match the length of the longest input.

        Parameters:
        inputs (list of numpy arrays): List of input arrays to be concatenated.

        Returns:
        numpy.ndarray: A zero-padded array with concatenated inputs.
        """
        # Get the maximum length among all inputs.
        max_t = max(inp.shape[0] for inp in inputs)

        # Determine the shape of the output based on the input dimensions.
        shape = None
        if len(inputs[0].shape) == 1:
            shape = (len(inputs), max_t)
        elif len(inputs[0].shape) == 2:
            shape = (len(inputs), max_t, inputs[0].shape[1])

        # Initialize an array with zeros to hold the concatenated inputs.
        input_mat = np.zeros(shape, dtype=np.float32)

        # Copy the input data into the zero-padded array.
        for e, inp in enumerate(inputs):
            if len(inp.shape) == 1:
                input_mat[e, :inp.shape[0]] = inp
            elif len(inp.shape) == 2:
                input_mat[e, :inp.shape[0], :] = inp
        return input_mat

def collate_fn_2x_wavs(data):
    """
    A custom collate function for combining batches of waveform input and label pairs.

    Parameters:
    data (list): List of tuples (inputs, labels).

    Returns:
    tuple: Batched inputs and labels as torch.FloatTensors.
    """
    inputs, labels = zip(*data)
    x = torch.FloatTensor(inputs)
    y = torch.FloatTensor(labels)
    return x, y


def collate_fn_2x_wavs_fbank(data):
    """
    A custom collate function for combining batches of waveform inputs, labels, and Fbank features.

    Parameters:
    data (list): List of tuples (inputs, labels, fbanks).

    Returns:
    tuple: Batched inputs, labels, and Fbank features as torch.FloatTensors.
    """
    inputs, labels, fbanks = zip(*data)
    x = torch.FloatTensor(inputs)
    y = torch.FloatTensor(labels)
    z = torch.FloatTensor(fbanks)
    return x, y, z
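
# Collation sketch (illustrative; batch and segment sizes below are assumptions):
# with a batch of two 4-second 16 kHz pairs, the result is two [2, 64000] tensors.
#
#   batch = [(np.zeros(64000, np.float32), np.zeros(64000, np.float32))] * 2
#   x, y = collate_fn_2x_wavs(batch)
#   print(x.shape, y.shape)   # torch.Size([2, 64000]) torch.Size([2, 64000])
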

class DistributedSampler(data.Sampler):
    """
    Sampler for distributed training. Divides the dataset among multiple replicas (processes),
    ensuring that each process gets a unique subset of the data. It also supports shuffling
    and managing epochs.

    Parameters:
    dataset (Dataset): The dataset to sample from.
    num_replicas (int): Number of processes participating in the training.
    rank (int): Rank of the current process.
    shuffle (bool): Whether to shuffle the data or not.
    seed (int): Random seed for reproducibility.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        # Shuffle the indices based on the epoch and seed.
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            ind = torch.randperm(int(len(self.dataset) / self.num_replicas), generator=g) * self.num_replicas
            indices = []
            for i in range(self.num_replicas):
                indices = indices + (ind + i).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Add extra samples to make the dataset evenly divisible.
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample the portion belonging to the current process.
        indices = indices[self.rank * self.num_samples:(self.rank + 1) * self.num_samples]
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
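
# Sampler sketch (illustrative; 2 replicas over a 10-item dataset, no torch.distributed
# initialization needed because num_replicas and rank are given explicitly):
#
#   sampler = DistributedSampler(list(range(10)), num_replicas=2, rank=0, shuffle=True)
#   sampler.set_epoch(3)               # reseed the shuffle for a new epoch
#   print(list(iter(sampler)))         # ceil(10 / 2) = 5 indices for rank 0
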

def get_dataloader(args, data_type):
    """
    Creates and returns a data loader and sampler for the specified dataset type (train, validation, or test).

    Parameters:
    args (Namespace): Configuration arguments containing details such as batch size, sampling rate,
                      network type, and whether distributed training is used.
    data_type (str): The type of dataset to load ('train', 'val', 'test').

    Returns:
    sampler (DistributedSampler or None): The sampler for distributed training, or None if not used.
    generator (DataLoader): The PyTorch DataLoader for the specified dataset.
    """
    # Initialize the dataset based on the given arguments and dataset type (train, val, or test).
    datasets = AudioDataset(args=args, data_type=data_type)

    # Create a distributed sampler if distributed training is enabled; otherwise, use no sampler.
    sampler = DistributedSampler(
        datasets,
        num_replicas=args.world_size,  # Number of replicas in distributed training.
        rank=args.local_rank  # Rank of the current process.
    ) if args.distributed else None

    # Select the appropriate collate function based on the network type.
    if args.network == 'FRCRN_SE_16K' or args.network == 'MossFormerGAN_SE_16K':
        # Use the collate function for waveform pairs (inputs and labels).
        collate_fn = collate_fn_2x_wavs
    elif args.network == 'MossFormer2_SE_48K':
        # Use the collate function for waveform pairs along with Fbank features.
        collate_fn = collate_fn_2x_wavs_fbank
    else:
        # Report an unknown network type and return without building a loader.
        print(f'In dataloader, {args.network} is unknown; please specify a correct network type using args.network!')
        return

    # Create a DataLoader with the specified dataset, batch size, and worker configuration.
    generator = data.DataLoader(
        datasets,
        batch_size=args.batch_size,  # Batch size for training.
        shuffle=(sampler is None),  # Shuffle the data only if no sampler is used.
        collate_fn=collate_fn,  # Use the selected collate function for batching data.
        num_workers=args.num_workers,  # Number of workers for data loading.
        sampler=sampler  # Use the distributed sampler if applicable.
    )

    # Return both the sampler and DataLoader (generator).
    return sampler, generator
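
# End-to-end usage sketch for get_dataloader (illustrative; every attribute below is
# an assumption about the training config, and 'tr_list.scp' is a hypothetical list file):
#
#   from argparse import Namespace
#   args = Namespace(tr_list='tr_list.scp', cv_list=None, tt_list=None,
#                    sampling_rate=16000, max_length=4, load_fbank=None,
#                    network='FRCRN_SE_16K', distributed=False, world_size=1,
#                    local_rank=0, batch_size=8, num_workers=4)
#   sampler, loader = get_dataloader(args, 'train')
#   for noisy, clean in loader:
#       # noisy/clean: [batch_size, sampling_rate * max_length] float tensors
#       break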