Crystalcareai
commited on
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +42 -9
modeling_gemmoe.py
CHANGED
@@ -689,10 +689,12 @@ class GemmoeDecoderLayer(nn.Module):
|
|
689 |
def __init__(self, config: GemmoeConfig, layer_idx: int):
|
690 |
super().__init__()
|
691 |
self.hidden_size = config.hidden_size
|
692 |
-
|
693 |
self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
|
|
694 |
self.block_sparse_moe = GemmoeSparseMoeBlock(config)
|
695 |
self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
696 |
|
697 |
def forward(
|
698 |
self,
|
@@ -705,9 +707,32 @@ class GemmoeDecoderLayer(nn.Module):
|
|
705 |
use_cache: Optional[bool] = False,
|
706 |
**kwargs,
|
707 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
708 |
residual = hidden_states
|
|
|
709 |
hidden_states = self.input_layernorm(hidden_states)
|
710 |
|
|
|
711 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
712 |
hidden_states=hidden_states,
|
713 |
attention_mask=attention_mask,
|
@@ -717,17 +742,21 @@ class GemmoeDecoderLayer(nn.Module):
|
|
717 |
use_cache=use_cache,
|
718 |
)
|
719 |
hidden_states = residual + hidden_states
|
720 |
-
|
|
|
721 |
residual = hidden_states
|
722 |
-
hidden_states = self.
|
723 |
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
724 |
hidden_states = residual + hidden_states
|
725 |
|
726 |
outputs = (hidden_states,)
|
|
|
727 |
if output_attentions:
|
728 |
outputs += (self_attn_weights,)
|
|
|
729 |
if use_cache:
|
730 |
outputs += (present_key_value,)
|
|
|
731 |
if output_router_logits:
|
732 |
outputs += (router_logits,)
|
733 |
|
@@ -950,6 +979,14 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
950 |
if inputs_embeds is None:
|
951 |
inputs_embeds = self.embed_tokens(input_ids)
|
952 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
953 |
past_seen_tokens = 0
|
954 |
if use_cache: # kept for BC (cache positions)
|
955 |
if not isinstance(past_key_values, StaticCache):
|
@@ -969,12 +1006,8 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
969 |
# embed positions
|
970 |
hidden_states = inputs_embeds
|
971 |
|
972 |
-
#
|
973 |
-
|
974 |
-
if inputs_embeds.dtype == torch.bfloat16:
|
975 |
-
hidden_states = inputs_embeds * hidden_size_sqrt
|
976 |
-
else:
|
977 |
-
hidden_states = inputs_embeds * hidden_size_sqrt
|
978 |
|
979 |
# decoder layers
|
980 |
all_hidden_states = () if output_hidden_states else None
|
|
|
689 |
def __init__(self, config: GemmoeConfig, layer_idx: int):
|
690 |
super().__init__()
|
691 |
self.hidden_size = config.hidden_size
|
692 |
+
|
693 |
self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
694 |
+
|
695 |
self.block_sparse_moe = GemmoeSparseMoeBlock(config)
|
696 |
self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
697 |
+
self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
698 |
|
699 |
def forward(
|
700 |
self,
|
|
|
707 |
use_cache: Optional[bool] = False,
|
708 |
**kwargs,
|
709 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
710 |
+
if "padding_mask" in kwargs:
|
711 |
+
warnings.warn(
|
712 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
713 |
+
)
|
714 |
+
"""
|
715 |
+
Args:
|
716 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
717 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
718 |
+
`(batch, sequence_length)` where padding elements are indicated by 0.
|
719 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
720 |
+
output_attentions (`bool`, *optional*):
|
721 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
722 |
+
returned tensors for more detail.
|
723 |
+
output_router_logits (`bool`, *optional*):
|
724 |
+
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
725 |
+
should not be returned during inference.
|
726 |
+
use_cache (`bool`, *optional*):
|
727 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
728 |
+
(see `past_key_values`).
|
729 |
+
"""
|
730 |
+
|
731 |
residual = hidden_states
|
732 |
+
|
733 |
hidden_states = self.input_layernorm(hidden_states)
|
734 |
|
735 |
+
# Self Attention
|
736 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
737 |
hidden_states=hidden_states,
|
738 |
attention_mask=attention_mask,
|
|
|
742 |
use_cache=use_cache,
|
743 |
)
|
744 |
hidden_states = residual + hidden_states
|
745 |
+
|
746 |
+
# Fully Connected
|
747 |
residual = hidden_states
|
748 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
749 |
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
750 |
hidden_states = residual + hidden_states
|
751 |
|
752 |
outputs = (hidden_states,)
|
753 |
+
|
754 |
if output_attentions:
|
755 |
outputs += (self_attn_weights,)
|
756 |
+
|
757 |
if use_cache:
|
758 |
outputs += (present_key_value,)
|
759 |
+
|
760 |
if output_router_logits:
|
761 |
outputs += (router_logits,)
|
762 |
|
|
|
979 |
if inputs_embeds is None:
|
980 |
inputs_embeds = self.embed_tokens(input_ids)
|
981 |
|
982 |
+
# Scale embeddings
|
983 |
+
# Fix for precision issue when casting to bfloat16
|
984 |
+
hidden_size_sqrt = math.sqrt(self.config.hidden_size)
|
985 |
+
if inputs_embeds.dtype == torch.bfloat16:
|
986 |
+
pass
|
987 |
+
|
988 |
+
hidden_states = inputs_embeds * hidden_size_sqrt
|
989 |
+
|
990 |
past_seen_tokens = 0
|
991 |
if use_cache: # kept for BC (cache positions)
|
992 |
if not isinstance(past_key_values, StaticCache):
|
|
|
1006 |
# embed positions
|
1007 |
hidden_states = inputs_embeds
|
1008 |
|
1009 |
+
# normalized
|
1010 |
+
hidden_states = hidden_states * (self.config.hidden_size**0.5)
|
|
|
|
|
|
|
|
|
1011 |
|
1012 |
# decoder layers
|
1013 |
all_hidden_states = () if output_hidden_states else None
|