drewwas commited on
Commit
b0cb9ba
·
verified ·
1 Parent(s): ae62696

Upload flashNorm_modeling_llama.py

Browse files
Files changed (1) hide show
  1. flashNorm_modeling_llama.py +1659 -0
flashNorm_modeling_llama.py ADDED
@@ -0,0 +1,1659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this file was copied from transformers version 4.45.2
2
+ # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/llama/modeling_llama.py
3
+ # wget https://raw.githubusercontent.com/huggingface/transformers/refs/tags/v4.45.2/src/transformers/models/llama/modeling_llama.py
4
+ #
5
+ # I made the following changes:
6
+ # - added new class 'FlashNorm' as alternative to RMSNorm
7
+ # - replaced RMSNorm calls by FlashNorm in class 'LlamaDecoderLayer'
8
+ # - renamed 'LlamaForCausalLM' to 'LlamaFlashNorm'
9
+ # - changed relative imports 'from ...foo' to absolute 'from transformers.foo':
10
+ # sed -i 's/from \.\.\./from transformers./'
11
+ # sed -i 's/from \.configuration_llama import LlamaConfig/from transformers import LlamaConfig/'
12
+ # - All changes (except for the import changes at the beginning) are marked
13
+ # by a preceeding comment saying '# FlashNorm change:'
14
+ #--------------------------------------------------------------------------------
15
+
16
+ # coding=utf-8
17
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
18
+ #
19
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
20
+ # and OPT implementations in this library. It has been modified from its
21
+ # original forms to accommodate minor architectural differences compared
22
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
23
+ #
24
+ # Licensed under the Apache License, Version 2.0 (the "License");
25
+ # you may not use this file except in compliance with the License.
26
+ # You may obtain a copy of the License at
27
+ #
28
+ # http://www.apache.org/licenses/LICENSE-2.0
29
+ #
30
+ # Unless required by applicable law or agreed to in writing, software
31
+ # distributed under the License is distributed on an "AS IS" BASIS,
32
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33
+ # See the License for the specific language governing permissions and
34
+ # limitations under the License.
35
+ import math
36
+ from typing import List, Optional, Tuple, Union
37
+
38
+ import torch
39
+ import torch.nn.functional as F
40
+ import torch.utils.checkpoint
41
+ from torch import nn
42
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
43
+
44
+ from transformers.activations import ACT2FN
45
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
46
+ from transformers.generation import GenerationMixin
47
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
48
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
49
+ from transformers.modeling_outputs import (
50
+ BaseModelOutputWithPast,
51
+ CausalLMOutputWithPast,
52
+ QuestionAnsweringModelOutput,
53
+ SequenceClassifierOutputWithPast,
54
+ TokenClassifierOutput,
55
+ )
56
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
57
+ from transformers.modeling_utils import PreTrainedModel
58
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
59
+ from transformers.utils import (
60
+ add_start_docstrings,
61
+ add_start_docstrings_to_model_forward,
62
+ is_flash_attn_greater_or_equal_2_10,
63
+ is_torchdynamo_compiling,
64
+ logging,
65
+ replace_return_docstrings,
66
+ )
67
+ from transformers import LlamaConfig
68
+
69
+
70
+ logger = logging.get_logger(__name__)
71
+
72
+ _CONFIG_FOR_DOC = "LlamaConfig"
73
+
74
+
75
+ def _prepare_4d_causal_attention_mask_with_cache_position(
76
+ attention_mask: torch.Tensor,
77
+ sequence_length: int,
78
+ target_length: int,
79
+ dtype: torch.dtype,
80
+ device: torch.device,
81
+ min_dtype: float,
82
+ cache_position: torch.Tensor,
83
+ batch_size: int,
84
+ ):
85
+ """
86
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
87
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
88
+
89
+ Args:
90
+ attention_mask (`torch.Tensor`):
91
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
92
+ sequence_length (`int`):
93
+ The sequence length being processed.
94
+ target_length (`int`):
95
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
96
+ dtype (`torch.dtype`):
97
+ The dtype to use for the 4D attention mask.
98
+ device (`torch.device`):
99
+ The device to plcae the 4D attention mask on.
100
+ min_dtype (`float`):
101
+ The minimum value representable with the dtype `dtype`.
102
+ cache_position (`torch.Tensor`):
103
+ Indices depicting the position of the input sequence tokens in the sequence.
104
+ batch_size (`torch.Tensor`):
105
+ Batch size.
106
+ """
107
+ if attention_mask is not None and attention_mask.dim() == 4:
108
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
109
+ causal_mask = attention_mask
110
+ else:
111
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
112
+ if sequence_length != 1:
113
+ causal_mask = torch.triu(causal_mask, diagonal=1)
114
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
115
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
116
+ if attention_mask is not None:
117
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
118
+ mask_length = attention_mask.shape[-1]
119
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
120
+ padding_mask = padding_mask == 0
121
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
122
+ padding_mask, min_dtype
123
+ )
124
+
125
+ return causal_mask
126
+
127
+
128
+ class LlamaRMSNorm(nn.Module):
129
+ def __init__(self, hidden_size, eps=1e-6):
130
+ """
131
+ LlamaRMSNorm is equivalent to T5LayerNorm
132
+ """
133
+ super().__init__()
134
+ self.weight = nn.Parameter(torch.ones(hidden_size))
135
+ self.variance_epsilon = eps
136
+
137
+ def forward(self, hidden_states):
138
+ input_dtype = hidden_states.dtype
139
+ hidden_states = hidden_states.to(torch.float32)
140
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
141
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
142
+ return self.weight * hidden_states.to(input_dtype)
143
+
144
+ def extra_repr(self):
145
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
146
+
147
+
148
+ # FlashNorm change: added the class below
149
+ class FlashNorm(nn.Module):
150
+ def __init__(self, hidden_size, eps=1e-6):
151
+ """
152
+ FlashNorm is like RMSNorm without weights, see https://arxiv.org/abs/2407.09577
153
+ """
154
+ super().__init__()
155
+ #self.weight = nn.Parameter(torch.ones(hidden_size))
156
+ self.variance_epsilon = eps
157
+
158
+ def forward(self, hidden_states):
159
+ input_dtype = hidden_states.dtype
160
+ hidden_states = hidden_states.to(torch.float32)
161
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
162
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
163
+ return hidden_states.to(input_dtype)
164
+
165
+
166
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
167
+
168
+
169
+ class LlamaRotaryEmbedding(nn.Module):
170
+ def __init__(
171
+ self,
172
+ dim=None,
173
+ max_position_embeddings=2048,
174
+ base=10000,
175
+ device=None,
176
+ scaling_factor=1.0,
177
+ rope_type="default",
178
+ config: Optional[LlamaConfig] = None,
179
+ ):
180
+ super().__init__()
181
+ # TODO (joao): remove the `if` below, only used for BC
182
+ self.rope_kwargs = {}
183
+ if config is None:
184
+ logger.warning_once(
185
+ "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
186
+ "`config` argument. All other arguments will be removed in v4.46"
187
+ )
188
+ self.rope_kwargs = {
189
+ "rope_type": rope_type,
190
+ "factor": scaling_factor,
191
+ "dim": dim,
192
+ "base": base,
193
+ "max_position_embeddings": max_position_embeddings,
194
+ }
195
+ self.rope_type = rope_type
196
+ self.max_seq_len_cached = max_position_embeddings
197
+ self.original_max_seq_len = max_position_embeddings
198
+ else:
199
+ # BC: "rope_type" was originally "type"
200
+ if config.rope_scaling is not None:
201
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
202
+ else:
203
+ self.rope_type = "default"
204
+ self.max_seq_len_cached = config.max_position_embeddings
205
+ self.original_max_seq_len = config.max_position_embeddings
206
+
207
+ self.config = config
208
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
209
+
210
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
211
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
212
+ self.original_inv_freq = self.inv_freq
213
+
214
+ def _dynamic_frequency_update(self, position_ids, device):
215
+ """
216
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
217
+ 1 - growing beyond the cached sequence length (allow scaling)
218
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
219
+ """
220
+ seq_len = torch.max(position_ids) + 1
221
+ if seq_len > self.max_seq_len_cached: # growth
222
+ inv_freq, self.attention_scaling = self.rope_init_fn(
223
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
224
+ )
225
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
226
+ self.max_seq_len_cached = seq_len
227
+
228
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
229
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
230
+ self.max_seq_len_cached = self.original_max_seq_len
231
+
232
+ @torch.no_grad()
233
+ def forward(self, x, position_ids):
234
+ if "dynamic" in self.rope_type:
235
+ self._dynamic_frequency_update(position_ids, device=x.device)
236
+
237
+ # Core RoPE block
238
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
239
+ position_ids_expanded = position_ids[:, None, :].float()
240
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
241
+ device_type = x.device.type
242
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
243
+ with torch.autocast(device_type=device_type, enabled=False):
244
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
245
+ emb = torch.cat((freqs, freqs), dim=-1)
246
+ cos = emb.cos()
247
+ sin = emb.sin()
248
+
249
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
250
+ cos = cos * self.attention_scaling
251
+ sin = sin * self.attention_scaling
252
+
253
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
254
+
255
+
256
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
257
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
258
+
259
+ def __init__(self, *args, **kwargs):
260
+ logger.warning_once(
261
+ "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
262
+ "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
263
+ )
264
+ kwargs["rope_type"] = "linear"
265
+ super().__init__(*args, **kwargs)
266
+
267
+
268
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
269
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
270
+
271
+ def __init__(self, *args, **kwargs):
272
+ logger.warning_once(
273
+ "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
274
+ "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
275
+ "__init__)."
276
+ )
277
+ kwargs["rope_type"] = "dynamic"
278
+ super().__init__(*args, **kwargs)
279
+
280
+
281
+ def rotate_half(x):
282
+ """Rotates half the hidden dims of the input."""
283
+ x1 = x[..., : x.shape[-1] // 2]
284
+ x2 = x[..., x.shape[-1] // 2 :]
285
+ return torch.cat((-x2, x1), dim=-1)
286
+
287
+
288
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
289
+ """Applies Rotary Position Embedding to the query and key tensors.
290
+
291
+ Args:
292
+ q (`torch.Tensor`): The query tensor.
293
+ k (`torch.Tensor`): The key tensor.
294
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
295
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
296
+ position_ids (`torch.Tensor`, *optional*):
297
+ Deprecated and unused.
298
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
299
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
300
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
301
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
302
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
303
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
304
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
305
+ Returns:
306
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
307
+ """
308
+ cos = cos.unsqueeze(unsqueeze_dim)
309
+ sin = sin.unsqueeze(unsqueeze_dim)
310
+ q_embed = (q * cos) + (rotate_half(q) * sin)
311
+ k_embed = (k * cos) + (rotate_half(k) * sin)
312
+ return q_embed, k_embed
313
+
314
+
315
+ class LlamaMLP(nn.Module):
316
+ def __init__(self, config):
317
+ super().__init__()
318
+ self.config = config
319
+ self.hidden_size = config.hidden_size
320
+ self.intermediate_size = config.intermediate_size
321
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
322
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
323
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
324
+ self.act_fn = ACT2FN[config.hidden_act]
325
+
326
+ def forward(self, x):
327
+ if self.config.pretraining_tp > 1:
328
+ slice = self.intermediate_size // self.config.pretraining_tp
329
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
330
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
331
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
332
+
333
+ gate_proj = torch.cat(
334
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
335
+ )
336
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
337
+
338
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
339
+ down_proj = [
340
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
341
+ ]
342
+ down_proj = sum(down_proj)
343
+ else:
344
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
345
+
346
+ return down_proj
347
+
348
+
349
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
350
+ """
351
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
352
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
353
+ """
354
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
355
+ if n_rep == 1:
356
+ return hidden_states
357
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
358
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
359
+
360
+
361
+ class LlamaAttention(nn.Module):
362
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
363
+
364
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
365
+ super().__init__()
366
+ self.config = config
367
+ self.layer_idx = layer_idx
368
+ if layer_idx is None:
369
+ logger.warning_once(
370
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
371
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
372
+ "when creating this class."
373
+ )
374
+
375
+ self.attention_dropout = config.attention_dropout
376
+ self.hidden_size = config.hidden_size
377
+ self.num_heads = config.num_attention_heads
378
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
379
+ self.num_key_value_heads = config.num_key_value_heads
380
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
381
+ self.max_position_embeddings = config.max_position_embeddings
382
+ self.rope_theta = config.rope_theta
383
+ self.is_causal = True
384
+
385
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
386
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
387
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
388
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
389
+
390
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
391
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states: torch.Tensor,
396
+ attention_mask: Optional[torch.Tensor] = None,
397
+ position_ids: Optional[torch.LongTensor] = None,
398
+ past_key_value: Optional[Cache] = None,
399
+ output_attentions: bool = False,
400
+ use_cache: bool = False,
401
+ cache_position: Optional[torch.LongTensor] = None,
402
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
403
+ **kwargs,
404
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
405
+ bsz, q_len, _ = hidden_states.size()
406
+
407
+ if self.config.pretraining_tp > 1:
408
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
409
+ query_slices = self.q_proj.weight.split(
410
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
411
+ )
412
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
413
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
414
+
415
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
416
+ query_states = torch.cat(query_states, dim=-1)
417
+
418
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
419
+ key_states = torch.cat(key_states, dim=-1)
420
+
421
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
422
+ value_states = torch.cat(value_states, dim=-1)
423
+
424
+ else:
425
+ query_states = self.q_proj(hidden_states)
426
+ key_states = self.k_proj(hidden_states)
427
+ value_states = self.v_proj(hidden_states)
428
+
429
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
430
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
431
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
432
+
433
+ if position_embeddings is None:
434
+ logger.warning_once(
435
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
436
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
437
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
438
+ "removed and `position_embeddings` will be mandatory."
439
+ )
440
+ cos, sin = self.rotary_emb(value_states, position_ids)
441
+ else:
442
+ cos, sin = position_embeddings
443
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
444
+
445
+ if past_key_value is not None:
446
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
447
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
448
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
449
+
450
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
451
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
452
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
453
+
454
+ if attention_mask is not None: # no matter the length, we just slice it
455
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
456
+ attn_weights = attn_weights + causal_mask
457
+
458
+ # upcast attention to fp32
459
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
460
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
461
+ attn_output = torch.matmul(attn_weights, value_states)
462
+
463
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
464
+ raise ValueError(
465
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
466
+ f" {attn_output.size()}"
467
+ )
468
+
469
+ attn_output = attn_output.transpose(1, 2).contiguous()
470
+
471
+ attn_output = attn_output.reshape(bsz, q_len, -1)
472
+
473
+ if self.config.pretraining_tp > 1:
474
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
475
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
476
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
477
+ else:
478
+ attn_output = self.o_proj(attn_output)
479
+
480
+ if not output_attentions:
481
+ attn_weights = None
482
+
483
+ return attn_output, attn_weights, past_key_value
484
+
485
+
486
+ class LlamaFlashAttention2(LlamaAttention):
487
+ """
488
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
489
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
490
+ flash attention and deal with padding tokens in case the input contains any of them.
491
+ """
492
+
493
+ def __init__(self, *args, **kwargs):
494
+ super().__init__(*args, **kwargs)
495
+
496
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
497
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
498
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
499
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
500
+
501
+ def forward(
502
+ self,
503
+ hidden_states: torch.Tensor,
504
+ attention_mask: Optional[torch.LongTensor] = None,
505
+ position_ids: Optional[torch.LongTensor] = None,
506
+ past_key_value: Optional[Cache] = None,
507
+ output_attentions: bool = False,
508
+ use_cache: bool = False,
509
+ cache_position: Optional[torch.LongTensor] = None,
510
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
511
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
512
+ if isinstance(past_key_value, StaticCache):
513
+ raise ValueError(
514
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
515
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
516
+ )
517
+
518
+ output_attentions = False
519
+
520
+ bsz, q_len, _ = hidden_states.size()
521
+
522
+ query_states = self.q_proj(hidden_states)
523
+ key_states = self.k_proj(hidden_states)
524
+ value_states = self.v_proj(hidden_states)
525
+
526
+ # Flash attention requires the input to have the shape
527
+ # batch_size x seq_length x head_dim x hidden_dim
528
+ # therefore we just need to keep the original shape
529
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
530
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
531
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
532
+
533
+ if position_embeddings is None:
534
+ logger.warning_once(
535
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
536
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
537
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
538
+ "removed and `position_embeddings` will be mandatory."
539
+ )
540
+ cos, sin = self.rotary_emb(value_states, position_ids)
541
+ else:
542
+ cos, sin = position_embeddings
543
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
544
+
545
+ if past_key_value is not None:
546
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
547
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
548
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
549
+
550
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
551
+ # to be able to avoid many of these transpose/reshape/view.
552
+ query_states = query_states.transpose(1, 2)
553
+ key_states = key_states.transpose(1, 2)
554
+ value_states = value_states.transpose(1, 2)
555
+
556
+ dropout_rate = self.attention_dropout if self.training else 0.0
557
+
558
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
559
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
560
+ # cast them back in the correct dtype just to be sure everything works as expected.
561
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
562
+ # in fp32. (LlamaRMSNorm handles it correctly)
563
+
564
+ input_dtype = query_states.dtype
565
+ if input_dtype == torch.float32:
566
+ if torch.is_autocast_enabled():
567
+ target_dtype = torch.get_autocast_gpu_dtype()
568
+ # Handle the case where the model is quantized
569
+ elif hasattr(self.config, "_pre_quantization_dtype"):
570
+ target_dtype = self.config._pre_quantization_dtype
571
+ else:
572
+ target_dtype = self.q_proj.weight.dtype
573
+
574
+ logger.warning_once(
575
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
576
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
577
+ f" {target_dtype}."
578
+ )
579
+
580
+ query_states = query_states.to(target_dtype)
581
+ key_states = key_states.to(target_dtype)
582
+ value_states = value_states.to(target_dtype)
583
+
584
+ attn_output = _flash_attention_forward(
585
+ query_states,
586
+ key_states,
587
+ value_states,
588
+ attention_mask,
589
+ q_len,
590
+ position_ids=position_ids,
591
+ dropout=dropout_rate,
592
+ sliding_window=getattr(self, "sliding_window", None),
593
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
594
+ is_causal=self.is_causal,
595
+ )
596
+
597
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
598
+ attn_output = self.o_proj(attn_output)
599
+
600
+ if not output_attentions:
601
+ attn_weights = None
602
+
603
+ return attn_output, attn_weights, past_key_value
604
+
605
+
606
+ class LlamaSdpaAttention(LlamaAttention):
607
+ """
608
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
609
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
610
+ SDPA API.
611
+ """
612
+
613
+ # Adapted from LlamaAttention.forward
614
+ def forward(
615
+ self,
616
+ hidden_states: torch.Tensor,
617
+ attention_mask: Optional[torch.Tensor] = None,
618
+ position_ids: Optional[torch.LongTensor] = None,
619
+ past_key_value: Optional[Cache] = None,
620
+ output_attentions: bool = False,
621
+ use_cache: bool = False,
622
+ cache_position: Optional[torch.LongTensor] = None,
623
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
624
+ **kwargs,
625
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
626
+ if output_attentions:
627
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
628
+ logger.warning_once(
629
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
630
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
631
+ )
632
+ return super().forward(
633
+ hidden_states=hidden_states,
634
+ attention_mask=attention_mask,
635
+ position_ids=position_ids,
636
+ past_key_value=past_key_value,
637
+ output_attentions=output_attentions,
638
+ use_cache=use_cache,
639
+ cache_position=cache_position,
640
+ position_embeddings=position_embeddings,
641
+ )
642
+
643
+ bsz, q_len, _ = hidden_states.size()
644
+
645
+ query_states = self.q_proj(hidden_states)
646
+ key_states = self.k_proj(hidden_states)
647
+ value_states = self.v_proj(hidden_states)
648
+
649
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
650
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
651
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
652
+
653
+ if position_embeddings is None:
654
+ logger.warning_once(
655
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
656
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
657
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
658
+ "removed and `position_embeddings` will be mandatory."
659
+ )
660
+ cos, sin = self.rotary_emb(value_states, position_ids)
661
+ else:
662
+ cos, sin = position_embeddings
663
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
664
+
665
+ if past_key_value is not None:
666
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
667
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
668
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
669
+
670
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
671
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
672
+
673
+ causal_mask = attention_mask
674
+ if attention_mask is not None:
675
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
676
+
677
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
678
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
679
+ if query_states.device.type == "cuda" and causal_mask is not None:
680
+ query_states = query_states.contiguous()
681
+ key_states = key_states.contiguous()
682
+ value_states = value_states.contiguous()
683
+
684
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
685
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
686
+ is_causal = True if causal_mask is None and q_len > 1 else False
687
+
688
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
689
+ query_states,
690
+ key_states,
691
+ value_states,
692
+ attn_mask=causal_mask,
693
+ dropout_p=self.attention_dropout if self.training else 0.0,
694
+ is_causal=is_causal,
695
+ )
696
+
697
+ attn_output = attn_output.transpose(1, 2).contiguous()
698
+ attn_output = attn_output.view(bsz, q_len, -1)
699
+
700
+ attn_output = self.o_proj(attn_output)
701
+
702
+ return attn_output, None, past_key_value
703
+
704
+
705
+ LLAMA_ATTENTION_CLASSES = {
706
+ "eager": LlamaAttention,
707
+ "flash_attention_2": LlamaFlashAttention2,
708
+ "sdpa": LlamaSdpaAttention,
709
+ }
710
+
711
+
712
+ class LlamaDecoderLayer(nn.Module):
713
+ def __init__(self, config: LlamaConfig, layer_idx: int):
714
+ super().__init__()
715
+ self.hidden_size = config.hidden_size
716
+
717
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
718
+
719
+ self.mlp = LlamaMLP(config)
720
+ # FlashNorm change: before it was:
721
+ #self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
722
+ #self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
723
+ self.input_layernorm = FlashNorm(config.hidden_size, eps=config.rms_norm_eps)
724
+ self.post_attention_layernorm = FlashNorm(config.hidden_size, eps=config.rms_norm_eps)
725
+
726
+ def forward(
727
+ self,
728
+ hidden_states: torch.Tensor,
729
+ attention_mask: Optional[torch.Tensor] = None,
730
+ position_ids: Optional[torch.LongTensor] = None,
731
+ past_key_value: Optional[Cache] = None,
732
+ output_attentions: Optional[bool] = False,
733
+ use_cache: Optional[bool] = False,
734
+ cache_position: Optional[torch.LongTensor] = None,
735
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
736
+ **kwargs,
737
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
738
+ """
739
+ Args:
740
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
741
+ attention_mask (`torch.FloatTensor`, *optional*):
742
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
743
+ query_sequence_length, key_sequence_length)` if default attention is used.
744
+ output_attentions (`bool`, *optional*):
745
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
746
+ returned tensors for more detail.
747
+ use_cache (`bool`, *optional*):
748
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
749
+ (see `past_key_values`).
750
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
751
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
752
+ Indices depicting the position of the input sequence tokens in the sequence
753
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
754
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
755
+ with `head_dim` being the embedding dimension of each attention head.
756
+ kwargs (`dict`, *optional*):
757
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
758
+ into the model
759
+ """
760
+ residual = hidden_states
761
+
762
+ hidden_states = self.input_layernorm(hidden_states)
763
+
764
+ # Self Attention
765
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
766
+ hidden_states=hidden_states,
767
+ attention_mask=attention_mask,
768
+ position_ids=position_ids,
769
+ past_key_value=past_key_value,
770
+ output_attentions=output_attentions,
771
+ use_cache=use_cache,
772
+ cache_position=cache_position,
773
+ position_embeddings=position_embeddings,
774
+ **kwargs,
775
+ )
776
+ hidden_states = residual + hidden_states
777
+
778
+ # Fully Connected
779
+ residual = hidden_states
780
+ hidden_states = self.post_attention_layernorm(hidden_states)
781
+ hidden_states = self.mlp(hidden_states)
782
+ hidden_states = residual + hidden_states
783
+
784
+ outputs = (hidden_states,)
785
+
786
+ if output_attentions:
787
+ outputs += (self_attn_weights,)
788
+
789
+ if use_cache:
790
+ outputs += (present_key_value,)
791
+
792
+ return outputs
793
+
794
+
795
+ LLAMA_START_DOCSTRING = r"""
796
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
797
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
798
+ etc.)
799
+
800
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
801
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
802
+ and behavior.
803
+
804
+ Parameters:
805
+ config ([`LlamaConfig`]):
806
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
807
+ load the weights associated with the model, only the configuration. Check out the
808
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
809
+ """
810
+
811
+
812
+ @add_start_docstrings(
813
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
814
+ LLAMA_START_DOCSTRING,
815
+ )
816
+ class LlamaPreTrainedModel(PreTrainedModel):
817
+ config_class = LlamaConfig
818
+ base_model_prefix = "model"
819
+ supports_gradient_checkpointing = True
820
+ _no_split_modules = ["LlamaDecoderLayer"]
821
+ _skip_keys_device_placement = ["past_key_values"]
822
+ _supports_flash_attn_2 = True
823
+ _supports_sdpa = True
824
+ _supports_cache_class = True
825
+ _supports_quantized_cache = True
826
+ _supports_static_cache = True
827
+
828
+ def _init_weights(self, module):
829
+ std = self.config.initializer_range
830
+ if isinstance(module, nn.Linear):
831
+ module.weight.data.normal_(mean=0.0, std=std)
832
+ if module.bias is not None:
833
+ module.bias.data.zero_()
834
+ elif isinstance(module, nn.Embedding):
835
+ module.weight.data.normal_(mean=0.0, std=std)
836
+ if module.padding_idx is not None:
837
+ module.weight.data[module.padding_idx].zero_()
838
+
839
+
840
+ LLAMA_INPUTS_DOCSTRING = r"""
841
+ Args:
842
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
843
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
844
+ it.
845
+
846
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
847
+ [`PreTrainedTokenizer.__call__`] for details.
848
+
849
+ [What are input IDs?](../glossary#input-ids)
850
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
851
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
852
+
853
+ - 1 for tokens that are **not masked**,
854
+ - 0 for tokens that are **masked**.
855
+
856
+ [What are attention masks?](../glossary#attention-mask)
857
+
858
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
859
+ [`PreTrainedTokenizer.__call__`] for details.
860
+
861
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
862
+ `past_key_values`).
863
+
864
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
865
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
866
+ information on the default strategy.
867
+
868
+ - 1 indicates the head is **not masked**,
869
+ - 0 indicates the head is **masked**.
870
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
871
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
872
+ config.n_positions - 1]`.
873
+
874
+ [What are position IDs?](../glossary#position-ids)
875
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
876
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
877
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
878
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
879
+
880
+ Two formats are allowed:
881
+ - a [`~cache_utils.Cache`] instance, see our
882
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
883
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
884
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
885
+ cache format.
886
+
887
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
888
+ legacy cache format will be returned.
889
+
890
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
891
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
892
+ of shape `(batch_size, sequence_length)`.
893
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
894
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
895
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
896
+ model's internal embedding lookup matrix.
897
+ use_cache (`bool`, *optional*):
898
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
899
+ `past_key_values`).
900
+ output_attentions (`bool`, *optional*):
901
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
902
+ tensors for more detail.
903
+ output_hidden_states (`bool`, *optional*):
904
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
905
+ more detail.
906
+ return_dict (`bool`, *optional*):
907
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
908
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
909
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
910
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
911
+ the complete sequence length.
912
+ """
913
+
914
+
915
+ @add_start_docstrings(
916
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
917
+ LLAMA_START_DOCSTRING,
918
+ )
919
+ class LlamaModel(LlamaPreTrainedModel):
920
+ """
921
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
922
+
923
+ Args:
924
+ config: LlamaConfig
925
+ """
926
+
927
+ def __init__(self, config: LlamaConfig):
928
+ super().__init__(config)
929
+ self.padding_idx = config.pad_token_id
930
+ self.vocab_size = config.vocab_size
931
+
932
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
933
+ self.layers = nn.ModuleList(
934
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
935
+ )
936
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
937
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
938
+ self.gradient_checkpointing = False
939
+
940
+ # Initialize weights and apply final processing
941
+ self.post_init()
942
+
943
+ def get_input_embeddings(self):
944
+ return self.embed_tokens
945
+
946
+ def set_input_embeddings(self, value):
947
+ self.embed_tokens = value
948
+
949
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
950
+ def forward(
951
+ self,
952
+ input_ids: torch.LongTensor = None,
953
+ attention_mask: Optional[torch.Tensor] = None,
954
+ position_ids: Optional[torch.LongTensor] = None,
955
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
956
+ inputs_embeds: Optional[torch.FloatTensor] = None,
957
+ use_cache: Optional[bool] = None,
958
+ output_attentions: Optional[bool] = None,
959
+ output_hidden_states: Optional[bool] = None,
960
+ return_dict: Optional[bool] = None,
961
+ cache_position: Optional[torch.LongTensor] = None,
962
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
963
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
964
+ output_hidden_states = (
965
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
966
+ )
967
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
968
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
969
+
970
+ if (input_ids is None) ^ (inputs_embeds is not None):
971
+ raise ValueError(
972
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
973
+ )
974
+
975
+ if self.gradient_checkpointing and self.training and use_cache:
976
+ logger.warning_once(
977
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
978
+ )
979
+ use_cache = False
980
+
981
+ if inputs_embeds is None:
982
+ inputs_embeds = self.embed_tokens(input_ids)
983
+
984
+ # kept for BC (non `Cache` `past_key_values` inputs)
985
+ return_legacy_cache = False
986
+ if use_cache and not isinstance(past_key_values, Cache):
987
+ return_legacy_cache = True
988
+ if past_key_values is None:
989
+ past_key_values = DynamicCache()
990
+ else:
991
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
992
+ logger.warning_once(
993
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
994
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
995
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
996
+ )
997
+
998
+ if cache_position is None:
999
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1000
+ cache_position = torch.arange(
1001
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1002
+ )
1003
+ if position_ids is None:
1004
+ position_ids = cache_position.unsqueeze(0)
1005
+
1006
+ causal_mask = self._update_causal_mask(
1007
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1008
+ )
1009
+ hidden_states = inputs_embeds
1010
+
1011
+ # create position embeddings to be shared across the decoder layers
1012
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1013
+
1014
+ # decoder layers
1015
+ all_hidden_states = () if output_hidden_states else None
1016
+ all_self_attns = () if output_attentions else None
1017
+ next_decoder_cache = None
1018
+
1019
+ for decoder_layer in self.layers:
1020
+ if output_hidden_states:
1021
+ all_hidden_states += (hidden_states,)
1022
+
1023
+ if self.gradient_checkpointing and self.training:
1024
+ layer_outputs = self._gradient_checkpointing_func(
1025
+ decoder_layer.__call__,
1026
+ hidden_states,
1027
+ causal_mask,
1028
+ position_ids,
1029
+ past_key_values,
1030
+ output_attentions,
1031
+ use_cache,
1032
+ cache_position,
1033
+ position_embeddings,
1034
+ )
1035
+ else:
1036
+ layer_outputs = decoder_layer(
1037
+ hidden_states,
1038
+ attention_mask=causal_mask,
1039
+ position_ids=position_ids,
1040
+ past_key_value=past_key_values,
1041
+ output_attentions=output_attentions,
1042
+ use_cache=use_cache,
1043
+ cache_position=cache_position,
1044
+ position_embeddings=position_embeddings,
1045
+ )
1046
+
1047
+ hidden_states = layer_outputs[0]
1048
+
1049
+ if use_cache:
1050
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1051
+
1052
+ if output_attentions:
1053
+ all_self_attns += (layer_outputs[1],)
1054
+
1055
+ hidden_states = self.norm(hidden_states)
1056
+
1057
+ # add hidden states from the last decoder layer
1058
+ if output_hidden_states:
1059
+ all_hidden_states += (hidden_states,)
1060
+
1061
+ next_cache = next_decoder_cache if use_cache else None
1062
+ if return_legacy_cache:
1063
+ next_cache = next_cache.to_legacy_cache()
1064
+
1065
+ if not return_dict:
1066
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1067
+ return BaseModelOutputWithPast(
1068
+ last_hidden_state=hidden_states,
1069
+ past_key_values=next_cache,
1070
+ hidden_states=all_hidden_states,
1071
+ attentions=all_self_attns,
1072
+ )
1073
+
1074
+ def _update_causal_mask(
1075
+ self,
1076
+ attention_mask: torch.Tensor,
1077
+ input_tensor: torch.Tensor,
1078
+ cache_position: torch.Tensor,
1079
+ past_key_values: Cache,
1080
+ output_attentions: bool,
1081
+ ):
1082
+ if self.config._attn_implementation == "flash_attention_2":
1083
+ if attention_mask is not None and 0.0 in attention_mask:
1084
+ return attention_mask
1085
+ return None
1086
+
1087
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1088
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1089
+ # to infer the attention mask.
1090
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1091
+ using_static_cache = isinstance(past_key_values, StaticCache)
1092
+
1093
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1094
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1095
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1096
+ attention_mask,
1097
+ inputs_embeds=input_tensor,
1098
+ past_key_values_length=past_seen_tokens,
1099
+ is_training=self.training,
1100
+ ):
1101
+ return None
1102
+
1103
+ dtype, device = input_tensor.dtype, input_tensor.device
1104
+ min_dtype = torch.finfo(dtype).min
1105
+ sequence_length = input_tensor.shape[1]
1106
+ if using_static_cache:
1107
+ target_length = past_key_values.get_max_length()
1108
+ else:
1109
+ target_length = (
1110
+ attention_mask.shape[-1]
1111
+ if isinstance(attention_mask, torch.Tensor)
1112
+ else past_seen_tokens + sequence_length + 1
1113
+ )
1114
+
1115
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1116
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1117
+ attention_mask,
1118
+ sequence_length=sequence_length,
1119
+ target_length=target_length,
1120
+ dtype=dtype,
1121
+ device=device,
1122
+ min_dtype=min_dtype,
1123
+ cache_position=cache_position,
1124
+ batch_size=input_tensor.shape[0],
1125
+ )
1126
+
1127
+ if (
1128
+ self.config._attn_implementation == "sdpa"
1129
+ and attention_mask is not None
1130
+ and attention_mask.device.type == "cuda"
1131
+ and not output_attentions
1132
+ ):
1133
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1134
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1135
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1136
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1137
+
1138
+ return causal_mask
1139
+
1140
+
1141
+ # FlashNorm change: before it was:
1142
+ #class LlamaForCausal(LlamaPreTrainedModel, GenerationMixin):
1143
+ class LlamaFlashNorm(LlamaPreTrainedModel, GenerationMixin):
1144
+ _tied_weights_keys = ["lm_head.weight"]
1145
+
1146
+ def __init__(self, config):
1147
+ super().__init__(config)
1148
+ self.model = LlamaModel(config)
1149
+ self.vocab_size = config.vocab_size
1150
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1151
+
1152
+ # Initialize weights and apply final processing
1153
+ self.post_init()
1154
+
1155
+ def get_input_embeddings(self):
1156
+ return self.model.embed_tokens
1157
+
1158
+ def set_input_embeddings(self, value):
1159
+ self.model.embed_tokens = value
1160
+
1161
+ def get_output_embeddings(self):
1162
+ return self.lm_head
1163
+
1164
+ def set_output_embeddings(self, new_embeddings):
1165
+ self.lm_head = new_embeddings
1166
+
1167
+ def set_decoder(self, decoder):
1168
+ self.model = decoder
1169
+
1170
+ def get_decoder(self):
1171
+ return self.model
1172
+
1173
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1174
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1175
+ def forward(
1176
+ self,
1177
+ input_ids: torch.LongTensor = None,
1178
+ attention_mask: Optional[torch.Tensor] = None,
1179
+ position_ids: Optional[torch.LongTensor] = None,
1180
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1181
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1182
+ labels: Optional[torch.LongTensor] = None,
1183
+ use_cache: Optional[bool] = None,
1184
+ output_attentions: Optional[bool] = None,
1185
+ output_hidden_states: Optional[bool] = None,
1186
+ return_dict: Optional[bool] = None,
1187
+ cache_position: Optional[torch.LongTensor] = None,
1188
+ num_logits_to_keep: int = 0,
1189
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1190
+ r"""
1191
+ Args:
1192
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1193
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1194
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1195
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1196
+
1197
+ num_logits_to_keep (`int`, *optional*):
1198
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1199
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1200
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1201
+
1202
+ Returns:
1203
+
1204
+ Example:
1205
+
1206
+ ```python
1207
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1208
+
1209
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
1210
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
1211
+
1212
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1213
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1214
+
1215
+ >>> # Generate
1216
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1217
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1218
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1219
+ ```"""
1220
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1221
+ output_hidden_states = (
1222
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1223
+ )
1224
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1225
+
1226
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1227
+ outputs = self.model(
1228
+ input_ids=input_ids,
1229
+ attention_mask=attention_mask,
1230
+ position_ids=position_ids,
1231
+ past_key_values=past_key_values,
1232
+ inputs_embeds=inputs_embeds,
1233
+ use_cache=use_cache,
1234
+ output_attentions=output_attentions,
1235
+ output_hidden_states=output_hidden_states,
1236
+ return_dict=return_dict,
1237
+ cache_position=cache_position,
1238
+ )
1239
+
1240
+ hidden_states = outputs[0]
1241
+ if self.config.pretraining_tp > 1:
1242
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1243
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1244
+ logits = torch.cat(logits, dim=-1)
1245
+ else:
1246
+ if labels is None and not is_torchdynamo_compiling():
1247
+ logger.warning_once(
1248
+ "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
1249
+ )
1250
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1251
+ # TODO: remove the float() operation in v4.46
1252
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1253
+
1254
+ loss = None
1255
+ if labels is not None:
1256
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1257
+ logits = logits.float()
1258
+ # Shift so that tokens < n predict n
1259
+ shift_logits = logits[..., :-1, :].contiguous()
1260
+ shift_labels = labels[..., 1:].contiguous()
1261
+ # Flatten the tokens
1262
+ loss_fct = CrossEntropyLoss()
1263
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1264
+ shift_labels = shift_labels.view(-1)
1265
+ # Enable model parallelism
1266
+ shift_labels = shift_labels.to(shift_logits.device)
1267
+ loss = loss_fct(shift_logits, shift_labels)
1268
+
1269
+ if not return_dict:
1270
+ output = (logits,) + outputs[1:]
1271
+ return (loss,) + output if loss is not None else output
1272
+
1273
+ return CausalLMOutputWithPast(
1274
+ loss=loss,
1275
+ logits=logits,
1276
+ past_key_values=outputs.past_key_values,
1277
+ hidden_states=outputs.hidden_states,
1278
+ attentions=outputs.attentions,
1279
+ )
1280
+
1281
+ def prepare_inputs_for_generation(
1282
+ self,
1283
+ input_ids,
1284
+ past_key_values=None,
1285
+ attention_mask=None,
1286
+ inputs_embeds=None,
1287
+ cache_position=None,
1288
+ position_ids=None,
1289
+ use_cache=True,
1290
+ num_logits_to_keep=None,
1291
+ **kwargs,
1292
+ ):
1293
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1294
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1295
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1296
+ if past_key_values is not None:
1297
+ if inputs_embeds is not None: # Exception 1
1298
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1299
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1300
+ input_ids = input_ids[:, cache_position]
1301
+
1302
+ if attention_mask is not None and position_ids is None:
1303
+ # create position_ids on the fly for batch generation
1304
+ position_ids = attention_mask.long().cumsum(-1) - 1
1305
+ position_ids.masked_fill_(attention_mask == 0, 1)
1306
+ if past_key_values:
1307
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1308
+
1309
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
1310
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1311
+
1312
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1313
+ if inputs_embeds is not None and cache_position[0] == 0:
1314
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1315
+ else:
1316
+ # The clone here is for the same reason as for `position_ids`.
1317
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
1318
+
1319
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1320
+ if model_inputs["inputs_embeds"] is not None:
1321
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1322
+ device = model_inputs["inputs_embeds"].device
1323
+ else:
1324
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1325
+ device = model_inputs["input_ids"].device
1326
+
1327
+ dtype = self.lm_head.weight.dtype
1328
+ min_dtype = torch.finfo(dtype).min
1329
+
1330
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1331
+ attention_mask,
1332
+ sequence_length=sequence_length,
1333
+ target_length=past_key_values.get_max_length(),
1334
+ dtype=dtype,
1335
+ device=device,
1336
+ min_dtype=min_dtype,
1337
+ cache_position=cache_position,
1338
+ batch_size=batch_size,
1339
+ )
1340
+
1341
+ if num_logits_to_keep is not None:
1342
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
1343
+
1344
+ model_inputs.update(
1345
+ {
1346
+ "position_ids": position_ids,
1347
+ "cache_position": cache_position,
1348
+ "past_key_values": past_key_values,
1349
+ "use_cache": use_cache,
1350
+ "attention_mask": attention_mask,
1351
+ }
1352
+ )
1353
+ return model_inputs
1354
+
1355
+
1356
+ @add_start_docstrings(
1357
+ """
1358
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1359
+
1360
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1361
+ (e.g. GPT-2) do.
1362
+
1363
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1364
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1365
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1366
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1367
+ each row of the batch).
1368
+ """,
1369
+ LLAMA_START_DOCSTRING,
1370
+ )
1371
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1372
+ def __init__(self, config):
1373
+ super().__init__(config)
1374
+ self.num_labels = config.num_labels
1375
+ self.model = LlamaModel(config)
1376
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1377
+
1378
+ # Initialize weights and apply final processing
1379
+ self.post_init()
1380
+
1381
+ def get_input_embeddings(self):
1382
+ return self.model.embed_tokens
1383
+
1384
+ def set_input_embeddings(self, value):
1385
+ self.model.embed_tokens = value
1386
+
1387
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1388
+ def forward(
1389
+ self,
1390
+ input_ids: Optional[torch.LongTensor] = None,
1391
+ attention_mask: Optional[torch.Tensor] = None,
1392
+ position_ids: Optional[torch.LongTensor] = None,
1393
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1394
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1395
+ labels: Optional[torch.LongTensor] = None,
1396
+ use_cache: Optional[bool] = None,
1397
+ output_attentions: Optional[bool] = None,
1398
+ output_hidden_states: Optional[bool] = None,
1399
+ return_dict: Optional[bool] = None,
1400
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1401
+ r"""
1402
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1403
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1404
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1405
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1406
+ """
1407
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1408
+
1409
+ transformer_outputs = self.model(
1410
+ input_ids,
1411
+ attention_mask=attention_mask,
1412
+ position_ids=position_ids,
1413
+ past_key_values=past_key_values,
1414
+ inputs_embeds=inputs_embeds,
1415
+ use_cache=use_cache,
1416
+ output_attentions=output_attentions,
1417
+ output_hidden_states=output_hidden_states,
1418
+ return_dict=return_dict,
1419
+ )
1420
+ hidden_states = transformer_outputs[0]
1421
+ logits = self.score(hidden_states)
1422
+
1423
+ if input_ids is not None:
1424
+ batch_size = input_ids.shape[0]
1425
+ else:
1426
+ batch_size = inputs_embeds.shape[0]
1427
+
1428
+ if self.config.pad_token_id is None and batch_size != 1:
1429
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1430
+ if self.config.pad_token_id is None:
1431
+ sequence_lengths = -1
1432
+ else:
1433
+ if input_ids is not None:
1434
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1435
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1436
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1437
+ sequence_lengths = sequence_lengths.to(logits.device)
1438
+ else:
1439
+ sequence_lengths = -1
1440
+
1441
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1442
+
1443
+ loss = None
1444
+ if labels is not None:
1445
+ labels = labels.to(logits.device)
1446
+ if self.config.problem_type is None:
1447
+ if self.num_labels == 1:
1448
+ self.config.problem_type = "regression"
1449
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1450
+ self.config.problem_type = "single_label_classification"
1451
+ else:
1452
+ self.config.problem_type = "multi_label_classification"
1453
+
1454
+ if self.config.problem_type == "regression":
1455
+ loss_fct = MSELoss()
1456
+ if self.num_labels == 1:
1457
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1458
+ else:
1459
+ loss = loss_fct(pooled_logits, labels)
1460
+ elif self.config.problem_type == "single_label_classification":
1461
+ loss_fct = CrossEntropyLoss()
1462
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1463
+ elif self.config.problem_type == "multi_label_classification":
1464
+ loss_fct = BCEWithLogitsLoss()
1465
+ loss = loss_fct(pooled_logits, labels)
1466
+ if not return_dict:
1467
+ output = (pooled_logits,) + transformer_outputs[1:]
1468
+ return ((loss,) + output) if loss is not None else output
1469
+
1470
+ return SequenceClassifierOutputWithPast(
1471
+ loss=loss,
1472
+ logits=pooled_logits,
1473
+ past_key_values=transformer_outputs.past_key_values,
1474
+ hidden_states=transformer_outputs.hidden_states,
1475
+ attentions=transformer_outputs.attentions,
1476
+ )
1477
+
1478
+
1479
+ @add_start_docstrings(
1480
+ """
1481
+ The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
1482
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1483
+ """,
1484
+ LLAMA_START_DOCSTRING,
1485
+ )
1486
+ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
1487
+ base_model_prefix = "transformer"
1488
+
1489
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
1490
+ def __init__(self, config):
1491
+ super().__init__(config)
1492
+ self.transformer = LlamaModel(config)
1493
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1494
+
1495
+ # Initialize weights and apply final processing
1496
+ self.post_init()
1497
+
1498
+ def get_input_embeddings(self):
1499
+ return self.transformer.embed_tokens
1500
+
1501
+ def set_input_embeddings(self, value):
1502
+ self.transformer.embed_tokens = value
1503
+
1504
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1505
+ def forward(
1506
+ self,
1507
+ input_ids: Optional[torch.LongTensor] = None,
1508
+ attention_mask: Optional[torch.FloatTensor] = None,
1509
+ position_ids: Optional[torch.LongTensor] = None,
1510
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1511
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1512
+ start_positions: Optional[torch.LongTensor] = None,
1513
+ end_positions: Optional[torch.LongTensor] = None,
1514
+ output_attentions: Optional[bool] = None,
1515
+ output_hidden_states: Optional[bool] = None,
1516
+ return_dict: Optional[bool] = None,
1517
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1518
+ r"""
1519
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1520
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1521
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1522
+ are not taken into account for computing the loss.
1523
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1524
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1525
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1526
+ are not taken into account for computing the loss.
1527
+ """
1528
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1529
+
1530
+ outputs = self.transformer(
1531
+ input_ids,
1532
+ attention_mask=attention_mask,
1533
+ position_ids=position_ids,
1534
+ past_key_values=past_key_values,
1535
+ inputs_embeds=inputs_embeds,
1536
+ output_attentions=output_attentions,
1537
+ output_hidden_states=output_hidden_states,
1538
+ return_dict=return_dict,
1539
+ )
1540
+
1541
+ sequence_output = outputs[0]
1542
+
1543
+ logits = self.qa_outputs(sequence_output)
1544
+ start_logits, end_logits = logits.split(1, dim=-1)
1545
+ start_logits = start_logits.squeeze(-1).contiguous()
1546
+ end_logits = end_logits.squeeze(-1).contiguous()
1547
+
1548
+ total_loss = None
1549
+ if start_positions is not None and end_positions is not None:
1550
+ # If we are on multi-GPU, split add a dimension
1551
+ if len(start_positions.size()) > 1:
1552
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1553
+ if len(end_positions.size()) > 1:
1554
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1555
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1556
+ ignored_index = start_logits.size(1)
1557
+ start_positions = start_positions.clamp(0, ignored_index)
1558
+ end_positions = end_positions.clamp(0, ignored_index)
1559
+
1560
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1561
+ start_loss = loss_fct(start_logits, start_positions)
1562
+ end_loss = loss_fct(end_logits, end_positions)
1563
+ total_loss = (start_loss + end_loss) / 2
1564
+
1565
+ if not return_dict:
1566
+ output = (start_logits, end_logits) + outputs[2:]
1567
+ return ((total_loss,) + output) if total_loss is not None else output
1568
+
1569
+ return QuestionAnsweringModelOutput(
1570
+ loss=total_loss,
1571
+ start_logits=start_logits,
1572
+ end_logits=end_logits,
1573
+ hidden_states=outputs.hidden_states,
1574
+ attentions=outputs.attentions,
1575
+ )
1576
+
1577
+
1578
+ @add_start_docstrings(
1579
+ """
1580
+ The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1581
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1582
+ """,
1583
+ LLAMA_START_DOCSTRING,
1584
+ )
1585
+ class LlamaForTokenClassification(LlamaPreTrainedModel):
1586
+ def __init__(self, config):
1587
+ super().__init__(config)
1588
+ self.num_labels = config.num_labels
1589
+ self.model = LlamaModel(config)
1590
+ if getattr(config, "classifier_dropout", None) is not None:
1591
+ classifier_dropout = config.classifier_dropout
1592
+ elif getattr(config, "hidden_dropout", None) is not None:
1593
+ classifier_dropout = config.hidden_dropout
1594
+ else:
1595
+ classifier_dropout = 0.1
1596
+ self.dropout = nn.Dropout(classifier_dropout)
1597
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1598
+
1599
+ # Initialize weights and apply final processing
1600
+ self.post_init()
1601
+
1602
+ def get_input_embeddings(self):
1603
+ return self.model.embed_tokens
1604
+
1605
+ def set_input_embeddings(self, value):
1606
+ self.model.embed_tokens = value
1607
+
1608
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1609
+ def forward(
1610
+ self,
1611
+ input_ids: Optional[torch.LongTensor] = None,
1612
+ attention_mask: Optional[torch.Tensor] = None,
1613
+ position_ids: Optional[torch.LongTensor] = None,
1614
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1615
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1616
+ labels: Optional[torch.LongTensor] = None,
1617
+ use_cache: Optional[bool] = None,
1618
+ output_attentions: Optional[bool] = None,
1619
+ output_hidden_states: Optional[bool] = None,
1620
+ return_dict: Optional[bool] = None,
1621
+ ) -> Union[Tuple, TokenClassifierOutput]:
1622
+ r"""
1623
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1624
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1625
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1626
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1627
+ """
1628
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1629
+
1630
+ outputs = self.model(
1631
+ input_ids,
1632
+ attention_mask=attention_mask,
1633
+ position_ids=position_ids,
1634
+ past_key_values=past_key_values,
1635
+ inputs_embeds=inputs_embeds,
1636
+ use_cache=use_cache,
1637
+ output_attentions=output_attentions,
1638
+ output_hidden_states=output_hidden_states,
1639
+ return_dict=return_dict,
1640
+ )
1641
+ sequence_output = outputs[0]
1642
+ sequence_output = self.dropout(sequence_output)
1643
+ logits = self.score(sequence_output)
1644
+
1645
+ loss = None
1646
+ if labels is not None:
1647
+ loss_fct = CrossEntropyLoss()
1648
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1649
+
1650
+ if not return_dict:
1651
+ output = (logits,) + outputs[2:]
1652
+ return ((loss,) + output) if loss is not None else output
1653
+
1654
+ return TokenClassifierOutput(
1655
+ loss=loss,
1656
+ logits=logits,
1657
+ hidden_states=outputs.hidden_states,
1658
+ attentions=outputs.attentions,
1659
+ )