Crystalcareai committed
Update modeling_gemmoe.py
modeling_gemmoe.py  +189 -148  CHANGED
@@ -31,7 +31,7 @@ from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_causal_attention_mask,
)
-from transformers.modeling_outputs import
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from transformers.utils import (

@@ -45,8 +45,6 @@ from transformers.utils import (
from transformers.utils.import_utils import is_torch_fx_available
from .configuration_gemmoe import GemmoeConfig

-from math import sqrt as math_sqrt
-

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func

@@ -55,7 +53,6 @@ if is_flash_attn_2_available():

# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
-
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx

@@ -67,9 +64,7 @@ logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "GemmoeConfig"

-def load_balancing_loss_func(
-    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
-) -> float:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -79,73 +74,42 @@ def load_balancing_loss_func(

    Args:
        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
-            Logits from the `gate`, should be a tuple of
-            shape [batch_size X sequence_length, num_experts].
-        attention_mask (`torch.Tensor`, None):
-            The attention_mask used in forward function
-            shape [batch_size X sequence_length] if not None.
        num_experts (`int`, *optional*):
            Number of experts

    Returns:
        The auxiliary loss.
    """
-    if gate_logits is None
        return 0

    if isinstance(gate_logits, tuple):
-
-
-
-        routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
-
-        _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

-
-
-    if attention_mask is None:
-        # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
-
-        # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.mean(routing_weights, dim=0)
-    else:
-        batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
-
-        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
-        expert_attention_mask = (
-            attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
-            .reshape(-1, top_k, num_experts)
-            .to(compute_device)
-        )

-
-
-
-        )

-
-
-            attention_mask[None, :, :, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
-            .reshape(-1, num_experts)
-            .to(compute_device)
-        )

-
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )

-
-


-def approx_gelu(x):
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -159,7 +123,6 @@ def _get_unpad_data(attention_mask):
    )

-
class GemmoeRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()

@@ -167,45 +130,48 @@ class GemmoeRMSNorm(nn.Module):
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
-
-        normed_x = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
-        return normed_x

    def forward(self, x):
-
-        #
-
-

ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)

class GemmoeRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        return

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
@@ -215,15 +181,34 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-
-
-
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

-
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """

@@ -236,6 +221,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

class GemmoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -638,8 +624,20 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
-
-

    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
@@ -786,7 +784,6 @@ GEMMOE_START_DOCSTRING = r"""
    "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
    GEMMOE_START_DOCSTRING,
)
-
class GemmoePreTrainedModel(PreTrainedModel):
    config_class = GemmoeConfig
    base_model_prefix = "model"

@@ -909,7 +906,7 @@ GEMMOE_INPUTS_DOCSTRING = r"""
    "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
    GEMMOE_START_DOCSTRING,
)
-
class GemmoeModel(GemmoePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]

@@ -929,6 +926,7 @@ class GemmoeModel(GemmoePreTrainedModel):
        )
        self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.

@@ -946,6 +944,7 @@ class GemmoeModel(GemmoePreTrainedModel):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,

@@ -961,6 +960,9 @@ class GemmoeModel(GemmoePreTrainedModel):
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

@@ -981,14 +983,6 @@ class GemmoeModel(GemmoePreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

-        # Scale embeddings
-        # Fix for precision issue when casting to bfloat16
-        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
-        if inputs_embeds.dtype == torch.bfloat16:
-            pass
-
-        hidden_states = inputs_embeds * hidden_size_sqrt
-
        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
@@ -1003,43 +997,58 @@ class GemmoeModel(GemmoePreTrainedModel):
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

-

-
-

        # normalized
-

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
-                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
-                    use_cache
                    cache_position,
-                    output_router_logits,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
-                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
-                    use_cache=use_cache
                    cache_position=cache_position,
                )

@@ -1051,6 +1060,9 @@ class GemmoeModel(GemmoePreTrainedModel):
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer

@@ -1063,15 +1075,24 @@ class GemmoeModel(GemmoePreTrainedModel):
            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
        )
        if not return_dict:
-            return tuple(
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

-
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask

@@ -1092,15 +1113,23 @@ class GemmoeModel(GemmoePreTrainedModel):
        causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
        causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
-            causal_mask = causal_mask.clone()
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
            elif attention_mask.dim() == 4:
                mask_shape = attention_mask.shape
                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[

        if (
            self.config._attn_implementation == "sdpa"
@@ -1121,6 +1150,8 @@ class GemmoeModel(GemmoePreTrainedModel):

        return causal_mask

class GemmoeForCausalLM(GemmoePreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

@@ -1155,7 +1186,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):

    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
-    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor = None,

@@ -1169,6 +1199,7 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        r"""
        Args:

@@ -1184,23 +1215,21 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
        ```python
        >>> from transformers import AutoTokenizer, GemmoeForCausalLM

-        >>> model = GemmoeForCausalLM.from_pretrained("
-        >>> tokenizer = AutoTokenizer.from_pretrained("

-        >>> prompt = "
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        "
        ```"""
-
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
-
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

@@ -1218,17 +1247,12 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
-
-        if self.training:
-            for expert in self.model.layers[-1].block_sparse_moe.experts:
-                for param in expert.parameters():
-                    if param.requires_grad and param.grad is None:
-                        param.grad = torch.zeros_like(param)

        loss = None
        if labels is not None:
@@ -1246,13 +1270,10 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
-                outputs.router_logits if return_dict else outputs[-1],
-                self.num_experts,
-                self.num_experts_per_tok,
-                attention_mask,
            )
            if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss

        if not return_dict:
            output = (logits,) + outputs[1:]

@@ -1271,20 +1292,26 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
        )

    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        inputs_embeds=None,
-        output_router_logits=False,
-        **kwargs,
    ):
-        #
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
-
-
-
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

@@ -1321,15 +1348,27 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
-

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
-                "output_router_logits": output_router_logits,
            }
        )
        return model_inputs

@@ -1342,10 +1381,12 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
-
@add_start_docstrings(
    """
    The Gemmoe Model transformer with a sequence classification head on top (linear layer).
    [`GemmoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

@@ -1357,7 +1398,7 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
    """,
    GEMMOE_START_DOCSTRING,
)
-
class GemmoeForSequenceClassification(GemmoePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
    AttentionMaskConverter,
    _prepare_4d_causal_attention_mask,
)
+from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from transformers.utils import (

from transformers.utils.import_utils import is_torch_fx_available
from .configuration_gemmoe import GemmoeConfig


if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func


# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx


_CONFIG_FOR_DOC = "GemmoeConfig"

+def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    Args:
        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+            Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_experts].
        num_experts (`int`, *optional*):
            Number of experts

    Returns:
        The auxiliary loss.
    """
+    if gate_logits is None:
        return 0

    if isinstance(gate_logits, tuple):
+        # cat along the layers?
+        gate_logits = torch.cat(gate_logits, dim=0)

+    routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
+    routing_weights = routing_weights.softmax(dim=-1)

+    # cast the expert indices to int64, otherwise one-hot encoding will fail
+    if selected_experts.dtype != torch.int64:
+        selected_experts = selected_experts.to(torch.int64)

+    if len(selected_experts.shape) == 2:
+        selected_experts = selected_experts.unsqueeze(2)

+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

+    # For a given token, determine if it was routed to a given expert.
+    expert_mask = torch.max(expert_mask, axis=-2).values

+    # cast to float32 otherwise mean will fail
+    expert_mask = expert_mask.to(torch.float32)
+    tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)

+    router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1)
+    return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1)) * (num_experts**2)
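A minimal usage sketch (not part of the commit) of the new `load_balancing_loss_func` above. The layer count, token count, expert count, and top-k below are illustrative assumptions, and the function is assumed to be in scope from this file.

```python
import torch

# Illustrative shapes: 2 decoder layers, 8 tokens, 4 experts, top-2 routing.
num_layers, tokens, num_experts, top_k = 2, 8, 4, 2

# The model returns one router-logits tensor per layer; the loss concatenates them along dim 0.
gate_logits = tuple(torch.randn(tokens, num_experts) for _ in range(num_layers))

aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=top_k)
print(aux_loss)  # scalar tensor, already scaled by num_experts**2
```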
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)

    )


class GemmoeRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemmoe is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)
+

ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)

+
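A minimal sketch (not part of the commit) of the numerics the new `GemmoeRMSNorm.forward` implements: normalize and scale in float32, then cast the final product back, i.e. `(x * w).to(dtype)` rather than Llama's `x.to(dtype) * w`.

```python
import torch

x = torch.randn(2, 5, dtype=torch.float16)
weight = torch.zeros(5)  # GemmoeRMSNorm initializes its weight to zeros
eps = 1e-6

normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
out = (normed * (1.0 + weight.float())).type_as(x)  # the downcast happens after the multiply
print(out.dtype)  # torch.float16
```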
class GemmoeRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
+
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
+        self.register_buffer("inv_freq", None, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if self.inv_freq is None:
+            self.inv_freq = 1.0 / (
+                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+            )
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):

    return torch.cat((-x2, x1), dim=-1)


+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

+
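A minimal sketch (not part of the commit) wiring the new `GemmoeRotaryEmbedding.forward` to `apply_rotary_pos_emb`. The batch, head, and dimension sizes are illustrative assumptions, and both callables are assumed to come from this file.

```python
import torch

batch, heads, seq_len, head_dim = 1, 2, 6, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)  # [1, seq_len]

rope = GemmoeRotaryEmbedding(dim=head_dim)
cos, sin = rope(q, position_ids)                     # each: [batch, seq_len, head_dim], in q's dtype
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 broadcasts over the head axis
print(q_rot.shape)  # torch.Size([1, 2, 6, 8])
```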
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """

    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

+
class GemmoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""


        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+        if config.hidden_activation is None:
+            logger.warning_once(
+                "Gemmoe's activation function should be approximate GeLU and not exact GeLU.\n"
+                "Changing the activation function to `gelu_pytorch_tanh`."
+                f"if you want to use the legacy `{config.hidden_act}`, "
+                f"edit the `model.config` to set `hidden_activation={config.hidden_act}` "
+                " instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
+            )
+            hidden_activation = "gelu_pytorch_tanh"
+        else:
+            hidden_activation = config.hidden_activation
+
+        self.act_fn = ACT2FN[hidden_activation]

    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
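A minimal sketch (not part of the commit) of why the warning above forces `gelu_pytorch_tanh`: the tanh-approximate GELU (the same formula the removed `approx_gelu` helper computed) differs slightly from the exact, erf-based GELU.

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
approx = F.gelu(x, approximate="tanh")  # tanh approximation, which `gelu_pytorch_tanh` maps to
exact = F.gelu(x)                       # exact erf-based GELU
print((approx - exact).abs().max())     # small but non-zero gap
```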
    "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
    GEMMOE_START_DOCSTRING,
)
class GemmoePreTrainedModel(PreTrainedModel):
    config_class = GemmoeConfig
    base_model_prefix = "model"

    "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
    GEMMOE_START_DOCSTRING,
)
+# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMOE,Llama->Gemmoe
class GemmoeModel(GemmoePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]

        )
        self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.

        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
+    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor = None,

        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

+        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+            is_padding_right = attention_mask[:, -1].sum().item() != inputs_embeds.shape[0]
+            if is_padding_right:
+                raise ValueError(
+                    "You are attempting to perform batched generation with padding_side='right'"
+                    " this may lead to unexpected behaviour for Flash Attention version of Gemmoe. Make sure to "
+                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                )

+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)

        # normalized
+        # Gemmoe downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype)
+        hidden_states = inputs_embeds * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
+        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
+                    causal_mask if not self._use_flash_attention_2 else attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
+                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
+                    attention_mask=causal_mask if not self._use_flash_attention_2 else attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
+                    use_cache=use_cache,
                    cache_position=cache_position,
                )
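A minimal sketch (not part of the commit) of the rounding the `# normalized` comment above refers to: materializing sqrt(hidden_size) in a 16-bit dtype changes the value the embeddings are scaled by. The hidden size is an illustrative value; the real one comes from the model config.

```python
import torch

hidden_size = 3072
print(torch.tensor(hidden_size**0.5))                        # tensor(55.4256) in float32
print(torch.tensor(hidden_size**0.5, dtype=torch.bfloat16))  # rounds to 55.5
print(torch.tensor(hidden_size**0.5, dtype=torch.float16))   # rounds to 55.4375
```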
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

+            if output_router_logits:
+                all_router_logits += (layer_outputs[-1],)
+
        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer

            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
        )
        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                if v is not None
+            )
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
+            router_logits=all_router_logits,
        )

+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+    def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
        causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
        causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
            elif attention_mask.dim() == 4:
+                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
+                # cache. In that case, the 4D attention mask attends to the newest tokens only.
+                if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
+                    offset = past_seen_tokens
+                else:
+                    offset = 0
                mask_shape = attention_mask.shape
                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+                causal_mask[
+                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+                ] = mask_slice

        if (
            self.config._attn_implementation == "sdpa"
        return causal_mask

+
+# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMOE,Llama->Gemmoe,llama->GEMMA
class GemmoeForCausalLM(GemmoePreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]


    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,

        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        r"""
        Args:

        ```python
        >>> from transformers import AutoTokenizer, GemmoeForCausalLM

+        >>> model = GemmoeForCausalLM.from_pretrained("google/GEMMA-7b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/GEMMA-7b")

+        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "What is your favorite condiment?"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
+            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
+                outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok
            )
            if labels is not None:
+                loss += self.router_aux_loss_coef * aux_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
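A minimal sketch (not part of the commit) of how the auxiliary router loss above folds into the language-modeling loss. The loss values and coefficient are illustrative assumptions; in the real code the coefficient comes from `config.router_aux_loss_coef`.

```python
import torch

lm_loss = torch.tensor(2.31)    # cross-entropy over shifted labels (illustrative value)
aux_loss = torch.tensor(1.05)   # load_balancing_loss_func over all layers' router logits (illustrative)
router_aux_loss_coef = 0.02     # illustrative coefficient

total_loss = lm_loss + router_aux_loss_coef * aux_loss
print(total_loss)  # tensor(2.3310)
```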
        )

    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
+        # With static cache, the `past_key_values` is None
+        # TODO joao: standardize interface for the different Cache classes and remove of this if
+        has_static_cache = False
+        if past_key_values is None:
+            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
+            has_static_cache = past_key_values is not None
+
+        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
+                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+                max_cache_length = (
+                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                    if past_key_values.get_max_length() is not None
+                    else None
+                )
+                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
+            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if cache_position is None:
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        else:
+            cache_position = cache_position[-input_length:]
+
+        if has_static_cache:
+            past_key_values = None

        model_inputs.update(
            {
                "position_ids": position_ids,
+                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
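A minimal sketch (not part of the commit) of the `cache_position` bookkeeping done in `prepare_inputs_for_generation` above, with an illustrative cache length and a three-token input.

```python
import torch

past_length = 5                            # tokens already in the KV cache (illustrative)
input_ids = torch.tensor([[11, 12, 13]])   # three new tokens
input_length = input_ids.shape[-1]

cache_position = torch.arange(past_length, past_length + input_length)
print(cache_position)  # tensor([5, 6, 7])
```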
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
+
+
@add_start_docstrings(
    """
    The Gemmoe Model transformer with a sequence classification head on top (linear layer).
+
    [`GemmoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    """,
    GEMMOE_START_DOCSTRING,
)
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMOE,Llama->Gemmoe
class GemmoeForSequenceClassification(GemmoePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)