Crystalcareai
committed on
Update modeling_gemmoe.py
Browse files — modeling_gemmoe.py +5 -4
modeling_gemmoe.py
CHANGED
@@ -1086,8 +1086,6 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
1086 |
for decoder_layer in self.layers:
|
1087 |
if output_hidden_states:
|
1088 |
all_hidden_states += (hidden_states,)
|
1089 |
-
|
1090 |
-
if self.gradient_checkpointing and self.training:
|
1091 |
layer_outputs = self._gradient_checkpointing_func(
|
1092 |
decoder_layer.__call__,
|
1093 |
hidden_states,
|
@@ -1095,8 +1093,10 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
1095 |
position_ids,
|
1096 |
past_key_values,
|
1097 |
output_attentions,
|
1098 |
-
|
|
|
1099 |
cache_position,
|
|
|
1100 |
)
|
1101 |
else:
|
1102 |
layer_outputs = decoder_layer(
|
@@ -1105,7 +1105,8 @@ class GemmoeModel(GemmoePreTrainedModel):
|
|
1105 |
position_ids=position_ids,
|
1106 |
past_key_value=past_key_values,
|
1107 |
output_attentions=output_attentions,
|
1108 |
-
|
|
|
1109 |
cache_position=cache_position,
|
1110 |
)
|
1111 |
|
|
|
1086 |
for decoder_layer in self.layers:
|
1087 |
if output_hidden_states:
|
1088 |
all_hidden_states += (hidden_states,)
|
|
|
|
|
1089 |
layer_outputs = self._gradient_checkpointing_func(
|
1090 |
decoder_layer.__call__,
|
1091 |
hidden_states,
|
|
|
1093 |
position_ids,
|
1094 |
past_key_values,
|
1095 |
output_attentions,
|
1096 |
+
output_router_logits,
|
1097 |
+
use_cache.item() if isinstance(use_cache, torch.Tensor) else use_cache,
|
1098 |
cache_position,
|
1099 |
+
output_router_logits,
|
1100 |
)
|
1101 |
else:
|
1102 |
layer_outputs = decoder_layer(
|
|
|
1105 |
position_ids=position_ids,
|
1106 |
past_key_value=past_key_values,
|
1107 |
output_attentions=output_attentions,
|
1108 |
+
output_router_logits=output_router_logits,
|
1109 |
+
use_cache=use_cache.item() if isinstance(use_cache, torch.Tensor) else use_cache,
|
1110 |
cache_position=cache_position,
|
1111 |
)
|
1112 |
|