import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Union

import datasets
import numpy as np
import torch
from accelerate import Accelerator
from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
from tqdm import tqdm
from transformers import AutoFeatureExtractor, AutoTokenizer


@dataclass
class DataCollatorEncodecWithPadding:
    """
    Data collator that will dynamically pad the audio inputs received, either to the longest sequence
    in the batch or to `max_length` if `max_length` is set and `padding="max_length"`.
    """

    feature_extractor: AutoFeatureExtractor
    audio_column_name: str
    feature_extractor_input_name: Optional[str] = "input_values"
    max_length: Optional[int] = None
    padding: Optional[str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # gather the raw audio arrays and record their original lengths before any
        # truncation or padding is applied
        audios = [feature[self.audio_column_name]["array"] for feature in features]
        len_audio = [len(audio) for audio in audios]
        if self.max_length is not None:
            audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)]

        # since resampling has already been performed in the `load_multiple_datasets` function,
        # a fixed sampling rate (the feature extractor's own, e.g. 44100 Hz) is passed here,
        # so no further resampling happens at collation time
        sampling_rate = self.feature_extractor.sampling_rate
        batch = self.feature_extractor(
            audios, sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length
        )
        batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
        return batch
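
# A minimal usage sketch of `DataCollatorEncodecWithPadding` (the checkpoint and
# column name below are illustrative assumptions, not fixed by this module):
#
#     feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_24khz")
#     collator = DataCollatorEncodecWithPadding(
#         feature_extractor=feature_extractor,
#         audio_column_name="audio",
#         max_length=30 * feature_extractor.sampling_rate,  # cap clips at ~30 s of samples
#     )
#     batch = collator([dataset[0], dataset[1]])
#     # -> the feature extractor's padded outputs plus "len_audio", the pre-padding
#     #    number of samples of each clip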


@dataclass
class DataCollatorParlerTTSWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Args:
        prompt_tokenizer (:class:`~transformers.AutoTokenizer`)
            The prompt tokenizer used for processing the data.
        description_tokenizer (:class:`~transformers.AutoTokenizer`)
            The description tokenizer used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
              only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    prompt_tokenizer: AutoTokenizer
    description_tokenizer: AutoTokenizer
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    prompt_max_length: Optional[int] = None
    description_max_length: Optional[int] = None
    audio_max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
        # (bsz, seq_len, num_codebooks)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        if self.audio_max_length is not None and self.padding == "max_length":
            labels = torch.nn.functional.pad(
                labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100
            )

        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        input_ids = self.description_tokenizer.pad(
            input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.description_max_length,
        )

        batch = {"labels": labels, **input_ids}

        prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
        prompt_input_ids = self.prompt_tokenizer.pad(
            prompt_input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.prompt_max_length,
        )

        batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
        if "attention_mask" in prompt_input_ids:
            batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]

        return batch
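
# A minimal usage sketch of `DataCollatorParlerTTSWithPadding` (illustrative tokenizer
# checkpoint; each feature dict is assumed to already hold "labels" shaped
# (num_codebooks, seq_len), description "input_ids" and "prompt_input_ids"):
#
#     prompt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
#     description_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
#     collator = DataCollatorParlerTTSWithPadding(
#         prompt_tokenizer=prompt_tokenizer,
#         description_tokenizer=description_tokenizer,
#         pad_to_multiple_of=8,
#     )
#     batch = collator(features)
#     # batch["labels"]: (bsz, seq_len, num_codebooks), padded with -100
#     # batch["input_ids"] / batch["attention_mask"]: padded descriptions
#     # batch["prompt_input_ids"] / batch["prompt_attention_mask"]: padded prompts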


def convert_dataset_str_to_list(
    dataset_names,
    dataset_config_names,
    metadata_dataset_names=None,
    splits=None,
    dataset_samples=None,
    default_split="train",
):
    if isinstance(dataset_names, str):
        dataset_names = dataset_names.split("+")
        dataset_config_names = dataset_config_names.split("+")
        splits = splits.split("+") if splits is not None else None
        dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
        metadata_dataset_names = metadata_dataset_names.split("+") if metadata_dataset_names is not None else None

    # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
    if len(dataset_names) != len(dataset_config_names):
        raise ValueError(
            f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(dataset_config_names)} configs."
        )

    if splits is not None and len(splits) != len(dataset_names):
        raise ValueError(
            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
        )

    if metadata_dataset_names is not None and len(metadata_dataset_names) != len(dataset_names):
        raise ValueError(
            f"Ensure one metadata dataset is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(metadata_dataset_names)} metadata datasets."
        )

    if dataset_samples is not None:
        if len(dataset_samples) != len(dataset_names):
            raise ValueError(
                f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
                f"{len(dataset_samples)} samples."
            )
        dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
    else:
        dataset_samples = [None] * len(dataset_names)

    splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
    # default to "no metadata dataset" for every entry so the loop below never indexes into None
    metadata_dataset_names = (
        metadata_dataset_names if metadata_dataset_names is not None else [None] * len(dataset_names)
    )

    dataset_names_dict = []
    for i, ds_name in enumerate(dataset_names):
        dataset_names_dict.append(
            {
                "name": ds_name,
                "config": dataset_config_names[i],
                "split": splits[i],
                "metadata_dataset_name": metadata_dataset_names[i],
                "samples": dataset_samples[i],
            }
        )
    return dataset_names_dict
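
# For instance, the "+"-separated string form expands to one entry per dataset
# (dataset names here are purely illustrative):
#
#     convert_dataset_str_to_list(
#         "org/dataset_a+org/dataset_b", "default+default", splits="train+dev"
#     )
#     # -> [{"name": "org/dataset_a", "config": "default", "split": "train",
#     #      "metadata_dataset_name": None, "samples": None},
#     #     {"name": "org/dataset_b", "config": "default", "split": "dev",
#     #      "metadata_dataset_name": None, "samples": None}]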


def load_multiple_datasets(
    accelerator: Accelerator,
    dataset_names: Union[List, str],
    dataset_config_names: Union[List, str],
    metadata_dataset_names: Optional[str] = None,
    splits: Optional[Union[List, str]] = None,
    label_column_names: Optional[List] = None,
    stopping_strategy: Optional[str] = "first_exhausted",
    dataset_samples: Optional[Union[List, np.ndarray]] = None,
    streaming: Optional[bool] = False,
    seed: Optional[int] = None,
    id_column_name: Optional[str] = None,
    columns_to_keep: Optional[Set[str]] = None,
    prompt_column_name: Optional[str] = None,
    sampling_rate: Optional[int] = None,
    audio_column_name: Optional[str] = None,
    logger: Optional[logging.Logger] = None,
    **kwargs,
) -> Union[Dataset, IterableDataset]:
    # pass the optional arguments by keyword: `convert_dataset_str_to_list` takes no
    # `label_column_names` parameter, so positional forwarding would shift the arguments
    dataset_names_dict = convert_dataset_str_to_list(
        dataset_names,
        dataset_config_names,
        metadata_dataset_names=metadata_dataset_names,
        splits=splits,
        dataset_samples=dataset_samples,
    )

    if dataset_samples is not None:
        dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
        probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
    else:
        probabilities = None

    all_datasets = []
    # iterate over the datasets we want to interleave
    for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
        with accelerator.local_main_process_first():
            dataset = load_dataset(
                dataset_dict["name"],
                dataset_dict["config"],
                split=dataset_dict["split"],
                streaming=streaming,
                **kwargs,
            )
            dataset_features = dataset.features.keys()

            if sampling_rate is not None and audio_column_name is not None:
                # resample target audio
                dataset = dataset.cast_column(audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))

            metadata_dataset_name = dataset_dict["metadata_dataset_name"]
            if metadata_dataset_name is not None:
                logger.info(
                    f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}'
                )
                metadata_dataset = load_dataset(
                    metadata_dataset_name,
                    dataset_dict["config"],
                    split=dataset_dict["split"],
                    streaming=streaming,
                    **kwargs,
                )

                # TODO(YL): I forgot to create unique ids for MLS English.
                # To iterate faster, I bypass the original id check and do another one below.
                # Done once, assuming the datasets won't change next time.
                # if dataset_dict["name"] == "parler-tts/mls_eng_10k":
                #     def concat_ids(book_id, speaker_id, begin_time):
                #         return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"}
                #     dataset = dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
                #     metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
                #     metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
if dataset_dict["name"] not in {"parler-tts/mls_eng_10k", "parler-tts/mls_eng"}: | |
if id_column_name is not None and id_column_name not in dataset.column_names: | |
raise ValueError( | |
f"id_column_name={id_column_name} but has not been found in the dataset columns" | |
f"- one of {', '.join(list(dataset.column_names))}." | |
) | |
if id_column_name is not None and id_column_name not in metadata_dataset.column_names: | |
raise ValueError( | |
f"id_column_name={id_column_name} but has not been found in the metadata dataset columns" | |
f"- one of {', '.join(list(metadata_dataset.column_names))}." | |
) | |
elif id_column_name is not None: | |
metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}") | |
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names)) | |
if prompt_column_name is not None: | |
# We might have applied some transformations to the prompts (e.g punctuation restoration) | |
# so we make sure to remove it from the original dataset | |
if prompt_column_name in dataset.column_names: | |
logger.info( | |
f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']" | |
) | |
dataset.remove_columns(prompt_column_name) | |
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names)) | |
metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove) | |
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1) | |

                if id_column_name is not None and dataset_dict["name"] not in {
                    "parler-tts/mls_eng_10k",
                    "parler-tts/mls_eng",
                }:
                    if (
                        len(
                            dataset.filter(
                                lambda id1, id2: id1 != id2,
                                input_columns=[id_column_name, f"metadata_{id_column_name}"],
                            )
                        )
                        != 0
                    ):
                        raise ValueError(
                            f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}"
                        )

            dataset_features = dataset.features.keys()

            if columns_to_keep is not None:
                dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
        all_datasets.append(dataset)

    if len(all_datasets) == 1:
        # we have a single dataset so just return it as is
        return all_datasets[0]

    if streaming:
        interleaved_dataset = interleave_datasets(
            all_datasets,
            stopping_strategy=stopping_strategy,
            probabilities=probabilities,
            seed=seed,
        )
    else:
        with accelerator.local_main_process_first():
            interleaved_dataset = concatenate_datasets(all_datasets)

    return interleaved_dataset
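
# A minimal usage sketch (dataset names and weights are illustrative assumptions;
# with `streaming=True`, the `dataset_samples` weights become the interleaving
# probabilities, here 0.8/0.2):
#
#     accelerator = Accelerator()
#     raw_datasets = load_multiple_datasets(
#         accelerator,
#         dataset_names="org/dataset_a+org/dataset_b",
#         dataset_config_names="default+default",
#         splits="train+train",
#         dataset_samples="8+2",
#         streaming=True,
#         seed=42,
#         logger=logging.getLogger(__name__),
#     )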