Move flash_attn assert from __init__ into calling func (#32)
Browse files- Change the assert to warning in __init__ (e4f4e18f099e6fd012bfed0bd3170dc2ea209592)
- Update modeling_phi3_small.py (e203eff1c173b6ff6f356c6a9e9e9fb9f2371c72)
- Update modeling_phi3_small.py (bef256d02055a0870d55b9d21fc35d68bdee1c61)
- Moved flash_attn assert close to the caller (e4770004838da0fec2d9f8cdb60f8417e5394ba7)
Co-authored-by: Roger Feng <[email protected]>
- modeling_phi3_small.py +2 -1
modeling_phi3_small.py
CHANGED
@@ -215,7 +215,6 @@ class Phi3SmallSelfAttention(nn.Module):
|
|
215 |
f"Layer {layer_idx + 1} is using dense attention since it is divisible by "
|
216 |
f"{self.config.dense_attention_every_n_layers}"
|
217 |
)
|
218 |
-
assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
|
219 |
else:
|
220 |
# BlockSparse related Parameters
|
221 |
self.blocksparse_params = BlockSparseParams.from_config(config)
|
@@ -419,6 +418,8 @@ class Phi3SmallSelfAttention(nn.Module):
|
|
419 |
avoid doing that.
|
420 |
|
421 |
"""
|
|
|
|
|
422 |
attention_dropout_prob = self.attention_dropout_rate if self.training else 0.0
|
423 |
# Get into the correct shape for the Flash Attention API
|
424 |
# shape: (bs, seq_len, nqp, hn)
|
|
|
215 |
f"Layer {layer_idx + 1} is using dense attention since it is divisible by "
|
216 |
f"{self.config.dense_attention_every_n_layers}"
|
217 |
)
|
|
|
218 |
else:
|
219 |
# BlockSparse related Parameters
|
220 |
self.blocksparse_params = BlockSparseParams.from_config(config)
|
|
|
418 |
avoid doing that.
|
419 |
|
420 |
"""
|
421 |
+
assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
|
422 |
+
|
423 |
attention_dropout_prob = self.attention_dropout_rate if self.training else 0.0
|
424 |
# Get into the correct shape for the Flash Attention API
|
425 |
# shape: (bs, seq_len, nqp, hn)
|