cstorm125 committed on
Commit
cbc9d72
·
1 Parent(s): c7ca823

Update README.md

  1. README.md +114 -1
README.md CHANGED
@@ -10,4 +10,117 @@ tags:
  license: apache-2.0
  ---
 
- # `wav2vec2-large-xlsr-53-th` - Finetuning `wav2vec2-large-xlsr-53` on Thai [Common Voice 7.0](https://commonvoice.mozilla.org/en/datasets)
+ # `wav2vec2-large-xlsr-53-th`
+ Finetuning `wav2vec2-large-xlsr-53` on Thai [Common Voice 7.0](https://commonvoice.mozilla.org/en/datasets)
+
+ We finetune [wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on the Thai examples of [Common Voice Corpus 7.0](https://commonvoice.mozilla.org/en/datasets), following the [Fine-tuning Wav2Vec2 for English ASR](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_tuning_Wav2Vec2_for_English_ASR.ipynb) notebook. The notebooks and scripts are in [vistec-ai/wav2vec2-large-xlsr-53-th](https://github.com/vistec-ai/wav2vec2-large-xlsr-53-th). The pretrained model and processor are at [airesearch/wav2vec2-large-xlsr-53-th](https://huggingface.co/airesearch/wav2vec2-large-xlsr-53-th).
+
+ ## Usage
+
+ ```
+ import torch
+ import torchaudio
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+
+ # load pretrained processor and model
+ processor = Wav2Vec2Processor.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
+ model = Wav2Vec2ForCTC.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
+
+ # function to load an audio file and resample it to 16 kHz
+ def speech_file_to_array_fn(batch,
+                             text_col="sentence",
+                             fname_col="path",
+                             resampling_to=16000):
+     speech_array, sampling_rate = torchaudio.load(batch[fname_col])
+     resampler = torchaudio.transforms.Resample(sampling_rate, resampling_to)
+     batch["speech"] = resampler(speech_array)[0].numpy()
+     batch["sampling_rate"] = resampling_to
+     batch["target_text"] = batch[text_col]
+     return batch
+
+ # get 2 examples as sample input
+ # (test_dataset is assumed to be an already-loaded test split with "path" and "sentence" columns)
+ test_dataset = test_dataset.map(speech_file_to_array_fn)
+ inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
+
+ # infer
+ with torch.no_grad():
+     logits = model(inputs.input_values).logits
+
+ # greedy decoding: take the most likely token at every frame
+ predicted_ids = torch.argmax(logits, dim=-1)
+
+ print("Prediction:", processor.batch_decode(predicted_ids))
+ print("Reference:", test_dataset["sentence"][:2])
+
+ # >> Prediction: ['และ เขา ก็ สัมผัส ดีบุก', 'คุณ สามารถ รับทราบ เมื่อ ข้อความ นี้ ถูก อ่าน แล้ว']
+ # >> Reference: ['และเขาก็สัมผัสดีบุก', 'คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว']
+ ```
+
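The `torch.argmax` plus `batch_decode` step in the usage snippet is greedy CTC decoding. As a rough illustration of what happens to the per-frame ids, here is a minimal pure-Python sketch. The toy vocabulary, the `blank_id` value and the function name `ctc_greedy_decode` are assumptions made for this example; the real mapping comes from the processor's tokenizer.

```python
from itertools import groupby


def ctc_greedy_decode(ids, id_to_token, blank_id=0, word_delimiter="|"):
    """Collapse repeated frame ids, drop CTC blanks, map the delimiter to a space."""
    # 1) merge consecutive duplicates, e.g. [2, 2, 0, 3] -> [2, 0, 3]
    collapsed = [token_id for token_id, _ in groupby(ids)]
    # 2) remove the blank token, which only separates repeated characters
    tokens = [id_to_token[i] for i in collapsed if i != blank_id]
    # 3) the word delimiter token stands for a space between words
    return "".join(tokens).replace(word_delimiter, " ")


# Hypothetical toy vocabulary, NOT the model's real vocabulary.
toy_vocab = {0: "<pad>", 1: "|", 2: "ก", 3: "า", 4: "ร"}

# One id per audio frame, as torch.argmax(logits, dim=-1) would produce.
frame_ids = [2, 2, 0, 3, 3, 3, 0, 0, 4, 4]
decoded = ctc_greedy_decode(frame_ids, toy_vocab)
print(decoded)  # การ
```

Because the training text is pre-tokenized into words, the decoded prediction contains spaces between words, while the raw reference sentence does not.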
+ ## Datasets
+
+ [Common Voice Corpus 7.0](https://commonvoice.mozilla.org/en/datasets) contains 133 validated hours of Thai (255 total hours) at 5 GB. We pre-tokenize with `pythainlp.tokenize.word_tokenize`. We preprocess the dataset using the cleaning rules described in `notebooks/cv-preprocess.ipynb` by [@tann9949](https://github.com/tann9949). We then deduplicate and split as described in [ekapolc/Thai_commonvoice_split](https://github.com/ekapolc/Thai_commonvoice_split) in order to 1) avoid the data leakage caused by random splits after cleaning in [Common Voice Corpus 7.0](https://commonvoice.mozilla.org/en/datasets) and 2) preserve the majority of the data for the training set. The dataset loading script is `scripts/th_common_voice_70.py`. The resulting dataset is as follows:
+
+ ```
+ DatasetDict({
+     train: Dataset({
+         features: ['path', 'sentence'],
+         num_rows: 86586
+     })
+     test: Dataset({
+         features: ['path', 'sentence'],
+         num_rows: 2502
+     })
+     validation: Dataset({
+         features: ['path', 'sentence'],
+         num_rows: 3027
+     })
+ })
+ ```
+
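To make the leakage concern concrete: if the same sentence is read by several speakers and rows are split at random, identical text can land in both train and test. The sketch below shows one generic way to avoid that by assigning whole sentence groups to a single split. It is an assumption-laden illustration, not the procedure from ekapolc/Thai_commonvoice_split; the function name `split_by_sentence`, the `test_frac` parameter and the toy rows are hypothetical.

```python
import random
from collections import defaultdict


def split_by_sentence(rows, test_frac=0.1, seed=0):
    """Split rows so that every distinct sentence appears in exactly one split."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["sentence"]].append(row)

    # Shuffle the distinct sentences (not the rows) with a fixed seed.
    sentences = sorted(groups)
    random.Random(seed).shuffle(sentences)

    n_test = max(1, int(len(sentences) * test_frac))
    test = [r for s in sentences[:n_test] for r in groups[s]]
    train = [r for s in sentences[n_test:] for r in groups[s]]
    return train, test


# Hypothetical toy rows: 10 distinct sentences, each recorded twice.
rows = [{"path": f"clip_{i}.mp3", "sentence": f"sentence {i % 10}"} for i in range(20)]
train, test = split_by_sentence(rows, test_frac=0.2)

train_sentences = {r["sentence"] for r in train}
test_sentences = {r["sentence"] for r in test}
print(len(train), len(test), train_sentences & test_sentences)
```

Keeping `test_frac` small is in line with the stated goal of preserving most of the data for training.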
+ ## Training
+
+ We finetuned using the following configuration on a single V100 GPU and chose the checkpoint with the lowest validation loss. The finetuning script is `scripts/wav2vec2_finetune.py`.
+
+ ```
+ from transformers import TrainingArguments, Wav2Vec2ForCTC
+
+ # create model
+ # (`processor` is assumed to be the Wav2Vec2Processor built for the Thai vocabulary)
+ model = Wav2Vec2ForCTC.from_pretrained(
+     "facebook/wav2vec2-large-xlsr-53",
+     attention_dropout=0.1,
+     hidden_dropout=0.1,
+     feat_proj_dropout=0.0,
+     mask_time_prob=0.05,
+     layerdrop=0.1,
+     gradient_checkpointing=True,
+     ctc_loss_reduction="mean",
+     pad_token_id=processor.tokenizer.pad_token_id,
+     vocab_size=len(processor.tokenizer)
+ )
+ # freeze the convolutional feature extractor so its weights are not updated during finetuning
+ model.freeze_feature_extractor()
+ training_args = TrainingArguments(
+     output_dir="../data/wav2vec2-large-xlsr-53-thai",
+     group_by_length=True,
+     per_device_train_batch_size=32,
+     gradient_accumulation_steps=1,
+     per_device_eval_batch_size=16,
+     metric_for_best_model='wer',
+     evaluation_strategy="steps",
+     eval_steps=1000,
+     logging_strategy="steps",
+     logging_steps=1000,
+     save_strategy="steps",
+     save_steps=1000,
+     num_train_epochs=100,
+     fp16=True,
+     learning_rate=1e-4,
+     warmup_steps=1000,
+     save_total_limit=3,
+     report_to="tensorboard"
+ )
+ ```
+
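The configuration above implies a step budget that is easy to work out by hand. The sketch below does that arithmetic and evaluates a linear warmup and decay schedule. It assumes a single GPU, no dropped last batch, a run that goes the full 100 epochs, and the `Trainer` default linear scheduler; none of these are stated in the configuration itself, and the helper name `linear_warmup_decay` is made up for this example.

```python
import math

train_rows = 86586   # size of the train split
batch_size = 32      # per_device_train_batch_size, single GPU assumed
grad_accum = 1       # gradient_accumulation_steps
epochs = 100         # num_train_epochs
warmup_steps = 1000
peak_lr = 1e-4       # learning_rate
eval_steps = 1000

steps_per_epoch = math.ceil(train_rows / (batch_size * grad_accum))
total_steps = steps_per_epoch * epochs


def linear_warmup_decay(step):
    """Learning rate at an optimizer step: linear ramp up, then linear decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))


print(steps_per_epoch, total_steps)  # 2706 270600
print(total_steps // eval_steps)     # 270
```

Under these assumptions, evaluating every 1000 steps means roughly one evaluation per 0.37 epochs, and `save_total_limit=3` keeps only a few of the resulting checkpoints on disk.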
+ ## Evaluation
+
+ We benchmark on the test set using WER, with words tokenized by [PyThaiNLP](https://github.com/PyThaiNLP/pythainlp) 2.3.1, and CER. We also measure performance when spell correction using [TNC](http://www.arts.chula.ac.th/ling/tnc/) ngrams is applied. The evaluation code is in `notebooks/wav2vec2_finetuning_tutorial.ipynb`.
+
+ | | WER | CER |
+ |--------------------------|------------|------------|
+ | without spell correction | 0.20754109 | 0.03727126 |
+ | with spell correction | TBD | TBD |
+
+
+
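For readers unfamiliar with the two metrics, both are edit distances normalized by the reference length: WER counts word tokens, CER counts characters. The following is a minimal pure-Python sketch, not the evaluation code from the notebook. It assumes the sentences are already tokenized into word lists, and that errors are summed over the corpus before dividing; the hypothesis with one wrong word is invented for illustration.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (lists of words or strings)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(refs, hyps):
    """Total edit distance divided by total reference length."""
    errors = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    return errors / sum(len(r) for r in refs)


# Reference words from the usage example; the hypothesis is hypothetical.
ref_words = ["และ", "เขา", "ก็", "สัมผัส", "ดีบุก"]
hyp_words = ["และ", "เขา", "ก็", "สัมผัส", "ดี"]

wer = error_rate([ref_words], [hyp_words])
cer = error_rate(["".join(ref_words)], ["".join(hyp_words)])
print(wer)  # 0.2
print(cer)
```

Since Thai is written without spaces, WER depends on the tokenizer that defines the words, which is why the tokenizer version is reported next to the score.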