casperhansen winglian commited on
Commit
a045db0
·
unverified ·
1 Parent(s): e1b214c

Mistral: Sliding Window Attention with Flash Attention and Sample Packing (#732)

Browse files

* Implement Mistral FA + SWA + Sample Packing

* Handle unbroadcastable tensor

* chore: lint

* Simplify _prepare_decoder_attention_mask

* Uncomment window size

* Upgrade flash-attn to minimum of 2.3.0 to support SWA

* Add original condition to avoid error during inference

* chore: lint

* use torchscript to prevent oom

* chore: pylint

---------

Co-authored-by: Wing Lian <[email protected]>

setup.py CHANGED
@@ -46,7 +46,7 @@ setup(
46
  dependency_links=dependency_links,
47
  extras_require={
48
  "flash-attn": [
49
- "flash-attn>=2.2.1",
50
  ],
51
  "deepspeed": [
52
  "deepspeed",
 
46
  dependency_links=dependency_links,
47
  extras_require={
48
  "flash-attn": [
49
+ "flash-attn>=2.3.0",
50
  ],
51
  "deepspeed": [
52
  "deepspeed",
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py CHANGED
@@ -14,6 +14,9 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
14
  flash_attn_varlen_qkvpacked_func,
15
  )
16
  from transformers.modeling_outputs import BaseModelOutputWithPast
 
 
 
17
  from transformers.models.mistral.modeling_mistral import (
18
  MistralDecoderLayer as OriginalMistralDecoderLayer,
19
  )
@@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
42
  )
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  # Disable the transformation of the attention mask in LlamaModel as the flash attention
46
  # requires the attention mask to be the same as the key_padding_mask
47
  def _prepare_decoder_attention_mask(
@@ -53,11 +94,29 @@ def _prepare_decoder_attention_mask(
53
  sliding_window,
54
  ): # pylint: disable=unused-argument
55
  # [bsz, seq_len]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  return attention_mask
57
 
58
 
59
  def flashattn_forward(
60
- self,
61
  hidden_states: torch.Tensor,
62
  attention_mask: Optional[torch.Tensor] = None,
63
  position_ids: Optional[torch.LongTensor] = None,
@@ -91,10 +150,41 @@ def flashattn_forward(
91
  query_states, key_states, cos, sin, position_ids
92
  )
93
 
 
 
 
 
 
 
 
 
 
 
94
  if past_key_value is not None:
95
- # reuse k, v, self_attention
96
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
97
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  past_key_value = (key_states, value_states) if use_cache else None
100
 
@@ -120,7 +210,13 @@ def flashattn_forward(
120
  qkv = rearrange(qkv, "b s ... -> (b s) ...")
121
 
122
  output = flash_attn_varlen_qkvpacked_func(
123
- qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
 
 
 
 
 
 
124
  )
125
  output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
126
  elif query_states.shape == key_states.shape:
@@ -146,6 +242,7 @@ def flashattn_forward(
146
  0.0,
147
  softmax_scale=None,
148
  causal=is_causal,
 
149
  )
150
  output = output_pad_fn(output_unpad)
151
  else:
@@ -157,6 +254,7 @@ def flashattn_forward(
157
  query_states,
158
  torch.stack([key_states, value_states], 2),
159
  causal=is_causal,
 
160
  )
161
  else:
162
  ( # pylint: disable=unbalanced-tuple-unpacking
@@ -191,6 +289,7 @@ def flashattn_forward(
191
  0.0,
192
  softmax_scale=None,
193
  causal=is_causal,
 
194
  )
195
  output = output_pad_fn(output_unpad)
196
 
 
14
  flash_attn_varlen_qkvpacked_func,
15
  )
16
  from transformers.modeling_outputs import BaseModelOutputWithPast
17
+ from transformers.models.mistral.modeling_mistral import (
18
+ MistralAttention as OriginalMistralAttention,
19
+ )
20
  from transformers.models.mistral.modeling_mistral import (
21
  MistralDecoderLayer as OriginalMistralDecoderLayer,
22
  )
 
45
  )
46
 
47
 
48
+ @torch.jit.script
49
+ def _make_sliding_window_causal_mask(
50
+ bsz: int,
51
+ tgt_len: int,
52
+ dtype: torch.dtype,
53
+ device: torch.device,
54
+ past_key_values_length: int = 0,
55
+ sliding_window: int = 4096,
56
+ ):
57
+ """
58
+ Make causal mask used for sliding window attention
59
+ """
60
+ tensor = torch.full(
61
+ (tgt_len, tgt_len),
62
+ fill_value=1,
63
+ device=device,
64
+ )
65
+ mask = torch.tril(tensor, diagonal=0)
66
+ # make the mask banded to account for sliding window
67
+ # NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
68
+ mask = torch.triu(mask, diagonal=-sliding_window + 1)
69
+ mask = torch.log(mask).to(dtype)
70
+
71
+ if past_key_values_length > 0:
72
+ mask = torch.cat(
73
+ [
74
+ torch.zeros(
75
+ tgt_len, past_key_values_length, dtype=dtype, device=device
76
+ ),
77
+ mask,
78
+ ],
79
+ dim=-1,
80
+ )
81
+ return mask[None, None, :, :].expand(
82
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
83
+ )
84
+
85
+
86
  # Disable the transformation of the attention mask in LlamaModel as the flash attention
87
  # requires the attention mask to be the same as the key_padding_mask
88
  def _prepare_decoder_attention_mask(
 
94
  sliding_window,
95
  ): # pylint: disable=unused-argument
96
  # [bsz, seq_len]
97
+ if attention_mask is None:
98
+ return attention_mask
99
+
100
+ # NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
101
+ # Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
102
+ if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
103
+ sliding_window_mask = _make_sliding_window_causal_mask(
104
+ bsz=input_shape[0],
105
+ tgt_len=input_shape[1],
106
+ dtype=inputs_embeds.dtype,
107
+ device=inputs_embeds.device,
108
+ past_key_values_length=past_key_values_length,
109
+ sliding_window=sliding_window,
110
+ )
111
+ attention_mask = attention_mask + sliding_window_mask
112
+ else:
113
+ LOG.info("skipping sliding window mask, not broadcastable with attention mask")
114
+
115
  return attention_mask
116
 
117
 
118
  def flashattn_forward(
119
+ self: OriginalMistralAttention,
120
  hidden_states: torch.Tensor,
121
  attention_mask: Optional[torch.Tensor] = None,
122
  position_ids: Optional[torch.LongTensor] = None,
 
150
  query_states, key_states, cos, sin, position_ids
151
  )
152
 
153
+ use_sliding_windows = (
154
+ hasattr(self.config, "sliding_window") is not None
155
+ and kv_seq_len > self.config.sliding_window
156
+ )
157
+
158
+ if use_sliding_windows:
159
+ window_size = (self.config.sliding_window, self.config.sliding_window)
160
+ else:
161
+ window_size = (-1, -1)
162
+
163
  if past_key_value is not None:
164
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
165
+ if (
166
+ hasattr(self.config, "sliding_window")
167
+ and kv_seq_len > self.config.sliding_window
168
+ ):
169
+ slicing_tokens = kv_seq_len - self.config.sliding_window
170
+
171
+ past_key = past_key_value[0]
172
+ past_value = past_key_value[1]
173
+
174
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
175
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
176
+
177
+ if past_key.shape[-2] != self.config.sliding_window - 1:
178
+ raise ValueError(
179
+ f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
180
+ f" {past_key.shape}"
181
+ )
182
+
183
+ past_key_value = (past_key, past_value) if use_cache else None
184
+
185
+ if past_key_value is not None:
186
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
187
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
188
 
189
  past_key_value = (key_states, value_states) if use_cache else None
190
 
 
210
  qkv = rearrange(qkv, "b s ... -> (b s) ...")
211
 
212
  output = flash_attn_varlen_qkvpacked_func(
213
+ qkv,
214
+ cu_seqlens,
215
+ max_seqlen,
216
+ 0.0,
217
+ softmax_scale=None,
218
+ causal=True,
219
+ window_size=window_size,
220
  )
221
  output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
222
  elif query_states.shape == key_states.shape:
 
242
  0.0,
243
  softmax_scale=None,
244
  causal=is_causal,
245
+ window_size=window_size,
246
  )
247
  output = output_pad_fn(output_unpad)
248
  else:
 
254
  query_states,
255
  torch.stack([key_states, value_states], 2),
256
  causal=is_causal,
257
+ window_size=window_size,
258
  )
259
  else:
260
  ( # pylint: disable=unbalanced-tuple-unpacking
 
289
  0.0,
290
  softmax_scale=None,
291
  causal=is_causal,
292
+ window_size=window_size,
293
  )
294
  output = output_pad_fn(output_unpad)
295