Crystalcareai committed on
Commit 686fb4f · verified · 1 Parent(s): d72d599

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +189 -148
modeling_gemmoe.py CHANGED
@@ -31,7 +31,7 @@ from transformers.modeling_attn_mask_utils import (
31
  AttentionMaskConverter,
32
  _prepare_4d_causal_attention_mask,
33
  )
34
- from transformers.modeling_outputs import SequenceClassifierOutputWithPast, MoeModelOutputWithPast, MoeCausalLMOutputWithPast
35
  from transformers.modeling_utils import PreTrainedModel
36
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
37
  from transformers.utils import (
@@ -45,8 +45,6 @@ from transformers.utils import (
45
  from transformers.utils.import_utils import is_torch_fx_available
46
  from .configuration_gemmoe import GemmoeConfig
47
 
48
- from math import sqrt as math_sqrt
49
-
50
 
51
  if is_flash_attn_2_available():
52
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -55,7 +53,6 @@ if is_flash_attn_2_available():
55
 
56
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
57
  # It means that the function will not be traced through and simply appear as a node in the graph.
58
-
59
  if is_torch_fx_available():
60
  if not is_torch_greater_or_equal_than_1_13:
61
  import torch.fx
@@ -67,9 +64,7 @@ logger = logging.get_logger(__name__)
67
 
68
  _CONFIG_FOR_DOC = "GemmoeConfig"
69
 
70
- def load_balancing_loss_func(
71
- gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
72
- ) -> float:
73
  r"""
74
  Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
75
 
@@ -79,73 +74,42 @@ def load_balancing_loss_func(
79
 
80
  Args:
81
  gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
82
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
83
- shape [batch_size X sequence_length, num_experts].
84
- attention_mask (`torch.Tensor`, None):
85
- The attention_mask used in forward function
86
- shape [batch_size X sequence_length] if not None.
87
  num_experts (`int`, *optional*):
88
  Number of experts
89
 
90
  Returns:
91
  The auxiliary loss.
92
  """
93
- if gate_logits is None or not isinstance(gate_logits, tuple):
94
  return 0
95
 
96
  if isinstance(gate_logits, tuple):
97
- compute_device = gate_logits[0].device
98
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
99
-
100
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
101
-
102
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
103
 
104
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
105
-
106
- if attention_mask is None:
107
- # Compute the percentage of tokens routed to each experts
108
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
109
-
110
- # Compute the average probability of routing to these experts
111
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
112
- else:
113
- batch_size, sequence_length = attention_mask.shape
114
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
115
-
116
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
117
- expert_attention_mask = (
118
- attention_mask[None, :, :, None, None]
119
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
120
- .reshape(-1, top_k, num_experts)
121
- .to(compute_device)
122
- )
123
 
124
- # Compute the percentage of tokens routed to each experts
125
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
126
- expert_attention_mask, dim=0
127
- )
128
 
129
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
130
- router_per_expert_attention_mask = (
131
- attention_mask[None, :, :, None]
132
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
133
- .reshape(-1, num_experts)
134
- .to(compute_device)
135
- )
136
 
137
- # Compute the average probability of routing to these experts
138
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
139
- router_per_expert_attention_mask, dim=0
140
- )
141
 
142
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
143
- return overall_loss * num_experts
144
 
 
 
 
145
 
 
 
146
 
147
- def approx_gelu(x):
148
- return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
149
 
150
  def _get_unpad_data(attention_mask):
151
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -159,7 +123,6 @@ def _get_unpad_data(attention_mask):
159
  )
160
 
161
 
162
-
163
  class GemmoeRMSNorm(nn.Module):
164
  def __init__(self, dim: int, eps: float = 1e-6):
165
  super().__init__()
@@ -167,45 +130,48 @@ class GemmoeRMSNorm(nn.Module):
167
  self.weight = nn.Parameter(torch.zeros(dim))
168
 
169
  def _norm(self, x):
170
- x_float = x.float()
171
- normed_x = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
172
- return normed_x
173
 
174
  def forward(self, x):
175
- normed_x = self._norm(x)
176
- # Downcast the result to the original dtype at the end
177
- normed_x = normed_x.type_as(x)
178
- return normed_x * (self.weight + 1)
 
 
179
 
180
  ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
181
 
 
182
  class GemmoeRotaryEmbedding(nn.Module):
183
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
184
  super().__init__()
 
185
  self.dim = dim
186
  self.max_position_embeddings = max_position_embeddings
187
  self.base = base
188
- self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
189
-
190
- def _set_cos_sin_cache(self, seq_len, device, dtype):
191
- self.max_seq_len_cached = seq_len
192
- freq_exponents = (2.0 / self.dim) * (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
193
- timescale = self.base ** freq_exponents
194
- positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32).float()
195
- radians_new = positions[..., None] / timescale[None, None, :]
196
- radians_new = radians_new.squeeze(0)
197
- emb = torch.cat((radians_new, radians_new), dim=-1)
198
- cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
199
- sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
200
- self.register_buffer("cos_cached", cos, persistent=False)
201
- self.register_buffer("sin_cached", sin, persistent=False)
202
-
203
- def forward(self, x, position_ids=None, seq_len=None):
204
- if seq_len is None:
205
- seq_len = x.size(2)
206
- if seq_len > self.max_seq_len_cached:
207
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
208
- return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
 
209
 
210
  # Copied from transformers.models.llama.modeling_llama.rotate_half
211
  def rotate_half(x):
@@ -215,15 +181,34 @@ def rotate_half(x):
215
  return torch.cat((-x2, x1), dim=-1)
216
 
217
 
 
218
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
219
- seq_len, dim = q.shape[-2], q.shape[-1]
220
- cos = cos[:seq_len].view(1, 1, seq_len, dim)
221
- sin = sin[:seq_len].view(1, 1, seq_len, dim)
222
  q_embed = (q * cos) + (rotate_half(q) * sin)
223
  k_embed = (k * cos) + (rotate_half(k) * sin)
224
  return q_embed, k_embed
225
 
226
-
227
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
228
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
229
  """
@@ -236,6 +221,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
236
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
237
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
238
 
 
239
  class GemmoeAttention(nn.Module):
240
  """Multi-headed attention from 'Attention Is All You Need' paper"""
241
 
@@ -638,8 +624,20 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
638
  self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
639
  self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
640
  self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
641
-
642
- self.act_fn = approx_gelu
643
 
644
  def forward(self, hidden_states):
645
  current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
@@ -786,7 +784,6 @@ GEMMOE_START_DOCSTRING = r"""
786
  "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
787
  GEMMOE_START_DOCSTRING,
788
  )
789
-
790
  class GemmoePreTrainedModel(PreTrainedModel):
791
  config_class = GemmoeConfig
792
  base_model_prefix = "model"
@@ -909,7 +906,7 @@ GEMMOE_INPUTS_DOCSTRING = r"""
909
  "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
910
  GEMMOE_START_DOCSTRING,
911
  )
912
-
913
  class GemmoeModel(GemmoePreTrainedModel):
914
  """
915
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]
@@ -929,6 +926,7 @@ class GemmoeModel(GemmoePreTrainedModel):
929
  )
930
  self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
931
  self.gradient_checkpointing = False
 
932
 
933
  # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
934
  # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
@@ -946,6 +944,7 @@ class GemmoeModel(GemmoePreTrainedModel):
946
  self.embed_tokens = value
947
 
948
  @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
 
949
  def forward(
950
  self,
951
  input_ids: torch.LongTensor = None,
@@ -961,6 +960,9 @@ class GemmoeModel(GemmoePreTrainedModel):
961
  cache_position: Optional[torch.LongTensor] = None,
962
  ) -> Union[Tuple, MoeModelOutputWithPast]:
963
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
964
  output_hidden_states = (
965
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
966
  )
@@ -981,14 +983,6 @@ class GemmoeModel(GemmoePreTrainedModel):
981
  if inputs_embeds is None:
982
  inputs_embeds = self.embed_tokens(input_ids)
983
 
984
- # Scale embeddings
985
- # Fix for precision issue when casting to bfloat16
986
- hidden_size_sqrt = math.sqrt(self.config.hidden_size)
987
- if inputs_embeds.dtype == torch.bfloat16:
988
- pass
989
-
990
- hidden_states = inputs_embeds * hidden_size_sqrt
991
-
992
  past_seen_tokens = 0
993
  if use_cache: # kept for BC (cache positions)
994
  if not isinstance(past_key_values, StaticCache):
@@ -1003,43 +997,58 @@ class GemmoeModel(GemmoePreTrainedModel):
1003
  if position_ids is None:
1004
  position_ids = cache_position.unsqueeze(0)
1005
 
1006
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
1007
 
1008
- # embed positions
1009
- hidden_states = inputs_embeds
 
 
 
1010
 
1011
  # normalized
1012
- hidden_states = hidden_states * (self.config.hidden_size**0.5)
 
 
 
1013
 
1014
  # decoder layers
1015
  all_hidden_states = () if output_hidden_states else None
1016
  all_self_attns = () if output_attentions else None
 
1017
  next_decoder_cache = None
1018
 
1019
  for decoder_layer in self.layers:
1020
  if output_hidden_states:
1021
  all_hidden_states += (hidden_states,)
 
 
1022
  layer_outputs = self._gradient_checkpointing_func(
1023
  decoder_layer.__call__,
1024
  hidden_states,
1025
- causal_mask,
1026
  position_ids,
1027
  past_key_values,
1028
  output_attentions,
1029
  output_router_logits,
1030
- use_cache.item() if isinstance(use_cache, torch.Tensor) else use_cache,
1031
  cache_position,
1032
- output_router_logits,
1033
  )
1034
  else:
1035
  layer_outputs = decoder_layer(
1036
  hidden_states,
1037
- attention_mask=causal_mask,
1038
  position_ids=position_ids,
1039
  past_key_value=past_key_values,
1040
  output_attentions=output_attentions,
1041
  output_router_logits=output_router_logits,
1042
- use_cache=use_cache.item() if isinstance(use_cache, torch.Tensor) else use_cache,
1043
  cache_position=cache_position,
1044
  )
1045
 
@@ -1051,6 +1060,9 @@ class GemmoeModel(GemmoePreTrainedModel):
1051
  if output_attentions:
1052
  all_self_attns += (layer_outputs[1],)
1053
 
 
 
 
1054
  hidden_states = self.norm(hidden_states)
1055
 
1056
  # add hidden states from the last decoder layer
@@ -1063,15 +1075,24 @@ class GemmoeModel(GemmoePreTrainedModel):
1063
  next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1064
  )
1065
  if not return_dict:
1066
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
1067
  return MoeModelOutputWithPast(
1068
  last_hidden_state=hidden_states,
1069
  past_key_values=next_cache,
1070
  hidden_states=all_hidden_states,
1071
  attentions=all_self_attns,
 
1072
  )
1073
 
1074
- def _update_causal_mask(self, attention_mask, input_tensor):
 
 
 
 
1075
  if self.config._attn_implementation == "flash_attention_2":
1076
  if attention_mask is not None and 0.0 in attention_mask:
1077
  return attention_mask
@@ -1092,15 +1113,23 @@ class GemmoeModel(GemmoePreTrainedModel):
1092
  causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
1093
  causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
1094
  if attention_mask is not None:
1095
- causal_mask = causal_mask.clone()
1096
  if attention_mask.dim() == 2:
1097
  mask_length = attention_mask.shape[-1]
1098
  padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1099
  causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1100
  elif attention_mask.dim() == 4:
 
 
 
 
 
 
1101
  mask_shape = attention_mask.shape
1102
  mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
1103
- causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice
 
 
1104
 
1105
  if (
1106
  self.config._attn_implementation == "sdpa"
@@ -1121,6 +1150,8 @@ class GemmoeModel(GemmoePreTrainedModel):
1121
 
1122
  return causal_mask
1123
 
 
 
1124
  class GemmoeForCausalLM(GemmoePreTrainedModel):
1125
  _tied_weights_keys = ["lm_head.weight"]
1126
 
@@ -1155,7 +1186,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1155
 
1156
  @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
1157
  @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1158
- # Ignore copy
1159
  def forward(
1160
  self,
1161
  input_ids: torch.LongTensor = None,
@@ -1169,6 +1199,7 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1169
  output_hidden_states: Optional[bool] = None,
1170
  output_router_logits: Optional[bool] = None,
1171
  return_dict: Optional[bool] = None,
 
1172
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1173
  r"""
1174
  Args:
@@ -1184,23 +1215,21 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1184
  ```python
1185
  >>> from transformers import AutoTokenizer, GemmoeForCausalLM
1186
 
1187
- >>> model = GemmoeForCausalLM.from_pretrained("mistralai/Gemmoe-8x7B-v0.1")
1188
- >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Gemmoe-8x7B-v0.1")
1189
 
1190
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1191
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1192
 
1193
  >>> # Generate
1194
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1195
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1196
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1197
  ```"""
1198
-
1199
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1200
  output_router_logits = (
1201
  output_router_logits if output_router_logits is not None else self.config.output_router_logits
1202
  )
1203
-
1204
  output_hidden_states = (
1205
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1206
  )
@@ -1218,17 +1247,12 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1218
  output_hidden_states=output_hidden_states,
1219
  output_router_logits=output_router_logits,
1220
  return_dict=return_dict,
 
1221
  )
1222
 
1223
  hidden_states = outputs[0]
1224
  logits = self.lm_head(hidden_states)
1225
  logits = logits.float()
1226
-
1227
- if self.training:
1228
- for expert in self.model.layers[-1].block_sparse_moe.experts:
1229
- for param in expert.parameters():
1230
- if param.requires_grad and param.grad is None:
1231
- param.grad = torch.zeros_like(param)
1232
 
1233
  loss = None
1234
  if labels is not None:
@@ -1246,13 +1270,10 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1246
  aux_loss = None
1247
  if output_router_logits:
1248
  aux_loss = load_balancing_loss_func(
1249
- outputs.router_logits if return_dict else outputs[-1],
1250
- self.num_experts,
1251
- self.num_experts_per_tok,
1252
- attention_mask,
1253
  )
1254
  if labels is not None:
1255
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
1256
 
1257
  if not return_dict:
1258
  output = (logits,) + outputs[1:]
@@ -1271,20 +1292,26 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1271
  )
1272
 
1273
  def prepare_inputs_for_generation(
1274
- self,
1275
- input_ids,
1276
- past_key_values=None,
1277
- attention_mask=None,
1278
- inputs_embeds=None,
1279
- output_router_logits=False,
1280
- **kwargs,
1281
  ):
1282
- # Omit tokens covered by past_key_values
1283
  if past_key_values is not None:
1284
  if isinstance(past_key_values, Cache):
1285
- cache_length = past_key_values.get_seq_length()
1286
- past_length = past_key_values.seen_tokens
1287
- max_cache_length = past_key_values.get_max_length()
 
 
 
 
 
1288
  else:
1289
  cache_length = past_length = past_key_values[0][0].shape[2]
1290
  max_cache_length = None
@@ -1321,15 +1348,27 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1321
  if inputs_embeds is not None and past_key_values is None:
1322
  model_inputs = {"inputs_embeds": inputs_embeds}
1323
  else:
1324
- model_inputs = {"input_ids": input_ids}
1325
 
1326
  model_inputs.update(
1327
  {
1328
  "position_ids": position_ids,
 
1329
  "past_key_values": past_key_values,
1330
  "use_cache": kwargs.get("use_cache"),
1331
  "attention_mask": attention_mask,
1332
- "output_router_logits": output_router_logits,
1333
  }
1334
  )
1335
  return model_inputs
@@ -1342,10 +1381,12 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1342
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1343
  )
1344
  return reordered_past
1345
-
 
1346
  @add_start_docstrings(
1347
  """
1348
  The Gemmoe Model transformer with a sequence classification head on top (linear layer).
 
1349
  [`GemmoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1350
  (e.g. GPT-2) do.
1351
 
@@ -1357,7 +1398,7 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1357
  """,
1358
  GEMMOE_START_DOCSTRING,
1359
  )
1360
-
1361
  class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1362
  def __init__(self, config):
1363
  super().__init__(config)
 
31
  AttentionMaskConverter,
32
  _prepare_4d_causal_attention_mask,
33
  )
34
+ from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast, SequenceClassifierOutputWithPast
35
  from transformers.modeling_utils import PreTrainedModel
36
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
37
  from transformers.utils import (
 
45
  from transformers.utils.import_utils import is_torch_fx_available
46
  from .configuration_gemmoe import GemmoeConfig
47
 
 
 
48
 
49
  if is_flash_attn_2_available():
50
  from flash_attn import flash_attn_func, flash_attn_varlen_func
 
53
 
54
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
55
  # It means that the function will not be traced through and simply appear as a node in the graph.
 
56
  if is_torch_fx_available():
57
  if not is_torch_greater_or_equal_than_1_13:
58
  import torch.fx
 
64
 
65
  _CONFIG_FOR_DOC = "GemmoeConfig"
66
 
67
+ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float:
 
 
68
  r"""
69
  Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
70
 
 
74
 
75
  Args:
76
  gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
77
+ Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_experts].
 
 
 
 
78
  num_experts (`int`, *optional*):
79
  Number of experts
80
 
81
  Returns:
82
  The auxiliary loss.
83
  """
84
+ if gate_logits is None:
85
  return 0
86
 
87
  if isinstance(gate_logits, tuple):
88
+ # concatenate the per-layer gate logits along the token dimension
89
+ gate_logits = torch.cat(gate_logits, dim=0)
 
 
 
 
90
 
91
+ routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
92
+ routing_weights = routing_weights.softmax(dim=-1)
93
 
94
+ # cast the expert indices to int64, otherwise one-hot encoding will fail
95
+ if selected_experts.dtype != torch.int64:
96
+ selected_experts = selected_experts.to(torch.int64)
 
97
 
98
+ if len(selected_experts.shape) == 2:
99
+ selected_experts = selected_experts.unsqueeze(2)
 
 
 
 
 
100
 
101
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
 
 
 
102
 
103
+ # For a given token, determine if it was routed to a given expert.
104
+ expert_mask = torch.max(expert_mask, axis=-2).values
105
 
106
+ # cast to float32 otherwise mean will fail
107
+ expert_mask = expert_mask.to(torch.float32)
108
+ tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
109
 
110
+ router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1)
111
+ return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1)) * (num_experts**2)
112
 
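A minimal usage sketch (not part of the commit) of the simplified `load_balancing_loss_func` defined above; the layer count, token count, and expert count are made-up toy values.

```python
import torch

# Toy shapes (assumptions): 2 layers, 4 tokens (batch * seq_len), 8 experts, top-2 routing.
num_experts, top_k = 8, 2
gate_logits = tuple(torch.randn(4, num_experts) for _ in range(2))  # one [tokens, experts] tensor per layer

aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=top_k)
print(aux_loss)  # scalar tensor; GemmoeForCausalLM scales it by router_aux_loss_coef and adds it to the LM loss
```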
 
 
113
 
114
  def _get_unpad_data(attention_mask):
115
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
 
123
  )
124
 
125
 
 
126
  class GemmoeRMSNorm(nn.Module):
127
  def __init__(self, dim: int, eps: float = 1e-6):
128
  super().__init__()
 
130
  self.weight = nn.Parameter(torch.zeros(dim))
131
 
132
  def _norm(self, x):
133
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
 
134
 
135
  def forward(self, x):
136
+ output = self._norm(x.float())
137
+ # Llama does x.to(float16) * w whilst Gemmoe is (x * w).to(float16)
138
+ # See https://github.com/huggingface/transformers/pull/29402
139
+ output = output * (1.0 + self.weight.float())
140
+ return output.type_as(x)
141
+
142
 
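A small sanity-check sketch (not part of the diff) of the rewritten `GemmoeRMSNorm`: because the learnable weight is initialized to zeros, a fresh module reduces to plain RMS normalization computed in float32 and downcast only at the end. The shapes are arbitrary.

```python
import torch

norm = GemmoeRMSNorm(dim=8)                       # weight starts at zeros, so the scale is (1 + 0)
x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
y = norm(x)

# Reference: RMS-normalize in float32, cast back at the very end, exactly as forward() does.
ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)).to(x.dtype)
print(torch.allclose(y, ref))                     # True for a freshly initialized module
```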
143
  ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
144
 
145
+
146
  class GemmoeRotaryEmbedding(nn.Module):
147
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
148
  super().__init__()
149
+
150
  self.dim = dim
151
  self.max_position_embeddings = max_position_embeddings
152
  self.base = base
153
+ self.register_buffer("inv_freq", None, persistent=False)
154
+
155
+ @torch.no_grad()
156
+ def forward(self, x, position_ids, seq_len=None):
157
+ # x: [bs, num_attention_heads, seq_len, head_size]
158
+ if self.inv_freq is None:
159
+ self.inv_freq = 1.0 / (
160
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
161
+ )
162
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
163
+ position_ids_expanded = position_ids[:, None, :].float()
164
+ # Force float32 since bfloat16 loses precision on long contexts
165
+ # See https://github.com/huggingface/transformers/pull/29285
166
+ device_type = x.device.type
167
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
168
+ with torch.autocast(device_type=device_type, enabled=False):
169
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
170
+ emb = torch.cat((freqs, freqs), dim=-1)
171
+ cos = emb.cos()
172
+ sin = emb.sin()
173
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
174
+
175
 
176
  # Copied from transformers.models.llama.modeling_llama.rotate_half
177
  def rotate_half(x):
 
181
  return torch.cat((-x2, x1), dim=-1)
182
 
183
 
184
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
185
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
186
+ """Applies Rotary Position Embedding to the query and key tensors.
187
+
188
+ Args:
189
+ q (`torch.Tensor`): The query tensor.
190
+ k (`torch.Tensor`): The key tensor.
191
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
192
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
193
+ position_ids (`torch.Tensor`, *optional*):
194
+ Deprecated and unused.
195
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
196
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
197
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
198
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
199
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
200
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
201
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
202
+ Returns:
203
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
204
+ """
205
+ cos = cos.unsqueeze(unsqueeze_dim)
206
+ sin = sin.unsqueeze(unsqueeze_dim)
207
  q_embed = (q * cos) + (rotate_half(q) * sin)
208
  k_embed = (k * cos) + (rotate_half(k) * sin)
209
  return q_embed, k_embed
210
 
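A short usage sketch (not in the commit) chaining the rewritten `GemmoeRotaryEmbedding` with `apply_rotary_pos_emb`; the batch, head, and dimension sizes are assumptions.

```python
import torch

batch, heads, seq_len, head_dim = 1, 4, 16, 32       # toy sizes (assumptions)
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)     # [batch, seq_len]

rope = GemmoeRotaryEmbedding(dim=head_dim)
cos, sin = rope(q, position_ids)                      # [batch, seq_len, head_dim], computed in float32
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)   # unsqueeze_dim=1 broadcasts over the head dimension
print(q_rot.shape, k_rot.shape)                       # torch.Size([1, 4, 16, 32]) twice
```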
211
+
212
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
213
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
214
  """
 
221
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
222
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
223
 
224
+
225
  class GemmoeAttention(nn.Module):
226
  """Multi-headed attention from 'Attention Is All You Need' paper"""
227
 
 
624
  self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
625
  self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
626
  self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
627
+
628
+ if config.hidden_activation is None:
629
+ logger.warning_once(
630
+ "Gemmoe's activation function should be approximate GeLU and not exact GeLU.\n"
631
+ "Changing the activation function to `gelu_pytorch_tanh`."
632
+ f"if you want to use the legacy `{config.hidden_act}`, "
633
+ f"edit the `model.config` to set `hidden_activation={config.hidden_act}` "
634
+ " instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
635
+ )
636
+ hidden_activation = "gelu_pytorch_tanh"
637
+ else:
638
+ hidden_activation = config.hidden_activation
639
+
640
+ self.act_fn = ACT2FN[hidden_activation]
641
 
642
  def forward(self, hidden_states):
643
  current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
 
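For context, a tiny sketch (not in the commit) of the numerical gap between exact GeLU and the `gelu_pytorch_tanh` approximation that the warning above refers to; `ACT2FN` comes from `transformers.activations`.

```python
import torch
from transformers.activations import ACT2FN

x = torch.linspace(-4, 4, steps=9)
exact = ACT2FN["gelu"](x)                  # erf-based GeLU
approx = ACT2FN["gelu_pytorch_tanh"](x)    # tanh approximation selected by the code above
print((exact - approx).abs().max())        # small but non-zero difference
```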
784
  "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
785
  GEMMOE_START_DOCSTRING,
786
  )
 
787
  class GemmoePreTrainedModel(PreTrainedModel):
788
  config_class = GemmoeConfig
789
  base_model_prefix = "model"
 
906
  "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
907
  GEMMOE_START_DOCSTRING,
908
  )
909
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMOE,Llama->Gemmoe
910
  class GemmoeModel(GemmoePreTrainedModel):
911
  """
912
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]
 
926
  )
927
  self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
928
  self.gradient_checkpointing = False
929
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
930
 
931
  # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
932
  # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
 
944
  self.embed_tokens = value
945
 
946
  @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
947
+ # Ignore copy
948
  def forward(
949
  self,
950
  input_ids: torch.LongTensor = None,
 
960
  cache_position: Optional[torch.LongTensor] = None,
961
  ) -> Union[Tuple, MoeModelOutputWithPast]:
962
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
963
+ output_router_logits = (
964
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
965
+ )
966
  output_hidden_states = (
967
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
968
  )
 
983
  if inputs_embeds is None:
984
  inputs_embeds = self.embed_tokens(input_ids)
985
 
986
  past_seen_tokens = 0
987
  if use_cache: # kept for BC (cache positions)
988
  if not isinstance(past_key_values, StaticCache):
 
997
  if position_ids is None:
998
  position_ids = cache_position.unsqueeze(0)
999
 
1000
+ if attention_mask is not None and self._use_flash_attention_2 and use_cache:
1001
+ is_padding_right = attention_mask[:, -1].sum().item() != inputs_embeds.shape[0]
1002
+ if is_padding_right:
1003
+ raise ValueError(
1004
+ "You are attempting to perform batched generation with padding_side='right'"
1005
+ " this may lead to unexpected behaviour for Flash Attention version of Gemmoe. Make sure to "
1006
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1007
+ )
1008
 
1009
+ if self._use_flash_attention_2:
1010
+ # 2d mask is passed through the layers
1011
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1012
+ else:
1013
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
1014
 
1015
  # normalized
1016
+ # Gemmoe downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
1017
+ # See https://github.com/huggingface/transformers/pull/29402
1018
+ normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype)
1019
+ hidden_states = inputs_embeds * normalizer
1020
 
1021
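A standalone illustration (not part of the diff) of the precision point in the comment above: sqrt(3072) is not representable in bfloat16 and rounds to 55.5, so the normalizer is deliberately materialized as a tensor in the embedding dtype. The hidden size of 3072 is taken from the comment.

```python
import torch

hidden_size = 3072                                   # value cited in the comment above
exact = hidden_size ** 0.5                           # 55.42562...
print(exact, torch.tensor(exact, dtype=torch.bfloat16).item())   # 55.42562... vs 55.5

inputs_embeds = torch.randn(1, 4, hidden_size, dtype=torch.bfloat16)
normalizer = torch.tensor(hidden_size ** 0.5, dtype=inputs_embeds.dtype)
hidden_states = inputs_embeds * normalizer           # the same rounded constant is used consistently
```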
  # decoder layers
1022
  all_hidden_states = () if output_hidden_states else None
1023
  all_self_attns = () if output_attentions else None
1024
+ all_router_logits = () if output_router_logits else None
1025
  next_decoder_cache = None
1026
 
1027
  for decoder_layer in self.layers:
1028
  if output_hidden_states:
1029
  all_hidden_states += (hidden_states,)
1030
+
1031
+ if self.gradient_checkpointing and self.training:
1032
  layer_outputs = self._gradient_checkpointing_func(
1033
  decoder_layer.__call__,
1034
  hidden_states,
1035
+ causal_mask if not self._use_flash_attention_2 else attention_mask,
1036
  position_ids,
1037
  past_key_values,
1038
  output_attentions,
1039
  output_router_logits,
1040
+ use_cache,
1041
  cache_position,
 
1042
  )
1043
  else:
1044
  layer_outputs = decoder_layer(
1045
  hidden_states,
1046
+ attention_mask=causal_mask if not self._use_flash_attention_2 else attention_mask,
1047
  position_ids=position_ids,
1048
  past_key_value=past_key_values,
1049
  output_attentions=output_attentions,
1050
  output_router_logits=output_router_logits,
1051
+ use_cache=use_cache,
1052
  cache_position=cache_position,
1053
  )
1054
 
 
1060
  if output_attentions:
1061
  all_self_attns += (layer_outputs[1],)
1062
 
1063
+ if output_router_logits:
1064
+ all_router_logits += (layer_outputs[-1],)
1065
+
1066
  hidden_states = self.norm(hidden_states)
1067
 
1068
  # add hidden states from the last decoder layer
 
1075
  next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1076
  )
1077
  if not return_dict:
1078
+ return tuple(
1079
+ v
1080
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1081
+ if v is not None
1082
+ )
1083
  return MoeModelOutputWithPast(
1084
  last_hidden_state=hidden_states,
1085
  past_key_values=next_cache,
1086
  hidden_states=all_hidden_states,
1087
  attentions=all_self_attns,
1088
+ router_logits=all_router_logits,
1089
  )
1090
 
1091
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1092
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1093
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1094
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1095
+ def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
1096
  if self.config._attn_implementation == "flash_attention_2":
1097
  if attention_mask is not None and 0.0 in attention_mask:
1098
  return attention_mask
 
1113
  causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
1114
  causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
1115
  if attention_mask is not None:
1116
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1117
  if attention_mask.dim() == 2:
1118
  mask_length = attention_mask.shape[-1]
1119
  padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1120
  causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1121
  elif attention_mask.dim() == 4:
1122
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1123
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
1124
+ if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
1125
+ offset = past_seen_tokens
1126
+ else:
1127
+ offset = 0
1128
  mask_shape = attention_mask.shape
1129
  mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
1130
+ causal_mask[
1131
+ : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
1132
+ ] = mask_slice
1133
 
1134
  if (
1135
  self.config._attn_implementation == "sdpa"
 
1150
 
1151
  return causal_mask
1152
 
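A self-contained sketch (not in the commit) of the masking pattern `_update_causal_mask` implements for a 2D padding mask: a causal buffer scaled to the dtype minimum, with padded key positions filled in on top. The sequence length and padding layout are toy assumptions.

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
seq_len = 4

# Analogue of self.causal_mask: 1s strictly above the diagonal, scaled to the mask value.
causal = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)[None, None].to(dtype) * min_dtype
causal = causal.expand(1, 1, -1, -1).clone()        # clone: contiguous memory for in-place edits

attention_mask = torch.tensor([[0, 1, 1, 1]])       # first token is padding
padding = causal.eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal = causal.masked_fill(padding, min_dtype)
print(causal[0, 0])                                 # padded key column and future positions are min_dtype
```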
1153
+
1154
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMOE,Llama->Gemmoe,llama->GEMMA
1155
  class GemmoeForCausalLM(GemmoePreTrainedModel):
1156
  _tied_weights_keys = ["lm_head.weight"]
1157
 
 
1186
 
1187
  @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
1188
  @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
1189
  def forward(
1190
  self,
1191
  input_ids: torch.LongTensor = None,
 
1199
  output_hidden_states: Optional[bool] = None,
1200
  output_router_logits: Optional[bool] = None,
1201
  return_dict: Optional[bool] = None,
1202
+ cache_position: Optional[torch.LongTensor] = None,
1203
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1204
  r"""
1205
  Args:
 
1215
  ```python
1216
  >>> from transformers import AutoTokenizer, GemmoeForCausalLM
1217
 
1218
+ >>> model = GemmoeForCausalLM.from_pretrained("google/GEMMA-7b")
1219
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/GEMMA-7b")
1220
 
1221
+ >>> prompt = "What is your favorite condiment?"
1222
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1223
 
1224
  >>> # Generate
1225
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1226
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1227
+ "What is your favorite condiment?"
1228
  ```"""
 
1229
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1230
  output_router_logits = (
1231
  output_router_logits if output_router_logits is not None else self.config.output_router_logits
1232
  )
 
1233
  output_hidden_states = (
1234
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1235
  )
 
1247
  output_hidden_states=output_hidden_states,
1248
  output_router_logits=output_router_logits,
1249
  return_dict=return_dict,
1250
+ cache_position=cache_position,
1251
  )
1252
 
1253
  hidden_states = outputs[0]
1254
  logits = self.lm_head(hidden_states)
1255
  logits = logits.float()
 
 
 
 
 
 
1256
 
1257
  loss = None
1258
  if labels is not None:
 
1270
  aux_loss = None
1271
  if output_router_logits:
1272
  aux_loss = load_balancing_loss_func(
1273
+ outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok
 
 
 
1274
  )
1275
  if labels is not None:
1276
+ loss += self.router_aux_loss_coef * aux_loss
1277
 
1278
  if not return_dict:
1279
  output = (logits,) + outputs[1:]
 
1292
  )
1293
 
1294
  def prepare_inputs_for_generation(
1295
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
 
 
 
 
 
 
1296
  ):
1297
+ # With static cache, the `past_key_values` is None
1298
+ # TODO joao: standardize interface for the different Cache classes and remove of this if
1299
+ has_static_cache = False
1300
+ if past_key_values is None:
1301
+ past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
1302
+ has_static_cache = past_key_values is not None
1303
+
1304
+ past_length = 0
1305
  if past_key_values is not None:
1306
  if isinstance(past_key_values, Cache):
1307
+ past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
1308
+ max_cache_length = (
1309
+ torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
1310
+ if past_key_values.get_max_length() is not None
1311
+ else None
1312
+ )
1313
+ cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
1314
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
1315
  else:
1316
  cache_length = past_length = past_key_values[0][0].shape[2]
1317
  max_cache_length = None
 
1348
  if inputs_embeds is not None and past_key_values is None:
1349
  model_inputs = {"inputs_embeds": inputs_embeds}
1350
  else:
1351
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1352
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1353
+ # TODO: use `next_tokens` directly instead.
1354
+ model_inputs = {"input_ids": input_ids.contiguous()}
1355
+
1356
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1357
+ if cache_position is None:
1358
+ cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
1359
+ else:
1360
+ cache_position = cache_position[-input_length:]
1361
+
1362
+ if has_static_cache:
1363
+ past_key_values = None
1364
 
1365
  model_inputs.update(
1366
  {
1367
  "position_ids": position_ids,
1368
+ "cache_position": cache_position,
1369
  "past_key_values": past_key_values,
1370
  "use_cache": kwargs.get("use_cache"),
1371
  "attention_mask": attention_mask,
 
1372
  }
1373
  )
1374
  return model_inputs
 
1381
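A tiny sketch (not in the commit) of the `cache_position` bookkeeping the updated `prepare_inputs_for_generation` performs: the first forward pass covers the whole prompt, later decode steps advance one position at a time. The lengths are toy numbers.

```python
import torch

# Prefill: nothing cached yet, positions cover the full 5-token prompt.
past_length, input_length = 0, 5
print(torch.arange(past_length, past_length + input_length))   # tensor([0, 1, 2, 3, 4])

# One decode step later: 5 tokens are cached and a single new token is fed in.
past_length, input_length = 5, 1
print(torch.arange(past_length, past_length + input_length))   # tensor([5])
```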
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1382
  )
1383
  return reordered_past
1384
+
1385
+
1386
  @add_start_docstrings(
1387
  """
1388
  The Gemmoe Model transformer with a sequence classification head on top (linear layer).
1389
+
1390
  [`GemmoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1391
  (e.g. GPT-2) do.
1392
 
 
1398
  """,
1399
  GEMMOE_START_DOCSTRING,
1400
  )
1401
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMOE,Llama->Gemmoe
1402
  class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1403
  def __init__(self, config):
1404
  super().__init__(config)