Crystalcareai
committed
Update modeling_gemmoe.py
modeling_gemmoe.py CHANGED (+5 -10)
@@ -670,10 +670,9 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
         self.act_fn = approx_gelu

     def forward(self, hidden_states):
-        hidden_states = hidden_states.to(torch.float32) # Cast to float32
         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
-        current_hidden_states = self.w2(current_hidden_states
-        return current_hidden_states
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states.to(hidden_states.dtype)


 class GemmoeSparseMoeBlock(nn.Module):
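For context on the first hunk: the fp32 upcast at the top of the expert MLP is dropped, and the output is instead cast back to the input dtype on return. Below is a minimal runnable sketch of the resulting forward pass; the class name, layer sizes, and the tanh-approximated GELU are illustrative assumptions, only the forward body mirrors the diff.

import torch
import torch.nn as nn

# Assumed stand-in for the approx_gelu used in the file (tanh-approximated GELU).
approx_gelu = lambda x: nn.functional.gelu(x, approximate="tanh")

class Top2MLPSketch(nn.Module):
    # Hypothetical sizes purely for illustration.
    def __init__(self, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # up projection
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # down projection
        self.act_fn = approx_gelu

    def forward(self, hidden_states):
        # No fp32 upcast anymore: compute in the incoming dtype, then cast the
        # result back to that dtype (a no-op unless something changed it mid-flight).
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states.to(hidden_states.dtype)

out = Top2MLPSketch()(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 16])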
@@ -694,15 +693,11 @@ class GemmoeSparseMoeBlock(nn.Module):
         hidden_states = hidden_states.view(-1, hidden_dim)

         # router_logits: (batch * sequence_length, n_experts)
-
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
+        router_logits = self.gate(hidden_states)
+        routing_weights = F.softmax(router_logits, dim=1)
         topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
         topk_weight /= topk_weight.sum(dim=-1, keepdim=True)

-        # we cast back to the input dtype
-        topk_weight = topk_weight.to(hidden_states.dtype)
-
         hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)

         y = torch.empty_like(hidden_states)
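The second hunk takes the router softmax in the logits' own dtype (the explicit dtype=torch.float32 and the later cast of topk_weight back to the input dtype are removed). A small shape walk-through of the routing path, with hypothetical sizes (4 tokens, 8 experts, top_k = 2) chosen only for illustration:

import torch
import torch.nn.functional as F

tokens, n_experts, top_k, hidden_dim = 4, 8, 2, 16
hidden_states = torch.randn(tokens, hidden_dim)
router_logits = torch.randn(tokens, n_experts)           # (tokens, n_experts)

# Softmax now runs in the logits' dtype; no fp32 upcast.
routing_weights = F.softmax(router_logits, dim=1)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)     # renormalize the two kept weights

# Each token is duplicated top_k times so each copy can be routed to one expert.
expanded = hidden_states.repeat_interleave(top_k, dim=0)
print(expanded.shape)  # torch.Size([8, 16]) -> tokens * top_k rows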
@@ -716,7 +711,7 @@ class GemmoeSparseMoeBlock(nn.Module):
         y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)

         final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
-        return final_hidden_states, router_logits
+        return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)


 # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMOE,Llama->Gemmoe
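The last hunk keeps the weighted top-2 combine unchanged and now casts both return values back to the dtype of hidden_states. A sketch of that combine step, continuing the hypothetical sizes above:

import torch

tokens, top_k, hidden_dim = 4, 2, 16
batch_size, sequence_length = 1, 4

# y holds one expert output per (token, expert-slot) row, in the layout
# produced by repeat_interleave; values here are random placeholders.
y = torch.randn(tokens * top_k, hidden_dim)
topk_weight = torch.rand(tokens, top_k)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)

# Weight each expert's output by its routing weight and sum over the two slots.
combined = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = combined.reshape(batch_size, sequence_length, hidden_dim)
print(final_hidden_states.shape)  # torch.Size([1, 4, 16])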