Crystalcareai committed (verified)
Commit 61c9550 · 1 Parent(s): 6354236

Update modeling_gemmoe.py

Files changed (1):
  1. modeling_gemmoe.py (+11 -10)
modeling_gemmoe.py CHANGED
@@ -682,42 +682,43 @@ class GemmoeBlockSparseTop2MLP(GemmoeBlockSparseTop2MLP):
         super().__init__(*args, **kwargs)
 
 class GemmoeSparseMoeBlock(nn.Module):
-    """
-    This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster since it formulates MoE operations in terms of block-sparse operations to accommodate imbalanced assignments of tokens to experts.
-    """
-
     def __init__(self, config):
         super().__init__()
         self.hidden_dim = config.hidden_size
         self.ffn_dim = config.intermediate_size
         self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
+        self.top_k = 2
 
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
         self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
-
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
         topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+
+        # we cast back to the input dtype
         topk_weight = topk_weight.to(hidden_states.dtype)
 
         hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+
         y = torch.empty_like(hidden_states)
+
         flat_topk_idx = topk_idx.view(-1)
         for i in range(self.num_experts):
             expert = self.experts[i]
-            mask = flat_topk_idx == i
-            if mask.any():
-                y[mask] = expert(hidden_states[mask]).to(y.dtype)
+            expert_output = expert(hidden_states[flat_topk_idx == i])
+            y[flat_topk_idx == i] = expert_output.to(y.dtype)  # Cast expert_output to the same dtype as y
+
         y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+
         final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states, router_logits
 
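
For reference, here is a minimal, self-contained sketch of the routing path as it stands after this commit. Plain nn.Linear layers stand in for GemmoeBlockSparseTop2MLP, and the batch, sequence, hidden, and expert sizes are invented for illustration; only the tensor manipulation mirrors the updated forward().

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, sequence_length, hidden_dim = 2, 4, 8
num_experts, top_k = 4, 2  # the commit hard-codes self.top_k = 2

gate = nn.Linear(hidden_dim, num_experts, bias=False)
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])  # stand-in experts

hidden_states = torch.randn(batch_size, sequence_length, hidden_dim)
hidden_states = hidden_states.view(-1, hidden_dim)              # (tokens, hidden_dim)

# router_logits: (tokens, num_experts)
router_logits = gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
topk_weight = topk_weight.to(hidden_states.dtype)               # cast back to the input dtype

# Duplicate every token top_k times so each (token, expert) pair has its own row.
hidden_states = hidden_states.repeat_interleave(top_k, dim=0)   # (tokens * top_k, hidden_dim)
y = torch.empty_like(hidden_states)
flat_topk_idx = topk_idx.view(-1)                               # (tokens * top_k,)

for i in range(num_experts):
    # Rows routed to expert i; an empty selection simply yields an empty output.
    expert_output = experts[i](hidden_states[flat_topk_idx == i])
    y[flat_topk_idx == i] = expert_output.to(y.dtype)           # keep y's dtype (matters in fp16/bf16)

# Weight each expert output by its normalized routing weight, then sum over the top_k slots.
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
print(final_hidden_states.shape)  # torch.Size([2, 4, 8])

Relative to the previous version of the block, the if mask.any(): guard is gone, so each expert is called unconditionally (an empty selection just produces an empty tensor), top_k is fixed at 2 instead of being read from config.num_experts_per_tok, and the forward return type is annotated as Tuple[torch.Tensor, torch.Tensor] to match the returned (final_hidden_states, router_logits) pair.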