#!/home/haroon/python_virtual_envs/whisper_fine_tuning/bin/python | |
from datasets import load_dataset, DatasetDict, Audio | |
from transformers import (WhisperTokenizer, WhisperFeatureExtractor, | |
WhisperProcessor, WhisperForConditionalGeneration, | |
Seq2SeqTrainingArguments, Seq2SeqTrainer) | |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer | |
import torch | |
from dataclasses import dataclass | |
from typing import Any, Dict, List, Union | |
import evaluate | |
# ## Load Dataset | |
# Hugging Face Hub: | |
# [mozilla-foundation/common_voice_11_0] | |
# (https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0). | |
common_voice = DatasetDict()
# Train on the union of the train and validation splits; hold out the official test split.
# `token=True` authenticates with the locally saved Hugging Face token.
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0",
                                     "hi",
                                     split="train+validation",
                                     token=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0",
                                    "hi",
                                    split="test",
                                    token=True)
print(f'YYY1a {common_voice=}')
# Keep only the columns the pipeline reads ('audio' and 'sentence'). The drop list is derived
# from the data instead of being hard-coded, so remove_columns does not fail when a dataset
# version lacks one of the expected metadata columns, and unexpected extra columns are dropped too.
_columns_to_keep = ("audio", "sentence")
common_voice = common_voice.remove_columns(
    [name for name in common_voice["train"].column_names if name not in _columns_to_keep])
print(f'YYY1b {common_voice=}')
print(f'YYY2 {type(common_voice)=}')
# ## Prepare Feature Extractor, Tokenizer and Data | |
# The ASR pipeline can be de-composed into three stages: | |
# 1) A feature extractor which pre-processes the raw audio-inputs | |
# 2) The model which performs the sequence-to-sequence mapping | |
# 3) A tokenizer which post-processes the model outputs to text format | |
# | |
# In 🤗 Transformers, the Whisper model has an associated feature extractor and tokenizer, called | |
# [WhisperFeatureExtractor] | |
# (https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperFeatureExtractor) | |
# and [WhisperTokenizer] | |
# (https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperTokenizer) | |
# respectively. | |
# ### Load WhisperFeatureExtractor | |
# The Whisper feature extractor performs two operations: | |
# 1. Pads / truncates the audio inputs to 30s: any audio inputs shorter than 30s are padded to 30s | |
# with silence (zeros), and those longer than 30s are truncated to 30s.
# 2. Converts the audio inputs to log-Mel spectrogram input features, a visual representation of the | |
# audio and the form of the input expected by the Whisper model. | |
# We'll load the feature extractor from the pre-trained checkpoint with the default values: | |
# Feature extractor with the checkpoint's default settings (30 s windows, log-Mel output).
feature_extractor = WhisperFeatureExtractor.from_pretrained(
    "openai/whisper-small",
)
# ### Load WhisperTokenizer | |
# The Whisper model outputs a sequence of token ids. | |
# The tokenizer maps each of these token ids to their corresponding text string. | |
# For Hindi, we can load the pre-trained tokenizer and use it for fine-tuning without any | |
# further modifications. | |
# We simply have to specify the target language and the task. | |
# These arguments inform the tokenizer to prefix the language and task tokens to the start of encoded | |
# label sequences: | |
# Language and task make the tokenizer prefix every encoded label with the matching special tokens.
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-small",
    task="transcribe",
    language="Hindi",
)
# ### Combine To Create A WhisperProcessor | |
# To simplify using the feature extractor and tokenizer, we can wrap both into a single | |
# `WhisperProcessor` class. This processor object wraps the `WhisperFeatureExtractor`
# and the `WhisperTokenizer`, and can be used on the audio inputs and model predictions as required.
# In doing so, we only need to keep track of two objects during training: | |
# the `processor` and the `model`: | |
# Single object bundling the feature extractor and the (Hindi / transcribe) tokenizer.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small",
    task="transcribe",
    language="Hindi",
)
# ### Prepare Data | |
# Let's print the first example of the Common Voice dataset to see what form the data is in: | |
print(common_voice["train"][0])
# Each example has two fields:
#   'audio'    -> dict with 'path' (str), 'array' (the decoded 1-D waveform samples) and
#                 'sampling_rate' (int; still the dataset's native rate here, before the
#                 cast_column resampling below)
#   'sentence' -> the reference transcription text (str)
# Since our input audio is sampled at 48kHz, we need to downsample it to 16kHz prior to passing | |
# it to the Whisper feature extractor, 16kHz being the sampling rate expected by the Whisper model. | |
# We'll set the audio inputs to the correct sampling rate using dataset's | |
# [`cast_column`] | |
# (https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=cast_column#datasets.DatasetDict.cast_column) | |
# method. | |
# This operation does not change the audio in-place, but rather signals to `datasets` to resample | |
# audio samples on the fly the first time that they are loaded: | |
# Ask `datasets` to resample 'audio' to 16 kHz lazily, i.e. whenever an example is loaded.
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16_000))
# Loading an example now yields audio at the new sampling rate.
print(common_voice["train"][0])
# We'll define our pre-processing strategy. We advise that you **do not** lower-case the transcriptions | |
# or remove punctuation unless mixing different datasets. | |
# This will enable you to fine-tune Whisper models that can predict punctuation and casing. | |
# Later, you will see how we can evaluate the predictions without punctuation or casing, so that | |
# the models benefit from the WER improvement obtained by normalising the transcriptions while | |
# still predicting fully formatted transcriptions. | |
# Label pre-processing switches: both off, so targets keep their casing and punctuation.
do_lower_case = do_remove_punctuation = False
# Whisper's basic text normaliser, shared by label clean-up and normalised-WER evaluation.
normalizer = BasicTextNormalizer()
# Now we can write a function to prepare our data ready for the model: | |
# 1. We load and resample the audio data by calling `batch["audio"]`. | |
# As explained above, 🤗 Datasets performs any necessary resampling operations on the fly. | |
# 2. We use the feature extractor to compute the log-Mel spectrogram input features from our | |
# 1-dimensional audio array. | |
# 3. We perform any optional pre-processing (lower-case or remove punctuation). | |
# 4. We encode the transcriptions to label ids through the use of the tokenizer. | |
def prepare_dataset(batch):
    """Turn one Common Voice example into Whisper model inputs and label ids.

    Adds three fields to ``batch`` (the same dict is updated and returned):
      * ``input_features``: log-Mel features from the processor's feature extractor;
      * ``input_length``: clip duration in seconds (sample count / sampling rate);
      * ``labels``: token ids of the transcription, after the optional lower-casing
        and normalisation selected by ``do_lower_case`` / ``do_remove_punctuation``.
    """
    clip = batch["audio"]  # decoding (and the cast_column resampling) happens on access
    samples, rate = clip["array"], clip["sampling_rate"]

    extracted = processor.feature_extractor(samples, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]
    batch["input_length"] = len(samples) / rate

    text = batch["sentence"]
    text = text.lower() if do_lower_case else text
    text = normalizer(text).strip() if do_remove_punctuation else text

    batch["labels"] = processor.tokenizer(text).input_ids
    return batch
# We can apply the data preparation function to all of our training examples using dataset's | |
# `.map` method. | |
# The argument `num_proc` specifies how many CPU cores to use. Setting `num_proc` > 1 will | |
# enable multiprocessing. If the `.map` method hangs with multiprocessing, set `num_proc=1` | |
# and process the dataset sequentially. | |
common_voice = common_voice.map(
    prepare_dataset,
    num_proc=2,  # set to 1 if multiprocessing makes .map hang
    remove_columns=common_voice.column_names["train"],  # drop the raw input columns
)
# Finally, we filter any training data with audio samples longer than 30s. | |
# These samples would otherwise be truncated by the Whisper feature-extractor which could affect | |
# the stability of training. | |
# We define a function that returns `True` for samples that are less than 30s, and `False` for | |
# those that are longer: | |
max_input_length = 30.0  # seconds


def is_audio_in_length_range(length):
    """Return True when ``length`` (seconds) is strictly shorter than ``max_input_length``."""
    return max_input_length > length
# We apply our filter function to all samples of our training dataset through 🤗 Datasets' | |
# `.filter` method: | |
# Only the training split is filtered; the test split keeps every clip.
common_voice["train"] = common_voice["train"].filter(
    is_audio_in_length_range, input_columns=["input_length"])
# ## Training and Evaluation | |
# Now that we've prepared our data, we're ready to dive into the training pipeline. | |
# The [🤗 Trainer] | |
# (https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer) | |
# will do much of the heavy lifting for us. All we have to do is: | |
# - Define a data collator: the data collator takes our pre-processed data and prepares PyTorch | |
# tensors ready for the model. | |
# - Evaluation metrics: during evaluation, we want to evaluate the model using the | |
# [word error rate (WER)] (https://huggingface.co/metrics/wer) metric. | |
# We need to define a `compute_metrics` function that handles this computation. | |
# - Load a pre-trained checkpoint: we need to load a pre-trained checkpoint and configure it correctly | |
# for training. | |
# - Define the training configuration: this will be used by the 🤗 Trainer to define the training | |
# schedule. | |
# Once we've fine-tuned the model, we will evaluate it on the test data to verify that we have | |
# correctly trained it to transcribe speech in Hindi. | |
# ### Define a Data Collator | |
# The data collator for a sequence-to-sequence speech model is unique in the sense that it treats | |
# the `input_features` and `labels` independently: the `input_features` must be handled by the | |
# feature extractor and the `labels` by the tokenizer. | |
# The `input_features` are already padded to 30s and converted to a log-Mel spectrogram of fixed | |
# dimension by action of the feature extractor, so all we have to do is convert the `input_features` | |
# to batched PyTorch tensors. | |
# We do this using the feature extractor's `.pad` method with `return_tensors="pt"`.
# The `labels` on the other hand are un-padded. We first pad the sequences to the maximum length | |
# in the batch using the tokenizer's `.pad` method. The padding tokens are then replaced by `-100` | |
# so that these tokens are **not** taken into account when computing the loss. | |
# We then cut the BOS token from the start of the label sequence as we append it later during training. | |
# We can leverage the `WhisperProcessor` we defined earlier to perform both the feature extractor | |
# and the tokenizer operations: | |
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Batch pre-processed Whisper examples for ``Seq2SeqTrainer``.

    Audio features and label ids need different padding, so they are handled separately:
    ``input_features`` are turned into a batched tensor by the feature extractor's ``pad``;
    ``labels`` are padded by the tokenizer's ``pad``, and the padded positions are set to -100
    so the loss ignores them.

    Attributes:
        processor: object exposing ``feature_extractor`` and ``tokenizer`` (a ``WhisperProcessor``
            in this script).
        decoder_start_token_id: token the model prepends to the labels when it builds decoder
            inputs. If every label sequence in a batch already starts with it, it is cut off here
            so it is not duplicated. When None, the tokenizer's ``<|startoftranscript|>`` id is used.
    """

    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]])\
            -> Dict[str, torch.Tensor]:
        # The audio inputs are already fixed-size log-Mel features, so "padding" them only
        # stacks them into a torch tensor.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Label sequences vary in length: pad them to the longest one in the batch...
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # ...then replace the padding with -100 so it is ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # The Whisper tokenizer starts every label sequence with <|startoftranscript|>, and the
        # model prepends its decoder start token again when it shifts labels into decoder inputs.
        # Cut the leading token so it appears only once. Whisper's bos_token is <|endoftext|>, so
        # comparing against bos_token_id (as before) never matched. Verify this if you switch
        # tokenizers.
        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if labels.shape[1] > 0 and (labels[:, 0] == start_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
# Let's initialise the data collator we've just defined: | |
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
)
# ### Evaluation Metrics | |
# We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing ASR systems. | |
# For more information, refer to the WER | |
# [docs] (https://huggingface.co/metrics/wer). | |
# We'll load the WER metric from 🤗 Evaluate: | |
# Word error rate from 🤗 Evaluate; compute_metrics scales its result to a percentage.
metric = evaluate.load(path="wer")
# We then simply have to define a function that takes our model predictions and returns the WER metric. | |
# This function, called `compute_metrics`, first replaces `-100` with the `pad_token_id` in the | |
# `label_ids` (undoing the step we applied in the data collator to ignore padded tokens correctly in | |
# the loss). | |
# It then decodes the predicted and label ids to strings. Finally, it computes the WER between the | |
# predictions and reference labels. | |
# Here, we have the option of evaluating with the 'normalised' transcriptions and predictions. | |
# We recommend you set this to `True` to benefit from the WER improvement obtained by normalising | |
# the transcriptions. | |
# Evaluate with the 'normalised' WER
do_normalize_eval = True


def compute_metrics(pred):
    """Compute the word error rate (as a percentage) for one evaluation run.

    Args:
        pred: the prediction object ``Seq2SeqTrainer`` passes to ``compute_metrics``.
            ``pred.predictions`` holds generated token ids (``predict_with_generate=True``), and
            ``pred.label_ids`` holds the reference ids, with padding marked as -100 by the
            data collator.

    Returns:
        ``{"wer": 100 * word_error_rate}``, computed over the pairs whose reference is non-empty.
    """
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # -100 is not a real token id; map it back to the pad token before decoding. This undoes the
    # loss masking from the data collator. Note that the arrays are modified in place.
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    # NOTE(review): generated ids may also carry -100 when the trainer concatenates batches of
    # different lengths; this replacement is a no-op otherwise.
    pred_ids[pred_ids == -100] = processor.tokenizer.pad_token_id

    # Special tokens (language/task prefix, end-of-text, padding) are not part of the transcript.
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(text) for text in pred_str]
        label_str = [normalizer(text) for text in label_str]
        # NOTE(review): BasicTextNormalizer appears to replace Unicode mark characters with
        # spaces, which for Devanagari would strip vowel signs. Confirm before comparing the
        # normalised Hindi WER with other results.

    # A reference can end up empty or whitespace-only (e.g. a punctuation-only sentence after
    # normalisation). WER is undefined for it and the metric raises, so only score pairs whose
    # reference still has content.
    kept = [(hyp, ref) for hyp, ref in zip(pred_str, label_str) if ref.strip()]
    pred_str = [hyp for hyp, _ in kept]
    label_str = [ref for _, ref in kept]

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
# ### Load a Pre-Trained Checkpoint | |
# Now let's load the pre-trained Whisper `small` checkpoint. Again, this is trivial through | |
# use of 🤗 Transformers! | |
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
# define your language of choice here; the task is pinned as well, so evaluation-time
# generation transcribes (rather than translates) the speech
model.generation_config.language = "hi"
model.generation_config.task = "transcribe"
# In current 🤗 Transformers versions generate() takes its settings from
# `model.generation_config`, not `model.config`. Clear any forced decoder ids shipped with the
# checkpoint there too, so they cannot conflict with the language/task set above.
model.generation_config.forced_decoder_ids = None
# Override generation arguments - no tokens are forced as decoder outputs
# (see [`forced_decoder_ids`]
# (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)),
# no tokens are suppressed during generation
# (see [`suppress_tokens`]
# (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)).
# NOTE(review): suppress_tokens is only cleared on `model.config`. If generation really should
# run without suppression, it probably has to be cleared on `model.generation_config` too; confirm.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# Set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible:
model.config.use_cache = False
# ### Define the Training Configuration | |
# In the final step, we define all the parameters related to training. | |
# For more detail on the training arguments, refer to the Seq2SeqTrainingArguments | |
# [docs] | |
# (https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments). | |
training_args = Seq2SeqTrainingArguments(
    # output location (checkpoints, logs, the saved processor) and reporting
    output_dir="./",
    push_to_hub=True,
    report_to=["tensorboard"],
    # optimisation schedule
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    # per-device batch x accumulation steps = per-device effective batch; when halving the
    # batch size, double the accumulation steps
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    # memory / precision
    gradient_checkpointing=True,
    fp16=True,
    # evaluation: generate transcripts and score them every `eval_steps` steps
    # NOTE(review): newer 🤗 Transformers renames this argument to `eval_strategy`; check your version.
    evaluation_strategy="steps",
    eval_steps=1000,
    per_device_eval_batch_size=4,
    predict_with_generate=True,
    generation_max_length=225,
    # checkpoint and logging cadence
    save_steps=1000,
    logging_steps=25,
    # keep the checkpoint with the lowest WER
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)
# **Note**: if one does not want to upload the model checkpoints to the Hub, set `push_to_hub=False`. | |
# We can forward the training arguments to the 🤗 Trainer along with our model, dataset, data collator | |
# and `compute_metrics` function: | |
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    compute_metrics=compute_metrics,
    # The feature extractor goes in the trainer's `tokenizer` slot (presumably so the trainer
    # saves it alongside model checkpoints).
    tokenizer=processor.feature_extractor,
)
# The processor has no trainable state, so it is saved to the output directory once, up front.
processor.save_pretrained(training_args.output_dir)
# ### Training
# Expect roughly 5-10 hours depending on the GPU; peak GPU memory for this configuration is
# about 36GB. On a CUDA "out-of-memory" error, lower `per_device_train_batch_size` by factors of
# 2 and raise [`gradient_accumulation_steps`]
# (https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments.gradient_accumulation_steps)
# to compensate.
trainer.train()
# We can label our checkpoint with the `whisper-event` tag on push by setting the appropriate | |
# keyword arguments (kwargs): | |
# Model-card metadata; "tags" marks the checkpoint with the `whisper-event` tag.
# NOTE(review): model_name still credits "Sanchit Gandhi"; change it if this is your own model.
kwargs = dict(
    dataset_tags="mozilla-foundation/common_voice_11_0",
    dataset="Common Voice 11.0",  # human-readable dataset name
    language="hi",
    model_name="Whisper Small Hi - Sanchit Gandhi",  # human-readable model name
    finetuned_from="openai/whisper-small",
    tasks="automatic-speech-recognition",
    tags="whisper-event",
)
# Upload the training results to the Hub together with the metadata above.
trainer.push_to_hub(**kwargs)
# ## Closing Remarks | |
# If you're interested in fine-tuning other Transformers models, both for English and multilingual ASR, | |
# be sure to check out the examples scripts at | |
# [examples/pytorch/speech-recognition] | |
# (https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition). | |