#!/home/haroon/python_virtual_envs/whisper_fine_tuning/bin/python
from datasets import load_dataset, DatasetDict, Audio
from transformers import (WhisperTokenizer, WhisperFeatureExtractor,
WhisperProcessor, WhisperForConditionalGeneration,
Seq2SeqTrainingArguments, Seq2SeqTrainer)
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
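# ## Load Dataset
# We load the Hindi ("hi") subset of Common Voice 11.0 from the Hugging Face Hub:
# [mozilla-foundation/common_voice_11_0]
# (https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0).
# The dataset is gated, so accept its terms of use on the Hub and log in first.
# `token=True` makes `load_dataset` use your locally stored Hugging Face token.
# The "train" and "validation" splits are combined for training, and "test" is held out for evaluation.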
common_voice = DatasetDict()
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0",
"hi",
split="train+validation",
token=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0",
"hi",
split="test",
token=True)
print(f'YYY1a {common_voice=}')
common_voice = common_voice.remove_columns([
"accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
print(f'YYY1b {common_voice=}')
print(f'YYY2 {type(common_voice)=}')
# ## Prepare Feature Extractor, Tokenizer and Data
# The ASR pipeline can be decomposed into three stages:
# 1) A feature extractor which pre-processes the raw audio-inputs
# 2) The model which performs the sequence-to-sequence mapping
# 3) A tokenizer which post-processes the model outputs to text format
#
# In 🤗 Transformers, the Whisper model has an associated feature extractor and tokenizer, called
# [WhisperFeatureExtractor]
# (https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperFeatureExtractor)
# and [WhisperTokenizer]
# (https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperTokenizer)
# respectively.
# ### Load WhisperFeatureExtractor
# The Whisper feature extractor performs two operations:
# 1. Pads / truncates the audio inputs to 30s: any audio inputs shorter than 30s are padded to 30s
# with silence (zeros), and those longer than 30s are truncated to 30s.
# 2. Converts the audio inputs to log-Mel spectrogram input features, a visual representation of the
# audio and the form of the input expected by the Whisper model.
# We'll load the feature extractor from the pre-trained checkpoint with the default values:
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
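# The pad/truncate step described above can be sketched in plain Python. This is a minimal
# illustration, not the feature extractor's actual implementation. It assumes the audio is
# already at 16kHz, so a 30s window holds 30 * 16000 = 480000 samples, and it pads on the right.
# `pad_or_truncate` is a hypothetical helper that the training pipeline below does not use.
from typing import List


def pad_or_truncate(samples: List[float], sampling_rate: int = 16000,
                    max_seconds: float = 30.0) -> List[float]:
    """Zero-pad or cut `samples` to exactly `max_seconds` of audio (illustrative sketch)."""
    target_len = int(max_seconds * sampling_rate)
    if len(samples) >= target_len:
        # longer clips lose everything after the first `max_seconds`
        return samples[:target_len]
    # shorter clips are padded at the end with silence (zeros)
    return samples + [0.0] * (target_len - len(samples))
# e.g. a 1s clip at 16kHz (16000 samples) comes back with 480000 samples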
# ### Load WhisperTokenizer
# The Whisper model outputs a sequence of token ids.
# The tokenizer maps each of these token ids to their corresponding text string.
# For Hindi, we can load the pre-trained tokenizer and use it for fine-tuning without any
# further modifications.
# We simply have to specify the target language and the task.
# These arguments inform the tokenizer to prefix the language and task tokens to the start of encoded
# label sequences:
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small",
language="Hindi", task="transcribe")
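# With `language="Hindi"` and `task="transcribe"`, every encoded label should follow the layout
# sketched below. Special prefix tokens come first, then the text tokens, then an end-of-text token.
# The helper works on token *strings* only and is hypothetical. To inspect the real ids on your
# machine, print `tokenizer.convert_ids_to_tokens(tokenizer("...").input_ids)` for a sample sentence.
from typing import List


def whisper_label_layout(text_tokens: List[str], language_code: str = "hi",
                         task: str = "transcribe", timestamps: bool = False) -> List[str]:
    # start-of-transcript, language and task tokens are always prefixed
    prefix = ["<|startoftranscript|>", f"<|{language_code}|>", f"<|{task}|>"]
    if not timestamps:
        # without timestamp prediction, a "no timestamps" token is added to the prefix
        prefix.append("<|notimestamps|>")
    return prefix + list(text_tokens) + ["<|endoftext|>"]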
# ### Combine To Create A WhisperProcessor
# To simplify using the feature extractor and tokenizer, we can wrap both into a single
# `WhisperProcessor` class. This processor object wraps the `WhisperFeatureExtractor`
# and `WhisperTokenizer`, and can be applied to the audio inputs and model predictions as required.
# In doing so, we only need to keep track of two objects during training:
# the `processor` and the `model`:
processor = WhisperProcessor.from_pretrained("openai/whisper-small",
language="Hindi", task="transcribe")
# ### Prepare Data
# Let's print the first example of the Common Voice dataset to see what form the data is in:
print(common_voice["train"][0])
# Each example is a dict with two keys:
#   'audio'    -> dict with 'path' (str), 'array' (the decoded waveform as a 1-D float array)
#                 and 'sampling_rate' (int)
#   'sentence' -> the target transcription (str)
# Since our input audio is sampled at 48kHz, we need to downsample it to 16kHz prior to passing
# it to the Whisper feature extractor, 16kHz being the sampling rate expected by the Whisper model.
# We'll set the audio inputs to the correct sampling rate using dataset's
# [`cast_column`]
# (https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=cast_column#datasets.DatasetDict.cast_column)
# method.
# This operation does not change the audio in-place, but rather signals to `datasets` to resample
# audio samples on the fly the first time that they are loaded:
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
# Re-loading the first audio sample in the Common Voice dataset will resample it to the
# desired sampling rate:
print(common_voice["train"][0])
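# Resampling from 48kHz to 16kHz keeps the clip's duration but leaves roughly a third of the
# samples. Below is a minimal sketch of that arithmetic. `resampled_num_samples` is a hypothetical
# helper. The actual resampling is done by `datasets` when the audio is decoded, and its output
# length may differ from this estimate by a sample.
def resampled_num_samples(num_samples: int, orig_sr: int, target_sr: int) -> int:
    duration_s = num_samples / orig_sr       # duration is unchanged by resampling
    return round(duration_s * target_sr)     # samples needed at the new rate
# e.g. 5s of 48kHz audio (240000 samples) becomes about 80000 samples at 16kHz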
# We'll define our pre-processing strategy. We advise that you **do not** lower-case the transcriptions
# or remove punctuation unless mixing different datasets.
# This will enable you to fine-tune Whisper models that can predict punctuation and casing.
# Later, you will see how we can evaluate the predictions without punctuation or casing, so that
# the models benefit from the WER improvement obtained by normalising the transcriptions while
# still predicting fully formatted transcriptions.
do_lower_case = False
do_remove_punctuation = False
normalizer = BasicTextNormalizer()
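# The sketch below approximates what `BasicTextNormalizer` does to a transcription. It is based on
# our reading of its source in `transformers/models/whisper/english_normalizer.py`, so check the
# version you have installed. The steps are:
#   - lower-case the text;
#   - drop text in brackets and parentheses;
#   - replace every character in the Unicode "Mark", "Symbol" or "Punctuation" categories with a space;
#   - collapse whitespace.
# Devanagari vowel signs and the virama are Unicode marks, so normalising Hindi this way also splits
# words. Keep this in mind when reading a "normalised" Hindi WER.
# `approx_basic_normalize` is a hypothetical helper. Unlike the original, it also strips the result.
import re
import unicodedata


def approx_basic_normalize(text: str) -> str:
    text = text.lower()
    text = re.sub(r"[<\[][^>\]]*[>\]]", "", text)   # remove words between brackets
    text = re.sub(r"\(([^)]+?)\)", "", text)        # remove words between parentheses
    # replace marks (M*), symbols (S*) and punctuation (P*) with spaces
    text = "".join(
        " " if unicodedata.category(ch)[0] in "MSP" else ch
        for ch in unicodedata.normalize("NFKC", text)
    )
    return re.sub(r"\s+", " ", text).strip()      # collapse runs of whitespace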
# Now we can write a function to prepare our data for the model:
# 1. We load and resample the audio data by calling `batch["audio"]`.
# As explained above, 🤗 Datasets performs any necessary resampling operations on the fly.
# 2. We use the feature extractor to compute the log-Mel spectrogram input features from our
# 1-dimensional audio array.
# 3. We perform any optional pre-processing (lower-case or remove punctuation).
# 4. We encode the transcriptions to label ids through the use of the tokenizer.
def prepare_dataset(batch):
    # load and (possibly) resample audio data to 16kHz
    audio = batch["audio"]
    # compute log-Mel input features from input audio array
    batch["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # compute input length of audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
    # optional pre-processing steps
    transcription = batch["sentence"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()
    # encode target text to label ids
    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch
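# Each `input_features` entry is a log-Mel spectrogram with a fixed shape, because every clip is
# first padded or truncated to 30s. Below is a minimal sketch of the shape arithmetic. It assumes
# the defaults of the `openai/whisper-small` feature extractor: 80 Mel bins, 16kHz audio and a hop
# of 160 samples. Check `feature_extractor.feature_size` and `feature_extractor.hop_length` for
# your checkpoint. `expected_feature_shape` is a hypothetical helper.
from typing import Tuple


def expected_feature_shape(num_mel_bins: int = 80, sampling_rate: int = 16000,
                           chunk_seconds: int = 30, hop_length: int = 160) -> Tuple[int, int]:
    num_samples = chunk_seconds * sampling_rate   # 480000 samples in a 30s window
    num_frames = num_samples // hop_length        # one spectrogram frame every `hop_length` samples
    return num_mel_bins, num_frames
# e.g. expected_feature_shape() returns (80, 3000)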
# We can apply the data preparation function to all of our training examples using dataset's
# `.map` method.
# The argument `num_proc` specifies how many CPU cores to use. Setting `num_proc` > 1 will
# enable multiprocessing. If the `.map` method hangs with multiprocessing, set `num_proc=1`
# and process the dataset sequentially.
common_voice = common_voice.map(prepare_dataset,
remove_columns=common_voice.column_names["train"],
num_proc=2)
# Finally, we filter out any training examples whose audio is 30s or longer.
# These samples would otherwise be truncated by the Whisper feature extractor, which could affect
# the stability of training.
# We define a function that returns `True` for samples shorter than 30s and `False` otherwise:
max_input_length = 30.0
def is_audio_in_length_range(length):
    return length < max_input_length
# We apply our filter function to all samples of our training dataset through 🤗 Datasets'
# `.filter` method:
common_voice["train"] = common_voice["train"].filter(
is_audio_in_length_range,
input_columns=["input_length"],
)
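# Before filtering, it can be useful to know how much training audio the 30s limit discards.
# Below is a minimal sketch with a hypothetical helper. It works on a plain list of durations in
# seconds, e.g. the "input_length" values of the training split read before the filter is applied.
from typing import Dict, List


def length_filter_summary(lengths_s: List[float], max_len_s: float = 30.0) -> Dict[str, float]:
    kept = [x for x in lengths_s if x < max_len_s]   # same strict "<" test as the filter above
    return {
        "kept": len(kept),
        "dropped": len(lengths_s) - len(kept),
        "kept_hours": sum(kept) / 3600,
    }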
# ## Training and Evaluation
# Now that we've prepared our data, we're ready to dive into the training pipeline.
# The [🤗 Trainer]
# (https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer)
# will do much of the heavy lifting for us. All we have to do is:
# - Define a data collator: the data collator takes our pre-processed data and prepares PyTorch
# tensors ready for the model.
# - Evaluation metrics: during evaluation, we want to evaluate the model using the
# [word error rate (WER)](https://huggingface.co/metrics/wer) metric.
# We need to define a `compute_metrics` function that handles this computation.
# - Load a pre-trained checkpoint: we need to load a pre-trained checkpoint and configure it correctly
# for training.
# - Define the training configuration: this will be used by the 🤗 Trainer to define the training
# schedule.
# Once we've fine-tuned the model, we will evaluate it on the test data to verify that we have
# correctly trained it to transcribe speech in Hindi.
# ### Define a Data Collator
# The data collator for a sequence-to-sequence speech model is unique in the sense that it treats
# the `input_features` and `labels` independently: the `input_features` must be handled by the
# feature extractor and the `labels` by the tokenizer.
# The `input_features` are already padded to 30s and converted to a log-Mel spectrogram of fixed
# dimension by action of the feature extractor, so all we have to do is convert the `input_features`
# to batched PyTorch tensors.
# We do this using the feature extractor's `.pad` method with `return_tensors="pt"`.
# The `labels`, on the other hand, are unpadded. We first pad the sequences to the maximum length
# in the batch using the tokenizer's `.pad` method. The padding tokens are then replaced by `-100`
# so that these tokens are **not** taken into account when computing the loss.
# We then cut the start token from the beginning of the label sequence, since it is prepended again
# during training.
# We can leverage the `WhisperProcessor` we defined earlier to perform both the feature extractor
# and the tokenizer operations:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) \
            -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different
        # padding methods.
        # First treat the audio inputs by simply returning torch tensors.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # The tokenizer prefixes every label with <|startoftranscript|>, the decoder start token.
        # For the openai/whisper-* tokenizers, `bos_token` is <|endoftext|>, so a `bos_token_id`
        # check would never match. Cut the start token here, because the model prepends it again
        # when it builds the decoder inputs from the labels.
        start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == start_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
# Let's initialise the data collator we've just defined:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
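# The collator's label handling can be sketched without torch, on plain lists of ids. This makes
# its three steps easy to check:
#   1. pad to the longest sequence;
#   2. replace the padded positions with -100, which the loss ignores;
#   3. drop the first column when every sequence starts with the start token.
# Masking by attention mask rather than by value matters when the padding id equals the end-of-text
# id, as it does for the Whisper tokenizers. In that case the real end-of-text token must be kept.
# `pad_and_mask_labels` and the ids in the tests are hypothetical.
from typing import List


def pad_and_mask_labels(label_seqs: List[List[int]], pad_id: int,
                        start_id: int) -> List[List[int]]:
    max_len = max(len(seq) for seq in label_seqs)
    masked = []
    for seq in label_seqs:
        n_pad = max_len - len(seq)
        padded = seq + [pad_id] * n_pad
        attention = [1] * len(seq) + [0] * n_pad
        # only padded positions (attention 0) become -100, even if a real token equals pad_id
        masked.append([tok if att == 1 else -100 for tok, att in zip(padded, attention)])
    if all(row[0] == start_id for row in masked):
        masked = [row[1:] for row in masked]
    return masked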
# ### Evaluation Metrics
# We'll use the word error rate (WER) metric, the de facto metric for assessing ASR systems.
# For more information, refer to the WER
# [docs](https://huggingface.co/metrics/wer).
# We'll load the WER metric from 🤗 Evaluate:
metric = evaluate.load("wer")
# We then simply have to define a function that takes our model predictions and returns the WER metric.
# This function, called `compute_metrics`, first replaces `-100` with the `pad_token_id` in the
# `label_ids` (undoing the step we applied in the data collator to ignore padded tokens correctly in
# the loss).
# It then decodes the predicted and label ids to strings. Finally, it computes the WER between the
# predictions and reference labels.
# Here, we have the option of evaluating with the 'normalised' transcriptions and predictions.
# We recommend you set this to `True` to benefit from the WER improvement obtained by normalising
# the transcriptions.
# Evaluate with the 'normalised' WER
do_normalize_eval = True
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    # replace -100 with the pad_token_id so the labels can be decoded
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    # decode ids to text, dropping special tokens (language, task, end-of-text, padding)
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    if do_normalize_eval:
        pred_str = [normalizer(text) for text in pred_str]
        label_str = [normalizer(text) for text in label_str]
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
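# For reference, the WER returned by `metric.compute` is a corpus-level ratio. The numerator is the
# word-level edit distance (substitutions + deletions + insertions) summed over all pairs. The
# denominator is the total number of reference words. `compute_metrics` then scales it by 100.
# Below is a minimal pure-Python sketch of that definition. `corpus_wer` is a hypothetical helper;
# 🤗 Evaluate delegates the real computation to the `jiwer` package.
from typing import List


def corpus_wer(predictions: List[str], references: List[str]) -> float:
    errors, total_words = 0, 0
    for hyp, ref in zip(predictions, references):
        h, r = hyp.split(), ref.split()
        # prev[j] = edit distance between the reference prefix seen so far and h[:j]
        prev = list(range(len(h) + 1))
        for i in range(1, len(r) + 1):
            cur = [i] + [0] * len(h)
            for j in range(1, len(h) + 1):
                cost = 0 if r[i - 1] == h[j - 1] else 1
                cur[j] = min(prev[j] + 1,         # deletion of a reference word
                             cur[j - 1] + 1,      # insertion of a hypothesis word
                             prev[j - 1] + cost)  # substitution or match
            prev = cur
        errors += prev[len(h)]
        total_words += len(r)
    return errors / total_words  # raises ZeroDivisionError if all references are empty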
# ### Load a Pre-Trained Checkpoint
# Now let's load the pre-trained Whisper `small` checkpoint. Again, this is trivial through
# use of 🤗 Transformers!
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
# define your language of choice here
model.generation_config.language = "hi"
# Override generation arguments - no tokens are forced as decoder outputs
# (see [`forced_decoder_ids`]
# (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)),
# no tokens are suppressed during generation
# (see [`suppress_tokens`]
# (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)).
# Set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False
# ### Define the Training Configuration
# In the final step, we define all the parameters related to training.
# For more detail on the training arguments, refer to the Seq2SeqTrainingArguments
# [docs]
# (https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",  # renamed to `eval_strategy` in newer transformers releases
    per_device_eval_batch_size=4,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)
# **Note**: if one does not want to upload the model checkpoints to the Hub, set `push_to_hub=False`.
# We can forward the training arguments to the 🤗 Trainer along with our model, dataset, data collator
# and `compute_metrics` function:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # newer transformers releases name this argument `processing_class`
    tokenizer=processor.feature_extractor,
)
# We'll save the processor object once before starting training. Since the processor is not trainable,
# it won't change over the course of training:
processor.save_pretrained(training_args.output_dir)
# ### Training
# Training will take approximately 5-10 hours depending on your GPU. The peak GPU memory for the
# given training configuration is approximately 36GB.
# Depending on your GPU, it is possible that you will encounter a CUDA `"out-of-memory"` error when
# you launch training. In this case, you can reduce the `per_device_train_batch_size` incrementally
# by factors of 2 and employ [`gradient_accumulation_steps`]
# (https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments.gradient_accumulation_steps)
# to compensate.
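# The batch-size advice above keeps the effective batch size constant. Each optimizer step sees
# per_device_train_batch_size * gradient_accumulation_steps * number_of_devices examples. With the
# configuration above on a single GPU, that is 8 * 8 = 64.
# Below is a minimal sketch of that arithmetic with hypothetical helpers. It assumes the target
# effective batch size is a multiple of the new per-device batch size times the device count.
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_devices: int = 1) -> int:
    return per_device_batch * grad_accum_steps * num_devices


def accumulation_steps_for(target_batch: int, per_device_batch: int, num_devices: int = 1) -> int:
    per_step = per_device_batch * num_devices
    if target_batch % per_step != 0:
        raise ValueError("target batch size must be a multiple of per-device batch * devices")
    return target_batch // per_step
# e.g. halving the per-device batch from 8 to 4 needs accumulation_steps_for(64, 4) == 16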
# To launch training, simply execute:
trainer.train()
# We can label our checkpoint with the `whisper-event` tag on push by setting the appropriate
# keyword arguments (kwargs):
kwargs = {
"dataset_tags": "mozilla-foundation/common_voice_11_0",
"dataset": "Common Voice 11.0", # a 'pretty' name for the training dataset
"language": "hi",
"model_name": "Whisper Small Hi - Sanchit Gandhi", # a 'pretty' name for your model
"finetuned_from": "openai/whisper-small",
"tasks": "automatic-speech-recognition",
"tags": "whisper-event",
}
# The training results can now be uploaded to the Hub. To do so, call `trainer.push_to_hub` with the
# keyword arguments above. The processor saved to `output_dir` before training is uploaded along
# with the model:
trainer.push_to_hub(**kwargs)
# ## Closing Remarks
# If you're interested in fine-tuning other Transformers models, both for English and multilingual ASR,
# be sure to check out the example scripts at
# [examples/pytorch/speech-recognition]
# (https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition).