Crystalcareai committed
Commit d72d599 · verified · 1 Parent(s): 65ab6d1

Update modeling_gemmoe.py

Files changed (1): modeling_gemmoe.py (+29, -37)
modeling_gemmoe.py CHANGED
@@ -629,32 +629,23 @@ GEMMOE_ATTENTION_CLASSES = {
     "sdpa": GemmoeSdpaAttention,
 }
 
-scatter2scatter=False
-
-class ParallelLinear(nn.Module):
-    def __init__(self, in_features, out_features, num_experts):
+class GemmoeBlockSparseTop2MLP(nn.Module):
+    def __init__(self, config: GemmoeConfig):
         super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.num_experts = num_experts
-        self.weight = nn.Parameter(torch.Tensor(num_experts, in_features, out_features))
-        nn.init.xavier_uniform_(self.weight)
-
-    def forward(self, input, ordering, top_k, routing_weights=None, grouped_in=False, grouped_out=False, use_scatter2scatter=False):
-        if use_scatter2scatter:
-            output = scatter2scatter(input, self.weight, ordering, top_k, grouped_in=grouped_in, grouped_out=grouped_out)
-        else:
-            if not grouped_in:
-                input = input[ordering]
-            output = torch.bmm(input.view(-1, top_k, self.in_features), self.weight[ordering].view(-1, self.in_features, self.out_features))
-            if not grouped_out:
-                output = output.view(-1, top_k * self.out_features)
-                output = output[torch.argsort(ordering)]
-
-        if routing_weights is not None:
-            output = output.view(-1, top_k, self.out_features)
-            output = torch.bmm(routing_weights.unsqueeze(1), output).squeeze(1)
-        return output
+        self.ffn_dim = config.intermediate_size
+        self.hidden_dim = config.hidden_size
+
+        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+        self.act_fn = approx_gelu
+
+    def forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states.to(hidden_states.dtype)
+
 
 class GemmoeSparseMoeBlock(nn.Module):
     def __init__(self, config):
@@ -663,14 +654,11 @@ class GemmoeSparseMoeBlock(nn.Module):
         self.ffn_dim = config.intermediate_size
         self.num_experts = config.num_local_experts
         self.top_k = 2
-        self.use_scatter2scatter = config.use_scatter2scatter if hasattr(config, 'use_scatter2scatter') else False
 
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
-        self.expert_mlp1 = ParallelLinear(self.hidden_dim, self.ffn_dim, self.num_experts)
-        self.expert_mlp2 = ParallelLinear(self.ffn_dim, self.hidden_dim, self.num_experts)
-        self.activation = nn.GELU()
+        self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -679,18 +667,22 @@ class GemmoeSparseMoeBlock(nn.Module):
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
         routing_weights = F.softmax(router_logits, dim=1)
-        _, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
-        topk_weight = routing_weights.gather(1, topk_idx)
+        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
         topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
 
-        ordering = topk_idx.view(-1)
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+
+        y = torch.empty_like(hidden_states)
+
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            expert_output = expert(hidden_states[flat_topk_idx == i])
+            y[flat_topk_idx == i] = expert_output
 
-        # ParallelLinear for expert MLP
-        hidden_states = self.expert_mlp1(hidden_states, ordering, self.top_k, grouped_in=False, grouped_out=True, use_scatter2scatter=self.use_scatter2scatter)
-        hidden_states = self.activation(hidden_states)
-        hidden_states = self.expert_mlp2(hidden_states, ordering, self.top_k, topk_weight, grouped_in=True, grouped_out=False, use_scatter2scatter=self.use_scatter2scatter)
+        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
 
-        final_hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
 
 
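For reference, the sketch below exercises the same top-2 routing pattern the updated GemmoeSparseMoeBlock uses: softmax the router logits, take the top-2 experts per token, renormalize their weights, duplicate each token with repeat_interleave, run each copy through its expert MLP, and combine the copies with the renormalized weights. It is a standalone illustration, not the repository code: the names Top2MLP and SparseMoeSketch, the toy sizes, and the use of F.gelu(..., approximate="tanh") in place of the file's approx_gelu are all assumptions made for this example.

# Minimal sketch of the top-2 MoE routing added in this commit (illustrative only).
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class Top2MLP(nn.Module):
    """GLU-style expert MLP mirroring GemmoeBlockSparseTop2MLP (hypothetical name)."""

    def __init__(self, hidden_dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # act(w1(x)) * w3(x), projected back down with w2; tanh-approx GELU is an
        # assumption standing in for approx_gelu defined elsewhere in the file.
        return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x))


class SparseMoeSketch(nn.Module):
    """Loop-over-experts dispatch mirroring the updated GemmoeSparseMoeBlock."""

    def __init__(self, hidden_dim: int, ffn_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.experts = nn.ModuleList(Top2MLP(hidden_dim, ffn_dim) for _ in range(num_experts))

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)           # (tokens, hidden)

        router_logits = self.gate(hidden_states)                     # (tokens, experts)
        routing_weights = F.softmax(router_logits, dim=1)
        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

        # Duplicate each token once per selected expert, then send each copy to
        # its expert via a boolean mask over the flattened expert indices.
        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
        y = torch.zeros_like(hidden_states)
        flat_topk_idx = topk_idx.view(-1)
        for i, expert in enumerate(self.experts):
            mask = flat_topk_idx == i
            y[mask] = expert(hidden_states[mask])

        # Weight each copy by its renormalized router probability and sum the top-k copies.
        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
        return y.view(batch_size, seq_len, hidden_dim), router_logits


if __name__ == "__main__":
    moe = SparseMoeSketch(hidden_dim=8, ffn_dim=16, num_experts=4)
    out, logits = moe(torch.randn(2, 5, 8))
    print(out.shape, logits.shape)  # torch.Size([2, 5, 8]) torch.Size([10, 4])

Compared with the removed ParallelLinear path, this loop-over-experts formulation dispatches tokens with plain boolean masks and standard nn.Linear experts, which avoids the scatter2scatter kernel and the bmm-based grouped dispatch that the old code relied on.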