Lukas Möller committed on
Commit
4be435e
·
1 Parent(s): 82a66d3

add implementation files and modified tokenizer

Browse files
attention.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Attention layers."""
5
+
6
+ import math
7
+ import warnings
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from einops import rearrange
12
+ from torch import nn
13
+
14
+ from .low_precision_layernorm import LPLayerNorm
15
+
16
+
17
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
                     original_is_causal: bool):
    """Decide whether causal masking is still needed for this call.

    Args:
        num_query_tokens: number of query positions.
        num_key_tokens: number of key positions.
        original_is_causal: the causal flag requested by the caller.

    Returns:
        ``original_is_causal`` unchanged when it is falsy or when query and
        key have the same number of tokens; ``False`` when causal attention
        was requested with exactly one query token against a different
        number of key tokens.

    Raises:
        NotImplementedError: causal attention was requested with differing
            query/key lengths and more than one query token.
    """
    # Nothing to adjust when masking is off or the lengths already agree.
    if not original_is_causal:
        return original_is_causal
    if num_query_tokens == num_key_tokens:
        return original_is_causal

    # Lengths differ: only the single-query case is supported, and it is
    # handled by switching the causal flag off.
    if num_query_tokens == 1:
        return False

    raise NotImplementedError(
        'ReplitLM does not support query and key with different number of tokens, unless number of query tokens is 1.'
    )
27
+
28
+
29
def scaled_multihead_dot_product_attention(
    query,
    key,
    value,
    n_heads,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
):
    """Compute multi-head scaled dot-product attention with plain torch ops.

    Args:
        query: tensor laid out as ``(batch, seq_q, n_heads * head_dim)``.
        key: tensor laid out as ``(batch, seq_k, n_heads * head_dim)``.
        value: tensor laid out like ``key``.
        n_heads: number of heads the last dimension is split into.
        softmax_scale: multiplier applied to the raw scores; defaults to
            ``1 / sqrt(head_dim)`` when ``None``.
        attn_bias: optional additive bias. Each of its last two dimensions
            must be 1 or match ``seq_q`` / ``seq_k`` respectively so that it
            broadcasts onto the score tensor.
        key_padding_mask: optional mask viewable as ``(batch, 1, 1, seq_k)``;
            key positions where the mask is False are filled with the most
            negative finite value of the score dtype.
            NOTE(review): ``~`` is applied to it, so a bool tensor is
            presumably expected - confirm against callers.
        is_causal: when true, each query position may only attend to key
            positions at or before it (aligned to the end of the sequence).
        dropout_p: dropout probability applied to the attention weights.
        training: forwarded to ``torch.nn.functional.dropout``.
        needs_weights: when true the attention weights are returned too.

    Returns:
        ``(out, attn_weight)`` if ``needs_weights`` else ``(out, None)``,
        where ``out`` is laid out as ``(batch, seq_q, n_heads * head_dim)``.

    Raises:
        RuntimeError: ``attn_bias`` cannot broadcast onto the score tensor.
    """
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads)  # includes key.t()
    v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)

    # Fill value for masked positions: finite, so softmax never sees -inf.
    min_val = torch.finfo(q.dtype).min

    b, _, s_q, d = q.shape
    s_k = k.size(-1)

    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)

    attn_weight = q.matmul(k) * softmax_scale

    if attn_bias is not None:
        if (attn_bias.size(-1) != 1 and
                attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and
                                               attn_bias.size(-2) != s_q):
            raise RuntimeError(
                f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
            )
        # The bias was previously validated but then discarded; apply it so
        # additive biases (e.g. ALiBi) actually influence the scores.
        attn_weight = attn_weight + attn_bias

    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn(
                'Propogating key_padding_mask to the attention module ' +
                'and applying it within the attention module can cause ' +
                'unneccessary computation/memory usage. Consider integrating ' +
                'into attn_bias once and passing that to each attention ' +
                'module instead.'
            )
        attn_weight = attn_weight.masked_fill(
            ~key_padding_mask.view((b, 1, 1, s_k)), min_val)

    if is_causal:
        s = max(s_q, s_k)
        # Build the lower-triangular mask in float32: it is only used to
        # derive a bool mask, and float32 `tril` is available on every
        # backend whereas half precision may not be on CPU.
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        # Keep the bottom-right corner so the last query lines up with the
        # last key when the two lengths differ.
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
                                              min_val)

    attn_weight = torch.softmax(attn_weight, dim=-1)

    if dropout_p:
        attn_weight = torch.nn.functional.dropout(attn_weight,
                                                  p=dropout_p,
                                                  training=training,
                                                  inplace=True)

    out = attn_weight.matmul(v)
    out = rearrange(out, 'b h s d -> b s (h d)')

    if needs_weights:
        return out, attn_weight
    return out, None
102
+
103
+
104
def check_valid_inputs(*tensors, valid_dtypes=(torch.float16, torch.bfloat16)):
    """Validate that every tensor has an accepted dtype and lives on CUDA.

    Args:
        *tensors: tensors to check.
        valid_dtypes: collection of accepted dtypes. The default is an
            immutable tuple so it cannot be altered between calls.

    Raises:
        TypeError: a tensor's dtype is not in ``valid_dtypes``, or a tensor
            is not a CUDA tensor.
    """
    for tensor in tensors:
        if tensor.dtype not in valid_dtypes:
            raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.')
        if not tensor.is_cuda:
            raise TypeError(
                f'Inputs must be cuda tensors ({tensor.is_cuda=}).')
111
+
112
+
113
def flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
):
    """Compute multi-head attention with the ``flash_attn`` CUDA kernels.

    Inputs are un-padded with ``key_padding_mask``, passed through
    ``flash_attn_unpadded_func`` and padded back to the query's shape.

    Args:
        query, key, value: tensors laid out as
            ``(batch, seq, n_heads * head_dim)``; they must pass
            ``check_valid_inputs`` (fp16/bf16 CUDA tensors).
        n_heads: number of heads the last dimension is split into.
        softmax_scale: forwarded to the kernel.
        attn_bias: must be ``None``; additive bias is not supported here.
        key_padding_mask: optional mask over key positions; when ``None``
            an all-True mask is used. The query mask is taken from its last
            ``seq_q`` columns.
        is_causal: requested causal flag, adjusted by ``_reset_is_causal``.
        dropout_p: dropout probability, only used when ``training`` is true.
        training: whether dropout is active.
        needs_weights: forwarded to the kernel as ``return_attn_probs``.

    Returns:
        ``(output, None)`` - attention weights are never returned.

    Raises:
        RuntimeError: ``flash_attn`` could not be imported (the underlying
            error is chained as ``__cause__``).
        NotImplementedError: ``attn_bias`` was provided.
    """
    try:
        from flash_attn import bert_padding, flash_attn_interface
    except Exception as err:
        # Keep the RuntimeError callers expect, but chain the real failure
        # and stop swallowing KeyboardInterrupt/SystemExit.
        raise RuntimeError('Please install flash_attn==0.2.8') from err

    check_valid_inputs(query, key, value)

    if attn_bias is not None:
        raise NotImplementedError(f'attn_bias not implemented for flash attn.')

    batch_size, seqlen = query.shape[:2]

    if key_padding_mask is None:
        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
    # With cached keys the query is shorter than the key; it corresponds to
    # the trailing positions of the mask.
    query_padding_mask = key_padding_mask[:, -query.size(1):]

    query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
        query, query_padding_mask)
    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

    key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
        key, key_padding_mask)
    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

    value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

    dropout_p = dropout_p if training else 0.0

    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)

    # NOTE(review): with needs_weights=True the kernel presumably returns
    # extra values alongside the output, which the rearrange below would not
    # accept - confirm against the flash_attn interface.
    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
        query_unpad,
        key_unpad,
        value_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=reset_is_causal,
        return_attn_probs=needs_weights)

    output = bert_padding.pad_input(
        rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
        seqlen)
    return output, None
174
+
175
+
176
def triton_flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
):
    """Compute multi-head attention with the Triton flash-attention kernel.

    Args:
        query, key, value: tensors laid out as
            ``(batch, seq, n_heads * head_dim)``; they must pass
            ``check_valid_inputs`` (fp16/bf16 CUDA tensors).
        n_heads: number of heads the last dimension is split into.
        softmax_scale: forwarded to the kernel.
        attn_bias: optional additive bias forwarded to the kernel.
        key_padding_mask: optional mask over key positions. It is folded
            into ``attn_bias`` by filling positions where it is False with
            the most negative finite value of the query dtype.
        is_causal: requested causal flag, adjusted by ``_reset_is_causal``.
        dropout_p: must be falsy; dropout is not supported.
        training: unused by this implementation.
        needs_weights: must be falsy; weights cannot be returned.

    Returns:
        ``(output, None)`` with ``output`` laid out as
        ``(batch, seq_q, n_heads * head_dim)``.

    Raises:
        RuntimeError: ``flash_attn``'s Triton module could not be imported
            (the underlying error is chained as ``__cause__``).
        NotImplementedError: dropout or attention weights were requested.
    """
    try:
        from flash_attn import flash_attn_triton  # type: ignore
    except Exception as err:
        # Keep the RuntimeError callers expect, but chain the real failure
        # and stop swallowing KeyboardInterrupt/SystemExit.
        raise RuntimeError(
            'Please install flash_attn==0.2.8 and triton==2.0.0.dev20221202.'
        ) from err

    check_valid_inputs(query, key, value)

    if dropout_p:
        raise NotImplementedError(
            f'Dropout not implemented for attn_impl: triton.')

    if needs_weights:
        raise NotImplementedError(
            f'attn_impl: triton cannot return attn weights.')

    if key_padding_mask is not None:
        warnings.warn(
            'Propagating key_padding_mask to the attention module ' +
            'and applying it within the attention module can cause ' +
            'unnecessary computation/memory usage. Consider integrating ' +
            'into attn_bias once and passing that to each attention ' +
            'module instead.'
        )
        b_size, s_k = key_padding_mask.shape[:2]

        # The kernel has no padding-mask input, so express the mask as an
        # additive bias instead.
        if attn_bias is None:
            attn_bias = query.new_zeros(b_size, 1, 1, s_k)

        attn_bias = attn_bias.masked_fill(
            ~key_padding_mask.view((b_size, 1, 1, s_k)),
            torch.finfo(query.dtype).min)

    query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
    key = rearrange(key, 'b s (h d) -> b s h d', h=n_heads)
    value = rearrange(value, 'b s (h d) -> b s h d', h=n_heads)

    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    attn_output = flash_attn_triton.flash_attn_func(query, key, value,
                                                    attn_bias, reset_is_causal,
                                                    softmax_scale)

    # Merge the head and head_dim axes back into one feature axis.
    output = attn_output.view(*attn_output.shape[:2], -1)

    return output, None
234
+
235
+
236
class MultiheadAttention(nn.Module):
    """Multi-head self attention built on a fused QKV projection.

    The attention kernel is selected by ``attn_impl`` (``'flash'``,
    ``'triton'`` or ``'torch'``). Using the torch or triton implementation
    enables the user to also use an additive bias.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = 'triton',
        attn_clip_qkv: Optional[float] = None,
        attn_qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        device: Optional[str] = None,
    ):
        """Create the projections and pick the attention function.

        Raises:
            ValueError: ``attn_impl`` is not one of the supported names.
        """
        super().__init__()

        self.attn_impl = attn_impl
        self.clip_qkv = attn_clip_qkv
        self.attn_qk_ln = attn_qk_ln

        self.d_model = d_model
        self.n_heads = n_heads
        # Fall back to 1/sqrt(head_dim) when no explicit scale is given.
        if softmax_scale is None:
            softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.softmax_scale = softmax_scale
        self.attn_dropout_p = attn_pdrop

        # One linear layer produces query, key and value together.
        self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
        # for param init fn; enables shape based init of fused layers:
        # (fused dimension, split indices)
        self.Wqkv._fused = (0, (d_model, 2 * d_model))  # type: ignore

        if self.attn_qk_ln:
            norm_cls = nn.LayerNorm
            if low_precision_layernorm:
                norm_cls = LPLayerNorm
            self.q_ln = norm_cls(self.d_model, device=device)
            self.k_ln = norm_cls(self.d_model, device=device)

        impl = self.attn_impl
        if impl == 'flash':
            self.attn_fn = flash_attn_fn
        elif impl == 'triton':
            self.attn_fn = triton_flash_attn_fn
            warnings.warn(
                'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +
                'it uses more memory. When training larger models this can trigger ' +
                'alloc retries which hurts performance. If encountered, we recommend ' +
                'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
        elif impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available():
                warnings.warn(
                    'Using `attn_impl: torch`. If your model does not use `alibi` or ' +
                    '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +
                    'we recommend using `attn_impl: triton`.'
                )
        else:
            raise ValueError(f'{attn_impl=} is an invalid setting.')

        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        # Marks the layer so the param init fn can rescale its weights.
        self.out_proj._is_residual = True  # type: ignore

    def forward(self,
                x,
                past_key_value=None,
                attn_bias=None,
                attention_mask=None,
                is_causal=True,
                needs_weights=False):
        """Run attention over ``x``.

        Args:
            x: input whose dimension 2 is the feature axis (the fused
                projection is chunked along ``dim=2``).
            past_key_value: ``None`` to disable caching, an empty sequence
                to start a cache, or a ``(key, value)`` pair that is
                prepended along dimension 1.
            attn_bias: optional additive bias; cropped to the trailing
                ``(query_len, key_len)`` window before use.
            attention_mask: passed to the attention function as the key
                padding mask.
            is_causal: forwarded to the attention function.
            needs_weights: forwarded to the attention function.

        Returns:
            ``(projected_output, attn_weights, past_key_value)``.
        """
        fused = self.Wqkv(x)

        if self.clip_qkv:
            limit = self.clip_qkv
            fused.clamp_(min=-limit, max=limit)

        query, key, value = fused.chunk(3, dim=2)

        if self.attn_qk_ln:
            # Applying layernorm to qk, then restoring the incoming dtype.
            in_dtype = query.dtype
            query = self.q_ln(query).to(in_dtype)
            key = self.k_ln(key).to(in_dtype)

        if past_key_value is not None:
            if len(past_key_value) != 0:
                cached_key, cached_value = past_key_value[0], past_key_value[1]
                key = torch.cat([cached_key, key], dim=1)
                value = torch.cat([cached_value, value], dim=1)

            past_key_value = (key, value)

        if attn_bias is not None:
            q_len, k_len = query.size(1), key.size(1)
            attn_bias = attn_bias[:, :, -q_len:, -k_len:]

        context, attn_weights = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=attention_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
        )

        projected = self.out_proj(context)
        return projected, attn_weights, past_key_value
348
+
349
+
350
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal,
                    use_sequence_id):
    """Return the shape the attention bias tensor needs, or ``None``.

    Args:
        attn_impl: ``'flash'``, ``'torch'`` or ``'triton'``.
        n_heads: number of attention heads.
        seq_len: sequence length used for the bias dimensions.
        alibi: whether an ALiBi bias is used.
        prefix_lm: whether prefix-LM masking is used.
        causal: whether attention is causal.
        use_sequence_id: whether sequence-id masking is used.

    Returns:
        ``None`` for ``'flash'`` or when no bias is required; otherwise a
        4-tuple. With ALiBi the head axis is ``n_heads`` and the query axis
        is ``seq_len`` only if prefix-LM, non-causal or sequence-id masking
        is active (else 1). Without ALiBi a ``(1, 1, seq_len, seq_len)``
        shape is returned for prefix-LM or sequence-id masking.

    Raises:
        ValueError: ``attn_impl`` is not a supported name.
    """
    if attn_impl == 'flash':
        return None
    if attn_impl not in ['torch', 'triton']:
        raise ValueError(f'{attn_impl=} is an invalid setting.')

    if alibi:
        # A per-query row is only needed when the mask is not purely causal.
        needs_full_rows = (prefix_lm or not causal) or use_sequence_id
        query_dim = seq_len if needs_full_rows else 1
        return (1, n_heads, query_dim, seq_len)

    if prefix_lm or use_sequence_id:
        return (1, 1, seq_len, seq_len)
    return None
364
+
365
+
366
def attn_bias(attn_impl,
              attn_bias,
              n_heads,
              seq_len,
              causal=False,
              alibi=False,
              alibi_bias_max=8):
    """Return the attention bias, with the ALiBi term added when requested.

    Args:
        attn_impl: ``'flash'``, ``'torch'`` or ``'triton'``.
        attn_bias: existing bias tensor the ALiBi term is added to.
        n_heads: number of attention heads.
        seq_len: accepted for interface compatibility; not read here.
        causal: accepted for interface compatibility; not read here.
        alibi: whether to add the ALiBi bias.
        alibi_bias_max: forwarded to ``alibi_bias``.

    Returns:
        ``None`` for ``'flash'``; otherwise ``attn_bias`` itself when
        ``alibi`` is falsy, or a new tensor holding ``attn_bias`` plus the
        ALiBi term.

    Raises:
        ValueError: ``attn_impl`` is not a supported name.
    """
    if attn_impl == 'flash':
        return None
    if attn_impl not in ['torch', 'triton']:
        raise ValueError(f'{attn_impl=} is an invalid setting.')

    if not alibi:
        return attn_bias

    # Build the ALiBi term on the same device/dtype as the incoming bias and
    # add it out of place.
    alibi_term = alibi_bias(n_heads,
                            alibi_bias_max=alibi_bias_max,
                            device=attn_bias.device,
                            dtype=attn_bias.dtype)
    return attn_bias.add(alibi_term)
387
+
388
+
389
def alibi_bias(n_heads,
               alibi_bias_max=8,
               device=None,
               dtype=None,
               seq_len=2048):
    """Build the ALiBi attention bias.

    Args:
        n_heads: number of attention heads.
        alibi_bias_max: the exponent of head ``i`` (1-based) is
            ``i * alibi_bias_max / n_heads``; its slope is ``2 ** -exponent``.
        device: device for the created tensors.
        dtype: dtype for the created tensors.
        seq_len: number of key positions to cover. Defaults to 2048, the
            value that was previously hard-coded.

    Returns:
        Tensor of shape ``(1, n_heads, 1, seq_len)`` whose last axis runs
        from ``1 - seq_len`` up to ``0``, scaled by each head's slope.
    """
    # Relative key offsets: ..., -2, -1, 0 (the last key gets no penalty).
    offsets = torch.arange(1 - seq_len, 1, dtype=dtype,
                           device=device).view(1, 1, 1, seq_len)
    m = torch.arange(1, n_heads + 1, dtype=dtype, device=device)
    m = m.mul(alibi_bias_max / n_heads)
    return offsets * (1. / (2**m.view(1, n_heads, 1, 1)))
gpt_blocks.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """GPT Blocks used for the GPT Model."""
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from .attention import MultiheadAttention
12
+ from .low_precision_layernorm import LPLayerNorm
13
+
14
+
15
class GPTMLP(nn.Module):
    """Feed-forward sub-layer: up-projection, exact GELU, down-projection."""

    def __init__(self,
                 d_model: int,
                 mlp_ratio: int,
                 device: Optional[str] = None):
        """Create the two linear layers.

        Args:
            d_model: input and output feature size.
            mlp_ratio: hidden size is ``mlp_ratio * d_model``.
            device: device passed to the linear layers.
        """
        super().__init__()
        hidden_dim = mlp_ratio * d_model
        self.mlp_up = nn.Linear(d_model, hidden_dim, device=device)
        self.mlp_act = nn.GELU(approximate='none')
        self.mlp_down = nn.Linear(hidden_dim, d_model, device=device)
        # Marks the layer so the param init fn can rescale its weights.
        self.mlp_down._is_residual = True  # type: ignore

    def forward(self, x):
        """Apply up-projection, activation and down-projection to ``x``."""
        hidden = self.mlp_up(x)
        hidden = self.mlp_act(hidden)
        return self.mlp_down(hidden)
29
+
30
+
31
class GPTBlock(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP.

    Each sub-layer output passes through residual dropout before being added
    back onto the running hidden state.
    """

    def __init__(self,
                 attn_impl: str,
                 d_model: int,
                 n_heads: int,
                 mlp_ratio: int,
                 attn_clip_qkv: Optional[float] = None,
                 attn_qk_ln: bool = False,
                 softmax_scale: Optional[float] = None,
                 attn_pdrop: float = 0.0,
                 alibi: bool = False,
                 resid_pdrop: float = 0.0,
                 low_precision_layernorm: bool = False,
                 device: Optional[str] = None,
                 **kwargs):
        """Build the block's sub-modules.

        Args:
            attn_impl: attention implementation name passed to
                ``MultiheadAttention``.
            d_model: hidden size.
            n_heads: number of attention heads.
            mlp_ratio: MLP hidden size as a multiple of ``d_model``.
            attn_clip_qkv, attn_qk_ln, softmax_scale, attn_pdrop: passed to
                ``MultiheadAttention``.
            alibi: accepted for config compatibility; not read here.
            resid_pdrop: dropout probability on both residual branches.
            low_precision_layernorm: use ``LPLayerNorm`` instead of
                ``nn.LayerNorm`` for this block's two norms.
            device: device passed to the sub-modules.
            **kwargs: ignored.
        """
        del kwargs  # unused, just to capture any extra args from the config
        super().__init__()

        layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm

        self.ln_1 = layernorm_class(d_model, device=device)
        self.attn = MultiheadAttention(
            attn_impl=attn_impl,
            attn_clip_qkv=attn_clip_qkv,
            attn_qk_ln=attn_qk_ln,
            softmax_scale=softmax_scale,
            attn_pdrop=attn_pdrop,
            d_model=d_model,
            n_heads=n_heads,
            device=device,
        )
        self.ln_2 = layernorm_class(d_model, device=device)
        self.mlp = GPTMLP(
            d_model=d_model,
            mlp_ratio=mlp_ratio,
            device=device,
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        inp_l: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        """Apply the block to ``inp_l``.

        Args:
            inp_l: hidden state entering the block.
            past_key_value: cache forwarded to the attention layer.
            attn_bias: additive attention bias forwarded to attention.
            attention_mask: padding mask forwarded to attention.
            is_causal: causal flag forwarded to attention.

        Returns:
            ``(hidden_state, past_key_value)`` where ``past_key_value`` is
            whatever the attention layer returned.
        """
        cur = self.ln_1(inp_l)
        cur, _, past_key_value = self.attn(cur,
                                           past_key_value=past_key_value,
                                           attn_bias=attn_bias,
                                           attention_mask=attention_mask,
                                           is_causal=is_causal)
        # The dropout modules were created but never applied, so resid_pdrop
        # had no effect. They are the identity in eval mode or with p=0.
        inp_l = inp_l + self.resid_attn_dropout(cur)
        cur = self.ln_2(inp_l)
        cur = self.mlp(cur)
        inp_l = inp_l + self.resid_mlp_dropout(cur)
        return inp_l, past_key_value
low_precision_layernorm.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
class LPLayerNorm(torch.nn.LayerNorm):
    """LayerNorm that normalizes in the active autocast dtype.

    Input, weight and bias are passed through ``_cast_if_autocast_enabled``
    and ``F.layer_norm`` is then called with autocast switched off.
    """

    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
        """Forward all arguments unchanged to ``torch.nn.LayerNorm``."""
        super().__init__(
            normalized_shape,
            eps,
            elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        """Normalize ``x`` over ``self.normalized_shape``."""
        device_type = x.device.type
        x_cast = _cast_if_autocast_enabled(x)

        # Weight and bias may be None; only cast them when they exist.
        weight_cast = self.weight
        if weight_cast is not None:
            weight_cast = _cast_if_autocast_enabled(weight_cast)
        bias_cast = self.bias
        if bias_cast is not None:
            bias_cast = _cast_if_autocast_enabled(bias_cast)

        # Disable autocast so the op runs in the dtype chosen above.
        with torch.autocast(enabled=False, device_type=device_type):
            return F.layer_norm(x_cast, self.normalized_shape, weight_cast, bias_cast, self.eps)
24
+
25
+
26
def _cast_if_autocast_enabled(tensor):
    """Cast ``tensor`` to the autocast dtype of its device, if autocast is on.

    Returns:
        ``tensor`` itself when ``torch.is_autocast_enabled()`` is false,
        otherwise ``tensor`` converted to the CUDA or CPU autocast dtype.

    Raises:
        NotImplementedError: autocast is enabled and the tensor lives on a
            device other than ``'cuda'`` or ``'cpu'``.
    """
    if not torch.is_autocast_enabled():
        return tensor

    device_type = tensor.device.type
    if device_type == 'cuda':
        target_dtype = torch.get_autocast_gpu_dtype()
    elif device_type == 'cpu':
        target_dtype = torch.get_autocast_cpu_dtype()
    else:
        raise NotImplementedError()
    return tensor.to(dtype=target_dtype)
param_init_fns.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import math
4
+ import warnings
5
+ from collections.abc import Sequence
6
+ from functools import partial
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+
13
def torch_default_param_init_fn_(
    module: nn.Module,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` through its own ``reset_parameters`` method.

    Args:
        module: module to initialize; left untouched when it has no
            ``reset_parameters`` attribute.
        verbose: a value above 1 emits a warning describing the scheme.
        **kwargs: ignored.
    """
    del kwargs  # swallow config entries this initializer does not consume
    if verbose > 1:
        warnings.warn(
            f"Initializing network using module's reset_parameters attribute")

    if not hasattr(module, 'reset_parameters'):
        return
    module.reset_parameters()  # type: ignore
25
+
26
+
27
def fused_init_helper_(module: nn.Module, init_fn_):
    """Apply ``init_fn_`` separately to each fused slice of ``module.weight``.

    Parameter initialization is often based on the parameter's shape. If a
    layer is fused, initialization should be based on the shapes of the
    original tensors instead of the shape of the fused tensor. Fused layers
    must define a ``_fused`` attribute: its first element is the dimension
    along which the tensor is fused, followed by an iterable of split
    indices.

    Args:
        module: module carrying ``weight`` and ``_fused``.
        init_fn_: in-place initializer called once per slice.

    Raises:
        RuntimeError: ``module`` has no ``_fused`` attribute.
    """
    _fused = getattr(module, '_fused', None)

    if _fused is None:
        raise RuntimeError(f'Internal logic error')

    dim, splits = _fused
    # Add the outer boundaries so consecutive pairs delimit every slice.
    splits = (0, *splits, module.weight.size(dim))  # type: ignore
    for s, e in zip(splits[:-1], splits[1:]):
        slice_indices = [slice(None)] * module.weight.ndim  # type: ignore
        slice_indices[dim] = slice(s, e)
        # Index with a tuple: indexing a tensor with a list of slices is
        # deprecated in PyTorch.
        init_fn_(module.weight[tuple(slice_indices)])  # type: ignore
46
+
47
+
48
def _resolve_div_is_residual(init_div_is_residual, n_layers):
    """Turn the ``init_div_is_residual`` setting into a numeric divisor."""
    if init_div_is_residual is False:
        # Division is disabled; the value is never used.
        return 1.0
    if init_div_is_residual is True:
        return math.sqrt(2 * n_layers)
    if isinstance(init_div_is_residual, (float, int)):
        return init_div_is_residual
    if isinstance(init_div_is_residual, str):
        # do not trust YAML parsing to always convert numbers to numbers;
        # float() also accepts decimals such as "2.5", unlike str.isnumeric()
        try:
            parsed = float(init_div_is_residual)
        except ValueError:
            parsed = None
        if parsed is not None and math.isfinite(parsed):
            return parsed
    raise ValueError(
        f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
    )


def generic_param_init_fn_(
    module: nn.Module,
    init_fn_,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize a single module's parameters in place with ``init_fn_``.

    Handles ``nn.Linear`` (including fused layers), ``nn.Embedding``,
    ``nn.LayerNorm`` and ``nn.MultiheadAttention``. Biases are zeroed and
    LayerNorm weights are set to one. Weights of layers flagged with
    ``_is_residual`` are divided by a configurable value afterwards.

    Args:
        module: module to initialize.
        init_fn_: in-place initializer applied to weight tensors.
        n_layers: used for the default residual divisor
            ``sqrt(2 * n_layers)``.
        d_model: required for ``nn.MultiheadAttention`` with a shared
            ``in_proj_weight``, which is split into three ``d_model`` rows.
        init_div_is_residual: ``False`` disables the division, ``True``
            uses the default divisor, a number or numeric string is used as
            the divisor directly.
        emb_init_std: if set, embeddings use a normal init with this std.
        emb_init_uniform_lim: if set (and ``emb_init_std`` is not),
            embeddings use a uniform init; either ``(min, max)`` or a single
            value ``lim`` meaning ``[-lim, lim]``.
        verbose: a value above 1 emits warnings describing what is done.
        **kwargs: ignored.

    Raises:
        ValueError: ``init_div_is_residual`` is not boolean or numeric, or
            ``emb_init_uniform_lim`` is a sequence whose length is not 2.
        NotImplementedError: ``module`` is of an unhandled type but owns
            parameters directly.
    """
    del kwargs  # unused, just to capture any extra args from the config
    if verbose > 1:
        warnings.warn(
            f'If model has bias parameters they are initialized to 0.')

    # enable user to divide _is_residual weights by
    # a value which defaults to math.sqrt(2 * cfg.n_layers)
    div_is_residual = _resolve_div_is_residual(init_div_is_residual, n_layers)

    if init_div_is_residual is not False:
        if verbose > 1:
            warnings.warn(
                f'Initializing _is_residual layers then dividing them by {div_is_residual}.' +
                f'set `init_div_is_residual: false` in model config to disable this.'
            )

    if isinstance(module, nn.Linear):
        # Linear
        if hasattr(module, '_fused'):
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

        if init_div_is_residual is not False and getattr(
                module, '_is_residual', False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)

    elif isinstance(module, nn.Embedding):
        # Embedding
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn(f'Embedding layer initialized to 0.')
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
            if verbose > 1:
                warnings.warn(
                    f'Embedding layer initialized using normal distribution with mean=0 and {std=}.'
                )
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                # Exactly two limits are needed; a shorter sequence used to
                # slip through and fail with an IndexError below.
                if len(lim) != 2:
                    raise ValueError(
                        f'Uniform init requires a min and a max limit. User input: {lim}.'
                    )
                if lim[0] == lim[1]:
                    warnings.warn(f'Embedding layer initialized to {lim[0]}.')
            else:
                if lim == 0:
                    warnings.warn(f'Embedding layer initialized to 0.')
                lim = [-lim, lim]
            a, b = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
            if verbose > 1:
                warnings.warn(
                    f'Embedding layer initialized using uniform distribution in range {lim}.'
                )
        else:
            emb_init_fn_ = init_fn_

        emb_init_fn_(module.weight)

    elif isinstance(module, nn.LayerNorm):
        # LayerNorm
        if verbose > 1:
            warnings.warn(
                f'LayerNorm gamma weights are set to 1. If the layer has a bias it is initialized to 0.'
            )
        torch.nn.init.ones_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    elif isinstance(module, nn.MultiheadAttention):
        # torch's MultiheadAttention
        if module._qkv_same_embed_dim:
            assert module.in_proj_weight is not None
            assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
            assert d_model is not None
            # in_proj_weight is actually 3 layers and should be split up for width based init
            _d = d_model
            splits = (0, _d, 2 * _d, 3 * _d)
            for s, e in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)

        # bias
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)

        # out proj
        init_fn_(module.out_proj.weight)
        if init_div_is_residual is not False and getattr(
                module.out_proj, '_is_residual', False):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)

    else:
        for _ in module.parameters(recurse=False):
            # raise error if uninitialized module has any parameters
            raise NotImplementedError(
                f'{module.__class__.__name__} parameters are not initialized by param_init_fn.'
            )
194
+
195
+
196
def _normal_init_(std, mean=0.0):
    """Return ``torch.nn.init.normal_`` pre-bound to ``mean`` and ``std``."""
    bound_kwargs = {'mean': mean, 'std': std}
    return partial(torch.nn.init.normal_, **bound_kwargs)
198
+
199
+
200
def _normal_param_init_fn_(
    module: nn.Module,
    std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` from a zero-mean normal distribution.

    Builds a normal initializer with the given ``std`` via ``_normal_init_``
    and hands it, together with the remaining settings, to
    ``generic_param_init_fn_``. Extra keyword arguments are discarded.
    """
    del kwargs  # swallow config entries this initializer does not consume
    normal_fn_ = _normal_init_(std=std)

    if verbose > 1:
        warnings.warn(
            f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    generic_param_init_fn_(module=module, init_fn_=normal_fn_, **forwarded)
228
+
229
+
230
def baseline_param_init_fn_(
    module: nn.Module,
    init_std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` from a normal distribution with std ``init_std``.

    Delegates to ``_normal_param_init_fn_`` once ``init_std`` has been
    checked. Extra keyword arguments are discarded.

    Raises:
        ValueError: ``init_std`` is ``None``.
    """
    del kwargs  # swallow config entries this initializer does not consume
    if init_std is None:
        raise ValueError(
            'You must set model.init_std to a float value to use the default initialization scheme.'
        )

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    _normal_param_init_fn_(module=module, std=init_std, **forwarded)
256
+
257
+
258
def small_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` with the "small init" normal scheme.

    The std is ``sqrt(2 / (5 * d_model))``, which is very close to kaiming
    normal; from Transformers without Tears (2019) - Nguyen & Salazar. The
    work itself is delegated to ``_normal_param_init_fn_``. Extra keyword
    arguments are discarded.
    """
    del kwargs  # swallow config entries this initializer does not consume
    small_std = math.sqrt(2 / (5 * d_model))

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    _normal_param_init_fn_(module=module, std=small_std, **forwarded)
282
+
283
+
284
def neox_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` following the GPT-NeoX-20B scheme.

    From section 2.3.1 of "GPT-NeoX-20B: An Open-Source Autoregressive
    Language Model" - Black et al. (2022). See
    https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py

    Uses ``small_param_init_fn_`` with the residual divisor fixed to
    ``n_layers / sqrt(10)``. Extra keyword arguments are discarded.
    """
    del kwargs  # swallow config entries this initializer does not consume
    # small std / wang std
    residual_div = n_layers / math.sqrt(10)

    if verbose > 1:
        warnings.warn(f'setting init_div_is_residual to {residual_div}')

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=residual_div,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    small_param_init_fn_(module=module, **forwarded)
314
+
315
+
316
def kaiming_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = 'fan_in',
    init_nonlinearity: str = 'leaky_relu',
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` with ``kaiming_uniform_``.

    ``init_gain``, ``fan_mode`` and ``init_nonlinearity`` are bound to the
    initializer's ``a``, ``mode`` and ``nonlinearity`` arguments; everything
    else is forwarded to ``generic_param_init_fn_``. Extra keyword arguments
    are discarded.
    """
    del kwargs  # swallow config entries this initializer does not consume

    if verbose > 1:
        warnings.warn(
            f'Using nn.init.kaiming_uniform_ init fn with parameters: ' +
            f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
        )

    kaiming_kwargs = {
        'a': init_gain,
        'mode': fan_mode,
        'nonlinearity': init_nonlinearity,
    }
    weight_init_ = partial(torch.nn.init.kaiming_uniform_, **kaiming_kwargs)

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    generic_param_init_fn_(module=module, init_fn_=weight_init_, **forwarded)
352
+
353
+
354
def kaiming_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = 'fan_in',
    init_nonlinearity: str = 'leaky_relu',
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` with ``kaiming_normal_``.

    ``init_gain``, ``fan_mode`` and ``init_nonlinearity`` are bound to the
    initializer's ``a``, ``mode`` and ``nonlinearity`` arguments; everything
    else is forwarded to ``generic_param_init_fn_``. Extra keyword arguments
    are discarded.
    """
    del kwargs  # swallow config entries this initializer does not consume

    if verbose > 1:
        warnings.warn(
            f'Using nn.init.kaiming_normal_ init fn with parameters: ' +
            f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
        )

    kaiming_kwargs = {
        'a': init_gain,
        'mode': fan_mode,
        'nonlinearity': init_nonlinearity,
    }
    weight_init_ = partial(nn.init.kaiming_normal_, **kaiming_kwargs)

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    generic_param_init_fn_(module=module, init_fn_=weight_init_, **forwarded)
390
+
391
+
392
def xavier_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` with ``xavier_uniform_``.

    ``init_gain`` is bound to the initializer's ``gain`` argument; everything
    else is forwarded to ``generic_param_init_fn_``. Extra keyword arguments
    are discarded.
    """
    del kwargs  # swallow config entries this initializer does not consume
    weight_init_ = partial(nn.init.xavier_uniform_, gain=init_gain)

    if verbose > 1:
        warnings.warn(
            f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' +
            f'gain={init_gain}'
        )

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    generic_param_init_fn_(module=module, init_fn_=weight_init_, **forwarded)
422
+
423
+
424
def xavier_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    """Initialize ``module`` with ``xavier_normal_``.

    ``init_gain`` is bound to the initializer's ``gain`` argument; everything
    else is forwarded to ``generic_param_init_fn_``. Extra keyword arguments
    are discarded.
    """
    del kwargs  # swallow config entries this initializer does not consume
    weight_init_ = partial(nn.init.xavier_normal_, gain=init_gain)

    if verbose > 1:
        warnings.warn(
            f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' +
            f'gain={init_gain}'
        )

    forwarded = dict(
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
    generic_param_init_fn_(module=module, init_fn_=weight_init_, **forwarded)
453
+
454
+
455
# Lookup table from an initializer name to the parameter-init function in
# this module that implements it.
# NOTE(review): presumably the key is read from the model config to select
# the scheme - confirm against the caller.
MODEL_INIT_REGISTRY = {
    'default_': torch_default_param_init_fn_,
    'baseline_': baseline_param_init_fn_,
    'kaiming_uniform_': kaiming_uniform_param_init_fn_,
    'kaiming_normal_': kaiming_normal_param_init_fn_,
    'neox_init_': neox_param_init_fn_,
    'small_init_': small_param_init_fn_,
    'xavier_uniform_': xavier_uniform_param_init_fn_,
    'xavier_normal_': xavier_normal_param_init_fn_,
}
replit_lm.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Forked from the MosaicGPT model class from the Mosaic Examples codebase of date May 1st, 2023.
5
+ Permalink: https://github.com/mosaicml/examples/blob/52cd4fef69497f225a034fcd10692f8613732d10/examples/llm/src/models/mosaic_gpt/mosaic_gpt.py
6
+ """
7
+
8
+ """A simple, flexible implementation of a GPT model.
9
+
10
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
11
+ """
12
+
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import warnings
18
+
19
+ from transformers import PreTrainedModel
20
+ from transformers.modeling_outputs import CausalLMOutputWithPast
21
+ from typing import List, Optional, Tuple
22
+
23
+ from .attention import attn_bias as module_attn_bias, attn_bias_shape as module_attn_bias_shape
24
+ from .gpt_blocks import GPTBlock
25
+ from .configuration_replit_lm import \
26
+ ReplitLMConfig
27
+ from .param_init_fns import MODEL_INIT_REGISTRY
28
+ from .low_precision_layernorm import LPLayerNorm
29
+
30
+
31
class ReplitLM(PreTrainedModel):
    """GPT-style causal language model with tied input/output embeddings.

    The model is a stack of ``GPTBlock`` layers wrapped in a token embedding
    (``wte``), an optional learned position embedding (``wpe``, omitted when
    ALiBi is enabled), embedding dropout and a final layer norm (``ln_f``).
    Output logits are computed with the ``wte`` weight matrix.
    """

    config_class = ReplitLMConfig
    base_model_prefix = 'replit_lm'

    def __init__(self, config: ReplitLMConfig):
        """Build the network described by ``config``.

        Raises:
            RuntimeError: If flash attention is requested together with ALiBi.
            ValueError: If ``config.logit_scale`` is an unrecognized string.
        """
        super().__init__(config)

        if config.attn_impl == 'flash' and config.alibi:
            raise RuntimeError("ALiBi is not supported with flash attention. Please use triton or torch.")

        self.attn_impl = config.attn_impl
        self.prefix_lm = config.prefix_lm
        self.attn_uses_sequence_id = config.attn_uses_sequence_id
        self.alibi = config.alibi
        self.alibi_bias_max = config.alibi_bias_max

        layernorm_class = LPLayerNorm if config.low_precision_layernorm else nn.LayerNorm

        # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
        # both report this helping with stabilizing training
        self.embedding_fraction = config.embedding_fraction

        self.transformer = nn.ModuleDict({
            'wte':
                nn.Embedding(config.vocab_size,
                             config.d_model,
                             device=config.init_device)
        })
        if not self.alibi:
            # Learned position embeddings are only needed without ALiBi.
            self.transformer.update({
                'wpe':
                    nn.Embedding(config.max_seq_len,
                                 config.d_model,
                                 device=config.init_device)
            })
        self.transformer.update({'emb_drop': nn.Dropout(config.emb_pdrop)})
        self.transformer.update({
            'blocks':
                nn.ModuleList([
                    GPTBlock(device=config.init_device,
                             **config.to_dict())
                    for _ in range(config.n_layers)
                ])
        })
        self.transformer.update({
            'ln_f': layernorm_class(config.d_model, device=config.init_device)
        })

        # enables scaling output logits; similar to a softmax "temperature"
        # PaLM paper uses scale 1/sqrt(config.d_model)
        self.logit_scale = None
        if config.logit_scale is not None:
            logit_scale = config.logit_scale
            if isinstance(logit_scale, str):
                if logit_scale == 'inv_sqrt_d_model':
                    logit_scale = 1 / math.sqrt(config.d_model)
                else:
                    raise ValueError(
                        f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
                    )
            self.logit_scale = logit_scale

        if config.init_device != 'meta':
            print(
                f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
            )
            self.apply(self.param_init_fn)

        self.is_causal = not self.prefix_lm

        # define attn mask; the tensor itself is built lazily in _attn_bias
        self._attn_bias_initialized = False
        self.attn_bias = None
        self.attn_bias_shape = module_attn_bias_shape(
            self.attn_impl,
            config.n_heads,
            config.max_seq_len,
            self.alibi,
            prefix_lm=self.prefix_lm,
            causal=self.is_causal,
            use_sequence_id=self.attn_uses_sequence_id)

        if config.no_bias:
            for module in self.modules():
                if hasattr(module, 'bias') and isinstance(
                        module.bias, nn.Parameter):
                    if config.verbose:
                        print(f'Removing bias ({module.bias}) from {module}.')
                    module.register_parameter('bias', None)

        if config.verbose and config.verbose > 2:
            print(self)

    @torch.no_grad()
    def _attn_bias(self,
                   device,
                   dtype,
                   attention_mask: Optional[torch.ByteTensor] = None,
                   prefix_mask: Optional[torch.ByteTensor] = None,
                   sequence_id: Optional[torch.LongTensor] = None):
        """Build the additive attention bias for the current forward pass.

        The base bias (e.g. ALiBi) is created once, on first use, with the
        ``device``/``dtype`` given on that call, and cached on ``self``.

        Returns:
            A tuple ``(attn_bias, attention_mask)``. For ``attn_impl ==
            'flash'`` the cached bias and the untouched ``attention_mask`` are
            returned. Otherwise the prefix mask, sequence ids and padding mask
            are folded into the bias and ``None`` is returned in place of the
            mask.

        Raises:
            ValueError: If ``attention_mask`` and ``prefix_mask`` differ in
                shape, or a mask is longer than ``config.max_seq_len``.
        """
        if not self._attn_bias_initialized:
            if self.attn_bias_shape:
                self.attn_bias = torch.zeros(self.attn_bias_shape,
                                             device=device,
                                             dtype=dtype)
                self.attn_bias = module_attn_bias(
                    self.attn_impl,
                    self.attn_bias,
                    self.config.n_heads,
                    self.config.max_seq_len,
                    causal=self.is_causal,
                    alibi=self.alibi,
                    alibi_bias_max=self.alibi_bias_max)

            self._attn_bias_initialized = True

        # flash does not support prefix_lm and will incorporate any
        # attention_mask inside the attention module
        if self.attn_impl == 'flash':
            return self.attn_bias, attention_mask

        attn_bias = self.attn_bias

        # If using torch or triton, we incorporate the prefix_mask (if
        # appropriate). Without this the required prefix_mask argument would
        # be silently ignored.
        if self.prefix_lm:
            assert isinstance(attn_bias, torch.Tensor)  # pyright
            assert isinstance(prefix_mask, torch.Tensor)  # pyright
            attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)

        # If using torch or triton, we incorporate sequence_id (if appropriate)
        if self.attn_uses_sequence_id and sequence_id is not None:
            assert isinstance(attn_bias, torch.Tensor)  # pyright
            attn_bias = self._apply_sequence_id(attn_bias, sequence_id)

        # If using torch or triton, we incorporate attention_mask. This will output
        # None in place of attention_mask since it will not be further needed in the
        # attention modules.
        if attention_mask is not None:
            s_k = attention_mask.shape[-1]
            if attn_bias is None:
                attn_bias = torch.zeros((1, 1, 1, s_k),
                                        device=device,
                                        dtype=dtype)
            else:
                attn_bias = attn_bias[:, :, :, -s_k:]
            if prefix_mask is not None and (attention_mask.shape !=
                                            prefix_mask.shape):
                raise ValueError(
                    f'attention_mask shape={attention_mask.shape} ' +\
                    f'and prefix_mask shape={prefix_mask.shape} are not equal.'
                )
            min_val = torch.finfo(attn_bias.dtype).min
            attn_bias = attn_bias.masked_fill(
                ~attention_mask.view(-1, 1, 1, s_k), min_val)

        return attn_bias, None

    def _apply_prefix_mask(self, attn_bias: torch.Tensor,
                           prefix_mask: torch.Tensor):
        """Mask out positions that are neither causal nor inside the prefix.

        Raises:
            ValueError: If the last two dimensions of ``attn_bias`` are not
                both ``config.max_seq_len``, or ``prefix_mask`` is longer than
                ``config.max_seq_len``.
        """
        s_k, s_q = attn_bias.shape[-2:]
        if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
            raise ValueError(
                'attn_bias does not match the expected shape. ' +\
                f'The last two dimensions should both be {self.config.max_seq_len} ' +\
                f'but are {s_k} and {s_q}.'
            )
        seq_len = prefix_mask.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
            )

        # select seq_len subset of attn mask
        attn_bias = attn_bias[..., :seq_len, :seq_len]

        # Mix the causal mask and the bidirectional mask to get the full
        # allowable attention (i.e. full = not accounting for padding yet)
        causal = torch.tril(
            torch.ones((seq_len, seq_len),
                       dtype=torch.bool,
                       device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
        prefix = prefix_mask.view(-1, 1, 1, seq_len)
        cannot_attend = ~torch.logical_or(causal, prefix.bool())

        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)

        return attn_bias

    def _apply_sequence_id(self, attn_bias: torch.Tensor,
                           sequence_id: torch.LongTensor):
        """Restrict attention to token pairs sharing the same ``sequence_id``.

        Raises:
            ValueError: If ``sequence_id`` is longer than
                ``config.max_seq_len``.
        """
        seq_len = sequence_id.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
            )

        # select seq_len subset of attn mask
        attn_bias = attn_bias[..., :seq_len, :seq_len]

        # Restrict attention to tokens that share the same value
        # in sequence_id
        cannot_attend = torch.logical_not(
            torch.eq(sequence_id.view(-1, seq_len, 1),
                     sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)

        return attn_bias

    def forward(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
            attention_mask: Optional[torch.ByteTensor] = None,
            prefix_mask: Optional[torch.ByteTensor] = None,
            sequence_id: Optional[torch.LongTensor] = None,
            return_dict: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            use_cache: Optional[bool] = None):
        """Run the model and return logits over the vocabulary.

        Args:
            input_ids: Token ids; the sequence length is read from dim 1.
            past_key_values: Optional per-layer key/value cache; must contain
                one entry per layer when learned position embeddings are used.
            attention_mask: Optional padding mask; non-zero marks real tokens.
            prefix_mask: Required when the model is configured with
                ``prefix_lm=True``.
            sequence_id: Required in training mode when
                ``attn_uses_sequence_id=True``.
            return_dict: Must resolve to ``True``.
            output_attentions: Must be falsy.
            output_hidden_states: When truthy, the input to every block is
                collected and returned as ``hidden_states``.
            use_cache: When truthy, a key/value cache is created and returned.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits``, ``past_key_values``
            and ``hidden_states``.

        Raises:
            NotImplementedError: For ``return_dict=False``,
                ``output_attentions`` or training with left padding.
            ValueError: For missing ``prefix_mask``/``sequence_id``, a cache
                with the wrong number of layers, or a total length above
                ``config.max_seq_len``.
            AssertionError: If the input sequence is longer than
                ``config.max_seq_len``.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Tokenizers emit integer masks; the bitwise-not / masked_fill logic
        # below needs a boolean mask. Boolean inputs are left unchanged.
        if attention_mask is not None:
            attention_mask = attention_mask.bool()

        # These args are passed in by keyword in huggingface's generate function
        # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
        # but have not yet been fully implemented in ReplitLM
        if not return_dict:
            raise NotImplementedError(
                'return_dict False is not implemented yet for ReplitLM')
        if output_attentions:
            raise NotImplementedError(
                'output_attentions is not implemented yet for ReplitLM')

        if attention_mask is not None and attention_mask[:, 0].sum(
        ) != attention_mask.shape[0] and self.training:
            raise NotImplementedError(
                'ReplitLM does not support training with left padding.')

        if self.prefix_lm and prefix_mask is None:
            raise ValueError(
                'prefix_mask is a required argument when ReplitLM is configured with prefix_lm=True.'
            )

        if self.training:
            if self.attn_uses_sequence_id and sequence_id is None:
                raise ValueError(
                    'sequence_id is a required argument when ReplitLM is configured with attn_uses_sequence_id=True ' +\
                    'and the model is in train mode.'
                )
            elif (self.attn_uses_sequence_id is False) and (sequence_id
                                                            is not None):
                warnings.warn(
                    'ReplitLM received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' +\
                    'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
                )

        S = input_ids.size(1)

        assert (
            S <= self.config.max_seq_len
        ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

        tok_emb = self.transformer.wte(input_ids)  # type: ignore

        if self.alibi:
            x = tok_emb
        else:
            past_position = 0
            if past_key_values is not None:
                if len(past_key_values) != self.config.n_layers:
                    raise ValueError(
                        f'past_key_values must provide a past_key_value for each attention ' +\
                        f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
                    )
                # get the key tensor whose spec should be (batch, seq, dim), and
                # collect the `seq`, so that the position embedding is shifted
                past_position = past_key_values[0][0].size(1)

            if S + past_position > self.config.max_seq_len:
                raise ValueError(
                    f'Cannot forward input with past sequence length {past_position} and current sequence length '
                    f'{S}, this model only supports total sequence length <= {self.config.max_seq_len}.'
                )
            pos = torch.arange(past_position,
                               S + past_position,
                               dtype=torch.long,
                               device=input_ids.device).unsqueeze(0)
            if attention_mask is not None:
                # adjust the position indices to account for padding tokens
                pos = torch.clamp(pos - torch.cumsum(
                    (~attention_mask).to(torch.int32), dim=1)[:,
                                                              past_position:],
                                  min=0)

            pos_emb = self.transformer.wpe(pos)  # type: ignore
            x = tok_emb + pos_emb

        if self.embedding_fraction == 1:
            x = self.transformer.emb_drop(x)  # type: ignore
        else:
            # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
            x_shrunk = (x * self.embedding_fraction) + (
                x.detach() * (1 - self.embedding_fraction))
            assert isinstance(self.transformer.emb_drop, nn.Module)  # pyright
            x = self.transformer.emb_drop(x_shrunk)

        attn_bias, attention_mask = self._attn_bias(
            device=x.device,
            dtype=x.dtype,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id)

        # initialize the past key values cache if it should be used
        if use_cache and past_key_values is None:
            past_key_values = [() for _ in range(self.config.n_layers)
                              ]  # type: ignore

        all_hidden_states = () if output_hidden_states else None
        for b_idx, block in enumerate(self.transformer.blocks):  # type: ignore
            if output_hidden_states:
                assert all_hidden_states is not None  # pyright
                all_hidden_states = all_hidden_states + (x,)
            past_key_value = past_key_values[
                b_idx] if past_key_values is not None else None
            x, past_key_value = block(x,
                                      past_key_value=past_key_value,
                                      attn_bias=attn_bias,
                                      attention_mask=attention_mask,
                                      is_causal=self.is_causal)
            if past_key_values is not None:
                past_key_values[b_idx] = past_key_value

        x = self.transformer.ln_f(x)  # type: ignore

        # output embedding weight tied to input embedding
        assert isinstance(self.transformer.wte, nn.Module)  # pyright
        assert isinstance(self.transformer.wte.weight, torch.Tensor)  # pyright
        logits = F.linear(x, self.transformer.wte.weight, None)

        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(
                    f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.'
                )
            logits *= self.logit_scale

        return CausalLMOutputWithPast(logits=logits,
                                      past_key_values=past_key_values,
                                      hidden_states=all_hidden_states)

    # Param Initialization, needed for device='meta' fast initialization
    def param_init_fn(self, module):
        """Initialize ``module`` with the init function named in the config."""
        init_fn_name = self.config.param_init_fn
        if self.config.verbose > 1:
            warnings.warn(f'Using {init_fn_name} initialization.')
        MODEL_INIT_REGISTRY[init_fn_name](module=module,
                                          **self.config.to_dict())

    # FSDP Wrap function
    def fsdp_wrap_fn(self, module):
        """Return ``True`` when ``module`` is a ``GPTBlock``."""
        return isinstance(module, GPTBlock)

    # Activation Checkpointing
    def activation_checkpointing_fn(self, module):
        """Return ``True`` when ``module`` is a ``GPTBlock``."""
        return isinstance(module, GPTBlock)

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      inputs_embeds=None,
                                      **kwargs):
        """Assemble the keyword arguments for one ``forward`` call in generation.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is given, the batch is
                right-padded, or ``prefix_lm`` is used with
                ``use_cache=False``.
        """
        if inputs_embeds is not None:
            raise NotImplementedError(
                'inputs_embeds is not implemented for ReplitLM yet')

        attention_mask = kwargs['attention_mask'].bool()
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError(
                'ReplitLM does not support generation with right padding.')

        if self.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        else:
            sequence_id = None

        if past_key_values is not None:
            # With a cache only the newest token needs to be fed.
            input_ids = input_ids[:, -1].unsqueeze(-1)

        if self.prefix_lm:
            # Leverage a convenience of sequential generation!
            prefix_mask = torch.ones_like(attention_mask)
            # This requires that we're using the cache
            if kwargs.get('use_cache') == False:
                raise NotImplementedError(
                    'ReplitLM with prefix_lm=True does not support use_cache=False.'
                )
        else:
            prefix_mask = None

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'prefix_mask': prefix_mask,
            'sequence_id': sequence_id,
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache', True),
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Used by HuggingFace generate when using beam search with kv-caching.

        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
        for an example in transformers.
        """
        reordered_past = []
        for layer_past in past_key_values:
            reordered_past += [
                tuple(
                    past_state.index_select(0, beam_idx)
                    for past_state in layer_past)
            ]
        return reordered_past
replit_lm_tokenizer.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Forked from the file src/transformers/models/bert_generation/tokenization_bert_generation.py from the HuggingFace Transformers library.
17
+ Permalink: https://github.com/huggingface/transformers/blob/04ab5605fbb4ef207b10bf2772d88c53fc242e83/src/transformers/models/bert_generation/tokenization_bert_generation.py
18
+
19
+ Class is modified for compatibility with custom vocabulary and to achieve desired encode/decode behavior for Replit Code v1.3b model.
20
+ """
21
+
22
+ """ Tokenizer class for ReplitLM"""
23
+
24
+
25
+ import os
26
+ import sentencepiece as spm
27
+ from shutil import copyfile
28
+ from transformers import PreTrainedTokenizer
29
+ from typing import Any, Dict, List, Optional, Tuple
30
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
31
+
32
+
33
class ReplitLMTokenizer(PreTrainedTokenizer):
    """
    Construct a ReplitLMTokenizer tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        bos_token (`str`, *optional*, defaults to `None`):
            The begin of sequence token.
        unk_token (`str`, *optional*, defaults to `"<|unk|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<|pad|>"`):
            The token used for padding, for example when batching sequences of different lengths.
        sep_token (`str`, *optional*, defaults to `None`):
            Forwarded unchanged to [`PreTrainedTokenizer`].
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:
            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.
            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    prefix_tokens: List[int] = []
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        bos_token=None,
        eos_token="<|endoftext|>",
        unk_token="<|unk|>",
        pad_token="<|pad|>",
        sep_token=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.vocab_file = vocab_file

        # Load the SentencePiece model BEFORE calling the base constructor:
        # newer transformers releases query the vocabulary (get_vocab /
        # vocab_size) from inside PreTrainedTokenizer.__init__, which needs
        # self.sp_model to exist already.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            sep_token=sep_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of pieces in the loaded SentencePiece model."""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Return a token -> id dict of the base vocabulary plus added tokens."""
        vocab = {self.convert_ids_to_tokens(
            i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        """Copy the instance dict with ``sp_model`` set to ``None``; it is rebuilt in ``__setstate__``."""
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        """Restore state and reload the SentencePiece model from ``vocab_file``."""
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.id_to_piece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        return self.sp_model.decode(tokens)

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the SentencePiece model file into ``save_directory``.

        The existing ``vocab_file`` is copied when it exists on disk and is
        not already the target path; if it does not exist, the in-memory
        model is serialized instead.

        Returns:
            A one-element tuple holding the path of the written vocab file.

        Raises:
            ValueError: If ``save_directory`` is not a directory.
        """
        if not os.path.isdir(save_directory):
            raise ValueError(
                f"Vocabulary path ({save_directory}) should be a directory")

        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") +
            VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(
                self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(
                    self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file, )
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "eos_token": "<|endoftext|>",
3
+ "pad_token": "<|pad|>",
4
+ "unk_token": "<|unk|>"
5
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34b064403defb5e56f9fdc4e8e0847d6395439bca206e74e419c53259fe47f02
3
+ size 707655
tokenizer_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "replit_lm_tokenizer.ReplitLMTokenizer",
5
+ null
6
+ ]
7
+ },
8
+ "bos_token": null,
9
+ "clean_up_tokenization_spaces": false,
10
+ "eos_token": "<|endoftext|>",
11
+ "model_max_length": 2048,
12
+ "pad_token": "<|pad|>",
13
+ "padding_side": "right",
14
+ "sep_token": null,
15
+ "sp_model_kwargs": {},
16
+ "tokenizer_class": "ReplitLMTokenizer",
17
+ "unk_token": "<|unk|>"
18
+ }