winglian committed on
Commit
fbf49a4
·
1 Parent(s): 58cf7e7

is_causal fix for evals?

Browse files
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -155,6 +155,8 @@ def flashattn_forward(
155
  # during training q,k,v always have same seqlen
156
  assert key_states.shape == query_states.shape
157
  is_causal = True
 
 
158
  else:
159
  # turn off FA causal mask after first inference autoregressive iteration
160
  # only on first autoregressive step q,k,v have same seqlen
 
155
  # during training q,k,v always have same seqlen
156
  assert key_states.shape == query_states.shape
157
  is_causal = True
158
+ elif past_key_value is None:
159
+ is_causal = True
160
  else:
161
  # turn off FA causal mask after first inference autoregressive iteration
162
  # only on first autoregressive step q,k,v have same seqlen