Crystalcareai committed (verified)
Commit 53006e5 · Parent: 892da81

Update modeling_gemmoe.py

Files changed (1):
  1. modeling_gemmoe.py +42 -9
modeling_gemmoe.py CHANGED
@@ -689,10 +689,12 @@ class GemmoeDecoderLayer(nn.Module):
     def __init__(self, config: GemmoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-
+
         self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
         self.block_sparse_moe = GemmoeSparseMoeBlock(config)
         self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -705,9 +707,32 @@ class GemmoeDecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+                should not be returned during inference.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+        """
+
         residual = hidden_states
+
         hidden_states = self.input_layernorm(hidden_states)
 
+        # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -717,17 +742,21 @@ class GemmoeDecoderLayer(nn.Module):
             use_cache=use_cache,
         )
         hidden_states = residual + hidden_states
-
+
+        # Fully Connected
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states, router_logits = self.block_sparse_moe(hidden_states)
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
+
         if output_attentions:
             outputs += (self_attn_weights,)
+
         if use_cache:
             outputs += (present_key_value,)
+
         if output_router_logits:
             outputs += (router_logits,)
 
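Net effect of the three hunks above: `GemmoeDecoderLayer` gains a dedicated `post_attention_layernorm`, and the MoE branch now normalizes with it instead of running `input_layernorm` a second time after the attention residual. Below is a minimal sketch of the resulting pre-norm layout; the generic `RMSNorm` and the `attn`/`moe` constructor arguments are stand-ins (not `GemmoeRMSNorm`, the `GEMMOE_ATTENTION_CLASSES` entries, or `GemmoeSparseMoeBlock`), so only the norm/residual ordering mirrors the diff.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Generic RMSNorm stand-in (not a line-for-line copy of GemmoeRMSNorm)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


class SketchDecoderLayer(nn.Module):
    """Norm/residual ordering after this commit: norm -> attn -> add, norm -> moe -> add."""

    def __init__(self, dim: int, attn: nn.Module, moe: nn.Module):
        super().__init__()
        self.self_attn = attn                         # placeholder for the attention module
        self.block_sparse_moe = moe                   # placeholder for the sparse-MoE block
        self.input_layernorm = RMSNorm(dim)
        self.post_attention_layernorm = RMSNorm(dim)  # the layer added in this commit

    def forward(self, hidden_states):
        # Pre-norm self-attention with a residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states)
        hidden_states = residual + hidden_states

        # MoE branch: normalized by post_attention_layernorm, where the old
        # code reapplied input_layernorm.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.block_sparse_moe(hidden_states)
        return residual + hidden_states


if __name__ == "__main__":
    # Identity submodules are enough to exercise the residual pattern end to end.
    layer = SketchDecoderLayer(8, nn.Identity(), nn.Identity())
    print(layer(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])
```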
 
@@ -950,6 +979,14 @@ class GemmoeModel(GemmoePreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # Scale embeddings
+        # Fix for precision issue when casting to bfloat16
+        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
+        if inputs_embeds.dtype == torch.bfloat16:
+            pass
+
+        hidden_states = inputs_embeds * hidden_size_sqrt
+
         past_seen_tokens = 0
         if use_cache: # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
@@ -969,12 +1006,8 @@ class GemmoeModel(GemmoePreTrainedModel):
         # embed positions
         hidden_states = inputs_embeds
 
-        # Scale embeddings
-        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
-        if inputs_embeds.dtype == torch.bfloat16:
-            hidden_states = inputs_embeds * hidden_size_sqrt
-        else:
-            hidden_states = inputs_embeds * hidden_size_sqrt
+        # normalized
+        hidden_states = hidden_states * (self.config.hidden_size**0.5)
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
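On the `GemmoeModel` side, note that the bfloat16 branch added in the first hunk is a bare `pass`, and the `hidden_states = inputs_embeds * hidden_size_sqrt` it introduces is overwritten by the later `hidden_states = inputs_embeds` context line, so the scaling that actually takes effect is the single multiply by `hidden_size**0.5` under `# normalized`. A minimal sketch of that step follows; `SimpleNamespace` and the hidden size of 2048 are illustrative stand-ins for the real `GemmoeConfig`, and the dtype-aware variant at the end is an assumption about one common way to handle the bfloat16 precision concern, not code from this commit.

```python
from types import SimpleNamespace

import torch

# Stand-in config; the real model reads hidden_size from GemmoeConfig.
config = SimpleNamespace(hidden_size=2048)
inputs_embeds = torch.randn(1, 4, config.hidden_size, dtype=torch.bfloat16)

# The scaling that remains in effect after this commit: one multiply by sqrt(hidden_size).
hidden_states = inputs_embeds
hidden_states = hidden_states * (config.hidden_size**0.5)

# Assumed variant (not in this diff): materialize the normalizer in the embedding
# dtype so the multiplier itself is rounded to bfloat16 before the multiply.
normalizer = torch.tensor(config.hidden_size**0.5, dtype=inputs_embeds.dtype)
hidden_states_alt = inputs_embeds * normalizer

print(hidden_states.dtype, hidden_states_alt.dtype)  # both remain torch.bfloat16
```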