Mistral: Sliding Window Attention with Flash Attention and Sample Packing (#732)
Browse files* Implement Mistral FA + SWA + Sample Packing
* Handle unbroadcastable tensor
* chore: lint
* Simplify _prepare_decoder_attention_mask
* Uncomment window size
* Upgrade flash-attn to minimum of 2.3.0 to support SWA
* Add original condition to avoid error during inference
* chore: lint
* use torchscript to prevent oom
* chore: pylint
---------
Co-authored-by: Wing Lian <[email protected]>
- setup.py +1 -1
- src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +104 -5
setup.py
CHANGED
@@ -46,7 +46,7 @@ setup(
|
|
46 |
dependency_links=dependency_links,
|
47 |
extras_require={
|
48 |
"flash-attn": [
|
49 |
-
"flash-attn>=2.
|
50 |
],
|
51 |
"deepspeed": [
|
52 |
"deepspeed",
|
|
|
46 |
dependency_links=dependency_links,
|
47 |
extras_require={
|
48 |
"flash-attn": [
|
49 |
+
"flash-attn>=2.3.0",
|
50 |
],
|
51 |
"deepspeed": [
|
52 |
"deepspeed",
|
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
CHANGED
@@ -14,6 +14,9 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
|
14 |
flash_attn_varlen_qkvpacked_func,
|
15 |
)
|
16 |
from transformers.modeling_outputs import BaseModelOutputWithPast
|
|
|
|
|
|
|
17 |
from transformers.models.mistral.modeling_mistral import (
|
18 |
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
19 |
)
|
@@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
|
|
42 |
)
|
43 |
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
46 |
# requires the attention mask to be the same as the key_padding_mask
|
47 |
def _prepare_decoder_attention_mask(
|
@@ -53,11 +94,29 @@ def _prepare_decoder_attention_mask(
|
|
53 |
sliding_window,
|
54 |
): # pylint: disable=unused-argument
|
55 |
# [bsz, seq_len]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
return attention_mask
|
57 |
|
58 |
|
59 |
def flashattn_forward(
|
60 |
-
self,
|
61 |
hidden_states: torch.Tensor,
|
62 |
attention_mask: Optional[torch.Tensor] = None,
|
63 |
position_ids: Optional[torch.LongTensor] = None,
|
@@ -91,10 +150,41 @@ def flashattn_forward(
|
|
91 |
query_states, key_states, cos, sin, position_ids
|
92 |
)
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
if past_key_value is not None:
|
95 |
-
#
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
past_key_value = (key_states, value_states) if use_cache else None
|
100 |
|
@@ -120,7 +210,13 @@ def flashattn_forward(
|
|
120 |
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
121 |
|
122 |
output = flash_attn_varlen_qkvpacked_func(
|
123 |
-
qkv,
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
)
|
125 |
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
126 |
elif query_states.shape == key_states.shape:
|
@@ -146,6 +242,7 @@ def flashattn_forward(
|
|
146 |
0.0,
|
147 |
softmax_scale=None,
|
148 |
causal=is_causal,
|
|
|
149 |
)
|
150 |
output = output_pad_fn(output_unpad)
|
151 |
else:
|
@@ -157,6 +254,7 @@ def flashattn_forward(
|
|
157 |
query_states,
|
158 |
torch.stack([key_states, value_states], 2),
|
159 |
causal=is_causal,
|
|
|
160 |
)
|
161 |
else:
|
162 |
( # pylint: disable=unbalanced-tuple-unpacking
|
@@ -191,6 +289,7 @@ def flashattn_forward(
|
|
191 |
0.0,
|
192 |
softmax_scale=None,
|
193 |
causal=is_causal,
|
|
|
194 |
)
|
195 |
output = output_pad_fn(output_unpad)
|
196 |
|
|
|
14 |
flash_attn_varlen_qkvpacked_func,
|
15 |
)
|
16 |
from transformers.modeling_outputs import BaseModelOutputWithPast
|
17 |
+
from transformers.models.mistral.modeling_mistral import (
|
18 |
+
MistralAttention as OriginalMistralAttention,
|
19 |
+
)
|
20 |
from transformers.models.mistral.modeling_mistral import (
|
21 |
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
22 |
)
|
|
|
45 |
)
|
46 |
|
47 |
|
48 |
+
@torch.jit.script
|
49 |
+
def _make_sliding_window_causal_mask(
|
50 |
+
bsz: int,
|
51 |
+
tgt_len: int,
|
52 |
+
dtype: torch.dtype,
|
53 |
+
device: torch.device,
|
54 |
+
past_key_values_length: int = 0,
|
55 |
+
sliding_window: int = 4096,
|
56 |
+
):
|
57 |
+
"""
|
58 |
+
Make causal mask used for sliding window attention
|
59 |
+
"""
|
60 |
+
tensor = torch.full(
|
61 |
+
(tgt_len, tgt_len),
|
62 |
+
fill_value=1,
|
63 |
+
device=device,
|
64 |
+
)
|
65 |
+
mask = torch.tril(tensor, diagonal=0)
|
66 |
+
# make the mask banded to account for sliding window
|
67 |
+
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
68 |
+
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
69 |
+
mask = torch.log(mask).to(dtype)
|
70 |
+
|
71 |
+
if past_key_values_length > 0:
|
72 |
+
mask = torch.cat(
|
73 |
+
[
|
74 |
+
torch.zeros(
|
75 |
+
tgt_len, past_key_values_length, dtype=dtype, device=device
|
76 |
+
),
|
77 |
+
mask,
|
78 |
+
],
|
79 |
+
dim=-1,
|
80 |
+
)
|
81 |
+
return mask[None, None, :, :].expand(
|
82 |
+
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
87 |
# requires the attention mask to be the same as the key_padding_mask
|
88 |
def _prepare_decoder_attention_mask(
|
|
|
94 |
sliding_window,
|
95 |
): # pylint: disable=unused-argument
|
96 |
# [bsz, seq_len]
|
97 |
+
if attention_mask is None:
|
98 |
+
return attention_mask
|
99 |
+
|
100 |
+
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
101 |
+
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
102 |
+
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
103 |
+
sliding_window_mask = _make_sliding_window_causal_mask(
|
104 |
+
bsz=input_shape[0],
|
105 |
+
tgt_len=input_shape[1],
|
106 |
+
dtype=inputs_embeds.dtype,
|
107 |
+
device=inputs_embeds.device,
|
108 |
+
past_key_values_length=past_key_values_length,
|
109 |
+
sliding_window=sliding_window,
|
110 |
+
)
|
111 |
+
attention_mask = attention_mask + sliding_window_mask
|
112 |
+
else:
|
113 |
+
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
114 |
+
|
115 |
return attention_mask
|
116 |
|
117 |
|
118 |
def flashattn_forward(
|
119 |
+
self: OriginalMistralAttention,
|
120 |
hidden_states: torch.Tensor,
|
121 |
attention_mask: Optional[torch.Tensor] = None,
|
122 |
position_ids: Optional[torch.LongTensor] = None,
|
|
|
150 |
query_states, key_states, cos, sin, position_ids
|
151 |
)
|
152 |
|
153 |
+
use_sliding_windows = (
|
154 |
+
hasattr(self.config, "sliding_window") is not None
|
155 |
+
and kv_seq_len > self.config.sliding_window
|
156 |
+
)
|
157 |
+
|
158 |
+
if use_sliding_windows:
|
159 |
+
window_size = (self.config.sliding_window, self.config.sliding_window)
|
160 |
+
else:
|
161 |
+
window_size = (-1, -1)
|
162 |
+
|
163 |
if past_key_value is not None:
|
164 |
+
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
165 |
+
if (
|
166 |
+
hasattr(self.config, "sliding_window")
|
167 |
+
and kv_seq_len > self.config.sliding_window
|
168 |
+
):
|
169 |
+
slicing_tokens = kv_seq_len - self.config.sliding_window
|
170 |
+
|
171 |
+
past_key = past_key_value[0]
|
172 |
+
past_value = past_key_value[1]
|
173 |
+
|
174 |
+
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
175 |
+
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
176 |
+
|
177 |
+
if past_key.shape[-2] != self.config.sliding_window - 1:
|
178 |
+
raise ValueError(
|
179 |
+
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
180 |
+
f" {past_key.shape}"
|
181 |
+
)
|
182 |
+
|
183 |
+
past_key_value = (past_key, past_value) if use_cache else None
|
184 |
+
|
185 |
+
if past_key_value is not None:
|
186 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
187 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
188 |
|
189 |
past_key_value = (key_states, value_states) if use_cache else None
|
190 |
|
|
|
210 |
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
211 |
|
212 |
output = flash_attn_varlen_qkvpacked_func(
|
213 |
+
qkv,
|
214 |
+
cu_seqlens,
|
215 |
+
max_seqlen,
|
216 |
+
0.0,
|
217 |
+
softmax_scale=None,
|
218 |
+
causal=True,
|
219 |
+
window_size=window_size,
|
220 |
)
|
221 |
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
222 |
elif query_states.shape == key_states.shape:
|
|
|
242 |
0.0,
|
243 |
softmax_scale=None,
|
244 |
causal=is_causal,
|
245 |
+
window_size=window_size,
|
246 |
)
|
247 |
output = output_pad_fn(output_unpad)
|
248 |
else:
|
|
|
254 |
query_states,
|
255 |
torch.stack([key_states, value_states], 2),
|
256 |
causal=is_causal,
|
257 |
+
window_size=window_size,
|
258 |
)
|
259 |
else:
|
260 |
( # pylint: disable=unbalanced-tuple-unpacking
|
|
|
289 |
0.0,
|
290 |
softmax_scale=None,
|
291 |
causal=is_causal,
|
292 |
+
window_size=window_size,
|
293 |
)
|
294 |
output = output_pad_fn(output_unpad)
|
295 |
|