flash attention 2
Browse files- docker/Dockerfile-base +1 -1
- src/axolotl/flash_attn.py +3 -3
docker/Dockerfile-base
CHANGED
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
|
40 |
|
41 |
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
42 |
cd flash-attention && \
|
43 |
-
git checkout
|
44 |
python3 setup.py bdist_wheel && \
|
45 |
cd csrc/fused_dense_lib && \
|
46 |
python3 setup.py bdist_wheel && \
|
|
|
40 |
|
41 |
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
42 |
cd flash-attention && \
|
43 |
+
git checkout v2.0.0 && \
|
44 |
python3 setup.py bdist_wheel && \
|
45 |
cd csrc/fused_dense_lib && \
|
46 |
python3 setup.py bdist_wheel && \
|
src/axolotl/flash_attn.py
CHANGED
@@ -8,7 +8,7 @@ import torch
|
|
8 |
import transformers
|
9 |
from einops import rearrange
|
10 |
from flash_attn.bert_padding import pad_input, unpad_input
|
11 |
-
from flash_attn.flash_attn_interface import
|
12 |
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
13 |
|
14 |
|
@@ -79,7 +79,7 @@ def forward(
|
|
79 |
dtype=torch.int32,
|
80 |
device=qkv.device,
|
81 |
)
|
82 |
-
output =
|
83 |
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
84 |
)
|
85 |
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
@@ -95,7 +95,7 @@ def forward(
|
|
95 |
three=3,
|
96 |
h=nheads,
|
97 |
)
|
98 |
-
output_unpad =
|
99 |
x_unpad,
|
100 |
cu_q_lens,
|
101 |
max_s,
|
|
|
8 |
import transformers
|
9 |
from einops import rearrange
|
10 |
from flash_attn.bert_padding import pad_input, unpad_input
|
11 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
12 |
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
13 |
|
14 |
|
|
|
79 |
dtype=torch.int32,
|
80 |
device=qkv.device,
|
81 |
)
|
82 |
+
output = flash_attn_varlen_qkvpacked_func(
|
83 |
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
84 |
)
|
85 |
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
|
|
95 |
three=3,
|
96 |
h=nheads,
|
97 |
)
|
98 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
99 |
x_unpad,
|
100 |
cu_q_lens,
|
101 |
max_s,
|