Crystalcareai
commited on
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +11 -10
modeling_gemmoe.py
CHANGED
@@ -682,42 +682,43 @@ class GemmoeBlockSparseTop2MLP(GemmoeBlockSparseTop2MLP):
|
|
682 |
super().__init__(*args, **kwargs)
|
683 |
|
684 |
class GemmoeSparseMoeBlock(nn.Module):
|
685 |
-
"""
|
686 |
-
This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster since it formulates MoE operations in terms of block-sparse operations to accommodate imbalanced assignments of tokens to experts.
|
687 |
-
"""
|
688 |
-
|
689 |
def __init__(self, config):
|
690 |
super().__init__()
|
691 |
self.hidden_dim = config.hidden_size
|
692 |
self.ffn_dim = config.intermediate_size
|
693 |
self.num_experts = config.num_local_experts
|
694 |
-
self.top_k =
|
695 |
|
696 |
# gating
|
697 |
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
698 |
|
699 |
self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
|
700 |
|
701 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
702 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
703 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
|
|
704 |
# router_logits: (batch * sequence_length, n_experts)
|
705 |
router_logits = self.gate(hidden_states)
|
706 |
-
|
707 |
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
708 |
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
|
709 |
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
|
|
|
|
710 |
topk_weight = topk_weight.to(hidden_states.dtype)
|
711 |
|
712 |
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
|
|
|
713 |
y = torch.empty_like(hidden_states)
|
|
|
714 |
flat_topk_idx = topk_idx.view(-1)
|
715 |
for i in range(self.num_experts):
|
716 |
expert = self.experts[i]
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
|
|
721 |
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
|
722 |
return final_hidden_states, router_logits
|
723 |
|
|
|
682 |
super().__init__(*args, **kwargs)
|
683 |
|
684 |
class GemmoeSparseMoeBlock(nn.Module):
|
|
|
|
|
|
|
|
|
685 |
def __init__(self, config):
|
686 |
super().__init__()
|
687 |
self.hidden_dim = config.hidden_size
|
688 |
self.ffn_dim = config.intermediate_size
|
689 |
self.num_experts = config.num_local_experts
|
690 |
+
self.top_k = 2
|
691 |
|
692 |
# gating
|
693 |
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
694 |
|
695 |
self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
|
696 |
|
697 |
+
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
698 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
699 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
700 |
+
|
701 |
# router_logits: (batch * sequence_length, n_experts)
|
702 |
router_logits = self.gate(hidden_states)
|
|
|
703 |
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
704 |
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
|
705 |
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
706 |
+
|
707 |
+
# we cast back to the input dtype
|
708 |
topk_weight = topk_weight.to(hidden_states.dtype)
|
709 |
|
710 |
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
|
711 |
+
|
712 |
y = torch.empty_like(hidden_states)
|
713 |
+
|
714 |
flat_topk_idx = topk_idx.view(-1)
|
715 |
for i in range(self.num_experts):
|
716 |
expert = self.experts[i]
|
717 |
+
expert_output = expert(hidden_states[flat_topk_idx == i])
|
718 |
+
y[flat_topk_idx == i] = expert_output.to(y.dtype) # Cast expert_output to the same dtype as y
|
719 |
+
|
720 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
721 |
+
|
722 |
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
|
723 |
return final_hidden_states, router_logits
|
724 |
|