don't worry about dupes

The three patched attention forwards intentionally share their setup boilerplate, so this commit adds `# pylint: disable=duplicate-code` pragmas to stop pylint's similarity checker (R0801) from flagging them.
src/axolotl/flash_attn.py

@@ -25,6 +25,7 @@ def forward(
 
         attention_mask: [bsz, q_len]
     """
+    # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
     query_states = (
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py

@@ -35,6 +35,7 @@ def xformers_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
     query_states = (

@@ -143,6 +144,7 @@ def sdp_attention_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
     query_states = (
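For context on the pragma itself: pylint's similarity checker reports R0801 (duplicate-code) when it finds long runs of near-identical lines across functions or modules, which is exactly what these attention forwards contain by design. A minimal sketch of the pattern follows; the file and function names are hypothetical and not taken from this repo.

# sketch.py - illustrative only; forward_a/forward_b are made-up names.
# pylint's duplicate-code check (R0801) fires on near-identical line runs.
# The inline pragma asks pylint to skip that check for the enclosing scope;
# whether R0801 honors scoped disables depends on the pylint version, and
# the fallback is a project-wide disable in .pylintrc (see note below).

import torch


def forward_a(hidden_states: torch.Tensor):
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()  # mirrors forward_b on purpose
    return bsz, q_len


def forward_b(hidden_states: torch.Tensor):
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()  # mirrors forward_a on purpose
    return bsz, q_len

The repo-wide alternative would be `disable=duplicate-code` under `[MESSAGES CONTROL]` in .pylintrc, but the inline form keeps the suppression scoped to the blocks that are duplicated on purpose.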