update flash-attn patch for 70B/GQA and inference, using a helper from the flash-attn tests
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
CHANGED
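
For reference, a minimal usage sketch of the patch: the new replace_llama_attn_with_flash_attn() entry point (added below) is meant to be called before the Llama model is instantiated, so that the patched attention forward and the pass-through attention-mask hook are installed first. The checkpoint name and dtype here are illustrative only, not part of this change.

# usage sketch (illustrative); assumes axolotl and flash-attn are installed
import torch
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()  # patches LlamaAttention.forward and the mask hook

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",  # illustrative GQA checkpoint
    torch_dtype=torch.float16,
)
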
@@ -2,26 +2,53 @@
@@ -37,124 +64,275 @@ def forward(

Removed: the previous FastChat-derived implementation. Its try-block imported only flash_attn_varlen_qkvpacked_func (falling back to flash_attn_unpadded_qkvpacked_func), and it imported apply_rotary_pos_emb alone from transformers.models.llama.modeling_llama. The old forward() projected hidden_states through self.q_proj / self.k_proj / self.v_proj, packed the results into a single qkv tensor, and dispatched on the padding state. The branch for inputs without padding was:

    if key_padding_mask is None:
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        max_s = q_len
        cu_q_lens = torch.arange(
            0,
            (bsz + 1) * q_len,
            step=q_len,
            dtype=torch.int32,
            device=qkv.device,
        )
        output = flash_attn_varlen_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )

alongside the sample-packing branch (retained below) and an unpadding branch that called unpad_input, rearranged the packed tensor with three=3, h=nheads, ran flash_attn_varlen_qkvpacked_func on the unpadded qkv, and re-padded the output with pad_input. The module-level assignment transformers.models.llama.modeling_llama.LlamaAttention.forward = forward is replaced by the replace_llama_attn_with_flash_attn() helper. The updated module follows.
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py

import warnings
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

try:
    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
    )
except ImportError:
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
    )
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
    )


def replace_llama_attn_with_flash_attn():
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
        _prepare_decoder_attention_mask
    )
    transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self,
    attention_mask,
    input_shape,
    inputs_embeds,
    past_key_values_length,
):  # pylint: disable=unused-argument
    # [bsz, seq_len]
    return attention_mask


def flashattn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    # ... unchanged lines collapsed in the diff ...
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()

    if not hasattr(self, "pretraining_tp"):
        self.pretraining_tp = 1

    if self.pretraining_tp > 1:
        key_value_slicing = (
            self.num_key_value_heads * self.head_dim
        ) // self.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [
            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [
            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [
            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    #
    # flash-attn v2 start
    #

    if self.training:
        # during training q,k,v always have same seqlen
        assert key_states.shape == query_states.shape
        is_causal = True
    else:
        # turn off FA causal mask after first inference autoregressive iteration
        # only on first autoregressive step q,k,v have same seqlen
        is_causal = key_states.shape == query_states.shape

    if self.training and attention_mask.shape[0] == 1:
        # special handling using sample packing
        qkv = torch.stack(
            [query_states, key_states, value_states], dim=2
        )  # [bsz, nh, 3, q_len, hd]
        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
        cu_q_lens = cu_q_lens.squeeze()

        output = flash_attn_varlen_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=is_causal
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    elif query_states.shape == key_states.shape:
        qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
            query_states.transpose(1, 2),
            key_states.transpose(1, 2),
            value_states.transpose(1, 2),
            qkvpacked=True,
            # We have disabled _prepare_decoder_attention_mask in LlamaModel
            # the attention_mask should be the same as the key_padding_mask
            key_padding_mask=attention_mask,
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            qkv_unpad,
            cu_seqlens_q,
            max_seqlen_q,
            0.0,
            softmax_scale=None,
            causal=is_causal,
        )
        output = output_pad_fn(output_unpad)
    else:
        (  # pylint: disable=unbalanced-tuple-unpacking
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            _,
            _,
            output_pad_fn,
        ) = generate_qkv(
            query_states.transpose(1, 2),
            key_states.transpose(1, 2),
            value_states.transpose(1, 2),
            kvpacked=True,
            key_padding_mask=attention_mask,
        )
        output_unpad = flash_attn_varlen_kvpacked_func(
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            0.0,
            softmax_scale=None,
            causal=is_causal,
        )
        output = output_pad_fn(output_unpad)

    attn_output = output
    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = rearrange(attn_output, "b s h d -> b s (h d)")

    #
    # flash-attn v2 end
    #

    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(
            self.hidden_size // self.pretraining_tp, dim=1
        )
        attn_output = sum(
            F.linear(attn_output[i], o_proj_slices[i])
            for i in range(self.pretraining_tp)
        )
    else:
        attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value

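The sample-packing branch above relies on get_cu_seqlens_from_pos_ids, which is imported from axolotl.monkeypatch.utils and not shown in this diff. As a rough, illustrative sketch only (not the actual helper), the behavior it is assumed to provide is: derive cumulative sequence lengths and the longest sub-sequence from position_ids that restart at 0 for every packed sample, in the int32 form the varlen flash-attn kernels expect.

import torch


def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor):
    # Illustrative stand-in; position_ids is [1, total_len] and restarts at 0
    # at the start of each packed sub-sequence, e.g. [[0, 1, 2, 0, 1, 0, 1, 2, 3]].
    pos = position_ids.flatten()
    starts = (pos == 0).nonzero(as_tuple=True)[0]
    ends = torch.cat([starts[1:], torch.tensor([pos.numel()], device=pos.device)])
    seqlens = ends - starts
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=pos.device),
            seqlens.cumsum(0).to(torch.int32),
        ]
    )
    return cu_seqlens, int(seqlens.max())


# three packed sequences of lengths 3, 2 and 4:
# cu_seqlens_from_pos_ids_sketch(torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]]))
# -> (tensor([0, 3, 5, 9], dtype=torch.int32), 4)
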
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
):  # pylint: disable=invalid-name,unnecessary-lambda-assignment
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
            q, query_padding_mask
        )

        output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
            output_unpad, indices_q, batch_size, seqlen_q
        )

    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        max_seqlen_q = seqlen_q

        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)

    if kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        return (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
        )

    return (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
    )
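
To make generate_qkv's contract concrete, here is a small shape check for the kvpacked path (the branch used above when query and key lengths differ at inference time). The sizes are arbitrary, and it assumes flash-attn is installed so that unpad_input/pad_input work and this module can be imported.

import torch

from axolotl.monkeypatch.llama_attn_hijack_flash import generate_qkv

bsz, nheads, d = 2, 4, 16
seqlen_q, seqlen_k = 1, 8  # one new query token attending over a longer key/value sequence
q = torch.randn(bsz, seqlen_q, nheads, d)
k = torch.randn(bsz, seqlen_k, nheads, d)
v = torch.randn(bsz, seqlen_k, nheads, d)
mask = torch.ones(bsz, seqlen_k, dtype=torch.bool)
mask[1, 5:] = False  # second sample has 3 padded key positions

(
    q_unpad,
    kv_unpad,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    _,
    _,
    output_pad_fn,
) = generate_qkv(q, k, v, key_padding_mask=mask, kvpacked=True)

assert q_unpad.shape == (bsz * seqlen_q, nheads, d)
assert kv_unpad.shape == (int(mask.sum()), 2, nheads, d)  # only the 13 real key positions remain
assert cu_seqlens_k.tolist() == [0, 8, 13]
assert max_seqlen_q == 1 and max_seqlen_k == 8
assert output_pad_fn(q_unpad).shape == (bsz, seqlen_q, nheads, d)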