Commit a0d5586 (verified) · Parent: 7b9b6d3
Crystalcareai committed

Update modeling_gemmoe.py

Files changed (1): modeling_gemmoe.py (+5 -10)
modeling_gemmoe.py CHANGED
@@ -670,10 +670,9 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
         self.act_fn = approx_gelu
 
     def forward(self, hidden_states):
-        hidden_states = hidden_states.to(torch.float32)  # Cast to float32
         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
-        current_hidden_states = self.w2(current_hidden_states.to(hidden_states.dtype))  # Cast back to original dtype
-        return current_hidden_states
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states.to(hidden_states.dtype)
 
 
 class GemmoeSparseMoeBlock(nn.Module):
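
This hunk drops the float32 round-trip inside each expert: the up-cast at the top of forward is removed, and a single cast at the end pins the output to the incoming dtype. A minimal runnable sketch of the expert after the change, assuming bias-free GeGLU-style projections and a tanh-approximated GELU (the layer sizes are illustrative; only the forward body mirrors the diff):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes for illustration; the real values come from the Gemmoe config.
HIDDEN, INTERMEDIATE = 16, 32

def approx_gelu(x):
    # tanh-approximated GELU, as the name approx_gelu in the diff suggests
    return F.gelu(x, approximate="tanh")

class ExpertMLPSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(HIDDEN, INTERMEDIATE, bias=False)
        self.w3 = nn.Linear(HIDDEN, INTERMEDIATE, bias=False)
        self.w2 = nn.Linear(INTERMEDIATE, HIDDEN, bias=False)
        self.act_fn = approx_gelu

    def forward(self, hidden_states):
        # GeGLU-style gating, computed in the incoming dtype (no float32 detour)
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        # one cast at the end pins the output to the input dtype
        return current_hidden_states.to(hidden_states.dtype)
```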
@@ -694,15 +693,11 @@ class GemmoeSparseMoeBlock(nn.Module):
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         # router_logits: (batch * sequence_length, n_experts)
-        hidden_states_float = hidden_states.float()  # Cast to float32
-        router_logits = self.gate(hidden_states_float)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
+        router_logits = self.gate(hidden_states)
+        routing_weights = F.softmax(router_logits, dim=1)
         topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
         topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
 
-        # we cast back to the input dtype
-        topk_weight = topk_weight.to(hidden_states.dtype)
-
         hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
 
         y = torch.empty_like(hidden_states)
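
The router likewise now stays in the block's native dtype: the gate is applied to hidden_states directly, the softmax no longer forces dtype=torch.float32, and the explicit cast of topk_weight back to the input dtype becomes unnecessary. A small sketch of the routing step on dummy data (the gate shape, bias=False, top_k = 2, and bfloat16 inputs are assumptions; the ops mirror the diff):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# illustrative shapes; in the model these come from the config
tokens, hidden_dim, n_experts, top_k = 6, 8, 4, 2

hidden_states = torch.randn(tokens, hidden_dim, dtype=torch.bfloat16)
gate = torch.nn.Linear(hidden_dim, n_experts, bias=False, dtype=torch.bfloat16)

router_logits = gate(hidden_states)                # stays bfloat16, no .float() up-cast
routing_weights = F.softmax(router_logits, dim=1)  # softmax in the native dtype
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

print(topk_weight.dtype)  # torch.bfloat16, so no cast back is needed
```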
@@ -716,7 +711,7 @@ class GemmoeSparseMoeBlock(nn.Module):
         y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
 
         final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
-        return final_hidden_states, router_logits
+        return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMOE,Llama->Gemmoe
 
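
The last hunk moves all remaining casts to the return statement, so both outputs leave the block in the caller's dtype. For reference, a sketch of the combine step that precedes that return, with illustrative shapes (token count, top_k, and hidden size are assumptions):

```python
import torch

# illustrative shapes for the combine step shown in the last hunk
tokens, top_k, hidden_dim = 6, 2, 8

y = torch.randn(tokens * top_k, hidden_dim)  # one row per (token, expert) pair
topk_weight = torch.rand(tokens, top_k)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

# weight each token's top-k expert outputs and sum them, as in the diff
combined = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
print(combined.shape)  # torch.Size([6, 8]): one hidden vector per token
```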