ver217 commited on
Commit
25d1d74
·
1 Parent(s): decbb79

init commit

Browse files
__init__.py ADDED
File without changes
configuration_grok1.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+ class Grok1Config(PretrainedConfig):
5
+ model_type = "grok-1"
6
+ keys_to_ignore_at_inference = ["past_key_values"]
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=32000,
11
+ hidden_size=4096,
12
+ widening_factor=4.0,
13
+ num_hidden_layers=32,
14
+ num_attention_heads=32,
15
+ num_key_value_heads=32,
16
+ attn_output_multiplier=1.0,
17
+ max_attn_value=1.0,
18
+ max_position_embeddings=4096,
19
+ rms_norm_eps=1e-5,
20
+ use_cache=True,
21
+ pad_token_id=None,
22
+ bos_token_id=1,
23
+ eos_token_id=2,
24
+ tie_word_embeddings=True,
25
+ num_experts_per_tok=2,
26
+ num_experts=8,
27
+ output_router_logits=False,
28
+ router_aux_loss_coef=0.001,
29
+ **kwargs
30
+ ):
31
+ self.vocab_size = vocab_size
32
+ self.attn_output_multiplier = attn_output_multiplier
33
+ self.max_attn_value = max_attn_value
34
+ self.max_position_embeddings = max_position_embeddings
35
+ self.hidden_size = hidden_size
36
+ self.widening_factor = widening_factor
37
+ self.num_hidden_layers = num_hidden_layers
38
+ self.num_attention_heads = num_attention_heads
39
+
40
+ # for backward compatibility
41
+ if num_key_value_heads is None:
42
+ num_key_value_heads = num_attention_heads
43
+
44
+ self.num_key_value_heads = num_key_value_heads
45
+ self.rms_norm_eps = rms_norm_eps
46
+ self.use_cache = use_cache
47
+
48
+ self.num_experts_per_tok = num_experts_per_tok
49
+ self.num_experts = num_experts
50
+ self.output_router_logits = output_router_logits
51
+ self.router_aux_loss_coef = router_aux_loss_coef
52
+ super().__init__(
53
+ pad_token_id=pad_token_id,
54
+ bos_token_id=bos_token_id,
55
+ eos_token_id=eos_token_id,
56
+ tie_word_embeddings=tie_word_embeddings,
57
+ **kwargs,
58
+ )
modeling_grok1.py ADDED
@@ -0,0 +1,923 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.modeling_utils import PreTrainedModel
7
+ from transformers.utils import logging
8
+
9
+ try:
10
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
11
+
12
+ HAS_MASK_UTILS = True
13
+ except ImportError:
14
+ HAS_MASK_UTILS = False
15
+
16
+ from .configuration_grok1 import Grok1Config
17
+ from .modeling_grok1_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ # copied from https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/models/mixtral/modeling_mixtral.py
23
+ def load_balancing_loss_func(
24
+ gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2
25
+ ) -> float:
26
+ r"""
27
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
28
+
29
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
30
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
31
+ experts is too unbalanced.
32
+
33
+ Args:
34
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
35
+ Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_experts].
36
+ num_experts (`int`, *optional*):
37
+ Number of experts
38
+
39
+ Returns:
40
+ The auxiliary loss.
41
+ """
42
+ if gate_logits is None:
43
+ return 0
44
+
45
+ if isinstance(gate_logits, tuple):
46
+ # cat along the layers?
47
+ compute_device = gate_logits[0].device
48
+ gate_logits = torch.cat(
49
+ [gate.to(compute_device) for gate in gate_logits], dim=0
50
+ )
51
+
52
+ routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
53
+ routing_weights = routing_weights.softmax(dim=-1)
54
+
55
+ # cast the expert indices to int64, otherwise one-hot encoding will fail
56
+ if selected_experts.dtype != torch.int64:
57
+ selected_experts = selected_experts.to(torch.int64)
58
+
59
+ if len(selected_experts.shape) == 2:
60
+ selected_experts = selected_experts.unsqueeze(2)
61
+
62
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
63
+
64
+ # For a given token, determine if it was routed to a given expert.
65
+ expert_mask = torch.max(expert_mask, axis=-2).values
66
+
67
+ # cast to float32 otherwise mean will fail
68
+ expert_mask = expert_mask.to(torch.float32)
69
+ tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
70
+
71
+ router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1)
72
+ return torch.mean(
73
+ tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1)
74
+ ) * (num_experts**2)
75
+
76
+
77
+ class RMSNorm(nn.Module):
78
+ def __init__(
79
+ self,
80
+ hidden_size: int,
81
+ eps: float = 1e-5,
82
+ create_scale: bool = True,
83
+ ) -> None:
84
+ super().__init__()
85
+ self.variance_epsilon = eps
86
+ if create_scale:
87
+ self.scale = nn.Parameter(torch.zeros(hidden_size))
88
+ else:
89
+ self.scale = 1.0
90
+
91
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
92
+ input_dtype = hidden_states.dtype
93
+ hidden_states = hidden_states.to(torch.float32)
94
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
95
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
96
+ hidden_states = self.scale * hidden_states
97
+ return hidden_states.to(input_dtype)
98
+
99
+
100
+ class RotaryEmbedding(nn.Module):
101
+ def __init__(
102
+ self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
103
+ ) -> None:
104
+ super().__init__()
105
+ assert dim % 2 == 0
106
+ self.dim = dim
107
+ self.max_position_embeddings = max_position_embeddings
108
+ self.base = base
109
+ inv_freq = 1.0 / (
110
+ self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
111
+ )
112
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
113
+
114
+ self._set_cos_sin_cache(
115
+ seq_len=max_position_embeddings,
116
+ device=self.inv_freq.device,
117
+ dtype=torch.get_default_dtype(),
118
+ )
119
+
120
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
121
+ self.max_seq_len_cached = seq_len
122
+ t = torch.arange(
123
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
124
+ )
125
+
126
+ freqs = torch.outer(t, self.inv_freq)
127
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
128
+ emb = torch.cat((freqs, freqs), dim=-1)
129
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
130
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
131
+
132
+ def forward(self, x, seq_len=None):
133
+ # x: [bs, num_attention_heads, seq_len, head_size]
134
+ if seq_len > self.max_seq_len_cached:
135
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
136
+
137
+ return (
138
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
139
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
140
+ )
141
+
142
+
143
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
144
+ def rotate_half(x):
145
+ """Rotates half the hidden dims of the input."""
146
+ x1 = x[..., : x.shape[-1] // 2]
147
+ x2 = x[..., x.shape[-1] // 2 :]
148
+ return torch.cat((-x2, x1), dim=-1)
149
+
150
+
151
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
152
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
153
+ """Applies Rotary Position Embedding to the query and key tensors.
154
+
155
+ Args:
156
+ q (`torch.Tensor`): The query tensor.
157
+ k (`torch.Tensor`): The key tensor.
158
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
159
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
160
+ position_ids (`torch.Tensor`):
161
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
162
+ used to pass offsetted position ids when working with a KV-cache.
163
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
164
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
165
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
166
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
167
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
168
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
169
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
170
+ Returns:
171
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
172
+ """
173
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
174
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
175
+ q_embed = (q * cos) + (rotate_half(q) * sin)
176
+ k_embed = (k * cos) + (rotate_half(k) * sin)
177
+ return q_embed, k_embed
178
+
179
+
180
+ class MultiHeadAttention(nn.Module):
181
+ def __init__(
182
+ self,
183
+ hidden_size: int,
184
+ num_heads: int,
185
+ num_key_value_heads: Optional[int] = None,
186
+ max_position_embeddings: int = 2048,
187
+ attn_output_multiplier: float = 1.0,
188
+ max_attn_val: float = 30.0,
189
+ ):
190
+ super().__init__()
191
+ self.hidden_size = hidden_size
192
+ self.num_heads = num_heads
193
+ self.head_dim = hidden_size // num_heads
194
+ if num_key_value_heads is None:
195
+ num_key_value_heads = num_heads
196
+ self.num_key_value_heads = num_key_value_heads
197
+ self.attn_output_multiplier = attn_output_multiplier
198
+ self.max_attn_val = max_attn_val
199
+
200
+ if (self.head_dim * self.num_heads) != self.hidden_size:
201
+ raise ValueError(
202
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
203
+ f" and `num_heads`: {self.num_heads})."
204
+ )
205
+
206
+ self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
207
+ self.k_proj = nn.Linear(
208
+ hidden_size, self.num_key_value_heads * self.head_dim, bias=False
209
+ )
210
+ self.v_proj = nn.Linear(
211
+ hidden_size, self.num_key_value_heads * self.head_dim, bias=False
212
+ )
213
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, hidden_size, bias=False)
214
+
215
+ self.rotary_emb = RotaryEmbedding(
216
+ self.head_dim,
217
+ max_position_embeddings=max_position_embeddings,
218
+ )
219
+
220
+ def forward(
221
+ self,
222
+ hidden_states: torch.Tensor,
223
+ attention_mask: Optional[torch.Tensor] = None,
224
+ position_ids: Optional[torch.LongTensor] = None,
225
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
226
+ output_attentions: bool = False,
227
+ use_cache: bool = False,
228
+ **kwargs,
229
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
230
+ bsz, q_len, _ = hidden_states.size()
231
+
232
+ query_states = self.q_proj(hidden_states)
233
+ key_states = self.k_proj(hidden_states)
234
+ value_states = self.v_proj(hidden_states)
235
+
236
+ query_states = query_states.view(
237
+ bsz, q_len, self.num_heads, self.head_dim
238
+ ).transpose(1, 2)
239
+ key_states = key_states.view(
240
+ bsz, q_len, self.num_key_value_heads, self.head_dim
241
+ ).transpose(1, 2)
242
+ value_states = value_states.view(
243
+ bsz, q_len, self.num_key_value_heads, self.head_dim
244
+ ).transpose(1, 2)
245
+
246
+ kv_seq_len = key_states.shape[-2]
247
+ if past_key_value is not None:
248
+ kv_seq_len += past_key_value[0].shape[-2]
249
+
250
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
251
+ query_states, key_states = apply_rotary_pos_emb(
252
+ query_states, key_states, cos, sin, position_ids
253
+ )
254
+
255
+ if past_key_value is not None:
256
+ # reuse k, v, self_attention
257
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
258
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
259
+
260
+ past_key_value = (key_states, value_states) if use_cache else None
261
+
262
+ # TODO: repeat kv
263
+
264
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(
265
+ torch.float
266
+ )
267
+ attn_weights = attn_weights * self.attn_output_multiplier
268
+ attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val)
269
+
270
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
271
+ raise ValueError(
272
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
273
+ f" {attn_weights.size()}"
274
+ )
275
+
276
+ if attention_mask is not None:
277
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
278
+ raise ValueError(
279
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
280
+ )
281
+
282
+ attn_weights = attn_weights + attention_mask
283
+
284
+ attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype)
285
+ attn_output = torch.matmul(attn_weights, value_states)
286
+
287
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
288
+ raise ValueError(
289
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
290
+ f" {attn_output.size()}"
291
+ )
292
+
293
+ attn_output = attn_output.transpose(1, 2).contiguous()
294
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
295
+
296
+ attn_output = self.o_proj(attn_output)
297
+
298
+ if not output_attentions:
299
+ attn_weights = None
300
+
301
+ return attn_output, attn_weights, past_key_value
302
+
303
+
304
+ class MoeMLP(nn.Module):
305
+ def __init__(
306
+ self,
307
+ hidden_dim: int,
308
+ ffn_dim: int,
309
+ ) -> None:
310
+ super().__init__()
311
+ self.linear_v = nn.Linear(hidden_dim, ffn_dim, bias=False)
312
+ self.linear_1 = nn.Linear(ffn_dim, hidden_dim, bias=False)
313
+ self.linear = nn.Linear(hidden_dim, ffn_dim, bias=False)
314
+ self.act_fn = nn.GELU()
315
+
316
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
317
+ current_hidden_states = self.act_fn(self.linear(hidden_states)) * self.linear_v(
318
+ hidden_states
319
+ )
320
+ current_hidden_states = self.linear_1(current_hidden_states)
321
+ return current_hidden_states
322
+
323
+
324
+ class MoeBlock(nn.Module):
325
+ def __init__(
326
+ self,
327
+ hidden_dim: int,
328
+ ffn_dim: int,
329
+ num_experts: int,
330
+ top_k: int,
331
+ ) -> None:
332
+ super().__init__()
333
+ self.num_experts = num_experts
334
+ self.top_k = top_k
335
+ self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
336
+ self.experts = nn.ModuleList(
337
+ [MoeMLP(hidden_dim, ffn_dim) for _ in range(num_experts)]
338
+ )
339
+
340
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
341
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
342
+ hidden_states = hidden_states.view(-1, hidden_dim)
343
+ # router_logits: (batch * sequence_length, n_experts)
344
+ router_logits = self.gate(hidden_states)
345
+
346
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
347
+ routing_weights, selected_experts = torch.topk(
348
+ routing_weights, self.top_k, dim=-1
349
+ )
350
+ # we cast back to the input dtype
351
+ routing_weights = routing_weights.to(hidden_states.dtype)
352
+
353
+ final_hidden_states = torch.zeros(
354
+ (batch_size * sequence_length, hidden_dim),
355
+ dtype=hidden_states.dtype,
356
+ device=hidden_states.device,
357
+ )
358
+ # One hot encode the selected experts to create an expert mask
359
+ # this will be used to easily index which expert is going to be sollicitated
360
+ expert_mask = torch.nn.functional.one_hot(
361
+ selected_experts, num_classes=self.num_experts
362
+ ).permute(2, 1, 0)
363
+
364
+ # Loop over all available experts in the model and perform the computation on each expert
365
+ for expert_idx in range(self.num_experts):
366
+ expert_layer = self.experts[expert_idx]
367
+ idx, top_x = torch.where(expert_mask[expert_idx])
368
+
369
+ if top_x.shape[0] == 0:
370
+ continue
371
+
372
+ # in torch it is faster to index using lists than torch tensors
373
+ top_x_list = top_x.tolist()
374
+ idx_list = idx.tolist()
375
+
376
+ # Index the correct hidden states and compute the expert hidden state for
377
+ # the current expert. We need to make sure to multiply the output hidden
378
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
379
+ current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
380
+ current_hidden_states = (
381
+ expert_layer(current_state)
382
+ * routing_weights[top_x_list, idx_list, None]
383
+ )
384
+
385
+ # However `index_add_` only support torch tensors for indexing so we'll use
386
+ # the `top_x` tensor here.
387
+ final_hidden_states.index_add_(
388
+ 0, top_x, current_hidden_states.to(hidden_states.dtype)
389
+ )
390
+ final_hidden_states = final_hidden_states.reshape(
391
+ batch_size, sequence_length, hidden_dim
392
+ )
393
+ return final_hidden_states, router_logits
394
+
395
+
396
+ class DecoderLayer(nn.Module):
397
+ def __init__(
398
+ self,
399
+ hidden_size: int,
400
+ num_heads: int,
401
+ num_key_value_heads: int,
402
+ num_experts: int,
403
+ top_k: int,
404
+ widening_factor: float = 4.0,
405
+ max_position_embeddings: int = 2048,
406
+ attn_output_multiplier: float = 1.0,
407
+ max_attn_val: float = 30.0,
408
+ rms_norm_eps: float = 1e-5,
409
+ ) -> None:
410
+ super().__init__()
411
+ self.attn = MultiHeadAttention(
412
+ hidden_size,
413
+ num_heads,
414
+ num_key_value_heads,
415
+ max_position_embeddings=max_position_embeddings,
416
+ attn_output_multiplier=attn_output_multiplier,
417
+ max_attn_val=max_attn_val,
418
+ )
419
+ ffn_dim = int(hidden_size * widening_factor)
420
+ self.moe_block = MoeBlock(hidden_size, ffn_dim, num_experts, top_k)
421
+ self.pre_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
422
+ self.post_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
423
+ self.pre_moe_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
424
+ self.post_moe_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
425
+
426
+ def forward(
427
+ self,
428
+ hidden_states: torch.Tensor,
429
+ attention_mask: Optional[torch.Tensor] = None,
430
+ position_ids: Optional[torch.LongTensor] = None,
431
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
432
+ output_attentions: Optional[bool] = False,
433
+ output_router_logits: Optional[bool] = False,
434
+ use_cache: Optional[bool] = False,
435
+ **kwargs,
436
+ ) -> Tuple[
437
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
438
+ ]:
439
+ residual = hidden_states
440
+ hidden_states = self.pre_attn_norm(hidden_states)
441
+ hidden_states, attention_weights, present_key_value = self.attn(
442
+ hidden_states,
443
+ attention_mask=attention_mask,
444
+ position_ids=position_ids,
445
+ past_key_value=past_key_value,
446
+ output_attentions=output_attentions,
447
+ use_cache=use_cache,
448
+ )
449
+ hidden_states = self.post_attn_norm(hidden_states)
450
+ hidden_states = residual + hidden_states
451
+
452
+ residual = hidden_states
453
+ hidden_states = self.pre_moe_norm(hidden_states)
454
+ hidden_states, router_logits = self.moe_block(hidden_states)
455
+ hidden_states = self.post_moe_norm(hidden_states)
456
+ hidden_states = residual + hidden_states
457
+
458
+ outputs = (hidden_states,)
459
+ if output_attentions:
460
+ outputs += (attention_weights,)
461
+ if use_cache:
462
+ outputs += (present_key_value,)
463
+ if output_router_logits:
464
+ outputs += (router_logits,)
465
+ return outputs
466
+
467
+
468
+ class Grok1PretrainedModel(PreTrainedModel):
469
+ config_class = Grok1Config
470
+ base_model_prefix = "model"
471
+ supports_gradient_checkpointing = True
472
+ _no_split_modules = ["DecoderLayer"]
473
+ _skip_keys_device_placement = "past_key_values"
474
+ _supports_flash_attn_2 = False
475
+ _supports_cache_class = False
476
+
477
+ def _init_weights(self, module) -> None:
478
+ if isinstance(module, nn.Linear):
479
+ module.weight.data.zero_()
480
+ if module.bias is not None:
481
+ module.bias.data.zero_()
482
+ elif isinstance(module, nn.Embedding):
483
+ module.weight.data.zero_()
484
+
485
+
486
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
487
+ def _make_causal_mask(
488
+ input_ids_shape: torch.Size,
489
+ dtype: torch.dtype,
490
+ device: torch.device,
491
+ past_key_values_length: int = 0,
492
+ ):
493
+ """
494
+ Make causal mask used for bi-directional self-attention.
495
+ """
496
+ bsz, tgt_len = input_ids_shape
497
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
498
+ mask_cond = torch.arange(mask.size(-1), device=device)
499
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
500
+ mask = mask.to(dtype)
501
+
502
+ if past_key_values_length > 0:
503
+ mask = torch.cat(
504
+ [
505
+ torch.zeros(
506
+ tgt_len, past_key_values_length, dtype=dtype, device=device
507
+ ),
508
+ mask,
509
+ ],
510
+ dim=-1,
511
+ )
512
+ return mask[None, None, :, :].expand(
513
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
514
+ )
515
+
516
+
517
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
518
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
519
+ """
520
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
521
+ """
522
+ bsz, src_len = mask.size()
523
+ tgt_len = tgt_len if tgt_len is not None else src_len
524
+
525
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
526
+
527
+ inverted_mask = 1.0 - expanded_mask
528
+
529
+ return inverted_mask.masked_fill(
530
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
531
+ )
532
+
533
+
534
+ class Grok1Model(Grok1PretrainedModel):
535
+ def __init__(self, config: Grok1Config) -> None:
536
+ super().__init__(config)
537
+ self.padding_idx = config.pad_token_id
538
+ self.vocab_size = config.vocab_size
539
+
540
+ self.embed_tokens = nn.Embedding(
541
+ config.vocab_size, config.hidden_size, self.padding_idx
542
+ )
543
+ self.layers = nn.ModuleList(
544
+ [
545
+ DecoderLayer(
546
+ hidden_size=config.hidden_size,
547
+ num_heads=config.num_attention_heads,
548
+ num_key_value_heads=config.num_key_value_heads,
549
+ num_experts=config.num_experts,
550
+ top_k=config.num_experts_per_tok,
551
+ widening_factor=config.widening_factor,
552
+ max_position_embeddings=config.max_position_embeddings,
553
+ attn_output_multiplier=config.attn_output_multiplier,
554
+ max_attn_val=config.max_attn_value,
555
+ rms_norm_eps=config.rms_norm_eps,
556
+ )
557
+ for layer_idx in range(config.num_hidden_layers)
558
+ ]
559
+ )
560
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
561
+ self.gradient_checkpointing = False
562
+ self.post_init()
563
+
564
+ def get_input_embeddings(self):
565
+ return self.embed_tokens
566
+
567
+ def set_input_embeddings(self, value):
568
+ self.embed_tokens = value
569
+
570
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
571
+ def _prepare_decoder_attention_mask(
572
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
573
+ ):
574
+ # create causal mask
575
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
576
+ combined_attention_mask = None
577
+ if input_shape[-1] > 1:
578
+ combined_attention_mask = _make_causal_mask(
579
+ input_shape,
580
+ inputs_embeds.dtype,
581
+ device=inputs_embeds.device,
582
+ past_key_values_length=past_key_values_length,
583
+ )
584
+
585
+ if attention_mask is not None:
586
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
587
+ expanded_attn_mask = _expand_mask(
588
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
589
+ ).to(inputs_embeds.device)
590
+ combined_attention_mask = (
591
+ expanded_attn_mask
592
+ if combined_attention_mask is None
593
+ else expanded_attn_mask + combined_attention_mask
594
+ )
595
+
596
+ return combined_attention_mask
597
+
598
+ def forward(
599
+ self,
600
+ input_ids: torch.LongTensor = None,
601
+ attention_mask: Optional[torch.Tensor] = None,
602
+ position_ids: Optional[torch.LongTensor] = None,
603
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
604
+ inputs_embeds: Optional[torch.FloatTensor] = None,
605
+ use_cache: Optional[bool] = None,
606
+ output_attentions: Optional[bool] = None,
607
+ output_hidden_states: Optional[bool] = None,
608
+ output_router_logits: Optional[bool] = None,
609
+ return_dict: Optional[bool] = None,
610
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
611
+ output_attentions = (
612
+ output_attentions
613
+ if output_attentions is not None
614
+ else self.config.output_attentions
615
+ )
616
+ output_hidden_states = (
617
+ output_hidden_states
618
+ if output_hidden_states is not None
619
+ else self.config.output_hidden_states
620
+ )
621
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
622
+
623
+ return_dict = (
624
+ return_dict if return_dict is not None else self.config.use_return_dict
625
+ )
626
+
627
+ # retrieve input_ids and inputs_embeds
628
+ if input_ids is not None and inputs_embeds is not None:
629
+ raise ValueError(
630
+ "You cannot specify both input_ids and inputs_embeds at the same time"
631
+ )
632
+ elif input_ids is not None:
633
+ batch_size, seq_length = input_ids.shape[:2]
634
+ elif inputs_embeds is not None:
635
+ batch_size, seq_length = inputs_embeds.shape[:2]
636
+ else:
637
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
638
+
639
+ seq_length_with_past = seq_length
640
+ past_key_values_length = 0
641
+ if past_key_values is not None:
642
+ past_key_values_length = past_key_values[0][0].shape[2]
643
+ seq_length_with_past = seq_length_with_past + past_key_values_length
644
+
645
+ if position_ids is None:
646
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
647
+ position_ids = torch.arange(
648
+ past_key_values_length,
649
+ seq_length + past_key_values_length,
650
+ dtype=torch.long,
651
+ device=device,
652
+ )
653
+ position_ids = position_ids.unsqueeze(0)
654
+
655
+ if inputs_embeds is None:
656
+ inputs_embeds = self.embed_tokens(input_ids)
657
+
658
+ if HAS_MASK_UTILS:
659
+ # 4d mask is passed through the layers
660
+ attention_mask = _prepare_4d_causal_attention_mask(
661
+ attention_mask,
662
+ (batch_size, seq_length),
663
+ inputs_embeds,
664
+ past_key_values_length,
665
+ )
666
+ else:
667
+ if attention_mask is None:
668
+ attention_mask = torch.ones(
669
+ (batch_size, seq_length_with_past),
670
+ dtype=torch.bool,
671
+ device=inputs_embeds.device,
672
+ )
673
+ attention_mask = self._prepare_decoder_attention_mask(
674
+ attention_mask,
675
+ (batch_size, seq_length),
676
+ inputs_embeds,
677
+ past_key_values_length,
678
+ )
679
+
680
+ # embed positions
681
+ hidden_states = inputs_embeds
682
+
683
+ if self.gradient_checkpointing and self.training:
684
+ if use_cache:
685
+ logger.warning_once(
686
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
687
+ )
688
+ use_cache = False
689
+
690
+ # decoder layers
691
+ all_hidden_states = () if output_hidden_states else None
692
+ all_self_attns = () if output_attentions else None
693
+ all_router_logits = () if output_router_logits else None
694
+ next_decoder_cache = () if use_cache else None
695
+
696
+ for idx, decoder_layer in enumerate(self.layers):
697
+ if output_hidden_states:
698
+ all_hidden_states += (hidden_states,)
699
+
700
+ past_key_value = (
701
+ past_key_values[idx] if past_key_values is not None else None
702
+ )
703
+
704
+ if self.gradient_checkpointing and self.training:
705
+
706
+ def create_custom_forward(module):
707
+ def custom_forward(*inputs):
708
+ # None for past_key_value
709
+ return module(*inputs, past_key_value, output_attentions)
710
+
711
+ return custom_forward
712
+
713
+ layer_outputs = torch.utils.checkpoint.checkpoint(
714
+ create_custom_forward(decoder_layer),
715
+ hidden_states,
716
+ attention_mask,
717
+ position_ids,
718
+ )
719
+ else:
720
+ layer_outputs = decoder_layer(
721
+ hidden_states,
722
+ attention_mask=attention_mask,
723
+ position_ids=position_ids,
724
+ past_key_value=past_key_value,
725
+ output_attentions=output_attentions,
726
+ use_cache=use_cache,
727
+ )
728
+
729
+ hidden_states = layer_outputs[0]
730
+
731
+ if use_cache:
732
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
733
+
734
+ if output_attentions:
735
+ all_self_attns += (layer_outputs[1],)
736
+
737
+ if output_router_logits:
738
+ all_router_logits += (layer_outputs[-1],)
739
+
740
+ hidden_states = self.norm(hidden_states)
741
+
742
+ # add hidden states from the last decoder layer
743
+ if output_hidden_states:
744
+ all_hidden_states += (hidden_states,)
745
+ next_cache = next_decoder_cache if use_cache else None
746
+
747
+ if not return_dict:
748
+ return tuple(
749
+ v
750
+ for v in [
751
+ hidden_states,
752
+ next_cache,
753
+ all_hidden_states,
754
+ all_self_attns,
755
+ all_router_logits,
756
+ ]
757
+ if v is not None
758
+ )
759
+ return MoeModelOutputWithPast(
760
+ last_hidden_state=hidden_states,
761
+ past_key_values=next_cache,
762
+ hidden_states=all_hidden_states,
763
+ attentions=all_self_attns,
764
+ router_logits=all_router_logits,
765
+ )
766
+
767
+
768
+ class Grok1ModelForCausalLM(Grok1PretrainedModel):
769
+ _tied_weights_keys = ["lm_head.weight"]
770
+
771
+ def __init__(self, config: Grok1Config):
772
+ super().__init__(config)
773
+ self.model = Grok1Model(config)
774
+ self.vocab_size = config.vocab_size
775
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
776
+ self.router_aux_loss_coef = config.router_aux_loss_coef
777
+ self.num_experts = config.num_experts
778
+ self.num_experts_per_tok = config.num_experts_per_tok
779
+ self.post_init()
780
+
781
+ def get_input_embeddings(self):
782
+ return self.model.embed_tokens
783
+
784
+ def set_input_embeddings(self, value):
785
+ self.model.embed_tokens = value
786
+
787
+ def get_output_embeddings(self):
788
+ return self.lm_head
789
+
790
+ def set_output_embeddings(self, new_embeddings):
791
+ self.lm_head = new_embeddings
792
+
793
+ def set_decoder(self, decoder):
794
+ self.model = decoder
795
+
796
+ def get_decoder(self):
797
+ return self.model
798
+
799
+ def forward(
800
+ self,
801
+ input_ids: torch.LongTensor = None,
802
+ attention_mask: Optional[torch.Tensor] = None,
803
+ position_ids: Optional[torch.LongTensor] = None,
804
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
805
+ inputs_embeds: Optional[torch.FloatTensor] = None,
806
+ labels: Optional[torch.LongTensor] = None,
807
+ use_cache: Optional[bool] = None,
808
+ output_attentions: Optional[bool] = None,
809
+ output_hidden_states: Optional[bool] = None,
810
+ output_router_logits: Optional[bool] = None,
811
+ return_dict: Optional[bool] = None,
812
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
813
+ output_attentions = (
814
+ output_attentions
815
+ if output_attentions is not None
816
+ else self.config.output_attentions
817
+ )
818
+ output_router_logits = (
819
+ output_router_logits
820
+ if output_router_logits is not None
821
+ else self.config.output_router_logits
822
+ )
823
+
824
+ output_hidden_states = (
825
+ output_hidden_states
826
+ if output_hidden_states is not None
827
+ else self.config.output_hidden_states
828
+ )
829
+ return_dict = (
830
+ return_dict if return_dict is not None else self.config.use_return_dict
831
+ )
832
+
833
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
834
+ outputs = self.model(
835
+ input_ids=input_ids,
836
+ attention_mask=attention_mask,
837
+ position_ids=position_ids,
838
+ past_key_values=past_key_values,
839
+ inputs_embeds=inputs_embeds,
840
+ use_cache=use_cache,
841
+ output_attentions=output_attentions,
842
+ output_hidden_states=output_hidden_states,
843
+ output_router_logits=output_router_logits,
844
+ return_dict=return_dict,
845
+ )
846
+
847
+ hidden_states = outputs[0]
848
+ logits = self.lm_head(hidden_states)
849
+ logits = logits.float()
850
+
851
+ loss = None
852
+ if labels is not None:
853
+ # Shift so that tokens < n predict n
854
+ shift_logits = logits[..., :-1, :].contiguous()
855
+ shift_labels = labels[..., 1:].contiguous()
856
+ # Flatten the tokens
857
+ loss_fct = nn.CrossEntropyLoss()
858
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
859
+ shift_labels = shift_labels.view(-1)
860
+ # Enable model parallelism
861
+ shift_labels = shift_labels.to(shift_logits.device)
862
+ loss = loss_fct(shift_logits, shift_labels)
863
+
864
+ aux_loss = None
865
+ if output_router_logits:
866
+ aux_loss = load_balancing_loss_func(
867
+ outputs.router_logits if return_dict else outputs[-1],
868
+ self.num_experts,
869
+ self.num_experts_per_tok,
870
+ )
871
+ if labels is not None:
872
+ loss += self.router_aux_loss_coef * aux_loss
873
+
874
+ if not return_dict:
875
+ output = (logits,) + outputs[1:]
876
+ if output_router_logits:
877
+ output = (aux_loss,) + output
878
+ return (loss,) + output if loss is not None else output
879
+
880
+ return MoeCausalLMOutputWithPast(
881
+ loss=loss,
882
+ aux_loss=aux_loss,
883
+ logits=logits,
884
+ past_key_values=outputs.past_key_values,
885
+ hidden_states=outputs.hidden_states,
886
+ attentions=outputs.attentions,
887
+ router_logits=outputs.router_logits,
888
+ )
889
+
890
+ def prepare_inputs_for_generation(
891
+ self,
892
+ input_ids,
893
+ past_key_values=None,
894
+ attention_mask=None,
895
+ inputs_embeds=None,
896
+ **kwargs,
897
+ ):
898
+ if past_key_values:
899
+ input_ids = input_ids[:, -1:]
900
+
901
+ position_ids = kwargs.get("position_ids", None)
902
+ if attention_mask is not None and position_ids is None:
903
+ # create position_ids on the fly for batch generation
904
+ position_ids = attention_mask.long().cumsum(-1) - 1
905
+ position_ids.masked_fill_(attention_mask == 0, 1)
906
+ if past_key_values:
907
+ position_ids = position_ids[:, -1].unsqueeze(-1)
908
+
909
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
910
+ if inputs_embeds is not None and past_key_values is None:
911
+ model_inputs = {"inputs_embeds": inputs_embeds}
912
+ else:
913
+ model_inputs = {"input_ids": input_ids}
914
+
915
+ model_inputs.update(
916
+ {
917
+ "position_ids": position_ids,
918
+ "past_key_values": past_key_values,
919
+ "use_cache": kwargs.get("use_cache"),
920
+ "attention_mask": attention_mask,
921
+ }
922
+ )
923
+ return model_inputs
modeling_grok1_outputs.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from transformers.modeling_outputs import ModelOutput
6
+
7
+ __all__ = [
8
+ "MoeModelOutputWithPast",
9
+ "MoeCausalLMOutputWithPast",
10
+ ]
11
+
12
+ try:
13
+ from transformers.modeling_outputs import (
14
+ MoeCausalLMOutputWithPast,
15
+ MoeModelOutputWithPast,
16
+ )
17
+ except:
18
+
19
+ @dataclass
20
+ class MoeModelOutputWithPast(ModelOutput):
21
+ """
22
+ Base class for model's outputs, with potential hidden states and attentions.
23
+
24
+ Args:
25
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
26
+ Sequence of hidden-states at the output of the last layer of the model.
27
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
28
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
29
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
30
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
31
+ encoder_sequence_length, embed_size_per_head)`.
32
+
33
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
34
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
35
+ input) to speed up sequential decoding.
36
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
37
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
38
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
39
+
40
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
41
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
42
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
43
+ sequence_length)`.
44
+
45
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
46
+ heads.
47
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
48
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
49
+
50
+ Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary
51
+ loss for Mixture of Experts models.
52
+ """
53
+
54
+ last_hidden_state: torch.FloatTensor = None
55
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
56
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
57
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
58
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
59
+
60
+ @dataclass
61
+ class MoeCausalLMOutputWithPast(ModelOutput):
62
+ """
63
+ Base class for causal language model (or autoregressive) with mixture of experts outputs.
64
+
65
+ Args:
66
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
67
+ Language modeling loss (for next-token prediction).
68
+
69
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
70
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
71
+
72
+ aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
73
+ aux_loss for the sparse modules.
74
+
75
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
76
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
77
+
78
+ Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary
79
+ loss for Mixture of Experts models.
80
+
81
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
82
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
83
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
84
+
85
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
86
+ `past_key_values` input) to speed up sequential decoding.
87
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
88
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
89
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
90
+
91
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
92
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
93
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
94
+ sequence_length)`.
95
+
96
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
97
+ heads.
98
+ """
99
+
100
+ loss: Optional[torch.FloatTensor] = None
101
+ aux_loss: Optional[torch.FloatTensor] = None
102
+ logits: torch.FloatTensor = None
103
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
104
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
105
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
106
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None