shadeMe-explosion commited on
Commit
9e6cb47
·
unverified ·
1 Parent(s): 235f4b6

Pull changes from upstream `tiiuae/falcon-40b`

Browse files
config.json CHANGED
@@ -3,16 +3,16 @@
3
  "alibi": false,
4
  "apply_residual_connection_post_layernorm": false,
5
  "architectures": [
6
- "RWForCausalLM"
7
  ],
8
  "attention_dropout": 0.0,
9
  "auto_map": {
10
- "AutoConfig": "configuration_RW.RWConfig",
11
- "AutoModel": "modelling_RW.RWModel",
12
- "AutoModelForCausalLM": "modelling_RW.RWForCausalLM",
13
- "AutoModelForQuestionAnswering": "modelling_RW.RWForQuestionAnswering",
14
- "AutoModelForSequenceClassification": "modelling_RW.RWForSequenceClassification",
15
- "AutoModelForTokenClassification": "modelling_RW.RWForTokenClassification"
16
  },
17
  "bias": false,
18
  "bos_token_id": 11,
@@ -21,10 +21,11 @@
21
  "hidden_size": 32,
22
  "initializer_range": 0.02,
23
  "layer_norm_epsilon": 1e-05,
24
- "model_type": "RefinedWebModel",
 
25
  "multi_query": true,
26
- "n_head": 4,
27
- "n_layer": 5,
28
  "parallel_attn": true,
29
  "torch_dtype": "float32",
30
  "transformers_version": "4.28.1",
 
3
  "alibi": false,
4
  "apply_residual_connection_post_layernorm": false,
5
  "architectures": [
6
+ "FalconForCausalLM"
7
  ],
8
  "attention_dropout": 0.0,
9
  "auto_map": {
10
+ "AutoConfig": "configuration_falcon.FalconConfig",
11
+ "AutoModel": "modeling_falcon.FalconModel",
12
+ "AutoModelForSequenceClassification": "modeling_falcon.FalconForSequenceClassification",
13
+ "AutoModelForTokenClassification": "modeling_falcon.FalconForTokenClassification",
14
+ "AutoModelForQuestionAnswering": "modeling_falcon.FalconForQuestionAnswering",
15
+ "AutoModelForCausalLM": "modeling_falcon.FalconForCausalLM"
16
  },
17
  "bias": false,
18
  "bos_token_id": 11,
 
21
  "hidden_size": 32,
22
  "initializer_range": 0.02,
23
  "layer_norm_epsilon": 1e-05,
24
+ "model_type": "falcon",
25
+ "new_decoder_architecture": true,
26
  "multi_query": true,
27
+ "num_attention_heads": 4,
28
+ "num_hidden_layers": 5,
29
  "parallel_attn": true,
30
  "torch_dtype": "float32",
31
  "transformers_version": "4.28.1",
configuration_RW.py DELETED
@@ -1,79 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Bloom configuration"""
16
- from transformers.configuration_utils import PretrainedConfig
17
- from transformers.utils import logging
18
-
19
-
20
- logger = logging.get_logger(__name__)
21
-
22
-
23
- class RWConfig(PretrainedConfig):
24
- model_type = "RefinedWebModel"
25
- keys_to_ignore_at_inference = ["past_key_values"]
26
- attribute_map = {
27
- "num_hidden_layers": "n_layer",
28
- "num_attention_heads": "n_head",
29
- }
30
-
31
- def __init__(
32
- self,
33
- vocab_size=250880,
34
- hidden_size=64,
35
- n_layer=2,
36
- n_head=8,
37
- layer_norm_epsilon=1e-5,
38
- initializer_range=0.02,
39
- use_cache=True,
40
- bos_token_id=1,
41
- eos_token_id=2,
42
- apply_residual_connection_post_layernorm=False,
43
- hidden_dropout=0.0,
44
- attention_dropout=0.0,
45
- multi_query=False,
46
- alibi=False,
47
- bias=False,
48
- parallel_attn=False,
49
- **kwargs,
50
- ):
51
- self.vocab_size = vocab_size
52
- # Backward compatibility with n_embed kwarg
53
- n_embed = kwargs.pop("n_embed", None)
54
- self.hidden_size = hidden_size if n_embed is None else n_embed
55
- self.n_layer = n_layer
56
- self.n_head = n_head
57
- self.layer_norm_epsilon = layer_norm_epsilon
58
- self.initializer_range = initializer_range
59
- self.use_cache = use_cache
60
- self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
61
- self.hidden_dropout = hidden_dropout
62
- self.attention_dropout = attention_dropout
63
-
64
- self.bos_token_id = bos_token_id
65
- self.eos_token_id = eos_token_id
66
- self.multi_query = multi_query
67
- self.alibi = alibi
68
- self.bias = bias
69
- self.parallel_attn = parallel_attn
70
-
71
- super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
72
-
73
- @property
74
- def head_dim(self):
75
- return self.hidden_size // self.n_head
76
-
77
- @property
78
- def rotary(self):
79
- return not self.alibi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configuration_falcon.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Falcon configuration"""
16
+ from transformers.configuration_utils import PretrainedConfig
17
+ from transformers.utils import logging
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+ FALCON_PRETRAINED_CONFIG_ARCHIVE_MAP = {
22
+ "tiiuae/falcon-40b": "https://huggingface.co/tiiuae/falcon-40b/resolve/main/config.json",
23
+ "tiiuae/falcon-7b": "https://huggingface.co/tiiuae/falcon-7b/resolve/main/config.json",
24
+ }
25
+
26
+
27
+ class FalconConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`FalconModel`]. It is used to instantiate a Falcon
30
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
31
+ defaults will yield a similar configuration to that of the
32
+ [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 65024):
40
+ Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`FalconModel`]
42
+ hidden_size (`int`, *optional*, defaults to 4544):
43
+ Dimension of the hidden representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer decoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 71):
47
+ Number of attention heads for each attention layer in the Transformer encoder.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
+ use_cache (`bool`, *optional*, defaults to `True`):
51
+ Whether the model should return the last key/values attentions (not used by all models). Only relevant if
52
+ `config.is_decoder=True`.
53
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
54
+ The epsilon used by the layer normalization layers.
55
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
56
+ The dropout probability for MLP layers.
57
+ attention_dropout (`float`, *optional*, defaults to 0.0):
58
+ The dropout probability for attention layers.
59
+ num_kv_heads (`int`, *optional*):
60
+ Number of key-value heads to use per attention layer. If unset, defaults to the same value as
61
+ `num_attention_heads`.
62
+ alibi (`bool`, *optional*, defaults to `False`):
63
+ Whether to use ALiBi positional biases during self-attention.
64
+ new_decoder_architecture (`bool`, *optional*, defaults to `False`):
65
+ Whether to use the new (Falcon-40B) decoder architecture. If `True`, the `multi_query` and `parallel_attn`
66
+ arguments are ignored, as the new decoder always uses parallel attention.
67
+ multi_query (`bool`, *optional*, defaults to `True`):
68
+ Whether to use multi-query attention in the decoder. Ignored when `new_decoder_architecture` is `True`.
69
+ parallel_attn (`bool`, *optional*, defaults to `True`):
70
+ Whether to compute attention in parallel with the feedforward layer. If False, they are consecutive
71
+ instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
72
+ bias (`bool`, *optional*, defaults to `False`):
73
+ Whether to use bias on Linear layers.
74
+ bos_token_id (`int`, *optional*, defaults to 11):
75
+ The id of the "beginning-of-sequence" token.
76
+ eos_token_id (`int`, *optional*, defaults to 11):
77
+ The id of the "end-of-sequence" token.
78
+
79
+ Example:
80
+
81
+ ```python
82
+ >>> from transformers import FalconModel, FalconConfig
83
+
84
+ >>> # Initializing a small (2-layer) Falcon configuration
85
+ >>> configuration = FalconConfig(num_hidden_layers=2)
86
+
87
+ >>> # Initializing a model from the small configuration
88
+ >>> model = FalconModel(configuration)
89
+
90
+ >>> # Accessing the model configuration
91
+ >>> configuration = model.config
92
+ ```"""
93
+ model_type = "falcon"
94
+ keys_to_ignore_at_inference = ["past_key_values"]
95
+
96
+ def __init__(
97
+ self,
98
+ vocab_size=65024,
99
+ hidden_size=4544,
100
+ num_hidden_layers=32,
101
+ num_attention_heads=71,
102
+ layer_norm_epsilon=1e-5,
103
+ initializer_range=0.02,
104
+ use_cache=True,
105
+ hidden_dropout=0.0,
106
+ attention_dropout=0.0,
107
+ num_kv_heads=None,
108
+ alibi=False,
109
+ new_decoder_architecture=False,
110
+ multi_query=True,
111
+ parallel_attn=True,
112
+ bias=False,
113
+ bos_token_id=11,
114
+ eos_token_id=11,
115
+ **kwargs,
116
+ ):
117
+ self.vocab_size = vocab_size
118
+ # Backward compatibility with n_embed kwarg
119
+ n_embed = kwargs.pop("n_embed", None)
120
+ self.hidden_size = hidden_size if n_embed is None else n_embed
121
+ self.num_hidden_layers = num_hidden_layers
122
+ self.num_attention_heads = num_attention_heads
123
+ self.layer_norm_epsilon = layer_norm_epsilon
124
+ self.initializer_range = initializer_range
125
+ self.use_cache = use_cache
126
+ self.hidden_dropout = hidden_dropout
127
+ self.attention_dropout = attention_dropout
128
+
129
+ self.bos_token_id = bos_token_id
130
+ self.eos_token_id = eos_token_id
131
+ self.num_kv_heads = num_attention_heads if num_kv_heads is None else num_kv_heads
132
+ self.alibi = alibi
133
+ self.new_decoder_architecture = new_decoder_architecture
134
+ self.multi_query = multi_query # Ignored when new_decoder_architecture is True
135
+ self.parallel_attn = parallel_attn
136
+ self.bias = bias
137
+
138
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
139
+
140
+ @property
141
+ def head_dim(self):
142
+ return self.hidden_size // self.num_attention_heads
143
+
144
+ @property
145
+ def rotary(self):
146
+ return not self.alibi
generation_config.json CHANGED
@@ -2,5 +2,5 @@
2
  "_from_model_config": true,
3
  "bos_token_id": 11,
4
  "eos_token_id": 11,
5
- "transformers_version": "4.28.1"
6
  }
 
2
  "_from_model_config": true,
3
  "bos_token_id": 11,
4
  "eos_token_id": 11,
5
+ "transformers_version": "4.31.0.dev0"
6
  }
modelling_RW.py → modeling_falcon.py RENAMED
@@ -1,9 +1,20 @@
1
- # port of models described in RW
2
- # We use the bloom model as a starting point for these model.
3
- # Please refer to the bloom models for usage instructions.
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import math
6
- import warnings
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
@@ -11,68 +22,67 @@ import torch.utils.checkpoint
11
  from torch import nn
12
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
13
  from torch.nn import functional as F
14
-
15
  from transformers.modeling_outputs import (
16
  BaseModelOutputWithPastAndCrossAttentions,
17
- CausalLMOutputWithCrossAttentions,
18
- QuestionAnsweringModelOutput,
19
- SequenceClassifierOutputWithPast,
20
- TokenClassifierOutput,
21
- )
22
  from transformers.modeling_utils import PreTrainedModel
23
- from transformers.utils import logging
24
- from .configuration_RW import RWConfig
 
 
 
25
 
26
  logger = logging.get_logger(__name__)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
29
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
30
- class Linear(nn.Linear):
31
  def forward(self, input: torch.Tensor) -> torch.Tensor:
32
- ret = input @ self.weight.T
33
  if self.bias is None:
34
- return ret
35
- else:
36
- return ret + self.bias
37
-
38
 
39
- from einops import rearrange
40
 
41
  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
42
  def rotate_half(x):
43
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
44
- return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
45
 
46
 
47
- class RotaryEmbedding(torch.nn.Module):
48
  """Implementation of RotaryEmbedding from GPT-NeoX.
49
- This implementation is design to operate on queries and keys that are compatible with
50
- [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
51
  """
52
 
53
- def __init__(
54
- self,
55
- head_dim: int,
56
- base=10000,
57
- ):
58
  super().__init__()
59
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
60
  self.register_buffer("inv_freq", inv_freq, persistent=False)
61
  self.head_dim = head_dim
62
- self.seq_len_cached = None
63
- self.batch_size_cached = None
64
  self.cos_cached: torch.Tensor | None = None
65
  self.sin_cached: torch.Tensor | None = None
66
 
67
- def cos_sin(
68
- self,
69
- seq_len: int,
70
- device="cuda",
71
- dtype=torch.bfloat16,
72
- ) -> torch.Tensor:
73
- if seq_len != self.seq_len_cached:
74
- self.seq_len_cached = seq_len
75
- t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
76
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
77
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
78
 
@@ -85,36 +95,46 @@ class RotaryEmbedding(torch.nn.Module):
85
  self.cos_cached = self.cos_cached.type(dtype)
86
  self.sin_cached = self.sin_cached.type(dtype)
87
 
88
- return self.cos_cached, self.sin_cached
 
 
 
89
 
90
- def forward(self, q, k):
91
- batch, seq_len, head_dim = q.shape
92
- cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
- return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
 
95
 
96
  def _make_causal_mask(
97
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
98
  ) -> torch.BoolTensor:
 
 
 
 
 
99
  batch_size, target_length = input_ids_shape
100
- mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
101
- # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
102
- seq_ids = torch.arange(target_length, device=device)
103
- mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
104
-
105
- if past_key_values_length > 0:
106
- mask[:, :past_key_values_length] = False
107
 
 
 
 
 
 
 
108
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
109
  return expanded_mask
110
 
111
 
112
- def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
113
- batch_size, src_length = mask.shape
114
- tgt_length = tgt_length if tgt_length is not None else src_length
 
 
 
115
 
116
  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
117
- return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
118
 
119
 
120
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -145,18 +165,32 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc
145
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
146
 
147
 
 
148
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  out = F.dropout(x, p=prob, training=training)
150
  out = residual + out
151
  return out
152
 
153
 
154
- class Attention(nn.Module):
155
- def __init__(self, config: RWConfig):
156
  super().__init__()
157
 
158
  self.hidden_size = config.hidden_size
159
- self.num_heads = config.n_head
160
  self.head_dim = self.hidden_size // self.num_heads
161
  self.split_size = self.hidden_size
162
  self.hidden_dropout = config.hidden_dropout
@@ -167,26 +201,27 @@ class Attention(nn.Module):
167
  f" {self.num_heads})."
168
  )
169
 
170
- self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
171
 
172
  # Layer-wise attention scaling
173
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
174
  self.beta = self.inv_norm_factor
175
-
176
- self.query_key_value = Linear(
177
- self.hidden_size,
178
- 3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
179
- bias=config.bias,
180
- )
 
 
181
  self.multi_query = config.multi_query
182
- self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
183
  self.attention_dropout = nn.Dropout(config.attention_dropout)
184
- self.num_kv = config.n_head if not self.multi_query else 1
185
 
186
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
187
  """
188
- Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
189
- storage as `fused_qkv`
190
 
191
  Args:
192
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
@@ -195,7 +230,18 @@ class Attention(nn.Module):
195
  query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
196
  value: [batch_size, seq_length, num_heads, head_dim]
197
  """
198
- if not self.multi_query:
 
 
 
 
 
 
 
 
 
 
 
199
  batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
200
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
201
  return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
@@ -204,12 +250,13 @@ class Attention(nn.Module):
204
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
205
  return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
206
 
 
207
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
208
  """
209
  Merge heads together over the last dimenstion
210
 
211
  Args:
212
- x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
213
 
214
  Returns:
215
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
@@ -232,7 +279,7 @@ class Attention(nn.Module):
232
  def forward(
233
  self,
234
  hidden_states: torch.Tensor,
235
- alibi: torch.Tensor,
236
  attention_mask: torch.Tensor,
237
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
238
  head_mask: Optional[torch.Tensor] = None,
@@ -240,105 +287,120 @@ class Attention(nn.Module):
240
  output_attentions: bool = False,
241
  ):
242
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
243
-
244
  # 3 x [batch_size, seq_length, num_heads, head_dim]
245
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
246
 
247
- batch_size, q_length, _, _ = query_layer.shape
248
 
249
- query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
250
  key_layer = key_layer.transpose(1, 2).reshape(
251
- batch_size * self.num_kv,
252
- q_length,
253
  self.head_dim,
254
  )
255
- value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
256
 
257
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
258
 
259
  if layer_past is not None:
260
  past_key, past_value = layer_past
261
  # concatenate along seq_length dimension:
262
- # - key: [batch_size * self.num_heads, head_dim, kv_length]
263
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
264
  key_layer = torch.cat((past_key, key_layer), dim=1)
265
  value_layer = torch.cat((past_value, value_layer), dim=1)
266
 
267
  _, kv_length, _ = key_layer.shape
268
-
269
- if use_cache is True:
270
  present = (key_layer, value_layer)
271
  else:
272
  present = None
273
 
 
 
 
 
 
 
274
  if alibi is None:
275
- query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
276
- key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
277
- value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
 
 
278
 
279
- attn_output = F.scaled_dot_product_attention(
280
- query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
281
- )
 
 
 
 
 
 
282
 
283
- x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
284
- x = x.permute(0, 2, 1, 3)
285
- attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
286
 
287
  output_tensor = self.dense(attn_output)
288
 
289
- outputs = (output_tensor, present)
290
- assert not output_attentions # not supported.
291
- return outputs
 
 
292
  else:
293
- attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
294
- matmul_result = query_layer @ key_layer.transpose(-1, -2)
295
 
296
  # change view to [batch_size, num_heads, q_length, kv_length]
297
- attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
298
 
299
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
300
  input_dtype = attention_scores.dtype
301
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
302
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
303
  attention_scores = attention_scores.to(torch.float32)
304
- # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
305
- attention_probs = F.softmax(
306
- (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
307
- dim=-1,
308
- dtype=hidden_states.dtype,
309
- )
 
310
  # [batch_size, num_heads, q_length, kv_length]
311
  attention_probs = self.attention_dropout(attention_probs)
312
 
313
  if head_mask is not None:
314
  attention_probs = attention_probs * head_mask
315
 
316
- # change view [batch_size x num_heads, q_length, kv_length]
317
- attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
318
 
319
  # matmul: [batch_size * num_heads, q_length, head_dim]
320
- context_layer = attention_probs_reshaped @ value_layer
321
 
322
  # change view [batch_size, num_heads, q_length, head_dim]
323
  context_layer = self._merge_heads(context_layer)
324
 
325
  output_tensor = self.dense(context_layer)
326
 
327
- outputs = (output_tensor, present)
328
  if output_attentions:
329
- outputs += (attention_probs,)
330
-
331
- return outputs
332
 
333
 
334
- class MLP(nn.Module):
335
- def __init__(self, config: RWConfig):
336
  super().__init__()
337
  hidden_size = config.hidden_size
338
 
339
- self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
340
  self.act = nn.GELU()
341
- self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
342
  self.hidden_dropout = config.hidden_dropout
343
 
344
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -347,43 +409,47 @@ class MLP(nn.Module):
347
  return x
348
 
349
 
350
- class DecoderLayer(nn.Module):
351
- def __init__(self, config: RWConfig):
352
  super().__init__()
353
  hidden_size = config.hidden_size
354
-
355
- self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
356
- self.num_heads = config.n_head
357
- self.self_attention = Attention(config)
358
-
359
- if not config.parallel_attn:
360
- # unused if parallel attn
361
- self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
362
-
363
- self.mlp = MLP(config)
364
-
365
- self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
366
  self.hidden_dropout = config.hidden_dropout
367
-
368
  self.config = config
369
 
 
 
 
 
 
 
 
 
 
 
370
  def forward(
371
  self,
372
  hidden_states: torch.Tensor,
373
- alibi: torch.Tensor,
374
  attention_mask: torch.Tensor,
375
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
376
  head_mask: Optional[torch.Tensor] = None,
377
  use_cache: bool = False,
378
  output_attentions: bool = False,
379
  ):
380
-
381
- layernorm_output = self.input_layernorm(hidden_states)
382
  residual = hidden_states
383
 
 
 
 
 
 
 
384
  # Self attention.
385
  attn_outputs = self.self_attention(
386
- layernorm_output,
387
  layer_past=layer_past,
388
  attention_mask=attention_mask,
389
  alibi=alibi,
@@ -394,16 +460,21 @@ class DecoderLayer(nn.Module):
394
 
395
  attention_output = attn_outputs[0]
396
 
397
- if not self.config.parallel_attn:
398
- residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
399
- layernorm_output = self.post_attention_layernorm(residual)
 
 
 
 
 
400
 
401
  outputs = attn_outputs[1:]
402
 
403
  # MLP.
404
- mlp_output = self.mlp(layernorm_output)
405
 
406
- if self.config.parallel_attn:
407
  mlp_output += attention_output
408
 
409
  output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
@@ -416,24 +487,93 @@ class DecoderLayer(nn.Module):
416
  return outputs # hidden_states, present, attentions
417
 
418
 
419
- class RWPreTrainedModel(PreTrainedModel):
420
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  """
422
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
423
  models.
424
  """
425
 
426
- config_class = RWConfig
427
  base_model_prefix = "transformer"
428
  supports_gradient_checkpointing = True
429
- _no_split_modules = ["DecoderLayer"]
430
 
431
  def __init__(self, *inputs, **kwargs):
432
  super().__init__(*inputs, **kwargs)
433
 
434
  def _init_weights(self, module: nn.Module):
435
  """Initialize the weights."""
436
- if isinstance(module, nn.Linear) or isinstance(module, Linear):
437
  # Slightly different from the TF version which uses truncated_normal for initialization
438
  # cf https://github.com/pytorch/pytorch/pull/5617
439
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -447,26 +587,28 @@ class RWPreTrainedModel(PreTrainedModel):
447
  module.bias.data.zero_()
448
  module.weight.data.fill_(1.0)
449
 
 
450
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
451
- if isinstance(module, RWModel):
452
  module.gradient_checkpointing = value
453
 
454
  @staticmethod
455
- def _convert_to_standard_cache(
456
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
457
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
458
  """
459
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
460
  num_heads, ...]))
461
  """
462
- batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
 
 
 
463
  num_heads = batch_size_times_num_heads // batch_size
464
- # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
465
- # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
466
  return tuple(
467
  (
468
- layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
469
- layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
470
  )
471
  for layer_past in past_key_value
472
  )
@@ -475,32 +617,35 @@ class RWPreTrainedModel(PreTrainedModel):
475
  def _convert_to_rw_cache(
476
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
477
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
478
- batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
479
  batch_size_times_num_heads = batch_size * num_heads
480
- # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
481
- # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
482
  return tuple(
483
  (
484
- layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
485
- layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
486
  )
487
  for layer_past in past_key_value
488
  )
489
 
490
 
491
- class RWModel(RWPreTrainedModel):
492
- def __init__(self, config: RWConfig):
 
 
 
 
493
  super().__init__(config)
494
 
495
  self.embed_dim = config.hidden_size
496
- self.num_heads = config.n_head
497
- self.alibi = config.alibi
498
 
499
  # Embedding + LN Embedding
500
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
501
 
502
  # Transformer blocks
503
- self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
504
 
505
  # Final Layer Norm
506
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -513,22 +658,31 @@ class RWModel(RWPreTrainedModel):
513
  def get_input_embeddings(self):
514
  return self.word_embeddings
515
 
 
516
  def _prepare_attn_mask(
517
- self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
518
  ) -> torch.BoolTensor:
519
- # create causal mask
520
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
 
 
 
 
 
 
 
 
521
  combined_attention_mask = None
522
  device = attention_mask.device
523
- _, src_length = input_shape
524
 
525
- if src_length > 1:
526
  combined_attention_mask = _make_causal_mask(
527
  input_shape, device=device, past_key_values_length=past_key_values_length
528
  )
529
 
530
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
531
- expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
532
  combined_attention_mask = (
533
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
534
  )
@@ -538,6 +692,12 @@ class RWModel(RWPreTrainedModel):
538
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
539
  self.word_embeddings = new_embeddings
540
 
 
 
 
 
 
 
541
  def forward(
542
  self,
543
  input_ids: Optional[torch.LongTensor] = None,
@@ -549,18 +709,7 @@ class RWModel(RWPreTrainedModel):
549
  output_attentions: Optional[bool] = None,
550
  output_hidden_states: Optional[bool] = None,
551
  return_dict: Optional[bool] = None,
552
- **deprecated_arguments,
553
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
554
- if deprecated_arguments.pop("position_ids", False) is not False:
555
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
556
- warnings.warn(
557
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
558
- " passing `position_ids`.",
559
- FutureWarning,
560
- )
561
- if len(deprecated_arguments) > 0:
562
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
563
-
564
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
565
  output_hidden_states = (
566
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -579,12 +728,14 @@ class RWModel(RWPreTrainedModel):
579
 
580
  if past_key_values is None:
581
  past_key_values = tuple([None] * len(self.h))
 
 
582
 
583
  # Prepare head mask if needed
584
  # 1.0 in head_mask indicate we keep the head
585
  # attention_probs has shape batch_size x num_heads x N x N
586
  # head_mask has shape n_layer x batch x num_heads x N x N
587
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
588
 
589
  if inputs_embeds is None:
590
  inputs_embeds = self.word_embeddings(input_ids)
@@ -596,17 +747,15 @@ class RWModel(RWPreTrainedModel):
596
  all_hidden_states = () if output_hidden_states else None
597
 
598
  # Compute alibi tensor: check build_alibi_tensor documentation
599
- seq_length_with_past = seq_length
600
  past_key_values_length = 0
601
  if past_key_values[0] is not None:
602
- past_key_values_length = past_key_values[0][0].shape[2]
603
- seq_length_with_past = seq_length_with_past + past_key_values_length
604
  if attention_mask is None:
605
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
606
  else:
607
  attention_mask = attention_mask.to(hidden_states.device)
608
 
609
- if self.alibi:
610
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
611
  else:
612
  alibi = None
@@ -618,12 +767,10 @@ class RWModel(RWPreTrainedModel):
618
  )
619
 
620
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
621
-
622
  if output_hidden_states:
623
  all_hidden_states = all_hidden_states + (hidden_states,)
624
 
625
  if self.gradient_checkpointing and self.training:
626
-
627
  if use_cache:
628
  logger.warning(
629
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -668,6 +815,9 @@ class RWModel(RWPreTrainedModel):
668
  if output_hidden_states:
669
  all_hidden_states = all_hidden_states + (hidden_states,)
670
 
 
 
 
671
  if not return_dict:
672
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
673
 
@@ -679,12 +829,16 @@ class RWModel(RWPreTrainedModel):
679
  )
680
 
681
 
682
- class RWForCausalLM(RWPreTrainedModel):
683
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
 
 
 
684
 
685
- def __init__(self, config: RWConfig):
686
  super().__init__(config)
687
- self.transformer = RWModel(config)
688
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
689
 
690
  # Initialize weights and apply final processing
@@ -699,25 +853,26 @@ class RWForCausalLM(RWPreTrainedModel):
699
  def prepare_inputs_for_generation(
700
  self,
701
  input_ids: torch.LongTensor,
702
- past: Optional[torch.Tensor] = None,
703
  attention_mask: Optional[torch.Tensor] = None,
704
  **kwargs,
705
  ) -> dict:
706
- # only last token for input_ids if past is not None
707
- if past:
708
- input_ids = input_ids[:, -1].unsqueeze(-1)
709
-
710
- # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
711
- if past[0][0].shape[0] == input_ids.shape[0]:
712
- past = self._convert_to_rw_cache(past)
713
 
714
  return {
715
  "input_ids": input_ids,
716
- "past_key_values": past,
717
  "use_cache": kwargs.get("use_cache"),
718
  "attention_mask": attention_mask,
719
  }
720
 
 
 
 
 
 
 
721
  def forward(
722
  self,
723
  input_ids: Optional[torch.LongTensor] = None,
@@ -730,7 +885,6 @@ class RWForCausalLM(RWPreTrainedModel):
730
  output_attentions: Optional[bool] = None,
731
  output_hidden_states: Optional[bool] = None,
732
  return_dict: Optional[bool] = None,
733
- **deprecated_arguments,
734
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
735
  r"""
736
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -738,15 +892,6 @@ class RWForCausalLM(RWPreTrainedModel):
738
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
739
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
740
  """
741
- if deprecated_arguments.pop("position_ids", False) is not False:
742
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
743
- warnings.warn(
744
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
745
- " passing `position_ids`.",
746
- FutureWarning,
747
- )
748
- if len(deprecated_arguments) > 0:
749
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
750
 
751
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
752
 
@@ -799,7 +944,6 @@ class RWForCausalLM(RWPreTrainedModel):
799
 
800
  Output shares the same memory storage as `past`.
801
  """
802
- standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
803
 
804
  # Get a copy of `beam_idx` on all the devices where we need those indices.
805
  device_to_beam_idx = {
@@ -810,23 +954,42 @@ class RWForCausalLM(RWPreTrainedModel):
810
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
811
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
812
  )
813
- for layer_past in standardized_past
814
  )
815
- return self._convert_to_rw_cache(reordered_past)
816
-
817
 
818
- class RWForSequenceClassification(RWPreTrainedModel):
819
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
820
 
821
- def __init__(self, config: RWConfig):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
822
  super().__init__(config)
823
  self.num_labels = config.num_labels
824
- self.transformer = RWModel(config)
825
  self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
826
 
827
  # Initialize weights and apply final processing
828
  self.post_init()
829
 
 
 
 
 
 
 
830
  def forward(
831
  self,
832
  input_ids: Optional[torch.LongTensor] = None,
@@ -839,7 +1002,6 @@ class RWForSequenceClassification(RWPreTrainedModel):
839
  output_attentions: Optional[bool] = None,
840
  output_hidden_states: Optional[bool] = None,
841
  return_dict: Optional[bool] = None,
842
- **deprecated_arguments,
843
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
844
  r"""
845
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -847,15 +1009,6 @@ class RWForSequenceClassification(RWPreTrainedModel):
847
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
848
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
849
  """
850
- if deprecated_arguments.pop("position_ids", False) is not False:
851
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
852
- warnings.warn(
853
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
854
- " passing `position_ids`.",
855
- FutureWarning,
856
- )
857
- if len(deprecated_arguments) > 0:
858
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
859
 
860
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
861
 
@@ -930,17 +1083,22 @@ class RWForSequenceClassification(RWPreTrainedModel):
930
  )
931
 
932
 
933
- class RWForTokenClassification(RWPreTrainedModel):
934
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
935
-
936
- def __init__(self, config: RWConfig):
 
 
 
 
 
937
  super().__init__(config)
938
  self.num_labels = config.num_labels
939
 
940
- self.transformer = RWModel(config)
941
- if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
942
  classifier_dropout = config.classifier_dropout
943
- elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
944
  classifier_dropout = config.hidden_dropout
945
  else:
946
  classifier_dropout = 0.1
@@ -950,6 +1108,12 @@ class RWForTokenClassification(RWPreTrainedModel):
950
  # Initialize weights and apply final processing
951
  self.post_init()
952
 
 
 
 
 
 
 
953
  def forward(
954
  self,
955
  input_ids: Optional[torch.LongTensor] = None,
@@ -962,7 +1126,6 @@ class RWForTokenClassification(RWPreTrainedModel):
962
  output_attentions: Optional[bool] = None,
963
  output_hidden_states: Optional[bool] = None,
964
  return_dict: Optional[bool] = None,
965
- **deprecated_arguments,
966
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
967
  r"""
968
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -970,15 +1133,6 @@ class RWForTokenClassification(RWPreTrainedModel):
970
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
971
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
972
  """
973
- if deprecated_arguments.pop("position_ids", False) is not False:
974
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
975
- warnings.warn(
976
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
977
- " passing `position_ids`.",
978
- FutureWarning,
979
- )
980
- if len(deprecated_arguments) > 0:
981
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
982
 
983
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
984
 
@@ -1002,7 +1156,9 @@ class RWForTokenClassification(RWPreTrainedModel):
1002
  if labels is not None:
1003
  batch_size, seq_length = labels.shape
1004
  loss_fct = CrossEntropyLoss()
1005
- loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
 
 
1006
 
1007
  if not return_dict:
1008
  output = (logits,) + transformer_outputs[2:]
@@ -1016,22 +1172,27 @@ class RWForTokenClassification(RWPreTrainedModel):
1016
  )
1017
 
1018
 
1019
- class RWForQuestionAnswering(RWPreTrainedModel):
1020
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
1021
-
 
 
 
 
 
1022
  def __init__(self, config):
1023
  super().__init__(config)
1024
- self.transformer = RWModel(config)
1025
  self.qa_outputs = nn.Linear(config.hidden_size, 2)
1026
 
1027
  # Initialize weights and apply final processing
1028
  self.post_init()
1029
 
 
1030
  def forward(
1031
  self,
1032
  input_ids: Optional[torch.LongTensor] = None,
1033
  attention_mask: Optional[torch.FloatTensor] = None,
1034
- position_ids: Optional[torch.LongTensor] = None,
1035
  head_mask: Optional[torch.FloatTensor] = None,
1036
  inputs_embeds: Optional[torch.FloatTensor] = None,
1037
  start_positions: Optional[torch.LongTensor] = None,
@@ -1055,7 +1216,6 @@ class RWForQuestionAnswering(RWPreTrainedModel):
1055
  outputs = self.transformer(
1056
  input_ids,
1057
  attention_mask=attention_mask,
1058
- position_ids=position_ids,
1059
  head_mask=head_mask,
1060
  inputs_embeds=inputs_embeds,
1061
  output_attentions=output_attentions,
 
1
+ # coding=utf-8
2
+ # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Falcon model."""
16
 
17
  import math
 
18
  from typing import Optional, Tuple, Union
19
 
20
  import torch
 
22
  from torch import nn
23
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
24
  from torch.nn import functional as F
 
25
  from transformers.modeling_outputs import (
26
  BaseModelOutputWithPastAndCrossAttentions,
27
+ CausalLMOutputWithCrossAttentions, QuestionAnsweringModelOutput,
28
+ SequenceClassifierOutputWithPast, TokenClassifierOutput)
 
 
 
29
  from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.utils import (add_code_sample_docstrings,
31
+ add_start_docstrings,
32
+ add_start_docstrings_to_model_forward, logging)
33
+
34
+ from .configuration_falcon import FalconConfig
35
 
36
  logger = logging.get_logger(__name__)
37
 
38
+ FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
39
+ "tiiuae/falcon-40b",
40
+ "tiiuae/falcon-40b-instruct",
41
+ "tiiuae/falcon-7b",
42
+ "tiiuae/falcon-7b-instruct",
43
+ "tiiuae/falcon-rw-7b",
44
+ "tiiuae/falcon-rw-1b",
45
+ ]
46
+ _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
47
+ _CONFIG_FOR_DOC = "FalconConfig"
48
+
49
+
50
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
51
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
52
+ class FalconLinear(nn.Linear):
53
  def forward(self, input: torch.Tensor) -> torch.Tensor:
54
+ hidden_states = input @ self.weight.T
55
  if self.bias is None:
56
+ return hidden_states
57
+ return hidden_states + self.bias
 
 
58
 
 
59
 
60
  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
61
  def rotate_half(x):
62
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
63
+ return torch.cat((-x2, x1), dim=-1)
64
 
65
 
66
+ class FalconRotaryEmbedding(nn.Module):
67
  """Implementation of RotaryEmbedding from GPT-NeoX.
68
+ This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
69
+ n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
70
  """
71
 
72
+ def __init__(self, head_dim: int, base=10000):
 
 
 
 
73
  super().__init__()
74
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
75
  self.register_buffer("inv_freq", inv_freq, persistent=False)
76
  self.head_dim = head_dim
77
+ self.seq_len_cached = -1
 
78
  self.cos_cached: torch.Tensor | None = None
79
  self.sin_cached: torch.Tensor | None = None
80
 
81
+ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
82
+ total_length = seq_len + past_key_values_length
83
+ if total_length > self.seq_len_cached:
84
+ self.seq_len_cached = total_length
85
+ t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
 
 
 
 
86
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
87
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
88
 
 
95
  self.cos_cached = self.cos_cached.type(dtype)
96
  self.sin_cached = self.sin_cached.type(dtype)
97
 
98
+ return (
99
+ self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
100
+ self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
101
+ )
102
 
103
+ def forward(self, query, key, past_key_values_length=0):
104
+ batch, seq_len, head_dim = query.shape
105
+ cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
106
+ return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
107
 
108
 
109
  def _make_causal_mask(
110
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
111
  ) -> torch.BoolTensor:
112
+ """
113
+ Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
114
+ just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
115
+ target_length, target_length+past_key_values_length]`.
116
+ """
117
  batch_size, target_length = input_ids_shape
 
 
 
 
 
 
 
118
 
119
+ mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
120
+ # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
121
+ # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
122
+ # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
123
+ past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
124
+ mask = torch.cat([past_mask, mask], dim=-1)
125
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
126
  return expanded_mask
127
 
128
 
129
+ def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
130
+ """
131
+ Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
132
+ """
133
+ batch_size, total_length = mask.shape
134
+ seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
135
 
136
  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
137
+ return expanded_mask.expand(batch_size, 1, seq_length, total_length)
138
 
139
 
140
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
 
165
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
166
 
167
 
168
+ # Copied from transformers.models.bloom.modeling_bloom.dropout_add
169
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
170
+ """
171
+ Dropout add function
172
+
173
+ Args:
174
+ x (`torch.tensor`, *required*):
175
+ input tensor
176
+ residual (`torch.tensor`, *required*):
177
+ residual tensor
178
+ prob (`float`, *required*):
179
+ dropout probability
180
+ training (`bool`, *required*):
181
+ training mode
182
+ """
183
  out = F.dropout(x, p=prob, training=training)
184
  out = residual + out
185
  return out
186
 
187
 
188
+ class FalconAttention(nn.Module):
189
+ def __init__(self, config: FalconConfig):
190
  super().__init__()
191
 
192
  self.hidden_size = config.hidden_size
193
+ self.num_heads = config.num_attention_heads
194
  self.head_dim = self.hidden_size // self.num_heads
195
  self.split_size = self.hidden_size
196
  self.hidden_dropout = config.hidden_dropout
 
201
  f" {self.num_heads})."
202
  )
203
 
204
+ self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
205
 
206
  # Layer-wise attention scaling
207
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
208
  self.beta = self.inv_norm_factor
209
+ if config.new_decoder_architecture:
210
+ qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
211
+ elif config.multi_query:
212
+ qkv_out_dim = self.hidden_size + 2 * self.head_dim
213
+ else:
214
+ qkv_out_dim = 3 * self.hidden_size
215
+ self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
216
+ self.new_decoder_architecture = config.new_decoder_architecture
217
  self.multi_query = config.multi_query
218
+ self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
219
  self.attention_dropout = nn.Dropout(config.attention_dropout)
220
+ self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
221
 
222
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
223
  """
224
+ Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
 
225
 
226
  Args:
227
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
 
230
  query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
231
  value: [batch_size, seq_length, num_heads, head_dim]
232
  """
233
+ if self.new_decoder_architecture:
234
+ batch, seq_len, _ = fused_qkv.shape
235
+ qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
236
+ query = qkv[:, :, :, :-2]
237
+ key = qkv[:, :, :, [-2]]
238
+ value = qkv[:, :, :, [-1]]
239
+ key = torch.broadcast_to(key, query.shape)
240
+ value = torch.broadcast_to(value, query.shape)
241
+
242
+ query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
243
+ return query, key, value
244
+ elif not self.multi_query:
245
  batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
246
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
247
  return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
 
250
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
251
  return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
252
 
253
+ # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
254
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
255
  """
256
  Merge heads together over the last dimenstion
257
 
258
  Args:
259
+ x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
260
 
261
  Returns:
262
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
 
279
  def forward(
280
  self,
281
  hidden_states: torch.Tensor,
282
+ alibi: Optional[torch.Tensor],
283
  attention_mask: torch.Tensor,
284
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
285
  head_mask: Optional[torch.Tensor] = None,
 
287
  output_attentions: bool = False,
288
  ):
289
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
290
+ num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
291
  # 3 x [batch_size, seq_length, num_heads, head_dim]
292
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
293
 
294
+ batch_size, query_length, _, _ = query_layer.shape
295
 
296
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
297
  key_layer = key_layer.transpose(1, 2).reshape(
298
+ batch_size * num_kv_heads,
299
+ query_length,
300
  self.head_dim,
301
  )
302
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
303
 
304
+ past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
305
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
306
 
307
  if layer_past is not None:
308
  past_key, past_value = layer_past
309
  # concatenate along seq_length dimension:
310
+ # - key: [batch_size * self.num_heads, kv_length, head_dim]
311
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
312
  key_layer = torch.cat((past_key, key_layer), dim=1)
313
  value_layer = torch.cat((past_value, value_layer), dim=1)
314
 
315
  _, kv_length, _ = key_layer.shape
316
+ if use_cache:
 
317
  present = (key_layer, value_layer)
318
  else:
319
  present = None
320
 
321
+ attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
322
+
323
+ query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
324
+ key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
325
+ value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
326
+
327
  if alibi is None:
328
+ if output_attentions:
329
+ # F.scaled_dot_product_attention doesn't return the attention weights, so we have
330
+ # to do it by hand if we want them
331
+ attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
332
+ attention_scores /= math.sqrt(self.head_dim)
333
 
334
+ attention_scores = F.softmax(
335
+ attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
336
+ )
337
+ attn_output = attention_scores @ value_layer_
338
+ else:
339
+ attn_output = F.scaled_dot_product_attention(
340
+ query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
341
+ )
342
+ attention_scores = None
343
 
344
+ attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
345
+ attn_output = attn_output.permute(0, 2, 1, 3)
346
+ attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
347
 
348
  output_tensor = self.dense(attn_output)
349
 
350
+ if output_attentions:
351
+ return output_tensor, present, attention_scores
352
+ else:
353
+ return output_tensor, present
354
+
355
  else:
356
+ matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
 
357
 
358
  # change view to [batch_size, num_heads, q_length, kv_length]
359
+ attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
360
 
361
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
362
  input_dtype = attention_scores.dtype
363
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
364
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
365
  attention_scores = attention_scores.to(torch.float32)
366
+ # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
367
+ # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
368
+ # equivalent and more performant, but there might be a numerical difference. If you're reading this
369
+ # and you'd like to experiment and maybe file a PR, feel free!
370
+ attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
371
+ attention_logits *= self.inv_norm_factor
372
+ attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
373
  # [batch_size, num_heads, q_length, kv_length]
374
  attention_probs = self.attention_dropout(attention_probs)
375
 
376
  if head_mask is not None:
377
  attention_probs = attention_probs * head_mask
378
 
379
+ # change view [batch_size, num_heads, q_length, kv_length]
380
+ attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
381
 
382
  # matmul: [batch_size * num_heads, q_length, head_dim]
383
+ context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
384
 
385
  # change view [batch_size, num_heads, q_length, head_dim]
386
  context_layer = self._merge_heads(context_layer)
387
 
388
  output_tensor = self.dense(context_layer)
389
 
 
390
  if output_attentions:
391
+ return output_tensor, present, attention_probs
392
+ else:
393
+ return output_tensor, present
394
 
395
 
396
+ class FalconMLP(nn.Module):
397
+ def __init__(self, config: FalconConfig):
398
  super().__init__()
399
  hidden_size = config.hidden_size
400
 
401
+ self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
402
  self.act = nn.GELU()
403
+ self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
404
  self.hidden_dropout = config.hidden_dropout
405
 
406
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
409
  return x
410
 
411
 
412
+ class FalconDecoderLayer(nn.Module):
413
+ def __init__(self, config: FalconConfig):
414
  super().__init__()
415
  hidden_size = config.hidden_size
416
+ self.num_heads = config.num_attention_heads
417
+ self.self_attention = FalconAttention(config)
418
+ self.mlp = FalconMLP(config)
 
 
 
 
 
 
 
 
 
419
  self.hidden_dropout = config.hidden_dropout
 
420
  self.config = config
421
 
422
+ if config.new_decoder_architecture:
423
+ # The layer norm before self-attention
424
+ self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
425
+ # The layer norm before the MLP
426
+ self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
427
+ else:
428
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
429
+ if not config.parallel_attn:
430
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
431
+
432
  def forward(
433
  self,
434
  hidden_states: torch.Tensor,
435
+ alibi: Optional[torch.Tensor],
436
  attention_mask: torch.Tensor,
437
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
438
  head_mask: Optional[torch.Tensor] = None,
439
  use_cache: bool = False,
440
  output_attentions: bool = False,
441
  ):
 
 
442
  residual = hidden_states
443
 
444
+ if self.config.new_decoder_architecture:
445
+ attention_layernorm_out = self.ln_attn(hidden_states)
446
+ mlp_layernorm_out = self.ln_mlp(hidden_states)
447
+ else:
448
+ attention_layernorm_out = self.input_layernorm(hidden_states)
449
+
450
  # Self attention.
451
  attn_outputs = self.self_attention(
452
+ attention_layernorm_out,
453
  layer_past=layer_past,
454
  attention_mask=attention_mask,
455
  alibi=alibi,
 
460
 
461
  attention_output = attn_outputs[0]
462
 
463
+ if not self.config.new_decoder_architecture:
464
+ if self.config.parallel_attn:
465
+ mlp_layernorm_out = attention_layernorm_out
466
+ else:
467
+ residual = dropout_add(
468
+ attention_output, residual, self.config.attention_dropout, training=self.training
469
+ )
470
+ mlp_layernorm_out = self.post_attention_layernorm(residual)
471
 
472
  outputs = attn_outputs[1:]
473
 
474
  # MLP.
475
+ mlp_output = self.mlp(mlp_layernorm_out)
476
 
477
+ if self.config.new_decoder_architecture or self.config.parallel_attn:
478
  mlp_output += attention_output
479
 
480
  output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
 
487
  return outputs # hidden_states, present, attentions
488
 
489
 
490
+ FALCON_START_DOCSTRING = r"""
491
+
492
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
493
+ library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
494
+
495
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
496
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
497
+ and behavior.
498
+
499
+ Parameters:
500
+ config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
501
+ Initializing with a config file does not load the weights associated with the model, only the
502
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
503
+ """
504
+
505
+ FALCON_INPUTS_DOCSTRING = r"""
506
+ Args:
507
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
508
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
509
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
510
+
511
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
512
+ `input_ids`.
513
+
514
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
515
+ [`PreTrainedTokenizer.__call__`] for details.
516
+
517
+ [What are input IDs?](../glossary#input-ids)
518
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
519
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
520
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
521
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
522
+
523
+ Each element of `past_key_values` is a tuple (past_key, past_value):
524
+ - past_key: [batch_size * num_heads, head_dim, kv_length]
525
+ - past_value: [batch_size * num_heads, kv_length, head_dim]
526
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
527
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
528
+
529
+ - 1 for tokens that are **not masked**,
530
+ - 0 for tokens that are **masked**.
531
+
532
+ [What are attention masks?](../glossary#attention-mask)
533
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
534
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
535
+
536
+ - 1 indicates the head is **not masked**,
537
+ - 0 indicates the head is **masked**.
538
+
539
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
540
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
541
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
542
+ model's internal embedding lookup matrix.
543
+
544
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
545
+ `past_key_values`).
546
+ use_cache (`bool`, *optional*):
547
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
548
+ `past_key_values`).
549
+ output_attentions (`bool`, *optional*):
550
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
551
+ tensors for more detail.
552
+ output_hidden_states (`bool`, *optional*):
553
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
554
+ more detail.
555
+ return_dict (`bool`, *optional*):
556
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
557
+ """
558
+
559
+
560
+ class FalconPreTrainedModel(PreTrainedModel):
561
  """
562
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
563
  models.
564
  """
565
 
566
+ config_class = FalconConfig
567
  base_model_prefix = "transformer"
568
  supports_gradient_checkpointing = True
569
+ _no_split_modules = ["FalconDecoderLayer"]
570
 
571
  def __init__(self, *inputs, **kwargs):
572
  super().__init__(*inputs, **kwargs)
573
 
574
  def _init_weights(self, module: nn.Module):
575
  """Initialize the weights."""
576
+ if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
577
  # Slightly different from the TF version which uses truncated_normal for initialization
578
  # cf https://github.com/pytorch/pytorch/pull/5617
579
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
 
587
  module.bias.data.zero_()
588
  module.weight.data.fill_(1.0)
589
 
590
+ # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
591
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
592
+ if isinstance(module, FalconModel):
593
  module.gradient_checkpointing = value
594
 
595
  @staticmethod
596
+ def _convert_cache_to_standard_format(
597
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
598
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
599
  """
600
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
601
  num_heads, ...]))
602
  """
603
+ batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
604
+ # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
605
+ # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
606
+ # on whether we use multi_query attention.
607
  num_heads = batch_size_times_num_heads // batch_size
 
 
608
  return tuple(
609
  (
610
+ layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
611
+ layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
612
  )
613
  for layer_past in past_key_value
614
  )
 
617
  def _convert_to_rw_cache(
618
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
619
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
620
+ batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
621
  batch_size_times_num_heads = batch_size * num_heads
622
+ # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
 
623
  return tuple(
624
  (
625
+ layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
626
+ layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
627
  )
628
  for layer_past in past_key_value
629
  )
630
 
631
 
632
+ @add_start_docstrings(
633
+ "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
634
+ FALCON_START_DOCSTRING,
635
+ )
636
+ class FalconModel(FalconPreTrainedModel):
637
+ def __init__(self, config: FalconConfig):
638
  super().__init__(config)
639
 
640
  self.embed_dim = config.hidden_size
641
+ self.num_heads = config.num_attention_heads
642
+ self.use_alibi = config.alibi
643
 
644
  # Embedding + LN Embedding
645
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
646
 
647
  # Transformer blocks
648
+ self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
649
 
650
  # Final Layer Norm
651
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
658
  def get_input_embeddings(self):
659
  return self.word_embeddings
660
 
661
+ @staticmethod
662
  def _prepare_attn_mask(
663
+ attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
664
  ) -> torch.BoolTensor:
665
+ # Create a causal mask
666
+ # The attention mask we receive as input should cover the whole extended sequence, including any past
667
+ # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
668
+ # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
669
+ if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
670
+ raise ValueError(
671
+ "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
672
+ f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
673
+ f" {past_key_values_length}."
674
+ )
675
  combined_attention_mask = None
676
  device = attention_mask.device
677
+ _, seq_length = input_shape
678
 
679
+ if seq_length > 1:
680
  combined_attention_mask = _make_causal_mask(
681
  input_shape, device=device, past_key_values_length=past_key_values_length
682
  )
683
 
684
+ # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
685
+ expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
686
  combined_attention_mask = (
687
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
688
  )
 
692
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
693
  self.word_embeddings = new_embeddings
694
 
695
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
696
+ @add_code_sample_docstrings(
697
+ checkpoint=_CHECKPOINT_FOR_DOC,
698
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
699
+ config_class=_CONFIG_FOR_DOC,
700
+ )
701
  def forward(
702
  self,
703
  input_ids: Optional[torch.LongTensor] = None,
 
709
  output_attentions: Optional[bool] = None,
710
  output_hidden_states: Optional[bool] = None,
711
  return_dict: Optional[bool] = None,
 
712
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
 
 
 
 
 
 
 
 
 
 
713
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
714
  output_hidden_states = (
715
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
728
 
729
  if past_key_values is None:
730
  past_key_values = tuple([None] * len(self.h))
731
+ else:
732
+ past_key_values = self._convert_to_rw_cache(past_key_values)
733
 
734
  # Prepare head mask if needed
735
  # 1.0 in head_mask indicate we keep the head
736
  # attention_probs has shape batch_size x num_heads x N x N
737
  # head_mask has shape n_layer x batch x num_heads x N x N
738
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
739
 
740
  if inputs_embeds is None:
741
  inputs_embeds = self.word_embeddings(input_ids)
 
747
  all_hidden_states = () if output_hidden_states else None
748
 
749
  # Compute alibi tensor: check build_alibi_tensor documentation
 
750
  past_key_values_length = 0
751
  if past_key_values[0] is not None:
752
+ past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
 
753
  if attention_mask is None:
754
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
755
  else:
756
  attention_mask = attention_mask.to(hidden_states.device)
757
 
758
+ if self.use_alibi:
759
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
760
  else:
761
  alibi = None
 
767
  )
768
 
769
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
770
  if output_hidden_states:
771
  all_hidden_states = all_hidden_states + (hidden_states,)
772
 
773
  if self.gradient_checkpointing and self.training:
 
774
  if use_cache:
775
  logger.warning(
776
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 
815
  if output_hidden_states:
816
  all_hidden_states = all_hidden_states + (hidden_states,)
817
 
818
+ if presents is not None:
819
+ presents = self._convert_cache_to_standard_format(presents, batch_size)
820
+
821
  if not return_dict:
822
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
823
 
 
829
  )
830
 
831
 
832
+ @add_start_docstrings(
833
+ "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
834
+ FALCON_START_DOCSTRING,
835
+ )
836
+ class FalconForCausalLM(FalconPreTrainedModel):
837
+ _tied_weights_keys = ["lm_head.weight"]
838
 
839
+ def __init__(self, config: FalconConfig):
840
  super().__init__(config)
841
+ self.transformer = FalconModel(config)
842
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
843
 
844
  # Initialize weights and apply final processing
 
853
  def prepare_inputs_for_generation(
854
  self,
855
  input_ids: torch.LongTensor,
856
+ past_key_values: Optional[torch.Tensor] = None,
857
  attention_mask: Optional[torch.Tensor] = None,
858
  **kwargs,
859
  ) -> dict:
860
+ if past_key_values is not None:
861
+ input_ids = input_ids[:, -1:]
 
 
 
 
 
862
 
863
  return {
864
  "input_ids": input_ids,
865
+ "past_key_values": past_key_values,
866
  "use_cache": kwargs.get("use_cache"),
867
  "attention_mask": attention_mask,
868
  }
869
 
870
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
871
+ @add_code_sample_docstrings(
872
+ checkpoint=_CHECKPOINT_FOR_DOC,
873
+ output_type=CausalLMOutputWithCrossAttentions,
874
+ config_class=_CONFIG_FOR_DOC,
875
+ )
876
  def forward(
877
  self,
878
  input_ids: Optional[torch.LongTensor] = None,
 
885
  output_attentions: Optional[bool] = None,
886
  output_hidden_states: Optional[bool] = None,
887
  return_dict: Optional[bool] = None,
 
888
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
889
  r"""
890
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
892
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
893
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
894
  """
 
 
 
 
 
 
 
 
 
895
 
896
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
897
 
 
944
 
945
  Output shares the same memory storage as `past`.
946
  """
 
947
 
948
  # Get a copy of `beam_idx` on all the devices where we need those indices.
949
  device_to_beam_idx = {
 
954
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
955
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
956
  )
957
+ for layer_past in past
958
  )
959
+ return reordered_past
 
960
 
 
 
961
 
962
+ @add_start_docstrings(
963
+ """
964
+ The Falcon Model transformer with a sequence classification head on top (linear layer).
965
+
966
+ [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
967
+ (e.g. GPT-1) do.
968
+
969
+ Since it does classification on the last token, it requires to know the position of the last token. If a
970
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
971
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
972
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
973
+ each row of the batch).
974
+ """,
975
+ FALCON_START_DOCSTRING,
976
+ )
977
+ class FalconForSequenceClassification(FalconPreTrainedModel):
978
+ def __init__(self, config: FalconConfig):
979
  super().__init__(config)
980
  self.num_labels = config.num_labels
981
+ self.transformer = FalconModel(config)
982
  self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
983
 
984
  # Initialize weights and apply final processing
985
  self.post_init()
986
 
987
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
988
+ @add_code_sample_docstrings(
989
+ checkpoint=_CHECKPOINT_FOR_DOC,
990
+ output_type=SequenceClassifierOutputWithPast,
991
+ config_class=_CONFIG_FOR_DOC,
992
+ )
993
  def forward(
994
  self,
995
  input_ids: Optional[torch.LongTensor] = None,
 
1002
  output_attentions: Optional[bool] = None,
1003
  output_hidden_states: Optional[bool] = None,
1004
  return_dict: Optional[bool] = None,
 
1005
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1006
  r"""
1007
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
1009
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1010
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1011
  """
 
 
 
 
 
 
 
 
 
1012
 
1013
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1014
 
 
1083
  )
1084
 
1085
 
1086
+ @add_start_docstrings(
1087
+ """
1088
+ Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1089
+ Named-Entity-Recognition (NER) tasks.
1090
+ """,
1091
+ FALCON_START_DOCSTRING,
1092
+ )
1093
+ class FalconForTokenClassification(FalconPreTrainedModel):
1094
+ def __init__(self, config: FalconConfig):
1095
  super().__init__(config)
1096
  self.num_labels = config.num_labels
1097
 
1098
+ self.transformer = FalconModel(config)
1099
+ if getattr(config, "classifier_dropout", None) is not None:
1100
  classifier_dropout = config.classifier_dropout
1101
+ elif getattr(config, "hidden_dropout", None) is not None:
1102
  classifier_dropout = config.hidden_dropout
1103
  else:
1104
  classifier_dropout = 0.1
 
1108
  # Initialize weights and apply final processing
1109
  self.post_init()
1110
 
1111
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1112
+ @add_code_sample_docstrings(
1113
+ checkpoint=_CHECKPOINT_FOR_DOC,
1114
+ output_type=TokenClassifierOutput,
1115
+ config_class=_CONFIG_FOR_DOC,
1116
+ )
1117
  def forward(
1118
  self,
1119
  input_ids: Optional[torch.LongTensor] = None,
 
1126
  output_attentions: Optional[bool] = None,
1127
  output_hidden_states: Optional[bool] = None,
1128
  return_dict: Optional[bool] = None,
 
1129
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1130
  r"""
1131
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
1133
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1134
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1135
  """
 
 
 
 
 
 
 
 
 
1136
 
1137
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1138
 
 
1156
  if labels is not None:
1157
  batch_size, seq_length = labels.shape
1158
  loss_fct = CrossEntropyLoss()
1159
+ loss = loss_fct(
1160
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1161
+ )
1162
 
1163
  if not return_dict:
1164
  output = (logits,) + transformer_outputs[2:]
 
1172
  )
1173
 
1174
 
1175
+ @add_start_docstrings(
1176
+ """
1177
+ The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1178
+ SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1179
+ """,
1180
+ FALCON_START_DOCSTRING,
1181
+ )
1182
+ class FalconForQuestionAnswering(FalconPreTrainedModel):
1183
  def __init__(self, config):
1184
  super().__init__(config)
1185
+ self.transformer = FalconModel(config)
1186
  self.qa_outputs = nn.Linear(config.hidden_size, 2)
1187
 
1188
  # Initialize weights and apply final processing
1189
  self.post_init()
1190
 
1191
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1192
  def forward(
1193
  self,
1194
  input_ids: Optional[torch.LongTensor] = None,
1195
  attention_mask: Optional[torch.FloatTensor] = None,
 
1196
  head_mask: Optional[torch.FloatTensor] = None,
1197
  inputs_embeds: Optional[torch.FloatTensor] = None,
1198
  start_positions: Optional[torch.LongTensor] = None,
 
1216
  outputs = self.transformer(
1217
  input_ids,
1218
  attention_mask=attention_mask,
 
1219
  head_mask=head_mask,
1220
  inputs_embeds=inputs_embeds,
1221
  output_attentions=output_attentions,