wanchichen committed · Commit 1b9d89a · verified · 1 Parent(s): b974e77

Update README.md

Files changed (1): README.md (+27 -1)
The code for XEUS is still in the process of being merged into the main ESPnet repo:

```
pip install -e git+git://github.com/wanchichen/espnet.git@ssl
```

XEUS supports [Flash Attention](https://github.com/Dao-AILab/flash-attention), which can be installed as follows:

```
pip install flash-attn --no-build-isolation
```

## Usage

Default Usage:
```python
import torch
import soundfile as sf
from torch.nn.utils.rnn import pad_sequence
from espnet2.tasks.ssl import SSLTask

device = "cuda" if torch.cuda.is_available() else "cpu"

xeus_model, xeus_train_args = SSLTask.build_model_from_file(
    None,
    '/path/to/checkpoint.pth',  # path to the downloaded XEUS checkpoint
    device,
)

# disable Flash Attention for the default usage path
for layer in xeus_model.encoder.encoders:
    layer.use_flash_attn = False

wavs, sampling_rate = sf.read('/path/to/audio.wav')  # sampling rate should be 16000
wavs = torch.from_numpy(wavs).float()  # pad_sequence expects tensors
wav_lengths = torch.LongTensor([len(wav) for wav in [wavs]]).to(device)
wavs = pad_sequence([wavs], batch_first=True).to(device)

feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1]  # take the output of the last layer
```
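The batching step above pads each waveform to the length of the longest one in the batch and keeps the true lengths alongside, so the encoder can ignore the zero padding. A minimal stdlib-only sketch of that logic (illustrative only, not part of the XEUS API):

```python
def pad_batch(wavs, pad_value=0.0):
    """Pad a list of variable-length waveforms (lists of floats) to equal
    length, batch-first, and return the original lengths."""
    lengths = [len(w) for w in wavs]
    max_len = max(lengths)
    padded = [list(w) + [pad_value] * (max_len - len(w)) for w in wavs]
    return padded, lengths

batch, lengths = pad_batch([[0.1, 0.2, 0.3], [0.4]])
# batch   -> [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]]
# lengths -> [3, 1]
```

`torch.nn.utils.rnn.pad_sequence` performs the same operation over tensors; the separate lengths tensor is what tells `encode` where the real audio ends.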
 
With Flash Attention:

```python
for layer in xeus_model.encoder.encoders:
    layer.use_flash_attn = True

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1]
```

Tune the masking settings:

```python
xeus_model.masker.mask_prob = 0.65  # default 0.8
xeus_model.masker.mask_length = 20  # default 10
xeus_model.masker.mask_selection = 'static'  # default 'uniform'
xeus_model.train()  # masking is applied in training mode
```
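The masker settings above control wav2vec 2.0-style span masking: `mask_prob` sets roughly what fraction of frames end up masked, and `mask_length` the width of each masked span. A rough stdlib-only sketch of fixed-length ("static") span selection, as an illustration of the idea rather than ESPnet's actual implementation:

```python
import random

def span_mask(seq_len, mask_prob=0.65, mask_length=20, seed=0):
    """Mark roughly mask_prob of seq_len frames as masked, using
    fixed-length ('static') spans of mask_length frames each."""
    rng = random.Random(seed)
    num_spans = max(1, round(seq_len * mask_prob / mask_length))
    mask = [False] * seq_len
    # sample distinct span start positions; spans may still overlap
    for start in rng.sample(range(seq_len - mask_length + 1), num_spans):
        for i in range(start, start + mask_length):
            mask[i] = True
    return mask

mask = span_mask(400)
coverage = sum(mask) / len(mask)  # at most mask_prob, less when spans overlap
```

With `'uniform'` selection, span lengths would instead be drawn from a range rather than fixed at `mask_length`.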
## Results

![image/png](https://cdn-uploads.huggingface.co/production/uploads/630438615c70c21d0eae6613/RCAWBxSuDLXJ5zdj-OBdn.png)