winglian committed
Commit 7fabc4d · unverified · 1 Parent(s): 9a5eb39

Mixtral official (#942)

* multipack support for official mixtral implementation

* fix patch to load multipack for mixtral

* chore: lint
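
Multipack ("sample packing") concatenates several short training examples into one row, restarting `position_ids` at 0 for each example, and keeps attention from crossing example boundaries by handing FlashAttention's varlen kernels the cumulative sequence lengths of the packed pieces. That is what the custom model deleted below does via `get_cu_seqlens_from_pos_ids` and `flash_attn_varlen_qkvpacked_func`, and presumably what this patch now wires into the official implementation. The sketch below illustrates only the boundary-recovery step; it is not the actual axolotl helper:

    import torch
    import torch.nn.functional as F

    def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
        # Positions reset to 0 wherever a new packed example starts, so the
        # zeros mark example boundaries within the packed row.
        pos = position_ids.view(-1)
        starts = torch.nonzero(pos == 0, as_tuple=False).view(-1)
        ends = torch.cat([starts[1:], torch.tensor([pos.numel()], device=pos.device)])
        seqlens = ends - starts
        # FlashAttention's varlen API expects int32 cumulative lengths with a leading 0.
        cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
        return cu_seqlens, int(seqlens.max())

    packed = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]])  # three packed examples
    cu_seqlens, max_seqlen = cu_seqlens_from_position_ids(packed)
    # cu_seqlens -> tensor([0, 4, 6, 9], dtype=torch.int32); max_seqlen -> 4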

examples/mistral/mixtral.yml CHANGED
@@ -1,5 +1,5 @@
-base_model: DiscoResearch/mixtral-7b-8expert
-model_type: MixtralForCausalLM
+base_model: mistralai/Mixtral-8x7B-v0.1
+model_type: AutoModelForCausalLM
 tokenizer_type: LlamaTokenizer
 trust_remote_code: true
 
 
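With Mixtral support available upstream, the example now points at the official `mistralai/Mixtral-8x7B-v0.1` checkpoint and lets transformers' auto classes resolve the architecture instead of the custom `MixtralForCausalLM` removed by this commit. As a rough illustration (this is not axolotl's loader code), the new `base_model`/`model_type` pair corresponds to something like:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "mistralai/Mixtral-8x7B-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",  # keep the checkpoint's native dtype
        # attn_implementation="flash_attention_2",  # optional; requires flash-attn
    )
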
requirements.txt CHANGED
@@ -2,7 +2,7 @@
 auto-gptq==0.5.1
 packaging
 peft==0.6.0
-transformers @ git+https://github.com/huggingface/transformers.git@df5c5c62ae253055336f5bb0828ca8e3e15ab6bd
+transformers @ git+https://github.com/huggingface/transformers.git@e5079b0b2abcef11ecbdae60ba4a6636c57b725d
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.24.1
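
The transformers pin advances to a newer commit, presumably to pick up the official Mixtral classes that the config change above relies on. A quick, illustrative way to confirm the installed revision exposes them (the `MixtralConfig` attribute names below follow the official implementation):

    from transformers import MixtralConfig

    cfg = MixtralConfig()  # defaults correspond to Mixtral-8x7B
    print(cfg.num_local_experts, cfg.num_experts_per_tok)  # expected: 8 2
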
src/axolotl/models/mixtral/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- """
2
- Custom modeling code for mixtral
3
- """
4
-
5
- from .configuration_moe_mistral import MixtralConfig # noqa
6
- from .modeling_moe_mistral import MixtralForCausalLM # noqa

src/axolotl/models/mixtral/configuration_moe_mistral.py DELETED
@@ -1,154 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Mistral model configuration"""
16
-
17
- from transformers.configuration_utils import PretrainedConfig
18
- from transformers.utils import logging
19
-
20
- logger = logging.get_logger(__name__)
21
-
22
- MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
23
- "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
24
- "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
25
- }
26
-
27
-
28
- class MixtralConfig(PretrainedConfig):
29
- r"""
30
- This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
31
- Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
32
- with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
33
-
34
- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
35
- [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
36
-
37
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
- documentation from [`PretrainedConfig`] for more information.
39
-
40
-
41
- Args:
42
- vocab_size (`int`, *optional*, defaults to 32000):
43
- Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
44
- `inputs_ids` passed when calling [`MistralModel`]
45
- hidden_size (`int`, *optional*, defaults to 4096):
46
- Dimension of the hidden representations.
47
- intermediate_size (`int`, *optional*, defaults to 14336):
48
- Dimension of the MLP representations.
49
- num_hidden_layers (`int`, *optional*, defaults to 32):
50
- Number of hidden layers in the Transformer encoder.
51
- num_attention_heads (`int`, *optional*, defaults to 32):
52
- Number of attention heads for each attention layer in the Transformer encoder.
53
- num_key_value_heads (`int`, *optional*, defaults to 8):
54
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
55
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
56
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
57
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
58
- by meanpooling all the original heads within that group. For more details checkout [this
59
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
60
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
61
- The non-linear activation function (function or string) in the decoder.
62
- max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
63
- The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
64
- allows sequence of up to 4096*32 tokens.
65
- initializer_range (`float`, *optional*, defaults to 0.02):
66
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
68
- The epsilon used by the rms normalization layers.
69
- use_cache (`bool`, *optional*, defaults to `True`):
70
- Whether or not the model should return the last key/values attentions (not used by all models). Only
71
- relevant if `config.is_decoder=True`.
72
- pad_token_id (`int`, *optional*):
73
- The id of the padding token.
74
- bos_token_id (`int`, *optional*, defaults to 1):
75
- The id of the "beginning-of-sequence" token.
76
- eos_token_id (`int`, *optional*, defaults to 2):
77
- The id of the "end-of-sequence" token.
78
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
79
- Whether the model's input and output word embeddings should be tied.
80
- rope_theta (`float`, *optional*, defaults to 10000.0):
81
- The base period of the RoPE embeddings.
82
- sliding_window (`int`, *optional*, defaults to 4096):
83
- Sliding window attention window size. If not specified, will default to `4096`.
84
- attention_dropout (`float`, *optional*, defaults to 0.0):
85
- The dropout ratio for the attention probabilities.
86
-
87
- ```python
88
- >>> from transformers import MistralModel, MistralConfig
89
-
90
- >>> # Initializing a Mistral 7B style configuration
91
- >>> configuration = MixtralConfig()
92
-
93
- >>> # Initializing a model from the Mistral 7B style configuration
94
- >>> model = MixtralModel(configuration)
95
-
96
- >>> # Accessing the model configuration
97
- >>> configuration = model.config
98
- ```"""
99
-
100
- model_type = "mistral"
101
- keys_to_ignore_at_inference = ["past_key_values"]
102
-
103
- def __init__(
104
- self,
105
- vocab_size=32000,
106
- hidden_size=4096,
107
- intermediate_size=14336,
108
- num_hidden_layers=32,
109
- num_attention_heads=32,
110
- num_key_value_heads=8,
111
- hidden_act="silu",
112
- max_position_embeddings=4096 * 32,
113
- initializer_range=0.02,
114
- rms_norm_eps=1e-6,
115
- use_cache=True,
116
- pad_token_id=None,
117
- bos_token_id=1,
118
- eos_token_id=2,
119
- tie_word_embeddings=False,
120
- rope_theta=10000.0,
121
- attention_dropout=0.0,
122
- num_experts_per_token=2,
123
- num_experts=8,
124
- **kwargs,
125
- ):
126
- self.vocab_size = vocab_size
127
- self.max_position_embeddings = max_position_embeddings
128
- self.hidden_size = hidden_size
129
- self.intermediate_size = intermediate_size
130
- self.num_hidden_layers = num_hidden_layers
131
- self.num_attention_heads = num_attention_heads
132
-
133
- # for backward compatibility
134
- if num_key_value_heads is None:
135
- num_key_value_heads = num_attention_heads
136
-
137
- self.num_key_value_heads = num_key_value_heads
138
- self.hidden_act = hidden_act
139
- self.initializer_range = initializer_range
140
- self.rms_norm_eps = rms_norm_eps
141
- self.use_cache = use_cache
142
- self.rope_theta = rope_theta
143
- self.attention_dropout = attention_dropout
144
- self.num_experts = num_experts
145
- self.num_experts_per_token = num_experts_per_token
146
-
147
- # pylint: disable=duplicate-code
148
- super().__init__(
149
- pad_token_id=pad_token_id,
150
- bos_token_id=bos_token_id,
151
- eos_token_id=eos_token_id,
152
- tie_word_embeddings=tie_word_embeddings,
153
- **kwargs,
154
- )

src/axolotl/models/mixtral/modeling_moe_mistral.py DELETED
@@ -1,1505 +0,0 @@
1
- # pylint: skip-file
2
- # coding=utf-8
3
- # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
4
- #
5
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
- # and OPT implementations in this library. It has been modified from its
7
- # original forms to accommodate minor architectural differences compared
8
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
- #
10
- # Licensed under the Apache License, Version 2.0 (the "License");
11
- # you may not use this file except in compliance with the License.
12
- # You may obtain a copy of the License at
13
- #
14
- # http://www.apache.org/licenses/LICENSE-2.0
15
- #
16
- # Unless required by applicable law or agreed to in writing, software
17
- # distributed under the License is distributed on an "AS IS" BASIS,
18
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
- # See the License for the specific language governing permissions and
20
- # limitations under the License.
21
- """ PyTorch Mistral model."""
22
- import inspect
23
- import math
24
- import warnings
25
- from typing import List, Optional, Tuple, Union
26
-
27
- import torch
28
- import torch.nn.functional as F
29
- import torch.utils.checkpoint
30
- from einops import rearrange
31
- from torch import nn
32
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
33
- from transformers.cache_utils import Cache, DynamicCache
34
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
35
- from transformers.modeling_outputs import (
36
- BaseModelOutputWithPast,
37
- CausalLMOutputWithPast,
38
- SequenceClassifierOutputWithPast,
39
- )
40
- from transformers.modeling_utils import PreTrainedModel
41
- from transformers.utils import (
42
- add_start_docstrings,
43
- add_start_docstrings_to_model_forward,
44
- is_flash_attn_2_available,
45
- is_flash_attn_greater_or_equal_2_10,
46
- logging,
47
- replace_return_docstrings,
48
- )
49
-
50
- from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
51
- from .configuration_moe_mistral import MixtralConfig
52
-
53
- if is_flash_attn_2_available():
54
- from flash_attn import (
55
- flash_attn_func,
56
- flash_attn_varlen_func,
57
- flash_attn_varlen_qkvpacked_func,
58
- )
59
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
60
-
61
- _flash_supports_window_size = "window_size" in list(
62
- inspect.signature(flash_attn_func).parameters
63
- )
64
-
65
-
66
- logger = logging.get_logger(__name__)
67
-
68
- _CONFIG_FOR_DOC = "MixtralConfig"
69
-
70
-
71
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
72
- def _get_unpad_data(attention_mask):
73
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
74
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
75
- max_seqlen_in_batch = seqlens_in_batch.max().item()
76
- cu_seqlens = F.pad(
77
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
78
- )
79
- return (
80
- indices,
81
- cu_seqlens,
82
- max_seqlen_in_batch,
83
- )
84
-
85
-
86
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
87
- class MistralRMSNorm(nn.Module):
88
- def __init__(self, hidden_size, eps=1e-6):
89
- """
90
- MistralRMSNorm is equivalent to T5LayerNorm
91
- """
92
- super().__init__()
93
- self.weight = nn.Parameter(torch.ones(hidden_size))
94
- self.variance_epsilon = eps
95
-
96
- def forward(self, hidden_states):
97
- input_dtype = hidden_states.dtype
98
- hidden_states = hidden_states.to(torch.float32)
99
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
100
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
101
- return self.weight * hidden_states.to(input_dtype)
102
-
103
-
104
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
105
- class MistralRotaryEmbedding(nn.Module):
106
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
107
- super().__init__()
108
-
109
- self.dim = dim
110
- self.max_position_embeddings = max_position_embeddings
111
- self.base = base
112
- inv_freq = 1.0 / (
113
- self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
114
- )
115
- self.register_buffer("inv_freq", inv_freq, persistent=False)
116
-
117
- # Build here to make `torch.jit.trace` work.
118
- self._set_cos_sin_cache(
119
- seq_len=max_position_embeddings,
120
- device=self.inv_freq.device,
121
- dtype=torch.get_default_dtype(),
122
- )
123
-
124
- def _set_cos_sin_cache(self, seq_len, device, dtype):
125
- self.max_seq_len_cached = seq_len
126
- t = torch.arange(
127
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
128
- )
129
-
130
- freqs = torch.outer(t, self.inv_freq)
131
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
132
- emb = torch.cat((freqs, freqs), dim=-1)
133
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
134
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
135
-
136
- def forward(self, x, seq_len=None):
137
- # x: [bs, num_attention_heads, seq_len, head_size]
138
- if seq_len > self.max_seq_len_cached:
139
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
140
-
141
- return (
142
- self.cos_cached[:seq_len].to(dtype=x.dtype),
143
- self.sin_cached[:seq_len].to(dtype=x.dtype),
144
- )
145
-
146
-
147
- # Copied from transformers.models.llama.modeling_llama.rotate_half
148
- def rotate_half(x):
149
- """Rotates half the hidden dims of the input."""
150
- x1 = x[..., : x.shape[-1] // 2]
151
- x2 = x[..., x.shape[-1] // 2 :]
152
- return torch.cat((-x2, x1), dim=-1)
153
-
154
-
155
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
156
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
157
- """Applies Rotary Position Embedding to the query and key tensors.
158
-
159
- Args:
160
- q (`torch.Tensor`): The query tensor.
161
- k (`torch.Tensor`): The key tensor.
162
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
163
- sin (`torch.Tensor`): The sine part of the rotary embedding.
164
- position_ids (`torch.Tensor`):
165
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
166
- used to pass offsetted position ids when working with a KV-cache.
167
- unsqueeze_dim (`int`, *optional*, defaults to 1):
168
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
169
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
170
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
171
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
172
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
173
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
174
- Returns:
175
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
176
- """
177
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
178
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
179
- q_embed = (q * cos) + (rotate_half(q) * sin)
180
- k_embed = (k * cos) + (rotate_half(k) * sin)
181
- return q_embed, k_embed
182
-
183
-
184
- class FeedForward(nn.Module):
185
- def __init__(self, config):
186
- """
187
- Initialize the FeedForward module.
188
-
189
- Args:
190
- dim (int): Input dimension.
191
- hidden_dim (int): Hidden dimension of the feedforward layer.
192
- multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
193
- ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
194
-
195
- Attributes:
196
- w1 (ColumnParallelLinear): Linear transformation for the first layer.
197
- w2 (RowParallelLinear): Linear transformation for the second layer.
198
- w3 (ColumnParallelLinear): Linear transformation for the third layer.
199
-
200
- """
201
- super().__init__()
202
-
203
- self.w1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
204
- self.w2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
205
- self.w3 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
206
-
207
- def forward(self, x):
208
- return self.w2(F.silu(self.w1(x)) * self.w3(x))
209
-
210
-
211
- class MoE(nn.Module):
212
- def __init__(
213
- self,
214
- config,
215
- ):
216
- super().__init__()
217
- self.config = config
218
- self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
219
- self.experts = nn.ModuleList(
220
- [FeedForward(config) for i in range(config.num_experts)]
221
- )
222
-
223
- def forward(self, x):
224
- orig_shape = x.shape
225
- x = x.view(-1, x.shape[-1])
226
-
227
- scores = self.gate(x).softmax(dim=-1)
228
- expert_weights, expert_indices = torch.topk(
229
- scores, self.config.num_experts_per_token, dim=-1
230
- )
231
- flat_expert_indices = expert_indices.view(-1)
232
-
233
- x = x.repeat_interleave(self.config.num_experts_per_token, dim=0)
234
- y = torch.empty_like(x)
235
- for i, expert in enumerate(self.experts):
236
- y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
237
- y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(
238
- dim=1
239
- )
240
- return y.view(*orig_shape)
241
-
242
-
243
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
244
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
245
- """
246
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
247
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
248
- """
249
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
250
- if n_rep == 1:
251
- return hidden_states
252
- hidden_states = hidden_states[:, :, None, :, :].expand(
253
- batch, num_key_value_heads, n_rep, slen, head_dim
254
- )
255
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
256
-
257
-
258
- class MistralAttention(nn.Module):
259
- """
260
- Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
261
- and "Generating Long Sequences with Sparse Transformers".
262
- """
263
-
264
- def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
265
- super().__init__()
266
- self.config = config
267
- self.layer_idx = layer_idx
268
- if layer_idx is None:
269
- logger.warning_once(
270
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
271
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
272
- "when creating this class."
273
- )
274
-
275
- self.hidden_size = config.hidden_size
276
- self.num_heads = config.num_attention_heads
277
- self.head_dim = self.hidden_size // self.num_heads
278
- self.num_key_value_heads = config.num_key_value_heads
279
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
280
- self.max_position_embeddings = config.max_position_embeddings
281
- self.rope_theta = config.rope_theta
282
- self.is_causal = True
283
- self.attention_dropout = config.attention_dropout
284
-
285
- if (self.head_dim * self.num_heads) != self.hidden_size:
286
- raise ValueError(
287
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
288
- f" and `num_heads`: {self.num_heads})."
289
- )
290
- self.q_proj = nn.Linear(
291
- self.hidden_size, self.num_heads * self.head_dim, bias=False
292
- )
293
- self.k_proj = nn.Linear(
294
- self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
295
- )
296
- self.v_proj = nn.Linear(
297
- self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
298
- )
299
- self.o_proj = nn.Linear(
300
- self.num_heads * self.head_dim, self.hidden_size, bias=False
301
- )
302
-
303
- self.rotary_emb = MistralRotaryEmbedding(
304
- self.head_dim,
305
- max_position_embeddings=self.max_position_embeddings,
306
- base=self.rope_theta,
307
- )
308
-
309
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
310
- return (
311
- tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
312
- .transpose(1, 2)
313
- .contiguous()
314
- )
315
-
316
- def forward(
317
- self,
318
- hidden_states: torch.Tensor,
319
- attention_mask: Optional[torch.Tensor] = None,
320
- position_ids: Optional[torch.LongTensor] = None,
321
- past_key_value: Optional[Cache] = None,
322
- output_attentions: bool = False,
323
- use_cache: bool = False,
324
- **kwargs,
325
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
326
- if "padding_mask" in kwargs:
327
- warnings.warn(
328
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
329
- )
330
- bsz, q_len, _ = hidden_states.size()
331
-
332
- query_states = self.q_proj(hidden_states)
333
- key_states = self.k_proj(hidden_states)
334
- value_states = self.v_proj(hidden_states)
335
-
336
- query_states = query_states.view(
337
- bsz, q_len, self.num_heads, self.head_dim
338
- ).transpose(1, 2)
339
- key_states = key_states.view(
340
- bsz, q_len, self.num_key_value_heads, self.head_dim
341
- ).transpose(1, 2)
342
- value_states = value_states.view(
343
- bsz, q_len, self.num_key_value_heads, self.head_dim
344
- ).transpose(1, 2)
345
-
346
- kv_seq_len = key_states.shape[-2]
347
- if past_key_value is not None:
348
- if self.layer_idx is None:
349
- raise ValueError(
350
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
351
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
352
- "with a layer index."
353
- )
354
- kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
355
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
356
- query_states, key_states = apply_rotary_pos_emb(
357
- query_states, key_states, cos, sin, position_ids
358
- )
359
-
360
- if past_key_value is not None:
361
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
362
- key_states, value_states = past_key_value.update(
363
- key_states, value_states, self.layer_idx, cache_kwargs
364
- )
365
-
366
- # repeat k/v heads if n_kv_heads < n_heads
367
- key_states = repeat_kv(key_states, self.num_key_value_groups)
368
- value_states = repeat_kv(value_states, self.num_key_value_groups)
369
-
370
- attn_weights = torch.matmul(
371
- query_states, key_states.transpose(2, 3)
372
- ) / math.sqrt(self.head_dim)
373
-
374
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
375
- raise ValueError(
376
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
377
- f" {attn_weights.size()}"
378
- )
379
-
380
- if attention_mask is not None:
381
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
382
- raise ValueError(
383
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
384
- )
385
-
386
- attn_weights = attn_weights + attention_mask
387
-
388
- # upcast attention to fp32
389
- attn_weights = nn.functional.softmax(
390
- attn_weights, dim=-1, dtype=torch.float32
391
- ).to(query_states.dtype)
392
- attn_weights = nn.functional.dropout(
393
- attn_weights, p=self.attention_dropout, training=self.training
394
- )
395
- attn_output = torch.matmul(attn_weights, value_states)
396
-
397
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
398
- raise ValueError(
399
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
400
- f" {attn_output.size()}"
401
- )
402
-
403
- attn_output = attn_output.transpose(1, 2).contiguous()
404
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
405
-
406
- attn_output = self.o_proj(attn_output)
407
-
408
- if not output_attentions:
409
- attn_weights = None
410
-
411
- return attn_output, attn_weights, past_key_value
412
-
413
-
414
- class MistralFlashAttention2(MistralAttention):
415
- """
416
- Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
417
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
418
- flash attention and deal with padding tokens in case the input contains any of them.
419
- """
420
-
421
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
422
- def __init__(self, *args, **kwargs):
423
- super().__init__(*args, **kwargs)
424
-
425
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
426
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
427
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
428
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
429
-
430
- def forward(
431
- self,
432
- hidden_states: torch.Tensor,
433
- attention_mask: Optional[torch.Tensor] = None,
434
- position_ids: Optional[torch.LongTensor] = None,
435
- past_key_value: Optional[Cache] = None,
436
- output_attentions: bool = False,
437
- use_cache: bool = False,
438
- cu_seqlens: Optional[torch.Tensor] = None,
439
- max_seqlen: Optional[torch.Tensor] = None,
440
- **kwargs,
441
- ):
442
- if "padding_mask" in kwargs:
443
- warnings.warn(
444
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
445
- )
446
-
447
- # overwrite attention_mask with padding_mask
448
- attention_mask = kwargs.pop("padding_mask")
449
- bsz, q_len, _ = hidden_states.size()
450
-
451
- query_states = self.q_proj(hidden_states)
452
- key_states = self.k_proj(hidden_states)
453
- value_states = self.v_proj(hidden_states)
454
-
455
- query_states = query_states.view(
456
- bsz, q_len, self.num_heads, self.head_dim
457
- ).transpose(1, 2)
458
- key_states = key_states.view(
459
- bsz, q_len, self.num_key_value_heads, self.head_dim
460
- ).transpose(1, 2)
461
- value_states = value_states.view(
462
- bsz, q_len, self.num_key_value_heads, self.head_dim
463
- ).transpose(1, 2)
464
-
465
- kv_seq_len = key_states.shape[-2]
466
- if past_key_value is not None:
467
- kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
468
-
469
- # Because the input can be padded, the absolute sequence length depends on the max position id.
470
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
471
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
472
-
473
- query_states, key_states = apply_rotary_pos_emb(
474
- query_states, key_states, cos, sin, position_ids
475
- )
476
-
477
- use_sliding_windows = (
478
- _flash_supports_window_size
479
- and getattr(self.config, "sliding_window", None) is not None
480
- and kv_seq_len > self.config.sliding_window
481
- )
482
-
483
- if not _flash_supports_window_size:
484
- logger.warning_once(
485
- "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
486
- " make sure to upgrade flash-attn library."
487
- )
488
-
489
- if past_key_value is not None:
490
- # Activate slicing cache only if the config has a value `sliding_windows` attribute
491
- if (
492
- getattr(self.config, "sliding_window", None) is not None
493
- and kv_seq_len > self.config.sliding_window
494
- ):
495
- slicing_tokens = 1 - self.config.sliding_window
496
-
497
- past_key = past_key_value[0]
498
- past_value = past_key_value[1]
499
-
500
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
501
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
502
-
503
- if past_key.shape[-2] != self.config.sliding_window - 1:
504
- raise ValueError(
505
- f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
506
- f" {past_key.shape}"
507
- )
508
-
509
- past_key_value = (past_key, past_value)
510
-
511
- if attention_mask is not None:
512
- attention_mask = attention_mask[:, slicing_tokens:]
513
- attention_mask = torch.cat(
514
- [attention_mask, torch.ones_like(attention_mask[:, -1:])],
515
- dim=-1,
516
- )
517
-
518
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
519
- key_states, value_states = past_key_value.update(
520
- key_states, value_states, self.layer_idx, cache_kwargs
521
- )
522
-
523
- # repeat k/v heads if n_kv_heads < n_heads
524
- key_states = repeat_kv(key_states, self.num_key_value_groups)
525
- value_states = repeat_kv(value_states, self.num_key_value_groups)
526
- dropout_rate = 0.0 if not self.training else self.attention_dropout
527
-
528
- if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
529
- # special handling using sample packing
530
- qkv = torch.stack(
531
- [query_states, key_states, value_states], dim=2
532
- ) # [bsz, nh, 3, q_len, hd]
533
- qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
534
- qkv = rearrange(qkv, "b s ... -> (b s) ...")
535
-
536
- attn_output = flash_attn_varlen_qkvpacked_func(
537
- qkv,
538
- cu_seqlens,
539
- max_seqlen,
540
- dropout_p=dropout_rate,
541
- softmax_scale=None,
542
- causal=True,
543
- )
544
- attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
545
- else:
546
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
547
- # therefore the input hidden states gets silently casted in float32. Hence, we need
548
- # cast them back in float16 just to be sure everything works as expected.
549
- input_dtype = query_states.dtype
550
- if input_dtype == torch.float32:
551
- # Handle the case where the model is quantized
552
- if hasattr(self.config, "_pre_quantization_dtype"):
553
- target_dtype = self.config._pre_quantization_dtype
554
- else:
555
- target_dtype = self.q_proj.weight.dtype
556
-
557
- logger.warning_once(
558
- f"The input hidden states seems to be silently casted in float32, this might be related to"
559
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
560
- f" {target_dtype}."
561
- )
562
-
563
- query_states = query_states.to(target_dtype)
564
- key_states = key_states.to(target_dtype)
565
- value_states = value_states.to(target_dtype)
566
-
567
- # Reshape to the expected shape for Flash Attention
568
- query_states = query_states.transpose(1, 2)
569
- key_states = key_states.transpose(1, 2)
570
- value_states = value_states.transpose(1, 2)
571
-
572
- attn_output = self._flash_attention_forward(
573
- query_states,
574
- key_states,
575
- value_states,
576
- attention_mask,
577
- q_len,
578
- dropout=dropout_rate,
579
- use_sliding_windows=use_sliding_windows,
580
- )
581
-
582
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
583
- attn_output = self.o_proj(attn_output)
584
-
585
- if not output_attentions:
586
- attn_weights = None
587
-
588
- return attn_output, attn_weights, past_key_value
589
-
590
- def _flash_attention_forward(
591
- self,
592
- query_states,
593
- key_states,
594
- value_states,
595
- attention_mask,
596
- query_length,
597
- dropout=0.0,
598
- softmax_scale=None,
599
- use_sliding_windows=False,
600
- ):
601
- """
602
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
603
- first unpad the input, then computes the attention scores and pad the final attention scores.
604
-
605
- Args:
606
- query_states (`torch.Tensor`):
607
- Input query states to be passed to Flash Attention API
608
- key_states (`torch.Tensor`):
609
- Input key states to be passed to Flash Attention API
610
- value_states (`torch.Tensor`):
611
- Input value states to be passed to Flash Attention API
612
- attention_mask (`torch.Tensor`):
613
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
614
- position of padding tokens and 1 for the position of non-padding tokens.
615
- dropout (`int`, *optional*):
616
- Attention dropout
617
- softmax_scale (`float`, *optional*):
618
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
619
- use_sliding_windows (`bool`, *optional*):
620
- Whether to activate sliding window attention.
621
- """
622
- if not self._flash_attn_uses_top_left_mask:
623
- causal = self.is_causal
624
- else:
625
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
626
- causal = self.is_causal and query_length != 1
627
-
628
- # Contains at least one padding token in the sequence
629
- if attention_mask is not None:
630
- batch_size = query_states.shape[0]
631
- (
632
- query_states,
633
- key_states,
634
- value_states,
635
- indices_q,
636
- cu_seq_lens,
637
- max_seq_lens,
638
- ) = self._upad_input(
639
- query_states, key_states, value_states, attention_mask, query_length
640
- )
641
-
642
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
643
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
644
-
645
- if not use_sliding_windows:
646
- attn_output_unpad = flash_attn_varlen_func(
647
- query_states,
648
- key_states,
649
- value_states,
650
- cu_seqlens_q=cu_seqlens_q,
651
- cu_seqlens_k=cu_seqlens_k,
652
- max_seqlen_q=max_seqlen_in_batch_q,
653
- max_seqlen_k=max_seqlen_in_batch_k,
654
- dropout_p=dropout,
655
- softmax_scale=softmax_scale,
656
- causal=causal,
657
- )
658
- else:
659
- attn_output_unpad = flash_attn_varlen_func(
660
- query_states,
661
- key_states,
662
- value_states,
663
- cu_seqlens_q=cu_seqlens_q,
664
- cu_seqlens_k=cu_seqlens_k,
665
- max_seqlen_q=max_seqlen_in_batch_q,
666
- max_seqlen_k=max_seqlen_in_batch_k,
667
- dropout_p=dropout,
668
- softmax_scale=softmax_scale,
669
- causal=causal,
670
- window_size=(
671
- self.config.sliding_window,
672
- self.config.sliding_window,
673
- ),
674
- )
675
-
676
- attn_output = pad_input(
677
- attn_output_unpad, indices_q, batch_size, query_length
678
- )
679
- else:
680
- if not use_sliding_windows:
681
- attn_output = flash_attn_func(
682
- query_states,
683
- key_states,
684
- value_states,
685
- dropout,
686
- softmax_scale=softmax_scale,
687
- causal=causal,
688
- )
689
- else:
690
- attn_output = flash_attn_func(
691
- query_states,
692
- key_states,
693
- value_states,
694
- dropout,
695
- softmax_scale=softmax_scale,
696
- causal=causal,
697
- window_size=(
698
- self.config.sliding_window,
699
- self.config.sliding_window,
700
- ),
701
- )
702
-
703
- return attn_output
704
-
705
- def _upad_input(
706
- self, query_layer, key_layer, value_layer, attention_mask, query_length
707
- ):
708
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
709
-
710
- # On the first iteration we need to properly re-create the padding mask
711
- # by slicing it on the proper place
712
- if kv_seq_len != attention_mask.shape[-1]:
713
- attention_mask_num_tokens = attention_mask.shape[-1]
714
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
715
-
716
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
717
-
718
- key_layer = index_first_axis(
719
- key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
720
- )
721
- value_layer = index_first_axis(
722
- value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
723
- )
724
-
725
- if query_length == kv_seq_len:
726
- query_layer = index_first_axis(
727
- query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
728
- indices_k,
729
- )
730
- cu_seqlens_q = cu_seqlens_k
731
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
732
- indices_q = indices_k
733
- elif query_length == 1:
734
- max_seqlen_in_batch_q = 1
735
- cu_seqlens_q = torch.arange(
736
- batch_size + 1, dtype=torch.int32, device=query_layer.device
737
- ) # There is a memcpy here, that is very bad.
738
- indices_q = cu_seqlens_q[:-1]
739
- query_layer = query_layer.squeeze(1)
740
- else:
741
- # The -q_len: slice assumes left padding.
742
- attention_mask = attention_mask[:, -query_length:]
743
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
744
- query_layer, attention_mask
745
- )
746
-
747
- return (
748
- query_layer,
749
- key_layer,
750
- value_layer,
751
- indices_q,
752
- (cu_seqlens_q, cu_seqlens_k),
753
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
754
- )
755
-
756
-
757
- class MixtralDecoderLayer(nn.Module):
758
- def __init__(self, config: MixtralConfig, layer_idx: int):
759
- super().__init__()
760
- self.hidden_size = config.hidden_size
761
- self.self_attn = MistralFlashAttention2(config, layer_idx=layer_idx)
762
- self.mlp = MoE(config)
763
- self.input_layernorm = MistralRMSNorm(
764
- config.hidden_size, eps=config.rms_norm_eps
765
- )
766
- self.post_attention_layernorm = MistralRMSNorm(
767
- config.hidden_size, eps=config.rms_norm_eps
768
- )
769
-
770
- def forward(
771
- self,
772
- hidden_states: torch.Tensor,
773
- attention_mask: Optional[torch.Tensor] = None,
774
- position_ids: Optional[torch.LongTensor] = None,
775
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
776
- output_attentions: Optional[bool] = False,
777
- use_cache: Optional[bool] = False,
778
- cu_seqlens: Optional[torch.Tensor] = None,
779
- max_seqlen: Optional[torch.Tensor] = None,
780
- **kwargs,
781
- ) -> Tuple[
782
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
783
- ]:
784
- if "padding_mask" in kwargs:
785
- warnings.warn(
786
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
787
- )
788
- """
789
- Args:
790
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
791
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
792
- `(batch, sequence_length)` where padding elements are indicated by 0.
793
- output_attentions (`bool`, *optional*):
794
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
795
- returned tensors for more detail.
796
- use_cache (`bool`, *optional*):
797
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
798
- (see `past_key_values`).
799
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
800
- """
801
-
802
- residual = hidden_states
803
-
804
- hidden_states = self.input_layernorm(hidden_states)
805
-
806
- # Self Attention
807
- # pylint: disable=duplicate-code
808
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
809
- hidden_states=hidden_states,
810
- attention_mask=attention_mask,
811
- position_ids=position_ids,
812
- past_key_value=past_key_value,
813
- output_attentions=output_attentions,
814
- use_cache=use_cache,
815
- cu_seqlens=cu_seqlens,
816
- max_seqlen=max_seqlen,
817
- )
818
- hidden_states = residual + hidden_states
819
-
820
- # Fully Connected
821
- residual = hidden_states
822
- hidden_states = self.post_attention_layernorm(hidden_states)
823
- hidden_states = self.mlp(hidden_states)
824
- hidden_states = residual + hidden_states
825
-
826
- outputs = (hidden_states,)
827
-
828
- if output_attentions:
829
- outputs += (self_attn_weights,)
830
-
831
- if use_cache:
832
- outputs += (present_key_value,)
833
-
834
- return outputs
835
-
836
-
837
- MISTRAL_START_DOCSTRING = r"""
838
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
839
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
840
- etc.)
841
-
842
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
843
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
844
- and behavior.
845
-
846
- Parameters:
847
- config ([`MixtralConfig`]):
848
- Model configuration class with all the parameters of the model. Initializing with a config file does not
849
- load the weights associated with the model, only the configuration. Check out the
850
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
851
- """
852
-
853
-
854
- @add_start_docstrings(
855
- "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
856
- MISTRAL_START_DOCSTRING,
857
- )
858
- class MixtralPreTrainedModel(PreTrainedModel):
859
- config_class = MixtralConfig
860
- base_model_prefix = "model"
861
- supports_gradient_checkpointing = True
862
- _no_split_modules = ["MixtralDecoderLayer"]
863
- _skip_keys_device_placement = "past_key_values"
864
- _supports_flash_attn_2 = True
865
- _supports_cache_class = True
866
-
867
- def _init_weights(self, module):
868
- std = self.config.initializer_range
869
- if isinstance(module, nn.Linear):
870
- module.weight.data.normal_(mean=0.0, std=std)
871
- if module.bias is not None:
872
- module.bias.data.zero_()
873
- elif isinstance(module, nn.Embedding):
874
- module.weight.data.normal_(mean=0.0, std=std)
875
- if module.padding_idx is not None:
876
- module.weight.data[module.padding_idx].zero_()
877
-
878
-
879
- MISTRAL_INPUTS_DOCSTRING = r"""
880
- Args:
881
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
882
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
883
- it.
884
-
885
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
886
- [`PreTrainedTokenizer.__call__`] for details.
887
-
888
- [What are input IDs?](../glossary#input-ids)
889
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
890
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
891
-
892
- - 1 for tokens that are **not masked**,
893
- - 0 for tokens that are **masked**.
894
-
895
- [What are attention masks?](../glossary#attention-mask)
896
-
897
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
898
- [`PreTrainedTokenizer.__call__`] for details.
899
-
900
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
901
- `past_key_values`).
902
-
903
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
904
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
905
- information on the default strategy.
906
-
907
- - 1 indicates the head is **not masked**,
908
- - 0 indicates the head is **masked**.
909
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
910
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
911
- config.n_positions - 1]`.
912
-
913
- [What are position IDs?](../glossary#position-ids)
914
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
915
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
916
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
917
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
918
-
919
- Two formats are allowed:
920
- - a [`~cache_utils.Cache`] instance;
921
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
922
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
923
- cache format.
924
-
925
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
926
- legacy cache format will be returned.
927
-
928
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
929
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
930
- of shape `(batch_size, sequence_length)`.
931
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
932
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
933
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
934
- model's internal embedding lookup matrix.
935
- use_cache (`bool`, *optional*):
936
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
937
- `past_key_values`).
938
- output_attentions (`bool`, *optional*):
939
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
940
- tensors for more detail.
941
- output_hidden_states (`bool`, *optional*):
942
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
943
- more detail.
944
- return_dict (`bool`, *optional*):
945
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
946
- """
947
-
948
-
949
- @add_start_docstrings(
950
- "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
951
- MISTRAL_START_DOCSTRING,
952
- )
953
- class MistralModel(MixtralPreTrainedModel):
954
- """
955
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
956
-
957
- Args:
958
- config: MixtralConfig
959
- """
960
-
961
- def __init__(self, config: MixtralConfig):
962
- super().__init__(config)
963
- self.padding_idx = config.pad_token_id
964
- self.vocab_size = config.vocab_size
965
-
966
- self.embed_tokens = nn.Embedding(
967
- config.vocab_size, config.hidden_size, self.padding_idx
968
- )
969
- self.layers = nn.ModuleList(
970
- [
971
- MixtralDecoderLayer(config, layer_idx)
972
- for layer_idx in range(config.num_hidden_layers)
973
- ]
974
- )
975
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
976
- self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
977
-
978
- self.gradient_checkpointing = False
979
- # Initialize weights and apply final processing
980
- self.post_init()
981
-
982
- def get_input_embeddings(self):
983
- return self.embed_tokens
984
-
985
- def set_input_embeddings(self, value):
986
- self.embed_tokens = value
987
-
988
- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
989
- def forward(
990
- self,
991
- input_ids: torch.LongTensor = None,
992
- attention_mask: Optional[torch.Tensor] = None,
993
- position_ids: Optional[torch.LongTensor] = None,
994
- past_key_values: Optional[List[torch.FloatTensor]] = None,
995
- inputs_embeds: Optional[torch.FloatTensor] = None,
996
- use_cache: Optional[bool] = None,
997
- output_attentions: Optional[bool] = None,
998
- output_hidden_states: Optional[bool] = None,
999
- return_dict: Optional[bool] = None,
1000
- ) -> Union[Tuple, BaseModelOutputWithPast]:
1001
- output_attentions = (
1002
- output_attentions
1003
- if output_attentions is not None
1004
- else self.config.output_attentions
1005
- )
1006
- output_hidden_states = (
1007
- output_hidden_states
1008
- if output_hidden_states is not None
1009
- else self.config.output_hidden_states
1010
- )
1011
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1012
-
1013
- return_dict = (
1014
- return_dict if return_dict is not None else self.config.use_return_dict
1015
- )
1016
-
1017
- # retrieve input_ids and inputs_embeds
1018
- if input_ids is not None and inputs_embeds is not None:
1019
- raise ValueError(
1020
- "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
1021
- )
1022
- elif input_ids is not None:
1023
- batch_size, seq_length = input_ids.shape
1024
- elif inputs_embeds is not None:
1025
- batch_size, seq_length, _ = inputs_embeds.shape
1026
- else:
1027
- raise ValueError(
1028
- "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1029
- )
1030
-
1031
- seq_length_with_past = seq_length
1032
- past_key_values_length = 0
1033
-
1034
- if use_cache:
1035
- use_legacy_cache = not isinstance(past_key_values, Cache)
1036
- if use_legacy_cache:
1037
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1038
- past_key_values_length = past_key_values.get_seq_length()
1039
- seq_length_with_past = seq_length_with_past + past_key_values_length
1040
-
1041
- cu_seqlens = None
1042
- max_seqlen = None
1043
- if position_ids is None:
1044
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1045
- position_ids = torch.arange(
1046
- past_key_values_length,
1047
- seq_length + past_key_values_length,
1048
- dtype=torch.long,
1049
- device=device,
1050
- )
1051
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1052
- else:
1053
- position_ids = position_ids.view(-1, seq_length).long()
1054
- cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
1055
- cu_seqlens = cu_seqlens.squeeze()
1056
-
1057
- if inputs_embeds is None:
1058
- inputs_embeds = self.embed_tokens(input_ids)
1059
-
1060
- if (
1061
- attention_mask is not None
1062
- and hasattr(self.config, "_flash_attn_2_enabled")
1063
- and self.config._flash_attn_2_enabled
1064
- and use_cache
1065
- ):
1066
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1067
- if is_padding_right:
1068
- raise ValueError(
1069
- "You are attempting to perform batched generation with padding_side='right'"
1070
- " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
1071
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1072
- )
1073
-
1074
- if getattr(self.config, "_flash_attn_2_enabled", False):
1075
- # 2d mask is passed through the layers
1076
- attention_mask = (
1077
- attention_mask
1078
- if (attention_mask is not None and 0 in attention_mask)
1079
- else None
1080
- )
1081
- else:
1082
- # 4d mask is passed through the layers
1083
- attention_mask = _prepare_4d_causal_attention_mask(
1084
- attention_mask,
1085
- (batch_size, seq_length),
1086
- inputs_embeds,
1087
- past_key_values_length,
1088
- )
1089
-
1090
- hidden_states = inputs_embeds
1091
-
1092
- if self.gradient_checkpointing and self.training:
1093
- if use_cache:
1094
- logger.warning_once(
1095
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1096
- )
1097
- use_cache = False
1098
-
1099
- # decoder layers
1100
- all_hidden_states = () if output_hidden_states else None
1101
- all_self_attns = () if output_attentions else None
1102
- next_decoder_cache = None
1103
-
1104
- for decoder_layer in self.layers:
1105
- if output_hidden_states:
1106
- all_hidden_states += (hidden_states,)
1107
-
1108
- if self.gradient_checkpointing and self.training:
1109
- layer_outputs = self._gradient_checkpointing_func(
1110
- decoder_layer.__call__,
1111
- hidden_states,
1112
- attention_mask,
1113
- position_ids,
1114
- past_key_values,
1115
- output_attentions,
1116
- use_cache,
1117
- cu_seqlens,
1118
- max_seqlen,
1119
- )
1120
- else:
1121
- layer_outputs = decoder_layer(
1122
- hidden_states,
1123
- attention_mask=attention_mask,
1124
- position_ids=position_ids,
1125
- past_key_value=past_key_values,
1126
- output_attentions=output_attentions,
1127
- use_cache=use_cache,
1128
- cu_seqlens=cu_seqlens,
1129
- max_seqlen=max_seqlen,
1130
- )
1131
-
1132
- hidden_states = layer_outputs[0]
1133
-
1134
- if use_cache:
1135
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1136
-
1137
- if output_attentions:
1138
- all_self_attns += (layer_outputs[1],)
1139
-
1140
- hidden_states = self.norm(hidden_states)
1141
-
1142
- # add hidden states from the last decoder layer
1143
- if output_hidden_states:
1144
- all_hidden_states += (hidden_states,)
1145
-
1146
- next_cache = None
1147
- if use_cache:
1148
- next_cache = (
1149
- next_decoder_cache.to_legacy_cache()
1150
- if use_legacy_cache
1151
- else next_decoder_cache
1152
- )
1153
-
1154
- if not return_dict:
1155
- return tuple(
1156
- v
1157
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1158
- if v is not None
1159
- )
1160
- return BaseModelOutputWithPast(
1161
- last_hidden_state=hidden_states,
1162
- past_key_values=next_cache,
1163
- hidden_states=all_hidden_states,
1164
- attentions=all_self_attns,
1165
- )
1166
-
1167
-
1168
- class MixtralForCausalLM(MixtralPreTrainedModel):
1169
- _tied_weights_keys = ["lm_head.weight"]
1170
-
1171
- def __init__(self, config):
1172
- super().__init__(config)
1173
- self.model = MistralModel(config)
1174
- self.vocab_size = config.vocab_size
1175
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1176
-
1177
- # Initialize weights and apply final processing
1178
- self.post_init()
1179
-
1180
- def get_input_embeddings(self):
1181
- return self.model.embed_tokens
1182
-
1183
- def set_input_embeddings(self, value):
1184
- self.model.embed_tokens = value
1185
-
1186
- def get_output_embeddings(self):
1187
- return self.lm_head
1188
-
1189
- def set_output_embeddings(self, new_embeddings):
1190
- self.lm_head = new_embeddings
1191
-
1192
- def set_decoder(self, decoder):
1193
- self.model = decoder
1194
-
1195
- def get_decoder(self):
1196
- return self.model
1197
-
1198
- def _init_weights(self, module):
1199
- return
1200
-
1201
- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1202
- @replace_return_docstrings(
1203
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1204
- )
1205
- def forward(
1206
- self,
1207
- input_ids: torch.LongTensor = None,
1208
- attention_mask: Optional[torch.Tensor] = None,
1209
- position_ids: Optional[torch.LongTensor] = None,
1210
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1211
- inputs_embeds: Optional[torch.FloatTensor] = None,
1212
- labels: Optional[torch.LongTensor] = None,
1213
- use_cache: Optional[bool] = None,
1214
- output_attentions: Optional[bool] = None,
1215
- output_hidden_states: Optional[bool] = None,
1216
- return_dict: Optional[bool] = None,
1217
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1218
- r"""
1219
- Args:
1220
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1221
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1222
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1223
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1224
-
1225
- Returns:
1226
-
1227
- Example:
1228
-
1229
- ```python
1230
- >>> from transformers import AutoTokenizer, MistralForCausalLM
1231
-
1232
- >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1233
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1234
-
1235
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1236
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1237
-
1238
- >>> # Generate
1239
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1240
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1241
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1242
- ```"""
1243
-
1244
- output_attentions = (
1245
- output_attentions
1246
- if output_attentions is not None
1247
- else self.config.output_attentions
1248
- )
1249
- output_hidden_states = (
1250
- output_hidden_states
1251
- if output_hidden_states is not None
1252
- else self.config.output_hidden_states
1253
- )
1254
- return_dict = (
1255
- return_dict if return_dict is not None else self.config.use_return_dict
1256
- )
1257
-
1258
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1259
- outputs = self.model(
1260
- input_ids=input_ids,
1261
- attention_mask=attention_mask,
1262
- position_ids=position_ids,
1263
- past_key_values=past_key_values,
1264
- inputs_embeds=inputs_embeds,
1265
- use_cache=use_cache,
1266
- output_attentions=output_attentions,
1267
- output_hidden_states=output_hidden_states,
1268
- return_dict=return_dict,
1269
- )
1270
-
1271
- hidden_states = outputs[0]
1272
- logits = self.lm_head(hidden_states)
1273
- logits = logits.float()
1274
-
1275
- loss = None
1276
- if labels is not None:
1277
- # Shift so that tokens < n predict n
1278
- shift_logits = logits[..., :-1, :].contiguous()
1279
- shift_labels = labels[..., 1:].contiguous()
1280
- # Flatten the tokens
1281
- loss_fct = CrossEntropyLoss()
1282
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1283
- shift_labels = shift_labels.view(-1)
1284
- # Enable model parallelism
1285
- shift_labels = shift_labels.to(shift_logits.device)
1286
- loss = loss_fct(shift_logits, shift_labels)
1287
-
1288
- if not return_dict:
1289
- output = (logits,) + outputs[1:]
1290
- return (loss,) + output if loss is not None else output
1291
-
1292
- return CausalLMOutputWithPast(
1293
- loss=loss,
1294
- logits=logits,
1295
- past_key_values=outputs.past_key_values,
1296
- hidden_states=outputs.hidden_states,
1297
- attentions=outputs.attentions,
1298
- )
1299
-
1300
- def prepare_inputs_for_generation(
1301
- self,
1302
- input_ids,
1303
- past_key_values=None,
1304
- attention_mask=None,
1305
- inputs_embeds=None,
1306
- **kwargs,
1307
- ):
1308
- # Omit tokens covered by past_key_values
1309
- if past_key_values is not None:
1310
- if isinstance(past_key_values, Cache):
1311
- cache_length = past_key_values.get_seq_length()
1312
- past_length = past_key_values.seen_tokens
1313
- else:
1314
- cache_length = past_length = past_key_values[0][0].shape[2]
1315
-
1316
- # Keep only the unprocessed tokens:
1317
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1318
- # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1319
- # input)
1320
- if (
1321
- attention_mask is not None
1322
- and attention_mask.shape[1] > input_ids.shape[1]
1323
- ):
1324
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1325
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1326
- # input_ids based on the past_length.
1327
- elif past_length < input_ids.shape[1]:
1328
- input_ids = input_ids[:, past_length:]
1329
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1330
-
1331
- # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
1332
- # older attention values, as their corresponding values are not part of the input.
1333
- if cache_length < past_length and attention_mask is not None:
1334
- attention_mask = attention_mask[
1335
- :, -(cache_length + input_ids.shape[1]) :
1336
- ]
1337
-
1338
- position_ids = kwargs.get("position_ids", None)
1339
- if attention_mask is not None and position_ids is None:
1340
- # create position_ids on the fly for batch generation
1341
- position_ids = attention_mask.long().cumsum(-1) - 1
1342
- position_ids.masked_fill_(attention_mask == 0, 1)
1343
- if past_key_values:
1344
- position_ids = position_ids[:, -input_ids.shape[1] :]
1345
-
1346
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1347
- if inputs_embeds is not None and past_key_values is None:
1348
- model_inputs = {"inputs_embeds": inputs_embeds}
1349
- else:
1350
- model_inputs = {"input_ids": input_ids}
1351
-
1352
- model_inputs.update(
1353
- {
1354
- "position_ids": position_ids,
1355
- "past_key_values": past_key_values,
1356
- "use_cache": kwargs.get("use_cache"),
1357
- "attention_mask": attention_mask,
1358
- }
1359
- )
1360
- return model_inputs
1361
-
1362
- @staticmethod
1363
- def _reorder_cache(past_key_values, beam_idx):
1364
- reordered_past = ()
1365
- for layer_past in past_key_values:
1366
- reordered_past += (
1367
- tuple(
1368
- past_state.index_select(0, beam_idx.to(past_state.device))
1369
- for past_state in layer_past
1370
- ),
1371
- )
1372
- return reordered_past
1373
-
1374
-
1375
- @add_start_docstrings(
1376
- """
1377
- The Mistral Model transformer with a sequence classification head on top (linear layer).
1378
-
1379
- [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1380
- (e.g. GPT-2) do.
1381
-
1382
- Since it does classification on the last token, it requires to know the position of the last token. If a
1383
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1384
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1385
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1386
- each row of the batch).
1387
- """,
1388
- MISTRAL_START_DOCSTRING,
1389
- )
1390
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
1391
- class MistralForSequenceClassification(MixtralPreTrainedModel):
1392
- def __init__(self, config):
1393
- super().__init__(config)
1394
- self.num_labels = config.num_labels
1395
- self.model = MistralModel(config)
1396
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1397
-
1398
- # Initialize weights and apply final processing
1399
- self.post_init()
1400
-
1401
- def get_input_embeddings(self):
1402
- return self.model.embed_tokens
1403
-
1404
- def set_input_embeddings(self, value):
1405
- self.model.embed_tokens = value
1406
-
1407
- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1408
- def forward(
1409
- self,
1410
- input_ids: torch.LongTensor = None,
1411
- attention_mask: Optional[torch.Tensor] = None,
1412
- position_ids: Optional[torch.LongTensor] = None,
1413
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1414
- inputs_embeds: Optional[torch.FloatTensor] = None,
1415
- labels: Optional[torch.LongTensor] = None,
1416
- use_cache: Optional[bool] = None,
1417
- output_attentions: Optional[bool] = None,
1418
- output_hidden_states: Optional[bool] = None,
1419
- return_dict: Optional[bool] = None,
1420
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1421
- r"""
1422
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1423
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1424
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1425
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1426
- """
1427
- return_dict = (
1428
- return_dict if return_dict is not None else self.config.use_return_dict
1429
- )
1430
-
1431
- transformer_outputs = self.model(
1432
- input_ids,
1433
- attention_mask=attention_mask,
1434
- position_ids=position_ids,
1435
- past_key_values=past_key_values,
1436
- inputs_embeds=inputs_embeds,
1437
- use_cache=use_cache,
1438
- output_attentions=output_attentions,
1439
- output_hidden_states=output_hidden_states,
1440
- return_dict=return_dict,
1441
- )
1442
- hidden_states = transformer_outputs[0]
1443
- logits = self.score(hidden_states)
1444
-
1445
- if input_ids is not None:
1446
- batch_size = input_ids.shape[0]
1447
- else:
1448
- batch_size = inputs_embeds.shape[0]
1449
-
1450
- if self.config.pad_token_id is None and batch_size != 1:
1451
- raise ValueError(
1452
- "Cannot handle batch sizes > 1 if no padding token is defined."
1453
- )
1454
- if self.config.pad_token_id is None:
1455
- sequence_lengths = -1
1456
- else:
1457
- if input_ids is not None:
1458
- sequence_lengths = (
1459
- torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1460
- ).to(logits.device)
1461
- else:
1462
- sequence_lengths = -1
1463
-
1464
- pooled_logits = logits[
1465
- torch.arange(batch_size, device=logits.device), sequence_lengths
1466
- ]
1467
-
1468
- loss = None
1469
- if labels is not None:
1470
- labels = labels.to(logits.device)
1471
- if self.config.problem_type is None:
1472
- if self.num_labels == 1:
1473
- self.config.problem_type = "regression"
1474
- elif self.num_labels > 1 and (
1475
- labels.dtype == torch.long or labels.dtype == torch.int
1476
- ):
1477
- self.config.problem_type = "single_label_classification"
1478
- else:
1479
- self.config.problem_type = "multi_label_classification"
1480
-
1481
- if self.config.problem_type == "regression":
1482
- loss_fct = MSELoss()
1483
- if self.num_labels == 1:
1484
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1485
- else:
1486
- loss = loss_fct(pooled_logits, labels)
1487
- elif self.config.problem_type == "single_label_classification":
1488
- loss_fct = CrossEntropyLoss()
1489
- loss = loss_fct(
1490
- pooled_logits.view(-1, self.num_labels), labels.view(-1)
1491
- )
1492
- elif self.config.problem_type == "multi_label_classification":
1493
- loss_fct = BCEWithLogitsLoss()
1494
- loss = loss_fct(pooled_logits, labels)
1495
- if not return_dict:
1496
- output = (pooled_logits,) + transformer_outputs[1:]
1497
- return ((loss,) + output) if loss is not None else output
1498
-
1499
- return SequenceClassifierOutputWithPast(
1500
- loss=loss,
1501
- logits=pooled_logits,
1502
- past_key_values=transformer_outputs.past_key_values,
1503
- hidden_states=transformer_outputs.hidden_states,
1504
- attentions=transformer_outputs.attentions,
1505
- )
src/axolotl/monkeypatch/mixtral/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ """
2
+ Patches to support multipack for mixtral
3
+ """
4
+ import transformers
5
+
6
+
7
+ def replace_mixtral_attn_with_multipack_flash_attn():
8
+ from .modeling_mixtral import (
9
+ MixtralMultipackFlashAttention2,
10
+ mixtral_decoder_layer_forward,
11
+ mixtral_model_forward,
12
+ )
13
+
14
+ transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
15
+ mixtral_decoder_layer_forward
16
+ )
17
+ transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
18
+ mixtral_model_forward
19
+ )
20
+ transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
21
+ "flash_attention_2"
22
+ ] = MixtralMultipackFlashAttention2
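
For orientation, the patch has to run before the Mixtral weights are instantiated so the loaded model picks up the replaced `forward` methods and attention class. A minimal sketch of applying it outside axolotl's own `load_model` path (the checkpoint name, dtype, and device kwargs are illustrative; it assumes flash-attn and the pinned transformers revision are installed):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from axolotl.monkeypatch.mixtral import replace_mixtral_attn_with_multipack_flash_attn

# Patch the transformers Mixtral classes in place before building the model.
replace_mixtral_attn_with_multipack_flash_attn()

# Mirror what load_model does further down: route attention through FA2 on the config.
config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
config._attn_implementation = "flash_attention_2"

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    config=config,
    torch_dtype=torch.bfloat16,  # illustrative; pick what fits your hardware
    device_map="auto",
)
```

Axolotl itself wires this up in `src/axolotl/utils/models.py` (see the hunks further down), gated on `cfg.model_config_type == "mixtral"`, `cfg.flash_attention`, and `cfg.sample_packing`.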
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py ADDED
@@ -0,0 +1,379 @@
1
+ """
2
+ Mixtral modeling for multipack
3
+ """
4
+ # pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
5
+ import logging
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ from einops import rearrange
11
+ from flash_attn import flash_attn_varlen_qkvpacked_func
12
+ from transformers import Cache, DynamicCache
13
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
14
+ from transformers.modeling_outputs import MoeModelOutputWithPast
15
+ from transformers.models.mixtral.modeling_mixtral import (
16
+ MixtralFlashAttention2,
17
+ apply_rotary_pos_emb,
18
+ repeat_kv,
19
+ )
20
+
21
+ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
22
+
23
+ LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
24
+
25
+
26
+ class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
27
+ """
28
+ Custom multipack implementation w flash attention 2
29
+ """
30
+
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+ self._flash_attn_uses_top_left_mask = True
34
+
35
+ def forward(
36
+ self,
37
+ hidden_states: torch.Tensor,
38
+ attention_mask: Optional[torch.Tensor] = None,
39
+ position_ids: Optional[torch.LongTensor] = None,
40
+ past_key_value: Optional[Cache] = None,
41
+ output_attentions: bool = False,
42
+ use_cache: bool = False,
43
+ cu_seqlens: Optional[torch.Tensor] = None,
44
+ max_seqlen: Optional[torch.Tensor] = None,
45
+ **kwargs,
46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
47
+ if "padding_mask" in kwargs:
48
+ warnings.warn(
49
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
50
+ )
51
+ bsz, q_len, _ = hidden_states.size()
52
+
53
+ query_states = self.q_proj(hidden_states)
54
+ key_states = self.k_proj(hidden_states)
55
+ value_states = self.v_proj(hidden_states)
56
+
57
+ query_states = query_states.view(
58
+ bsz, q_len, self.num_heads, self.head_dim
59
+ ).transpose(1, 2)
60
+ key_states = key_states.view(
61
+ bsz, q_len, self.num_key_value_heads, self.head_dim
62
+ ).transpose(1, 2)
63
+ value_states = value_states.view(
64
+ bsz, q_len, self.num_key_value_heads, self.head_dim
65
+ ).transpose(1, 2)
66
+
67
+ kv_seq_len = key_states.shape[-2]
68
+ if past_key_value is not None:
69
+ if self.layer_idx is None:
70
+ raise ValueError(
71
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
72
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
73
+ "with a layer index."
74
+ )
75
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
76
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
77
+ query_states, key_states = apply_rotary_pos_emb(
78
+ query_states, key_states, cos, sin, position_ids
79
+ )
80
+
81
+ if past_key_value is not None:
82
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
83
+ key_states, value_states = past_key_value.update(
84
+ key_states, value_states, self.layer_idx, cache_kwargs
85
+ )
86
+
87
+ # repeat k/v heads if n_kv_heads < n_heads
88
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
89
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
90
+
91
+ if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
92
+ # special handling using sample packing
93
+ qkv = torch.stack(
94
+ [query_states, key_states, value_states], dim=2
95
+ ) # [bsz, nh, 3, q_len, hd]
96
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
97
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
98
+
99
+ attn_output = flash_attn_varlen_qkvpacked_func(
100
+ qkv,
101
+ cu_seqlens,
102
+ max_seqlen,
103
+ dropout_p=self.attention_dropout,
104
+ softmax_scale=None,
105
+ causal=True,
106
+ )
107
+ attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
108
+
109
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
110
+ attn_output = self.o_proj(attn_output)
111
+
112
+ if not output_attentions:
113
+ attn_weights = None
114
+
115
+ return attn_output, attn_weights, past_key_value
116
+
117
+
118
+ def mixtral_decoder_layer_forward(
119
+ self,
120
+ hidden_states: torch.Tensor,
121
+ attention_mask: Optional[torch.Tensor] = None,
122
+ position_ids: Optional[torch.LongTensor] = None,
123
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
124
+ output_attentions: Optional[bool] = False,
125
+ output_router_logits: Optional[bool] = False,
126
+ use_cache: Optional[bool] = False,
127
+ cu_seqlens: Optional[torch.Tensor] = None,
128
+ max_seqlen: Optional[torch.Tensor] = None,
129
+ **kwargs,
130
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
131
+ if "padding_mask" in kwargs:
132
+ warnings.warn(
133
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
134
+ )
135
+ """
136
+ Args:
137
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
138
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
139
+ `(batch, sequence_length)` where padding elements are indicated by 0.
140
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
141
+ output_attentions (`bool`, *optional*):
142
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
143
+ returned tensors for more detail.
144
+ output_router_logits (`bool`, *optional*):
145
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
146
+ should not be returned during inference.
147
+ use_cache (`bool`, *optional*):
148
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
149
+ (see `past_key_values`).
150
+ """
151
+
152
+ residual = hidden_states
153
+
154
+ hidden_states = self.input_layernorm(hidden_states)
155
+
156
+ # Self Attention
157
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
158
+ hidden_states=hidden_states,
159
+ attention_mask=attention_mask,
160
+ position_ids=position_ids,
161
+ past_key_value=past_key_value,
162
+ output_attentions=output_attentions,
163
+ use_cache=use_cache,
164
+ cu_seqlens=cu_seqlens,
165
+ max_seqlen=max_seqlen,
166
+ )
167
+ hidden_states = residual + hidden_states
168
+
169
+ # Fully Connected
170
+ residual = hidden_states
171
+ hidden_states = self.post_attention_layernorm(hidden_states)
172
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states)
173
+ hidden_states = residual + hidden_states
174
+
175
+ outputs = (hidden_states,)
176
+
177
+ if output_attentions:
178
+ outputs += (self_attn_weights,)
179
+
180
+ if use_cache:
181
+ outputs += (present_key_value,)
182
+
183
+ if output_router_logits:
184
+ outputs += (router_logits,)
185
+
186
+ return outputs
187
+
188
+
189
+ def mixtral_model_forward(
190
+ self,
191
+ input_ids: torch.LongTensor = None,
192
+ attention_mask: Optional[torch.Tensor] = None,
193
+ position_ids: Optional[torch.LongTensor] = None,
194
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
195
+ inputs_embeds: Optional[torch.FloatTensor] = None,
196
+ use_cache: Optional[bool] = None,
197
+ output_attentions: Optional[bool] = None,
198
+ output_hidden_states: Optional[bool] = None,
199
+ output_router_logits: Optional[bool] = None,
200
+ return_dict: Optional[bool] = None,
201
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
202
+ output_attentions = (
203
+ output_attentions
204
+ if output_attentions is not None
205
+ else self.config.output_attentions
206
+ )
207
+ output_router_logits = (
208
+ output_router_logits
209
+ if output_router_logits is not None
210
+ else self.config.output_router_logits
211
+ )
212
+ output_hidden_states = (
213
+ output_hidden_states
214
+ if output_hidden_states is not None
215
+ else self.config.output_hidden_states
216
+ )
217
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
218
+
219
+ return_dict = (
220
+ return_dict if return_dict is not None else self.config.use_return_dict
221
+ )
222
+
223
+ # retrieve input_ids and inputs_embeds
224
+ if input_ids is not None and inputs_embeds is not None:
225
+ raise ValueError(
226
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
227
+ )
228
+ if input_ids is not None:
229
+ batch_size, seq_length = input_ids.shape
230
+ elif inputs_embeds is not None:
231
+ batch_size, seq_length, _ = inputs_embeds.shape
232
+ else:
233
+ raise ValueError(
234
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
235
+ )
236
+
237
+ past_key_values_length = 0
238
+
239
+ if use_cache:
240
+ use_legacy_cache = not isinstance(past_key_values, Cache)
241
+ if use_legacy_cache:
242
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
243
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
244
+
245
+ cu_seqlens = None
246
+ max_seqlen = None
247
+ if position_ids is None:
248
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
249
+ position_ids = torch.arange(
250
+ past_key_values_length,
251
+ seq_length + past_key_values_length,
252
+ dtype=torch.long,
253
+ device=device,
254
+ )
255
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
256
+ else:
257
+ position_ids = position_ids.view(-1, seq_length).long()
258
+ cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
259
+ cu_seqlens = cu_seqlens.squeeze()
260
+
261
+ if inputs_embeds is None:
262
+ inputs_embeds = self.embed_tokens(input_ids)
263
+
264
+ if attention_mask is not None and self._use_flash_attention_2 and use_cache:
265
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
266
+ if is_padding_right:
267
+ raise ValueError(
268
+ "You are attempting to perform batched generation with padding_side='right'"
269
+ " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
270
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
271
+ )
272
+
273
+ if self._use_flash_attention_2:
274
+ # 2d mask is passed through the layers
275
+ attention_mask = (
276
+ attention_mask
277
+ if (attention_mask is not None and 0 in attention_mask)
278
+ else None
279
+ )
280
+ else:
281
+ # 4d mask is passed through the layers
282
+ attention_mask = _prepare_4d_causal_attention_mask(
283
+ attention_mask,
284
+ (batch_size, seq_length),
285
+ inputs_embeds,
286
+ past_key_values_length,
287
+ sliding_window=self.config.sliding_window,
288
+ )
289
+
290
+ hidden_states = inputs_embeds
291
+
292
+ if self.gradient_checkpointing and self.training:
293
+ if use_cache:
294
+ LOG.warning_once(
295
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
296
+ )
297
+ use_cache = False
298
+
299
+ # decoder layers
300
+ all_hidden_states = () if output_hidden_states else None
301
+ all_self_attns = () if output_attentions else None
302
+ all_router_logits = () if output_router_logits else None
303
+ next_decoder_cache = None
304
+
305
+ for decoder_layer in self.layers:
306
+ if output_hidden_states:
307
+ all_hidden_states += (hidden_states,)
308
+
309
+ if self.gradient_checkpointing and self.training:
310
+ layer_outputs = self._gradient_checkpointing_func(
311
+ decoder_layer.__call__,
312
+ hidden_states,
313
+ attention_mask,
314
+ position_ids,
315
+ past_key_values,
316
+ output_attentions,
317
+ output_router_logits,
318
+ use_cache,
319
+ cu_seqlens,
320
+ max_seqlen,
321
+ )
322
+ else:
323
+ layer_outputs = decoder_layer(
324
+ hidden_states,
325
+ attention_mask=attention_mask,
326
+ position_ids=position_ids,
327
+ past_key_value=past_key_values,
328
+ output_attentions=output_attentions,
329
+ output_router_logits=output_router_logits,
330
+ use_cache=use_cache,
331
+ cu_seqlens=cu_seqlens,
332
+ max_seqlen=max_seqlen,
333
+ )
334
+
335
+ hidden_states = layer_outputs[0]
336
+
337
+ if use_cache:
338
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
339
+
340
+ if output_attentions:
341
+ all_self_attns += (layer_outputs[1],)
342
+
343
+ if output_router_logits:
344
+ all_router_logits += (layer_outputs[-1],)
345
+
346
+ hidden_states = self.norm(hidden_states)
347
+
348
+ # add hidden states from the last decoder layer
349
+ if output_hidden_states:
350
+ all_hidden_states += (hidden_states,)
351
+
352
+ next_cache = None
353
+ if use_cache:
354
+ next_cache = (
355
+ next_decoder_cache.to_legacy_cache()
356
+ if use_legacy_cache
357
+ else next_decoder_cache
358
+ )
359
+
360
+ if not return_dict:
361
+ return tuple(
362
+ v
363
+ for v in [
364
+ hidden_states,
365
+ next_cache,
366
+ all_hidden_states,
367
+ all_self_attns,
368
+ all_router_logits,
369
+ ]
370
+ if v is not None
371
+ )
372
+
373
+ return MoeModelOutputWithPast(
374
+ last_hidden_state=hidden_states,
375
+ past_key_values=next_cache,
376
+ hidden_states=all_hidden_states,
377
+ attentions=all_self_attns,
378
+ router_logits=all_router_logits,
379
+ )
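
The core idea of the multipack path above: when several samples are packed into one row, `position_ids` restart at zero at every sample boundary, and the patched model forward turns those boundaries into `cu_seqlens`/`max_seqlen` for `flash_attn_varlen_qkvpacked_func`, so attention never crosses sample boundaries. A toy stand-in for axolotl's `get_cu_seqlens_from_pos_ids` helper, simplified to a single packed row (the real helper lives in `axolotl.monkeypatch.utils`):

```python
import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
    """Simplified illustration: sample boundaries are wherever position ids reset to 0."""
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten()  # each packed sample starts at position 0
    ends = torch.cat([starts[1:], torch.tensor([pos.numel()])])
    seqlens = ends - starts
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0).to(torch.int32)]
    )
    return cu_seqlens, int(seqlens.max())


# Two samples of lengths 4 and 3 packed into one row:
packed = torch.tensor([[0, 1, 2, 3, 0, 1, 2]])
print(cu_seqlens_from_position_ids(packed))
# (tensor([0, 4, 7], dtype=torch.int32), 4)
```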
src/axolotl/utils/models.py CHANGED
@@ -54,25 +54,19 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
54
  def load_model_config(cfg):
55
  model_config_name = cfg.base_model_config or cfg.base_model
56
  trust_remote_code = cfg.trust_remote_code is True
57
- model_type = cfg.model_type
58
-
59
- if model_type == "MixtralForCausalLM":
60
- from axolotl.models.mixtral.configuration_moe_mistral import MixtralConfig
61
 
62
- model_config = MixtralConfig.from_pretrained(model_config_name)
63
- else:
64
- try:
65
- model_config = AutoConfig.from_pretrained(
66
- model_config_name, trust_remote_code=trust_remote_code
67
  )
68
- except ValueError as err:
69
- if "mamba" in model_config_name:
70
- return addict.Dict(
71
- {
72
- "model_type": "mamba",
73
- }
74
- )
75
- raise err
76
 
77
  if cfg.model_config:
78
  for key, val in cfg.model_config.items():
@@ -255,6 +249,18 @@ def load_model(
255
  LOG.info("patching with flash attention")
256
  replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
257
 
258
  if cfg.is_llama_derived_model and cfg.xpos_rope:
259
  from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
260
  replace_llama_rope_with_xpos_rope,
@@ -302,15 +308,22 @@ def load_model(
302
  bnb_4bit_quant_type="nf4",
303
  )
304
  # sample packing uses custom FA2 patch
305
- if cfg.flash_attention and not cfg.sample_packing:
306
- if (
307
- cfg.is_llama_derived_model
308
- or cfg.is_falcon_derived_model
309
- or cfg.is_mistral_derived_model
310
- ):
311
- # TODO enable once properly supported in transformers
312
- # model_kwargs["attn_implementation"] = "flash_attention_2"
313
- model_kwargs["use_flash_attention_2"] = True # legacy, to be deprecated
314
 
315
  try:
316
  if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
@@ -372,15 +385,6 @@ def load_model(
372
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
373
  **model_kwargs,
374
  )
375
- elif model_type == "MixtralForCausalLM":
376
- from axolotl.models.mixtral import MixtralForCausalLM
377
-
378
- model = MixtralForCausalLM.from_pretrained(
379
- base_model,
380
- load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
381
- load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
382
- **model_kwargs,
383
- )
384
  elif model_type == "MambaLMHeadModel":
385
  # FIXME this is janky at best and hacked together to make it work
386
  MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
 
54
  def load_model_config(cfg):
55
  model_config_name = cfg.base_model_config or cfg.base_model
56
  trust_remote_code = cfg.trust_remote_code is True
57
 
58
+ try:
59
+ model_config = AutoConfig.from_pretrained(
60
+ model_config_name, trust_remote_code=trust_remote_code
61
+ )
62
+ except ValueError as err:
63
+ if "mamba" in model_config_name:
64
+ return addict.Dict(
65
+ {
66
+ "model_type": "mamba",
67
+ }
68
  )
69
+ raise err
70
 
71
  if cfg.model_config:
72
  for key, val in cfg.model_config.items():
 
249
  LOG.info("patching with flash attention")
250
  replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
251
 
252
+ if (
253
+ cfg.model_config_type == "mixtral"
254
+ and cfg.flash_attention
255
+ and cfg.sample_packing
256
+ ):
257
+ from axolotl.monkeypatch.mixtral import (
258
+ replace_mixtral_attn_with_multipack_flash_attn,
259
+ )
260
+
261
+ LOG.info("patching with flash attention")
262
+ replace_mixtral_attn_with_multipack_flash_attn()
263
+
264
  if cfg.is_llama_derived_model and cfg.xpos_rope:
265
  from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
266
  replace_llama_rope_with_xpos_rope,
 
308
  bnb_4bit_quant_type="nf4",
309
  )
310
  # sample packing uses custom FA2 patch
311
+ if cfg.flash_attention:
312
+ if not cfg.sample_packing:
313
+ if (
314
+ cfg.is_llama_derived_model
315
+ or cfg.is_falcon_derived_model
316
+ or cfg.is_mistral_derived_model
317
+ or model_config.model_type == "mixtral"
318
+ ):
319
+ model_config._attn_implementation = ( # pylint: disable=protected-access
320
+ "flash_attention_2"
321
+ )
322
+ else:
323
+ if model_config.model_type == "mixtral":
324
+ model_config._attn_implementation = ( # pylint: disable=protected-access
325
+ "flash_attention_2"
326
+ )
327
 
328
  try:
329
  if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
 
385
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
386
  **model_kwargs,
387
  )
388
  elif model_type == "MambaLMHeadModel":
389
  # FIXME this is janky at best and hacked together to make it work
390
  MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
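
Because the patch mutates the transformers module in place rather than subclassing, a quick way to confirm it took effect is to check the patched attributes directly. A hypothetical sanity check, not part of this commit (it assumes flash-attn and the pinned transformers revision are importable):

```python
import transformers

from axolotl.monkeypatch.mixtral import replace_mixtral_attn_with_multipack_flash_attn
from axolotl.monkeypatch.mixtral.modeling_mixtral import (
    mixtral_decoder_layer_forward,
    mixtral_model_forward,
)

replace_mixtral_attn_with_multipack_flash_attn()

mixtral_mod = transformers.models.mixtral.modeling_mixtral
# The patch assigns these plain functions onto the upstream classes.
assert mixtral_mod.MixtralDecoderLayer.forward is mixtral_decoder_layer_forward
assert mixtral_mod.MixtralModel.forward is mixtral_model_forward
print("mixtral multipack monkeypatch is active")
```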