winglian committed
Commit b6ab8aa (unverified)
Parent: 85b0be2

Mistral flash attn packing (#646)

* add mistral monkeypatch

* add arg for decoder attention mask

* fix lint for duplicate code

* make sure to update transformers too

* tweak install for e2e

* move mistral patch to conditional
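
For context, the conditional wiring this commit adds (see the src/axolotl/utils/models.py hunk below) reduces to a few lines. A minimal sketch, assuming an axolotl-style cfg object that exposes is_mistral_derived_model, flash_attention, and sample_packing flags; the helper name here is hypothetical:

# Illustrative sketch only; mirrors the conditional added to models.py in this diff.
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
    replace_mistral_attn_with_flash_attn,
)


def maybe_patch_mistral_flash_attn(cfg) -> None:
    """Apply the Mistral flash-attention patch when the config asks for it."""
    if cfg.is_mistral_derived_model and cfg.flash_attention:
        # packed=True additionally swaps in the packing-aware decoder layer and
        # model forward so cu_seqlens/max_seqlen can be threaded through.
        replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)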

.github/workflows/tests.yml CHANGED
@@ -44,7 +44,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip3 install -e .
+          pip3 install -U -e .
           pip3 install -r requirements-tests.txt
 
       - name: Run tests
@@ -69,8 +69,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip3 install -e .
-          pip3 install flash-attn
+          pip3 install -U -e .[flash-attn]
           pip3 install -r requirements-tests.txt
 
       - name: Run e2e tests
requirements.txt CHANGED
@@ -4,7 +4,7 @@ torch==2.0.1
 auto-gptq
 packaging
 peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git@0ac3875011d32dc85e0e83970507e3afe8f0febb
+transformers @ git+https://github.com/huggingface/transformers.git@78dd120
 bitsandbytes>=0.41.1
 accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
 deepspeed
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py ADDED
@@ -0,0 +1,401 @@
+"""Flash attention monkey patch for mistral model"""
+# pylint: disable=duplicate-code
+
+import logging
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import transformers
+from einops import rearrange
+from torch import nn
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.mistral.modeling_mistral import (
+    MistralDecoderLayer as OriginalMistralDecoderLayer,
+)
+from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
+
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+
+try:
+    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+        flash_attn_varlen_qkvpacked_func,
+    )
+except ImportError:
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
+    )
+
+
+LOG = logging.getLogger("axolotl.monkeypatch.mistral")
+
+
+def replace_mistral_attn_with_flash_attn(
+    packed: Optional[bool] = False,
+):
+    transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
+        _prepare_decoder_attention_mask
+    )
+    transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
+        flashattn_forward
+    )
+    if packed:
+        transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
+            MistralDecoderLayer
+        )
+        transformers.models.mistral.modeling_mistral.MistralModel.forward = (
+            mistral_model_forward
+        )
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+    self,
+    attention_mask,
+    input_shape,
+    inputs_embeds,
+    past_key_values_length,
+    sliding_window,
+):  # pylint: disable=unused-argument
+    # [bsz, seq_len]
+    return attention_mask
+
+
+def flashattn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(
+        bsz, q_len, self.num_heads, self.head_dim
+    ).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
+        # special handling using sample packing
+        qkv = torch.stack(
+            [query_states, key_states, value_states], dim=2
+        )  # [bsz, nh, 3, q_len, hd]
+        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+        qkv = rearrange(qkv, "b s ... -> (b s) ...")
+
+        output = flash_attn_varlen_qkvpacked_func(
+            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
+        )
+        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+        attn_output = output
+        if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+        attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
+def mistral_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError(
+            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+        )
+    if input_ids is not None:
+        batch_size, seq_length = input_ids.shape
+    elif inputs_embeds is not None:
+        batch_size, seq_length, _ = inputs_embeds.shape
+    else:
+        raise ValueError(
+            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+        )
+
+    seq_length_with_past = seq_length
+    past_key_values_length = 0
+
+    if past_key_values is not None:
+        past_key_values_length = past_key_values[0][0].shape[2]
+        seq_length_with_past = seq_length_with_past + past_key_values_length
+
+    cu_seqlens = None
+    max_seqlen = None
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length,
+            seq_length + past_key_values_length,
+            dtype=torch.long,
+            device=device,
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    else:
+        position_ids = position_ids.view(-1, seq_length).long()
+        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
+        cu_seqlens = cu_seqlens.squeeze()
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+    # embed positions
+    if attention_mask is None:
+        attention_mask = torch.ones(
+            (batch_size, seq_length_with_past),
+            dtype=torch.bool,
+            device=inputs_embeds.device,
+        )
+    attention_mask = (
+        self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+            sliding_window=self.config.sliding_window,
+        )
+    )
+
+    hidden_states = inputs_embeds
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            transformers.logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    for idx, decoder_layer in enumerate(self.layers):
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.gradient_checkpointing and self.training:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer),
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                None,
+                cu_seqlens,
+                max_seqlen,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(
+            v
+            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+            if v is not None
+        )
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+
+class MistralDecoderLayer(OriginalMistralDecoderLayer):
+    """
+    patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[torch.Tensor] = None,
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
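
To illustrate the packed code path above: with sample packing, several short examples share one row, position_ids restart at 0 at each packed sequence, and get_cu_seqlens_from_pos_ids turns those restarts into the cumulative sequence lengths that flash_attn_varlen_qkvpacked_func consumes. Below is a self-contained sketch of that conversion in plain torch (it does not call the axolotl helper, and the function name is hypothetical):

import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
    """Illustrative only: derive cu_seqlens/max_seqlen for one packed row.

    position_ids has shape (1, total_len) and restarts at 0 at the start of
    each packed sequence, matching how packed batches are laid out.
    """
    pos = position_ids.squeeze(0)
    # A new sequence begins wherever the position counter resets to zero.
    starts = torch.nonzero(pos == 0, as_tuple=False).squeeze(-1)
    ends = torch.cat([starts[1:], torch.tensor([pos.numel()])])
    seqlens = ends - starts
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0).to(torch.int32)]
    )
    return cu_seqlens, int(seqlens.max())


# Three sequences of lengths 3, 2, and 4 packed into a single row:
pos_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
print(cu_seqlens_from_position_ids(pos_ids))
# (tensor([0, 3, 5, 9], dtype=torch.int32), 4)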
src/axolotl/utils/models.py CHANGED
@@ -150,6 +150,14 @@ def load_model(
         # Note: This might overwrite previous additional_special_tokens
         tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
+    if cfg.is_mistral_derived_model and cfg.flash_attention:
+        from axolotl.monkeypatch.mistral_attn_hijack_flash import (
+            replace_mistral_attn_with_flash_attn,
+        )
+
+        LOG.info("patching with flash attention")
+        replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
+
     if cfg.is_llama_derived_model and cfg.xpos_rope:
         from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
            replace_llama_rope_with_xpos_rope,
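
End to end, the patch should run before the model object is built: with packed=True it rebinds MistralDecoderLayer at module level in transformers.models.mistral.modeling_mistral, and only models constructed after that swap pick up the packing-aware layer. A hedged usage sketch; the model id and dtype are illustrative, not taken from this diff:

# Usage sketch under the assumptions above: patch first, then load.
import torch
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.mistral_attn_hijack_flash import (
    replace_mistral_attn_with_flash_attn,
)

replace_mistral_attn_with_flash_attn(packed=True)  # packed=True when sample packing is on

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # illustrative model id
    torch_dtype=torch.bfloat16,
)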