Crystalcareai committed on
Commit
2577c85
·
verified ·
1 Parent(s): 5a791d5

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +5 -4
modeling_gemmoe.py CHANGED
@@ -1086,8 +1086,6 @@ class GemmoeModel(GemmoePreTrainedModel):
1086
  for decoder_layer in self.layers:
1087
  if output_hidden_states:
1088
  all_hidden_states += (hidden_states,)
1089
-
1090
- if self.gradient_checkpointing and self.training:
1091
  layer_outputs = self._gradient_checkpointing_func(
1092
  decoder_layer.__call__,
1093
  hidden_states,
@@ -1095,8 +1093,10 @@ class GemmoeModel(GemmoePreTrainedModel):
1095
  position_ids,
1096
  past_key_values,
1097
  output_attentions,
1098
- bool(use_cache),
 
1099
  cache_position,
 
1100
  )
1101
  else:
1102
  layer_outputs = decoder_layer(
@@ -1105,7 +1105,8 @@ class GemmoeModel(GemmoePreTrainedModel):
1105
  position_ids=position_ids,
1106
  past_key_value=past_key_values,
1107
  output_attentions=output_attentions,
1108
- use_cache=bool(use_cache),
 
1109
  cache_position=cache_position,
1110
  )
1111
 
 
1086
  for decoder_layer in self.layers:
1087
  if output_hidden_states:
1088
  all_hidden_states += (hidden_states,)
 
 
1089
  layer_outputs = self._gradient_checkpointing_func(
1090
  decoder_layer.__call__,
1091
  hidden_states,
 
1093
  position_ids,
1094
  past_key_values,
1095
  output_attentions,
1096
+ output_router_logits,
1097
+ use_cache.item() if isinstance(use_cache, torch.Tensor) else use_cache,
1098
  cache_position,
1099
+ output_router_logits,
1100
  )
1101
  else:
1102
  layer_outputs = decoder_layer(
 
1105
  position_ids=position_ids,
1106
  past_key_value=past_key_values,
1107
  output_attentions=output_attentions,
1108
+ output_router_logits=output_router_logits,
1109
+ use_cache=use_cache.item() if isinstance(use_cache, torch.Tensor) else use_cache,
1110
  cache_position=cache_position,
1111
  )
1112