Crystalcareai
commited on
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +37 -29
modeling_gemmoe.py
CHANGED
@@ -629,23 +629,32 @@ GEMMOE_ATTENTION_CLASSES = {
|
|
629 |
"sdpa": GemmoeSdpaAttention,
|
630 |
}
|
631 |
|
632 |
-
|
633 |
-
def __init__(self, config: GemmoeConfig):
|
634 |
-
super().__init__()
|
635 |
-
self.ffn_dim = config.intermediate_size
|
636 |
-
self.hidden_dim = config.hidden_size
|
637 |
-
|
638 |
-
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
639 |
-
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
|
640 |
-
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
641 |
-
|
642 |
-
self.act_fn = approx_gelu
|
643 |
-
|
644 |
-
def forward(self, hidden_states):
|
645 |
-
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
|
646 |
-
current_hidden_states = self.w2(current_hidden_states)
|
647 |
-
return current_hidden_states.to(hidden_states.dtype)
|
648 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
649 |
|
650 |
class GemmoeSparseMoeBlock(nn.Module):
|
651 |
def __init__(self, config):
|
@@ -654,11 +663,14 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
654 |
self.ffn_dim = config.intermediate_size
|
655 |
self.num_experts = config.num_local_experts
|
656 |
self.top_k = 2
|
|
|
657 |
|
658 |
# gating
|
659 |
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
660 |
|
661 |
-
self.
|
|
|
|
|
662 |
|
663 |
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
664 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
@@ -667,22 +679,18 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
667 |
# router_logits: (batch * sequence_length, n_experts)
|
668 |
router_logits = self.gate(hidden_states)
|
669 |
routing_weights = F.softmax(router_logits, dim=1)
|
670 |
-
|
|
|
671 |
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
672 |
|
673 |
-
|
674 |
-
|
675 |
-
y = torch.empty_like(hidden_states)
|
676 |
-
|
677 |
-
flat_topk_idx = topk_idx.view(-1)
|
678 |
-
for i in range(self.num_experts):
|
679 |
-
expert = self.experts[i]
|
680 |
-
expert_output = expert(hidden_states[flat_topk_idx == i])
|
681 |
-
y[flat_topk_idx == i] = expert_output
|
682 |
|
683 |
-
|
|
|
|
|
|
|
684 |
|
685 |
-
final_hidden_states =
|
686 |
return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
|
687 |
|
688 |
|
|
|
629 |
"sdpa": GemmoeSdpaAttention,
|
630 |
}
|
631 |
|
632 |
+
scatter2scatter=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
633 |
|
634 |
+
class ParallelLinear(nn.Module):
|
635 |
+
def __init__(self, in_features, out_features, num_experts):
|
636 |
+
super().__init__()
|
637 |
+
self.in_features = in_features
|
638 |
+
self.out_features = out_features
|
639 |
+
self.num_experts = num_experts
|
640 |
+
self.weight = nn.Parameter(torch.Tensor(num_experts, in_features, out_features))
|
641 |
+
nn.init.xavier_uniform_(self.weight)
|
642 |
+
|
643 |
+
def forward(self, input, ordering, top_k, routing_weights=None, grouped_in=False, grouped_out=False, use_scatter2scatter=False):
|
644 |
+
if use_scatter2scatter:
|
645 |
+
output = scatter2scatter(input, self.weight, ordering, top_k, grouped_in=grouped_in, grouped_out=grouped_out)
|
646 |
+
else:
|
647 |
+
if not grouped_in:
|
648 |
+
input = input[ordering]
|
649 |
+
output = torch.bmm(input.view(-1, top_k, self.in_features), self.weight[ordering].view(-1, self.in_features, self.out_features))
|
650 |
+
if not grouped_out:
|
651 |
+
output = output.view(-1, top_k * self.out_features)
|
652 |
+
output = output[torch.argsort(ordering)]
|
653 |
+
|
654 |
+
if routing_weights is not None:
|
655 |
+
output = output.view(-1, top_k, self.out_features)
|
656 |
+
output = torch.bmm(routing_weights.unsqueeze(1), output).squeeze(1)
|
657 |
+
return output
|
658 |
|
659 |
class GemmoeSparseMoeBlock(nn.Module):
|
660 |
def __init__(self, config):
|
|
|
663 |
self.ffn_dim = config.intermediate_size
|
664 |
self.num_experts = config.num_local_experts
|
665 |
self.top_k = 2
|
666 |
+
self.use_scatter2scatter = config.use_scatter2scatter if hasattr(config, 'use_scatter2scatter') else False
|
667 |
|
668 |
# gating
|
669 |
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
670 |
|
671 |
+
self.expert_mlp1 = ParallelLinear(self.hidden_dim, self.ffn_dim, self.num_experts)
|
672 |
+
self.expert_mlp2 = ParallelLinear(self.ffn_dim, self.hidden_dim, self.num_experts)
|
673 |
+
self.activation = nn.GELU()
|
674 |
|
675 |
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
676 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
|
|
679 |
# router_logits: (batch * sequence_length, n_experts)
|
680 |
router_logits = self.gate(hidden_states)
|
681 |
routing_weights = F.softmax(router_logits, dim=1)
|
682 |
+
_, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
|
683 |
+
topk_weight = routing_weights.gather(1, topk_idx)
|
684 |
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
685 |
|
686 |
+
ordering = topk_idx.view(-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
687 |
|
688 |
+
# ParallelLinear for expert MLP
|
689 |
+
hidden_states = self.expert_mlp1(hidden_states, ordering, self.top_k, grouped_in=False, grouped_out=True, use_scatter2scatter=self.use_scatter2scatter)
|
690 |
+
hidden_states = self.activation(hidden_states)
|
691 |
+
hidden_states = self.expert_mlp2(hidden_states, ordering, self.top_k, topk_weight, grouped_in=True, grouped_out=False, use_scatter2scatter=self.use_scatter2scatter)
|
692 |
|
693 |
+
final_hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
694 |
return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
|
695 |
|
696 |
|