pascal lim
committed on
Commit
·
9e32bde
1
Parent(s):
da15453
update eval script with lm
Browse files- eval_lm.py +24 -25
eval_lm.py
CHANGED
@@ -4,7 +4,7 @@ import re
|
|
4 |
from typing import Dict
|
5 |
|
6 |
from datasets import Audio, Dataset, load_dataset, load_metric
|
7 |
-
|
8 |
from transformers import AutoFeatureExtractor, pipeline, Wav2Vec2ProcessorWithLM
|
9 |
|
10 |
|
@@ -62,39 +62,38 @@ def normalize_text(text: str) -> str:
|
|
62 |
|
63 |
return text
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def main(args):
|
67 |
# load dataset
|
68 |
dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
|
|
|
|
|
69 |
|
70 |
# for testing: only process the first two examples as a test
|
71 |
# dataset = dataset.select(range(10))
|
72 |
|
73 |
# load processor
|
74 |
-
processor = Wav2Vec2ProcessorWithLM.from_pretrained("
|
75 |
|
76 |
-
model = Wav2Vec2ForCTC.from_pretrained(
|
77 |
-
|
78 |
-
sampling_rate = feature_extractor.sampling_rate
|
79 |
-
|
80 |
-
# resample audio
|
81 |
-
dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
|
82 |
-
|
83 |
-
# load eval pipeline
|
84 |
-
asr = pipeline("automatic-speech-recognition", model=args.model_id)
|
85 |
-
|
86 |
-
# map function to decode audio
|
87 |
-
def map_to_pred(batch):
|
88 |
-
prediction = asr(
|
89 |
-
batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s
|
90 |
-
)
|
91 |
-
|
92 |
-
batch["prediction"] = prediction["text"]
|
93 |
-
batch["target"] = normalize_text(batch["sentence"])
|
94 |
-
return batch
|
95 |
|
96 |
# run inference on all examples
|
97 |
-
result = dataset.map(
|
98 |
|
99 |
# compute and log_results
|
100 |
# do not change function below
|
@@ -104,9 +103,9 @@ def main(args):
|
|
104 |
if __name__ == "__main__":
|
105 |
parser = argparse.ArgumentParser()
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
parser.add_argument(
|
111 |
"--dataset",
|
112 |
type=str,
|
|
|
from typing import Dict

import torch
from datasets import Audio, Dataset, load_dataset, load_metric
from transformers import (
    AutoFeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2ProcessorWithLM,
    pipeline,
)
|
|
|
62 |
|
63 |
return text
|
64 |
|
def evaluate_with_lm(batch):
    """Transcribe one dataset example with the LM-boosted CTC decoder.

    Expects a module-level ``processor`` (Wav2Vec2ProcessorWithLM) and
    ``model`` (Wav2Vec2ForCTC, already moved to CUDA) to be in scope.
    NOTE(review): in this commit both are created *inside* ``main()``, so
    ``dataset.map(evaluate_with_lm, ...)`` would raise NameError — they
    need to be promoted to module scope or passed via ``fn_kwargs``; confirm.

    Args:
        batch: one dataset example with ``audio["array"]`` (a 16 kHz
            waveform — presumably, given the ``sampling_rate=16_000``
            resample in ``main``) and a reference ``sentence``.

    Returns:
        The same batch dict with ``prediction`` (decoded string) and
        ``target`` (normalized reference) added.
    """
    inputs = processor(
        batch["audio"]["array"], sampling_rate=16_000, return_tensors="pt", padding=True
    )

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs.to('cuda')).logits

    # Beam-search decode with the n-gram LM; pyctcdecode runs on CPU numpy.
    int_result = processor.batch_decode(logits.cpu().numpy())

    # batch_decode returns a list of transcriptions. This map is unbatched
    # (one example per call), so store the single string — the original
    # `int_result.text` stored the whole one-element list, which breaks the
    # downstream string-based WER computation.
    batch["prediction"] = int_result.text[0]
    batch["target"] = normalize_text(batch["sentence"])

    # Drop the decoder output and cached CUDA blocks between examples to
    # keep peak GPU memory down over a long eval run.
    del int_result
    torch.cuda.empty_cache()

    return batch
|
80 |
def main(args):
|
81 |
# load dataset
|
82 |
dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
|
83 |
+
# resample audio
|
84 |
+
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
|
85 |
|
86 |
# for testing: only process the first two examples as a test
|
87 |
# dataset = dataset.select(range(10))
|
88 |
|
89 |
# load processor
|
90 |
+
processor = Wav2Vec2ProcessorWithLM.from_pretrained("./")
|
91 |
|
92 |
+
model = Wav2Vec2ForCTC.from_pretrained("./")
|
93 |
+
model.to('cuda')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
# run inference on all examples
|
96 |
+
result = dataset.map(evaluate_with_lm, remove_columns=dataset.column_names)
|
97 |
|
98 |
# compute and log_results
|
99 |
# do not change function below
|
|
|
103 |
if __name__ == "__main__":
|
104 |
parser = argparse.ArgumentParser()
|
105 |
|
106 |
+
# parser.add_argument(
|
107 |
+
# "--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers"
|
108 |
+
# )
|
109 |
parser.add_argument(
|
110 |
"--dataset",
|
111 |
type=str,
|