# File size: 7,839 Bytes
# Revision: 23ab61f
# The random seed is applied first, at the very top of the yaml, so that it
# takes effect before any object holding parameters is instantiated below.
seed: 1994
__set_seed: !apply:torch.manual_seed
    - !ref <seed>
# NOTE(review): presumably read by the training script to skip training and
# only evaluate — confirm.
skip_training: True
# Every experiment output (WER outputs, checkpoints, training log) is placed
# under <output_folder>.
output_folder: !ref output_folder_seq2seq_cv_podcast_arhiv_augmentation_128_emb_5000_vocab
output_wer_folder: !ref <output_folder>/
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
# Language-model folder; only referenced by the commented-out pretrainer at
# the end of this file.
lm_folder: LM/output_folder_lm
# Data files
data_folder: "../../data/combined_data/speechbrain_splits"
# Pre-trained wav2vec 2.0 model on the Hugging Face hub, and the local folder
# it is saved to (passed as save_path to the encoder).
wav2vec2_hub: "facebook/wav2vec2-large-xlsr-53"
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
# Tokenizer folder: keep exactly one of these active, matching the training data.
# pretrained_tokenizer_path: "Tokenizer/output_folder_cv/1K_subword_unigram" # Use this for the CV model
pretrained_tokenizer_path: "Tokenizer/output_folder_cv_podcast_arhiv/5K_subword_unigram" # Use this for the CV+Podcast+Arhiv model
####################### Training Parameters ####################################
number_of_epochs: 50
# Not referenced elsewhere in this yaml; presumably read by the training script.
number_of_ctc_epochs: 15
# Batch sizes used for earlier data combinations are kept for reference;
# only the last line is active.
# batch_size: 16
# batch_size: 6 # for cv+podcast
batch_size: 6 # for cv+podcast+arhiv
label_smoothing: 0.1
lr: 0.0001
ctc_weight: 0.5
# Optimiser class. !name binds lr without instantiating, so the optimiser can
# be built later over the model parameters (presumably by the training
# script — confirm).
opt_class: !name:torch.optim.Adam
    lr: !ref <lr>
# Learning-rate annealing starting from <lr>; checkpointed as "scheduler".
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.8
    patient: 0
# Dataloader options
num_workers: 4
# NOTE(review): parent key names follow the SpeechBrain recipe convention —
# confirm they match what the training script reads.
train_dataloader_opts:
    num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
valid_dataloader_opts:
    num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>
# Test-time decoding runs one utterance at a time.
test_dataloader_opts:
    batch_size: 1
####################### Model Parameters #######################################
dropout: 0.15
# Size of the wav2vec2 output features; used as the decoder's enc_dim and as
# the CTC head's input size.
wav2vec_output_dim: 1024
emb_size: 128
dec_neurons: 1024
dec_layers: 1
# Vocabulary size: sizes the embedding table and both output projections.
output_neurons: 5000
# All special tokens share index 0.
blank_index: 0
bos_index: 0
eos_index: 0
unk_index: 0
# Decoding parameters (consumed by the searchers defined further down and,
# presumably, by the training script)
min_decode_ratio: 0.0
max_decode_ratio: 1.0
valid_beam_size: 10
test_beam_size: 10
using_eos_threshold: True
eos_threshold: 1.5
using_max_attn_shift: True
max_attn_shift: 300
temperature: 1.0
ctc_window_size: 200
# Only referenced by the commented-out RNNLM scorer.
temperature_lm: 1.25
# Scoring parameters; coverage_penalty and lm_weight are only referenced by
# the commented-out scorer definitions.
ctc_weight_decode: 0.0
coverage_penalty: 1.5
lm_weight: 0.0
# Counts training epochs up to <number_of_epochs>; checkpointed as "counter".
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
# Wav2vec2 encoder, loaded from <wav2vec2_hub> and saved to <wav2vec2_folder>.
# The model itself is not frozen (freeze: False), but its feature extractor is.
encoder_w2v2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: False
    freeze_feature_extractor: True
    save_path: !ref <wav2vec2_folder>
    output_all_hiddens: False
# Token embedding: one <emb_size> vector per entry of the <output_neurons>
# vocabulary; its size matches the decoder's input_size.
embedding: !new:speechbrain.nnet.embedding.Embedding
    num_embeddings: !ref <output_neurons>
    embedding_dim: !ref <emb_size>
# Attention-based RNN decoder: <dec_layers>-layer GRU with location-aware
# attention over encoder features of size <wav2vec_output_dim>.
decoder: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
    enc_dim: !ref <wav2vec_output_dim>
    input_size: !ref <emb_size>
    rnn_type: gru
    attn_type: location
    hidden_size: !ref <dec_neurons>
    attn_dim: 512
    num_layers: !ref <dec_layers>
    scaling: 1.0
    channels: 10
    kernel_size: 100
    re_init: True
    dropout: !ref <dropout>
# Output projections onto the <output_neurons> vocabulary: ctc_lin from the
# encoder features, seq_lin from the decoder hidden state.
ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <wav2vec_output_dim>
    n_neurons: !ref <output_neurons>
seq_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <dec_neurons>
    n_neurons: !ref <output_neurons>
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True
# CTC loss; the blank symbol is <blank_index>.
ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>
# Sequence (NLL) loss. Uses the global <label_smoothing> rather than repeating
# the literal, so the training-parameter value is the single source of truth.
nll_cost: !name:speechbrain.nnet.losses.nll_loss
    label_smoothing: !ref <label_smoothing>
# This RNNLM follows the architecture published in the Hugging Face repository.
# NB: it has to match the pre-trained RNNLM checkpoint exactly!
#lm_model: !new:speechbrain.lobes.models.RNNLM.RNNLM
# output_neurons: !ref <output_neurons>
# embedding_dim: !ref <emb_size>
# activation: !name:torch.nn.LeakyReLU
# dropout: 0.0
# rnn_layers: 2
# rnn_neurons: 2048
# dnn_blocks: 1
# dnn_neurons: 512
# return_hidden: True # For inference
# SentencePiece tokenizer. The model file name is derived from <output_neurons>
# because the tokenizer vocabulary must match the output layer size
# (output_neurons 5000 -> 5000_unigram.model).
tokenizer: !new:sentencepiece.SentencePieceProcessor
    model_file: !ref <pretrained_tokenizer_path>/<output_neurons>_unigram.model
# Named modules. NOTE(review): the `modules` parent key follows the SpeechBrain
# convention (presumably passed to the Brain class) — confirm against the
# training script.
modules:
    encoder_w2v2: !ref <encoder_w2v2>
    embedding: !ref <embedding>
    decoder: !ref <decoder>
    ctc_lin: !ref <ctc_lin>
    seq_lin: !ref <seq_lin>
    #lm_model: !ref <lm_model>
# All trainable modules gathered into one ModuleList, checkpointed as "model".
model: !new:torch.nn.ModuleList
    - [!ref <encoder_w2v2>, !ref <embedding>, !ref <decoder>, !ref <ctc_lin>, !ref <seq_lin>]
############################## Decoding ########################################
#coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
# vocab_size: !ref <output_neurons>
#rnnlm_scorer: !new:speechbrain.decoders.scorer.RNNLMScorer
# language_model: !ref <lm_model>
# temperature: !ref <temperature_lm>
#scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
# full_scorers: [!ref <rnnlm_scorer>,
# !ref <coverage_scorer>]
# weights:
# rnnlm: !ref <lm_weight>
# coverage: !ref <coverage_penalty>
# Search
# Greedy decoding over the seq2seq branch (embedding -> decoder -> seq_lin).
greedy_search: !new:speechbrain.decoders.S2SRNNGreedySearcher
    embedding: !ref <embedding>
    decoder: !ref <decoder>
    linear: !ref <seq_lin>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
# Beam search with <test_beam_size> beams. <using_eos_threshold> is passed
# explicitly so the flag in the decoding parameters actually takes effect.
test_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <embedding>
    decoder: !ref <decoder>
    linear: !ref <seq_lin>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
    using_eos_threshold: !ref <using_eos_threshold>
    eos_threshold: !ref <eos_threshold>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
    temperature: !ref <temperature>
    #scorer: !ref <scorer>
############################## Augmentations ###################################
# Speed perturbation to 95%, 100% or 105% (assumes 16 kHz input audio).
speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: 16000
    speeds: [95, 100, 105]
# Frequency drop: randomly drops a number of frequency bands to zero.
drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: 0
    drop_freq_high: 1
    drop_freq_count_low: 1
    drop_freq_count_high: 3
    drop_freq_width: 0.05
# Time drop: randomly drops a number of temporal chunks.
drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: 1000
    drop_length_high: 2000
    drop_count_low: 1
    drop_count_high: 5
# Augmenter: combines the augmentations above (between 1 and 3 of them,
# augment_prob 0.5) without concatenating the original signals.
wav_augment: !new:speechbrain.augment.augmenter.Augmenter
    concat_original: False
    min_augmentations: 1
    max_augmentations: 3
    augment_prob: 0.5
    augmentations: [
        !ref <speed_perturb>,
        !ref <drop_freq>,
        !ref <drop_chunk>]
############################## Logging and Pretrainer ##########################
# Saves/restores the model, LR scheduler and epoch counter under <save_folder>.
# Checkpointer takes these objects through its `recoverables` mapping.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        scheduler: !ref <lr_annealing>
        counter: !ref <epoch_counter>
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
# Same metric with split_tokens enabled (character-level error rate).
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
# The pretrainer maps pretrained checkpoint files to instances declared in
# this yaml. Here, the LM checkpoint model.ckpt under <lm_folder> would be
# collected into <lm_folder> and loaded into "lm", which points to the
# <lm_model> defined above. Both the pretrainer and <lm_model> are currently
# commented out.
#pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
# collect_in: !ref <lm_folder>
# loadables:
# lm: !ref <lm_model>
# paths:
# lm: !ref <lm_folder>/save/CKPT+2024-07-19+14-16-05+00/model.ckpt