winglian commited on
Commit
db8a8af
·
unverified ·
1 Parent(s): 1470650

adds llama and mistral dropout support (#858)

Browse files

* adds llama and mistral dropout support

* gracefully handle attention dropout if not available yet

src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -321,6 +321,8 @@ def flashattn_forward(
321
  # only on first autoregressive step q,k,v have same seqlen
322
  is_causal = key_states.shape == query_states.shape
323
 
 
 
324
  if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
325
  # special handling using sample packing
326
  qkv = torch.stack(
@@ -330,7 +332,12 @@ def flashattn_forward(
330
  qkv = rearrange(qkv, "b s ... -> (b s) ...")
331
 
332
  output = flash_attn_varlen_qkvpacked_func(
333
- qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
 
 
 
 
 
334
  )
335
  output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
336
  elif query_states.shape == key_states.shape:
@@ -353,7 +360,7 @@ def flashattn_forward(
353
  qkv_unpad,
354
  cu_seqlens_q,
355
  max_seqlen_q,
356
- 0.0,
357
  softmax_scale=None,
358
  causal=is_causal,
359
  )
@@ -366,6 +373,7 @@ def flashattn_forward(
366
  output = flash_attn_kvpacked_func(
367
  query_states,
368
  torch.stack([key_states, value_states], 2),
 
369
  causal=is_causal,
370
  )
371
  else:
@@ -398,7 +406,7 @@ def flashattn_forward(
398
  cu_seqlens_k,
399
  max_seqlen_q,
400
  max_seqlen_k,
401
- 0.0,
402
  softmax_scale=None,
403
  causal=is_causal,
404
  )
 
321
  # only on first autoregressive step q,k,v have same seqlen
322
  is_causal = key_states.shape == query_states.shape
323
 
324
+ dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
325
+
326
  if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
327
  # special handling using sample packing
328
  qkv = torch.stack(
 
332
  qkv = rearrange(qkv, "b s ... -> (b s) ...")
333
 
334
  output = flash_attn_varlen_qkvpacked_func(
335
+ qkv,
336
+ cu_seqlens,
337
+ max_seqlen,
338
+ dropout_p=dropout_rate,
339
+ softmax_scale=None,
340
+ causal=True,
341
  )
342
  output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
343
  elif query_states.shape == key_states.shape:
 
360
  qkv_unpad,
361
  cu_seqlens_q,
362
  max_seqlen_q,
363
+ dropout_p=dropout_rate,
364
  softmax_scale=None,
365
  causal=is_causal,
366
  )
 
373
  output = flash_attn_kvpacked_func(
374
  query_states,
375
  torch.stack([key_states, value_states], 2),
376
+ dropout_p=dropout_rate,
377
  causal=is_causal,
378
  )
379
  else:
 
406
  cu_seqlens_k,
407
  max_seqlen_q,
408
  max_seqlen_k,
409
+ dropout_p=dropout_rate,
410
  softmax_scale=None,
411
  causal=is_causal,
412
  )
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py CHANGED
@@ -201,6 +201,8 @@ def flashattn_forward(
201
  # only on first autoregressive step q,k,v have same seqlen
202
  is_causal = key_states.shape == query_states.shape
203
 
 
 
204
  if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
205
  # special handling using sample packing
206
  qkv = torch.stack(
@@ -213,7 +215,7 @@ def flashattn_forward(
213
  qkv,
214
  cu_seqlens,
215
  max_seqlen,
216
- 0.0,
217
  softmax_scale=None,
218
  causal=True,
219
  window_size=window_size,
@@ -239,7 +241,7 @@ def flashattn_forward(
239
  qkv_unpad,
240
  cu_seqlens_q,
241
  max_seqlen_q,
242
- 0.0,
243
  softmax_scale=None,
244
  causal=is_causal,
245
  window_size=window_size,
@@ -253,6 +255,7 @@ def flashattn_forward(
253
  output = flash_attn_kvpacked_func(
254
  query_states,
255
  torch.stack([key_states, value_states], 2),
 
256
  causal=is_causal,
257
  window_size=window_size,
258
  )
@@ -286,7 +289,7 @@ def flashattn_forward(
286
  cu_seqlens_k,
287
  max_seqlen_q,
288
  max_seqlen_k,
289
- 0.0,
290
  softmax_scale=None,
291
  causal=is_causal,
292
  window_size=window_size,
 
201
  # only on first autoregressive step q,k,v have same seqlen
202
  is_causal = key_states.shape == query_states.shape
203
 
204
+ dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
205
+
206
  if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
207
  # special handling using sample packing
208
  qkv = torch.stack(
 
215
  qkv,
216
  cu_seqlens,
217
  max_seqlen,
218
+ dropout_p=dropout_rate,
219
  softmax_scale=None,
220
  causal=True,
221
  window_size=window_size,
 
241
  qkv_unpad,
242
  cu_seqlens_q,
243
  max_seqlen_q,
244
+ dropout_p=dropout_rate,
245
  softmax_scale=None,
246
  causal=is_causal,
247
  window_size=window_size,
 
255
  output = flash_attn_kvpacked_func(
256
  query_states,
257
  torch.stack([key_states, value_states], 2),
258
+ dropout_p=dropout_rate,
259
  causal=is_causal,
260
  window_size=window_size,
261
  )
 
289
  cu_seqlens_k,
290
  max_seqlen_q,
291
  max_seqlen_k,
292
+ dropout_p=dropout_rate,
293
  softmax_scale=None,
294
  causal=is_causal,
295
  window_size=window_size,