tmm1 committed
Commit d773384 · 1 Parent(s): 985dcbc

update flash-attn patch for 70B/GQA and inference using helper from flash-attn tests
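For orientation, a minimal usage sketch (not part of the diff): the patch takes effect by calling replace_llama_attn_with_flash_attn() before the Llama model is instantiated, so the patched attention forward is already in place when the weights load. The checkpoint name, dtype, and device_map below are illustrative assumptions, not taken from this commit, and flash-attn must be installed with a CUDA device available.

# Hypothetical usage sketch; the checkpoint id and dtype are placeholders.
import torch
from transformers import LlamaForCausalLM

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

# Patch LlamaAttention.forward and _prepare_decoder_attention_mask first ...
replace_llama_attn_with_flash_attn()

# ... then load a GQA model, e.g. a Llama-2 70B checkpoint (placeholder id).
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)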

src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -2,26 +2,53 @@
 
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
 
 from typing import Optional, Tuple
 
 import torch
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
 
 try:
-    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 except ImportError:
     from flash_attn.flash_attn_interface import (
         flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
     )
 
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
 
-def forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
@@ -37,124 +64,275 @@ def forward(
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = (
-        self.q_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    key_states = (
-        self.k_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    value_states = (
-        self.v_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
     # [bsz, q_len, nh, hd]
     # [bsz, nh, q_len, hd]
 
     kv_seq_len = key_states.shape[-2]
-    assert past_key_value is None, "past_key_value is not supported"
 
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )
     # [bsz, nh, t, hd]
-    assert not output_attentions, "output_attentions is not supported"
-    assert not use_cache, "use_cache is not supported"
-
-    # Flash attention codes from
-    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
-
-    # transform the data into the format required by flash attention
-    qkv = torch.stack(
-        [query_states, key_states, value_states], dim=2
-    )  # [bsz, nh, 3, q_len, hd]
-    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-    # We have disabled _prepare_decoder_attention_mask in LlamaModel
-    # the attention_mask should be the same as the key_padding_mask
-    key_padding_mask = attention_mask
-
-    if key_padding_mask is None:
-        qkv = rearrange(qkv, "b s ... -> (b s) ...")
-        max_s = q_len
-        cu_q_lens = torch.arange(
-            0,
-            (bsz + 1) * q_len,
-            step=q_len,
-            dtype=torch.int32,
-            device=qkv.device,
-        )
-        output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
-        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-    elif attention_mask.shape[0] == 1:
         # special handling using sample packing
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
         cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
         cu_q_lens = cu_q_lens.squeeze()
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-    else:
-        nheads = qkv.shape[-2]
-
-        # pylint: disable=invalid-name
-        x = rearrange(qkv, "b s three h d -> b s (three h d)")
-        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
-        x_unpad = rearrange(
-            x_unpad,
-            "nnz (three h d) -> nnz three h d",
-            three=3,
-            h=nheads,
         )
         output_unpad = flash_attn_varlen_qkvpacked_func(
-            x_unpad,
-            cu_q_lens,
-            max_s,
             0.0,
             softmax_scale=None,
-            causal=True,
         )
-        output = rearrange(
-            pad_input(
-                rearrange(output_unpad, "nnz h d -> nnz (h d)"),
-                indices,
-                bsz,
-                q_len,
-            ),
-            "b s (h d) -> b s h d",
-            h=nheads,
         )
 
-    return (
-        self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
-        None,
-        None,
-    )
 
 
-# Disable the transformation of the attention mask in LlamaModel as the flash attention
-# requires the attention mask to be the same as the key_padding_mask
-def _prepare_decoder_attention_mask(
-    self,
-    attention_mask,
-    input_shape,
-    inputs_embeds,
-    past_key_values_length,
-):  # pylint: disable=unused-argument
-    # [bsz, seq_len]
-    return attention_mask
 
 
-def replace_llama_attn_with_flash_attn():
-    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
-        _prepare_decoder_attention_mask
     )
-    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
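The added code that follows projects k/v with num_key_value_heads instead of num_heads and then expands them back to the full query-head count with repeat_kv before calling flash attention. A minimal shape sketch of that expansion, using illustrative 70B-style sizes (64 query heads, 8 KV heads, head_dim 128) that are assumptions here rather than values read from the commit:

# Sketch of the GQA head expansion the patched forward relies on (illustrative sizes).
import torch
from transformers.models.llama.modeling_llama import repeat_kv

bsz, q_len, head_dim = 2, 16, 128
num_heads, num_key_value_heads = 64, 8  # assumed 70B-style head counts
num_key_value_groups = num_heads // num_key_value_heads

# k_proj/v_proj now produce only num_key_value_heads heads ...
key_states = torch.randn(bsz, num_key_value_heads, q_len, head_dim)

# ... and repeat_kv tiles each KV head across its group of query heads,
# so the packed flash-attn kernels see matching q/k/v head counts.
key_states = repeat_kv(key_states, num_key_value_groups)
assert key_states.shape == (bsz, num_heads, q_len, head_dim)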
 
 
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
 
+import warnings
 from typing import Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
 try:
+    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_qkvpacked_func,
+    )
 except ImportError:
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
+    )
     from flash_attn.flash_attn_interface import (
         flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
     )
 
 
+def replace_llama_attn_with_flash_attn():
+    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
+        _prepare_decoder_attention_mask
+    )
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
 
 
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+    self,
+    attention_mask,
+    input_shape,
+    inputs_embeds,
+    past_key_values_length,
+):  # pylint: disable=unused-argument
+    # [bsz, seq_len]
+    return attention_mask
+
+
+def flashattn_forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
 
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
+    if not hasattr(self, "pretraining_tp"):
+        self.pretraining_tp = 1
+
+    if self.pretraining_tp > 1:
+        key_value_slicing = (
+            self.num_key_value_heads * self.head_dim
+        ) // self.pretraining_tp
+        query_slices = self.q_proj.weight.split(
+            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
+        )
+        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+        query_states = [
+            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        query_states = torch.cat(query_states, dim=-1)
+
+        key_states = [
+            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        key_states = torch.cat(key_states, dim=-1)
+
+        value_states = [
+            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        value_states = torch.cat(value_states, dim=-1)
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(
+        bsz, q_len, self.num_heads, self.head_dim
+    ).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
     # [bsz, q_len, nh, hd]
     # [bsz, nh, q_len, hd]
 
     kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
 
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )
     # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if output_attentions:
+        warnings.warn(
+            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
         )
+
+    #
+    # flash-attn v2 start
+    #
+
+    if self.training:
+        # during training q,k,v always have same seqlen
+        assert key_states.shape == query_states.shape
+        is_causal = True
+    else:
+        # turn off FA causal mask after first inference autoregressive iteration
+        # only on first autoregressive step q,k,v have same seqlen
+        is_causal = key_states.shape == query_states.shape
+
+    if self.training and attention_mask.shape[0] == 1:
         # special handling using sample packing
+        qkv = torch.stack(
+            [query_states, key_states, value_states], dim=2
+        )  # [bsz, nh, 3, q_len, hd]
+        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
         cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
         cu_q_lens = cu_q_lens.squeeze()
 
         output = flash_attn_varlen_qkvpacked_func(
+            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=is_causal
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+    elif query_states.shape == key_states.shape:
+        qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
+            query_states.transpose(1, 2),
+            key_states.transpose(1, 2),
+            value_states.transpose(1, 2),
+            qkvpacked=True,
+            # We have disabled _prepare_decoder_attention_mask in LlamaModel
+            # the attention_mask should be the same as the key_padding_mask
+            key_padding_mask=attention_mask,
         )
         output_unpad = flash_attn_varlen_qkvpacked_func(
+            qkv_unpad,
+            cu_seqlens_q,
+            max_seqlen_q,
             0.0,
             softmax_scale=None,
+            causal=is_causal,
         )
+        output = output_pad_fn(output_unpad)
+    else:
+        (  # pylint: disable=unbalanced-tuple-unpacking
+            q_unpad,
+            kv_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            _,
+            _,
+            output_pad_fn,
+        ) = generate_qkv(
+            query_states.transpose(1, 2),
+            key_states.transpose(1, 2),
+            value_states.transpose(1, 2),
+            kvpacked=True,
+            key_padding_mask=attention_mask,
         )
+        output_unpad = flash_attn_varlen_kvpacked_func(
+            q_unpad,
+            kv_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            0.0,
+            softmax_scale=None,
+            causal=is_causal,
+        )
+        output = output_pad_fn(output_unpad)
 
+    attn_output = output
+    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
+        raise ValueError(
+            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+    attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
 
+    #
+    # flash-attn v2 end
+    #
 
+    if self.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(
+            self.hidden_size // self.pretraining_tp, dim=1
+        )
+        attn_output = sum(
+            F.linear(attn_output[i], o_proj_slices[i])
+            for i in range(self.pretraining_tp)
+        )
+    else:
+        attn_output = self.o_proj(attn_output)
 
+    return attn_output, None, past_key_value
 
+
+# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
+def generate_qkv(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    kvpacked=False,
+    qkvpacked=False,
+):  # pylint: disable=invalid-name,unnecessary-lambda-assignment
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, d)
+        k: (batch_size, seqlen_k, nheads_k, d)
+        v: (batch_size, seqlen_k, nheads_k, d)
+        query_padding_mask: (batch_size, seqlen), bool
+        key_padding_mask: (batch_size, seqlen), bool
+    """
+    assert not (kvpacked and qkvpacked)
+    batch_size, seqlen_q, nheads, d = q.shape
+    _, seqlen_k, nheads_k, _ = k.shape
+    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
+    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
+
+    if query_padding_mask is not None:
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
+            q, query_padding_mask
+        )
+
+        output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
+            output_unpad, indices_q, batch_size, seqlen_q
+        )
+
+    else:
+        q_unpad = rearrange(q, "b s h d -> (b s) h d")
+        cu_seqlens_q = torch.arange(
+            0,
+            (batch_size + 1) * seqlen_q,
+            step=seqlen_q,
+            dtype=torch.int32,
+            device=q_unpad.device,
+        )
+        max_seqlen_q = seqlen_q
+
+        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
+            output_unpad, "(b s) h d -> b s h d", b=batch_size
+        )
+
+    if key_padding_mask is not None:
+        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
+        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+    else:
+        k_unpad = rearrange(k, "b s h d -> (b s) h d")
+        v_unpad = rearrange(v, "b s h d -> (b s) h d")
+        cu_seqlens_k = torch.arange(
+            0,
+            (batch_size + 1) * seqlen_k,
+            step=seqlen_k,
+            dtype=torch.int32,
+            device=k_unpad.device,
+        )
+        max_seqlen_k = seqlen_k
+
+    if qkvpacked:
+        assert nheads == nheads_k
+        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
+        qkv = torch.stack([q, k, v], dim=2)
+        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
+
+    if kvpacked:
+        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
+        kv = torch.stack([k, v], dim=2)
+        return (
+            q_unpad,
+            kv_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q,
+            kv,
+            output_pad_fn,
+        )
+
+    return (
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        q,
+        k,
+        v,
+        output_pad_fn,
     )
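A small sketch of how the generate_qkv helper is exercised on the decode path of the patched forward (a single-token query attending to a longer, possibly padded key/value cache). The shapes and mask below are made-up toy inputs; importing the module assumes a working flash-attn install, and the unpadded tensors returned here are what flash_attn_varlen_kvpacked_func consumes on GPU.

# Toy sketch of the kvpacked path; tensors and mask are fabricated examples.
import torch

from axolotl.monkeypatch.llama_attn_hijack_flash import generate_qkv

bsz, q_len, kv_len, nheads, head_dim = 2, 1, 8, 4, 16
q = torch.randn(bsz, q_len, nheads, head_dim)
k = torch.randn(bsz, kv_len, nheads, head_dim)
v = torch.randn(bsz, kv_len, nheads, head_dim)

# Right-padded key mask: the second sequence only has 5 valid positions.
key_padding_mask = torch.ones(bsz, kv_len, dtype=torch.bool)
key_padding_mask[1, 5:] = False

(
    q_unpad,        # (bsz * q_len, nheads, head_dim)
    kv_unpad,       # (total_valid_kv, 2, nheads, head_dim)
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    _,
    _,
    output_pad_fn,  # restores the (bsz, q_len, nheads, head_dim) layout after the kernel
) = generate_qkv(q, k, v, kvpacked=True, key_padding_mask=key_padding_mask)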