sanchit-gandhi commited on
Commit
ed727c2
·
1 Parent(s): 22aad52

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +21 -10
README.md CHANGED
@@ -91,15 +91,23 @@ To transcribe audio files the model can be used as a standalone acoustic model a
91
  ```
92
 
93
  ## Evaluation
94
-
95
- This code snippet shows how to evaluate **facebook/wav2vec2-base-960h** on LibriSpeech's "clean" and "other" test data.
96
-
 
 
 
 
 
 
 
 
 
97
  ```python
 
98
  from datasets import load_dataset
99
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
100
- import torch
101
- from jiwer import wer
102
-
103
 
104
  librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
105
 
@@ -107,18 +115,21 @@ model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to("cuda")
107
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
108
 
109
  def map_to_pred(batch):
110
- input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
 
 
111
  with torch.no_grad():
112
  logits = model(input_values.to("cuda")).logits
113
 
114
  predicted_ids = torch.argmax(logits, dim=-1)
115
  transcription = processor.batch_decode(predicted_ids)
116
- batch["transcription"] = transcription
117
  return batch
118
 
119
- result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
 
120
 
121
- print("WER:", wer(result["text"], result["transcription"]))
122
  ```
123
 
124
  *Result (WER)*:
 
91
  ```
92
 
93
  ## Evaluation
94
+
95
+ First, ensure the required Python packages are installed. We'll require `transformers` for running the Wav2Vec2 model,
96
+ `datasets` for loading the LibriSpeech dataset, and `evaluate` plus `jiwer` for computing the word-error rate (WER):
97
+
98
+ ```
99
+ pip install --upgrade pip
100
+ pip install --upgrade transformers datasets evaluate jiwer
101
+ ```
102
+
103
+ The following code snippet shows how to evaluate **facebook/wav2vec2-base-960h** on LibriSpeech's "clean" and "other" test data.
104
+ The batch size can be set according to your device, and is set to `8` by default:
105
+
106
  ```python
107
+ import torch
108
  from datasets import load_dataset
109
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
110
+ from evaluate import load
 
 
111
 
112
  librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
113
 
 
115
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
116
 
117
  def map_to_pred(batch):
118
+ audios = [audio["array"] for audio in batch["audio"]]
119
+ sampling_rate = batch["audio"][0]["sampling_rate"]
120
+ input_values = processor(audios, sampling_rate=sampling_rate, return_tensors="pt", padding="longest").input_values
121
  with torch.no_grad():
122
  logits = model(input_values.to("cuda")).logits
123
 
124
  predicted_ids = torch.argmax(logits, dim=-1)
125
  transcription = processor.batch_decode(predicted_ids)
126
+ batch["transcription"] = [t for t in transcription]
127
  return batch
128
 
129
+ result = librispeech_eval.map(map_to_pred, batched=True, batch_size=8, remove_columns=["audio"])
130
+ wer = load("wer")
131
 
132
+ print("WER:", wer.compute(references=result["text"], predictions=result["transcription"]))
133
  ```
134
 
135
  *Result (WER)*: