Crystalcareai committed
Update modeling_gemmoe.py

modeling_gemmoe.py CHANGED (+8 -18)
@@ -670,16 +670,11 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
         self.act_fn = approx_gelu
 
     def forward(self, hidden_states):
+        hidden_states = hidden_states.to(torch.float32) # Cast to float32
         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
-        current_hidden_states = self.w2(current_hidden_states)
+        current_hidden_states = self.w2(current_hidden_states.to(hidden_states.dtype)) # Cast back to original dtype
         return current_hidden_states
 
-class GemmoeBLockSparseTop2MLP(GemmoeBlockSparseTop2MLP):
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "GemmoeBLockSparseTop2MLP is deprecated by GemmoeBlockSparseTop2MLP and will be removed in v4.40."
-        )
-        super().__init__(*args, **kwargs)
 
 class GemmoeSparseMoeBlock(nn.Module):
     def __init__(self, config):
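For context on the hunk above: it upcasts the MLP input to float32 before the gated activation. Below is a minimal, hypothetical sketch of that upcast pattern, not the model's actual module; the w1/w3/w2 names follow the diff, while the layer sizes and the tanh-approximated GELU stand in for the real config and approx_gelu. Unlike the hunk (which reassigns hidden_states, so its inner .to(hidden_states.dtype) stays in float32), the sketch saves the caller's dtype and restores it on the output.

import torch
import torch.nn as nn

class GatedMLPSketch(nn.Module):
    # Illustrative sizes only; the real module reads them from the model config.
    def __init__(self, hidden_size=16, intermediate_size=64):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # up projection
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # down projection
        self.act_fn = nn.GELU(approximate="tanh")  # stand-in for approx_gelu

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)  # upcast for numerical stability
        current = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        return self.w2(current).to(input_dtype)  # cast the output back to the caller's dtype

For example, GatedMLPSketch()(torch.randn(2, 16)) returns a (2, 16) tensor in the input dtype.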
@@ -699,8 +694,9 @@ class GemmoeSparseMoeBlock(nn.Module):
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         # router_logits: (batch * sequence_length, n_experts)
-
-
+        hidden_states_float = hidden_states.float() # Cast to float32
+        router_logits = self.gate(hidden_states_float)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
         topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
         topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
 
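The routing hunk above scores the gate in float32, keeps the top-k experts per token, and renormalizes their weights. A minimal sketch of that routing step, assuming gate is a plain nn.Linear over n_experts and top_k matches self.top_k; the function name route_top2 is hypothetical.

import torch
import torch.nn as nn
import torch.nn.functional as F

def route_top2(hidden_states: torch.Tensor, gate: nn.Linear, top_k: int = 2):
    # Score every token against every expert in float32 for a stable softmax.
    router_logits = gate(hidden_states.float())               # (tokens, n_experts)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
    # Keep only the top-k experts per token and renormalize their weights to sum to 1.
    topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
    topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
    return topk_weight, topk_idx, router_logits

For example, route_top2(torch.randn(4, 16), nn.Linear(16, 8, bias=False)) returns per-token weights and expert indices of shape (4, 2).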
@@ -715,7 +711,7 @@ class GemmoeSparseMoeBlock(nn.Module):
         for i in range(self.num_experts):
             expert = self.experts[i]
             expert_output = expert(hidden_states[flat_topk_idx == i])
-            y[flat_topk_idx == i] = expert_output
+            y[flat_topk_idx == i] = expert_output
 
         y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
 
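The loop above runs each expert only on the tokens routed to it, and the final line takes the routing-weighted sum over each token's top-k expert outputs. A hedged sketch of that dispatch-and-combine step follows; dispatch_and_combine and hidden_rep are illustrative names, and it assumes tokens were repeated top_k times so rows line up with flat_topk_idx, as the flattened indexing in the hunk implies.

import torch
import torch.nn as nn

def dispatch_and_combine(hidden_states, experts, topk_weight, topk_idx):
    num_tokens, hidden_dim = hidden_states.shape
    top_k = topk_idx.shape[-1]
    flat_topk_idx = topk_idx.view(-1)                          # (tokens * top_k,)
    # Repeat each token once per selected expert so rows align with flat_topk_idx.
    hidden_rep = hidden_states.repeat_interleave(top_k, dim=0)
    y = torch.zeros_like(hidden_rep)
    for i, expert in enumerate(experts):
        mask = flat_topk_idx == i
        if mask.any():
            y[mask] = expert(hidden_rep[mask])                 # run expert i on its tokens only
    # Routing-weighted sum over the top-k expert outputs for each token.
    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
    return y

With, say, experts = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)]) and the weights/indices from the routing step, the result has the same shape as the flattened hidden states.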
@@ -983,7 +979,6 @@ class GemmoeModel(GemmoePreTrainedModel):
         self.embed_tokens = value
 
     @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
-    # Ignore copy
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -994,7 +989,7 @@ class GemmoeModel(GemmoePreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        output_router_logits: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
@@ -1023,7 +1018,6 @@ class GemmoeModel(GemmoePreTrainedModel):
         # Fix for precision issue when casting to bfloat16
         hidden_size_sqrt = math.sqrt(self.config.hidden_size)
         if inputs_embeds.dtype == torch.bfloat16:
-
             pass
 
         hidden_states = inputs_embeds * hidden_size_sqrt
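The lines above keep the Gemma-style normalization hidden_states = inputs_embeds * sqrt(hidden_size), with the bfloat16 branch left as a no-op. As a point of comparison only, a common way to make the bfloat16 rounding explicit is to materialize the normalizer as a tensor in the embedding dtype; the sketch below shows that alternative and is not what the hunk itself does.

import math
import torch

def scale_embeddings(inputs_embeds: torch.Tensor, hidden_size: int) -> torch.Tensor:
    # Build sqrt(hidden_size) in the embeddings' dtype (e.g. bfloat16) so the rounding
    # of the scale factor happens once, explicitly, before the multiply.
    normalizer = torch.tensor(math.sqrt(hidden_size),
                              dtype=inputs_embeds.dtype,
                              device=inputs_embeds.device)
    return inputs_embeds * normalizer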
@@ -1110,10 +1104,6 @@ class GemmoeModel(GemmoePreTrainedModel):
             attentions=all_self_attns,
         )
 
-    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
@@ -1135,7 +1125,7 @@ class GemmoeModel(GemmoePreTrainedModel):
         causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
         causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
-            causal_mask = causal_mask.clone()
+            causal_mask = causal_mask.clone()
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
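The hunk above clones the cached causal mask before merging in the 2D padding mask. A hedged sketch of that merge: positions the causal mask leaves open (0.0) but that attention_mask marks as padding (0) get filled with the dtype minimum. The function name is hypothetical, and the final masked_fill line is an assumption about the code that follows the lines shown in the hunk.

import torch

def apply_padding_to_causal_mask(causal_mask: torch.Tensor,
                                 attention_mask: torch.Tensor,
                                 min_dtype: float) -> torch.Tensor:
    # Clone so the shared/cached (expanded) buffer is not modified in place.
    causal_mask = causal_mask.clone()
    if attention_mask is not None and attention_mask.dim() == 2:
        mask_length = attention_mask.shape[-1]
        # True where the causal mask allows attention but the token is padding.
        padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
        # Assumed follow-up step: block those positions with the most negative value.
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
    return causal_mask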