wanchichen committed
Update README.md

README.md CHANGED

The code for XEUS is still in the process of being merged into the main ESPnet repo:

```
pip install -e git+https://github.com/wanchichen/espnet.git@ssl#egg=espnet
```

XEUS supports [Flash Attention](https://github.com/Dao-AILab/flash-attention), which can be installed as follows:

```
pip install flash-attn --no-build-isolation
```
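
Since flash-attn wheels are tied to specific CUDA and PyTorch versions, it can be worth confirming the build imports cleanly before loading XEUS. This sanity check is not from the README, just a common verification step:

```python
# verify the flash-attn build; an ImportError here usually means the wheel
# does not match your installed CUDA/PyTorch versions
import flash_attn
print(flash_attn.__version__)
```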

## Usage

Default usage:

```python
import torch
import soundfile as sf
from torch.nn.utils.rnn import pad_sequence
from espnet2.tasks.ssl import SSLTask

device = "cuda" if torch.cuda.is_available() else "cpu"

xeus_model, xeus_train_args = SSLTask.build_model_from_file(
    None,                      # config; None reads config.yaml next to the checkpoint
    '/path/to/checkpoint.pth', # XEUS checkpoint path
    device,
)

# Flash Attention is disabled for the default forward pass
use_flash_attn = False
for layer in xeus_model.encoder.encoders:
    layer.use_flash_attn = use_flash_attn

wavs, sampling_rate = sf.read('/path/to/audio.wav') # sampling rate should be 16000
wavs = [torch.from_numpy(wavs).float()] # pad_sequence expects a list of tensors
wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
wavs = pad_sequence(wavs, batch_first=True).to(device)

feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1] # take the output of the last layer
```
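
The block above runs a single file through the encoder; utterances of different lengths can be batched the same way. A minimal sketch under the same setup as above (`audio_paths` is a hypothetical list of 16 kHz files, not part of the README):

```python
# hypothetical list of 16 kHz mono wav files to batch together
audio_paths = ['/path/to/a.wav', '/path/to/b.wav']

batch = [torch.from_numpy(sf.read(p)[0]).float() for p in audio_paths]
wav_lengths = torch.LongTensor([len(w) for w in batch]).to(device)
wavs = pad_sequence(batch, batch_first=True).to(device) # zero-pad to the longest utterance

# [0] selects the per-layer feature list, [-1] the last layer:
# shape (batch, frames, hidden_dim); wav_lengths lets the model ignore padding
feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1]
```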

With Flash Attention:

```python
for layer in xeus_model.encoder.encoders:
    layer.use_flash_attn = True

# the Flash Attention kernels run in half precision
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1]
```
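
flash-attn only helps on a CUDA device with the package actually installed, so guarding the toggle keeps the same script runnable on CPU. A small sketch; the availability check is my own assumption, not part of the XEUS API:

```python
import importlib.util

# fall back to standard attention when flash-attn is unavailable
use_flash = torch.cuda.is_available() and importlib.util.find_spec('flash_attn') is not None
for layer in xeus_model.encoder.encoders:
    layer.use_flash_attn = use_flash
```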

Tune the masking settings:

```python
xeus_model.masker.mask_prob = 0.65 # default 0.8
xeus_model.masker.mask_length = 20 # default 10
xeus_model.masker.mask_selection = 'static' # default 'uniform'
xeus_model.train() # masking is only applied in training mode
```
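
These settings only take effect when masking is actually requested in the forward pass; pairing them with `use_mask=True` (as in fine-tuning) is my reading of the `encode` signature shown above, not something the README states explicitly:

```python
# masked forward pass, e.g. for fine-tuning with the masker settings above
feats = xeus_model.encode(wavs, wav_lengths, use_mask=True, use_final_output=False)[0][-1]
```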

## Results

![image/png](https://cdn-uploads.huggingface.co/production/uploads/630438615c70c21d0eae6613/RCAWBxSuDLXJ5zdj-OBdn.png)