|
import torch |
|
from speechbrain.pretrained import Pretrained |
|
|
|
class WhisperASR(Pretrained):
    """A ready-to-use Whisper ASR model.

    The class can be used either to run only the encoder
    (``encode_batch()``) or to run the entire encoder-decoder Whisper model
    (``transcribe_batch()`` / ``transcribe_file()``) to transcribe speech.
    The given YAML must contain the fields specified in the ``*_NEEDED``
    lists. The ``normalized_transcripts`` hyperparameter is optional and is
    treated as ``False`` when it is absent.

    Example
    -------
    >>> from speechbrain.pretrained.interfaces import foreign_class
    >>> tmpdir = getfixture("tmpdir")
    >>> asr_model = foreign_class(source="hf",
    ...                           pymodule_file="custom_interface.py",
    ...                           classname="WhisperASR",
    ...                           hparams_file='hparams.yaml',
    ...                           savedir=tmpdir,
    ...                          )
    >>> asr_model.transcribe_file("tests/samples/example2.wav")
    """

    # Hyperparameters that the YAML file must define.
    HPARAMS_NEEDED = ["language"]
    # Modules that the YAML file must define (accessed through ``self.mods``).
    MODULES_NEEDED = ["whisper", "decoder"]

    def __init__(self, *args, **kwargs):
        """Builds the interface and configures the decoding prefix tokens.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.tokenizer = self.hparams.whisper.tokenizer
        # Configure the tokenizer prefix for the requested language and the
        # "transcribe" task.
        # NOTE(review): the third positional argument is presumably the
        # "predict timestamps" switch -- confirm against the tokenizer API.
        self.tokenizer.set_prefix_tokens(
            self.hparams.language, "transcribe", False
        )
        # The decoder must start every hypothesis from the same prefix tokens.
        self.hparams.decoder.set_decoder_input_tokens(
            self.tokenizer.prefix_tokens
        )

    def transcribe_file(self, path):
        """Transcribes the given audio file.

        Arguments
        ---------
        path : str
            Path to the audio file to transcribe.

        Returns
        -------
        list
            The transcription output of ``transcribe_batch`` for a batch that
            contains only this file. The batch-level result is returned
            as is (it is not unwrapped to a bare string).
        """
        waveform = self.load_audio(path)

        # Fake a batch of size one; its only item has full relative length.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words, _ = self.transcribe_batch(batch, rel_length)
        return predicted_words

    def encode_batch(self, wavs, wav_lens):
        """Encodes the input audio into a sequence of hidden states.

        The waveforms should already be in the model's desired format.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms. They are cast to float and moved to the
            model device before being passed to the Whisper encoder.
        wav_lens : torch.tensor
            Relative lengths of the waveforms. Accepted for interface
            compatibility only: the encoder call does not use it.

        Returns
        -------
        torch.tensor
            The encoded batch.
        """
        wavs = wavs.float().to(self.device)
        return self.mods.whisper.forward_encoder(wavs)

    def transcribe_batch(self, wavs, wav_lens):
        """Transcribes the input audio into a sequence of words.

        The waveforms should already be in the model's desired format.
        The whole computation runs without gradient tracking.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            It is forwarded to the decoder.

        Returns
        -------
        list
            Each waveform in the batch transcribed. When the optional
            ``normalized_transcripts`` hyperparameter is truthy, every
            transcription is normalized by the tokenizer and split on
            single spaces into a list of words.
        predicted_tokens
            The predicted token ids, as returned by the decoder.
        """
        with torch.no_grad():
            wav_lens = wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs, wav_lens)
            # The second decoder output is not needed here.
            predicted_tokens, _ = self.mods.decoder(encoder_out, wav_lens)
            predicted_words = self.tokenizer.batch_decode(
                predicted_tokens, skip_special_tokens=True
            )
            # ``normalized_transcripts`` is not listed in HPARAMS_NEEDED, so
            # a YAML file may legitimately omit it; default to no
            # normalization instead of failing with an AttributeError.
            if getattr(self.hparams, "normalized_transcripts", False):
                predicted_words = [
                    self.tokenizer._normalize(text).split(" ")
                    for text in predicted_words
                ]

        return predicted_words, predicted_tokens

    def forward(self, wavs, wav_lens):
        """Runs full transcription - note: no gradients through decoding"""
        return self.transcribe_batch(wavs, wav_lens)
|
|