Norquinal commited on
Commit
cd0221e
·
1 Parent(s): f7d4f8c

Upload model

Browse files
cache.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+
5
+ from torch import Tensor
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional
8
+
9
+
10
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
11
+ @dataclass
12
+ class InferenceParams:
13
+ """Inference parameters that are passed to the main model in order
14
+ to efficienly calculate and store the context during inference."""
15
+
16
+ max_seqlen: int
17
+ max_batch_size: int
18
+ seqlen_offset: int = 0
19
+ batch_size_offset: int = 0
20
+ key_value_memory_dict: dict = field(default_factory=dict)
21
+ lengths_per_sample: Optional[Tensor] = None
22
+
23
+ def reset(self, max_seqlen, max_batch_size):
24
+ self.max_seqlen = max_seqlen
25
+ self.max_batch_size = max_batch_size
26
+ self.seqlen_offset = 0
27
+ if self.lengths_per_sample is not None:
28
+ self.lengths_per_sample.zero_()
29
+
30
+
31
+ @dataclass
32
+ class RecurrentInferenceParams:
33
+ """Inference parameters passed to blocks with recurrent mode."""
34
+
35
+ fir_filter_length: int = 3
36
+ state_dim: int = 16
37
+ seqlen_offset: int = 0
38
+ fir_state_dict: dict = field(default_factory=dict)
39
+ state_dict: dict = field(default_factory=dict)
40
+
41
+ def reset(self):
42
+ self.fir_filter_length = 3
43
+ self.state_dim = 16
44
+ self.seqlen_offset = 0
config.json ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_commit_hash": null,
3
+ "_name_or_path": "togethercomputer_StripedHyena-Hessian-7B",
4
+ "architectures": [
5
+ "StripedHyenaModelForCausalLM"
6
+ ],
7
+ "attn_layer_idxs": [
8
+ 1,
9
+ 3,
10
+ 5,
11
+ 7,
12
+ 9,
13
+ 11,
14
+ 13,
15
+ 15,
16
+ 17,
17
+ 19,
18
+ 21,
19
+ 23,
20
+ 25,
21
+ 27,
22
+ 29,
23
+ 31
24
+ ],
25
+ "auto_map": {
26
+ "AutoConfig": "configuration_hyena.StripedHyenaConfig",
27
+ "AutoModelForCausalLM": "modeling_hyena.StripedHyenaModelForCausalLM"
28
+ },
29
+ "column_split": false,
30
+ "column_split_hyena": true,
31
+ "eps": 1e-05,
32
+ "final_norm": true,
33
+ "hidden_size": 4096,
34
+ "hyena_filter_groups": 1,
35
+ "hyena_layer_idxs": [
36
+ 0,
37
+ 2,
38
+ 4,
39
+ 6,
40
+ 8,
41
+ 10,
42
+ 12,
43
+ 14,
44
+ 16,
45
+ 18,
46
+ 20,
47
+ 22,
48
+ 24,
49
+ 26,
50
+ 28,
51
+ 30,
52
+ 32
53
+ ],
54
+ "inference_mode": false,
55
+ "inner_mlp_size": 14336,
56
+ "log_intermediate_values": false,
57
+ "make_vocab_size_divisible_by": 8,
58
+ "max_seqlen": 32768,
59
+ "mha_out_proj_bias": false,
60
+ "model_parallel_size": 1,
61
+ "model_type": "stripedhyena",
62
+ "num_attention_heads": 32,
63
+ "num_filters": 4096,
64
+ "num_layers": 32,
65
+ "pipe_parallel_size": 1,
66
+ "prefill_style": "fft",
67
+ "proj_groups": 4,
68
+ "qkv_proj_bias": false,
69
+ "rotary_emb_base": 500000,
70
+ "short_filter_bias": true,
71
+ "short_filter_length": 3,
72
+ "smeared_gqa": false,
73
+ "split_k0": true,
74
+ "state_size": 2,
75
+ "tie_embeddings": false,
76
+ "torch_dtype": "bfloat16",
77
+ "transformers_version": null,
78
+ "use_cache": true,
79
+ "use_flash_attention_2": true,
80
+ "use_flash_depthwise": false,
81
+ "use_flash_rmsnorm": true,
82
+ "use_flashfft": false,
83
+ "vocab_size": 32000
84
+ }
configuration_hyena.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ import json
3
+
4
+
5
+ class StripedHyenaConfig(PretrainedConfig):
6
+ model_type = "stripedhyena"
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=32000,
11
+ hidden_size=4096,
12
+ num_filters=4096,
13
+ inner_mlp_size=14336,
14
+ attn_layer_idxs=[],
15
+ hyena_layer_idxs=[],
16
+ num_layers=32,
17
+ tie_embeddings=False,
18
+ short_filter_length=3,
19
+ num_attention_heads=32,
20
+ proj_groups=4,
21
+ hyena_filter_groups=1,
22
+ split_k0=True,
23
+ column_split_hyena=True,
24
+ column_split=False,
25
+ model_parallel_size=1,
26
+ pipe_parallel_size=1,
27
+ short_filter_bias=True,
28
+ mha_out_proj_bias=False,
29
+ qkv_proj_bias=False,
30
+ final_norm=True,
31
+ use_cache=True,
32
+ use_flash_attention_2=True,
33
+ use_flash_rmsnorm=True,
34
+ use_flash_depthwise=False,
35
+ use_flashfft=False,
36
+ inference_mode=False,
37
+ prefill_style="fft",
38
+ max_seqlen=32768,
39
+ eps=1e-5,
40
+ state_size=2,
41
+ rotary_emb_base=500000,
42
+ smeared_gqa=False,
43
+ make_vocab_size_divisible_by=8,
44
+ log_intermediate_values=False,
45
+ **kwargs,
46
+ ):
47
+ self.vocab_size = vocab_size
48
+ self.hidden_size = hidden_size
49
+ self.num_filters = num_filters
50
+ self.inner_mlp_size = inner_mlp_size
51
+ self.attn_layer_idxs = attn_layer_idxs
52
+ self.hyena_layer_idxs = hyena_layer_idxs
53
+ self.num_layers = num_layers
54
+ self.tie_embeddings = tie_embeddings
55
+ self.short_filter_length = short_filter_length
56
+ self.num_attention_heads = num_attention_heads
57
+ self.proj_groups = proj_groups
58
+ self.hyena_filter_groups = hyena_filter_groups
59
+ self.split_k0 = split_k0
60
+ self.column_split_hyena = column_split_hyena
61
+ self.column_split = column_split
62
+ self.model_parallel_size = model_parallel_size
63
+ self.pipe_parallel_size = pipe_parallel_size
64
+ self.short_filter_bias = short_filter_bias
65
+ self.mha_out_proj_bias = mha_out_proj_bias
66
+ self.qkv_proj_bias = qkv_proj_bias
67
+ self.final_norm = final_norm
68
+ self.use_cache = use_cache
69
+ self.use_flash_attention_2 = use_flash_attention_2
70
+ self.use_flash_rmsnorm = use_flash_rmsnorm
71
+ self.use_flash_depthwise = use_flash_depthwise
72
+ self.use_flashfft = use_flashfft
73
+ self.inference_mode = inference_mode
74
+ self.prefill_style = prefill_style
75
+ self.max_seqlen = max_seqlen
76
+ self.eps = eps
77
+ self.state_size = state_size
78
+ self.rotary_emb_base = rotary_emb_base
79
+ self.smeared_gqa = smeared_gqa
80
+ self.make_vocab_size_divisible_by = make_vocab_size_divisible_by
81
+ self.log_intermediate_values = log_intermediate_values
82
+ super().__init__(**kwargs)
83
+
84
+ def to_dict(self):
85
+ return {attr: getattr(self, attr) for attr in self.__dict__}
86
+
87
+ @classmethod
88
+ def from_original_config(cls, config_path, **kwargs):
89
+ with open(config_path, "r") as f:
90
+ config = json.load(f)
91
+
92
+ return cls(**config, **kwargs)
engine.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ try:
10
+ import conv1d_cpp
11
+ except:
12
+ pass
13
+ from .utils import column_split
14
+
15
+
16
+ def canonicalize_modal_system(poles, residues):
17
+ """Canonicalize a modal system.
18
+
19
+ Args:
20
+ poles (Tensor): The poles of the system.
21
+ residues (Tensor): The residues of the system.
22
+
23
+ Returns:
24
+ Tuple[Tensor, Tensor]: The canonicalized poles and residues.
25
+ """
26
+ raise NotImplementedError
27
+
28
+
29
+ IIR_PREFILL_MODES = [
30
+ "recurrence",
31
+ "modal-fft",
32
+ "hybrid-modal-recurrence",
33
+ "modal-scan",
34
+ "canonical-fft",
35
+ "iir-fir-caching",
36
+ ]
37
+
38
+
39
+ class HyenaInferenceEngine:
40
+ def __init__(
41
+ self, fir_fn=None, fftconv_fn=None, iir_prefill_style="modal-fft", layer_idx=None
42
+ ) -> None:
43
+ self.fir_fn = fir_fn
44
+ self.fftconv_fn = fftconv_fn
45
+ assert (
46
+ iir_prefill_style in IIR_PREFILL_MODES
47
+ ), f"iir_prefill_style must be one of {IIR_PREFILL_MODES}"
48
+ self.iir_prefill_style = iir_prefill_style
49
+ self.layer_idx = layer_idx
50
+ self.low_mem_mode = False
51
+
52
+ def parallel_fir(
53
+ self,
54
+ fir_fn,
55
+ u,
56
+ weight,
57
+ bias,
58
+ L,
59
+ fir_length=3,
60
+ inference_params=None,
61
+ prefill_mode=None,
62
+ padding_mask=None,
63
+ ):
64
+ """Compute the output state of the long convolutional filter."""
65
+ # prepare input layout, dimensions and dispatch to fir kernel
66
+ if fir_fn != torch.nn.functional.conv1d:
67
+ z_pre = fir_fn(u)[:, :L] # B, L, D
68
+ z_pre = z_pre.permute(0, 2, 1)
69
+ else:
70
+ u = u.permute(0, 2, 1) # B, D, L
71
+ z_pre = fir_fn(
72
+ u,
73
+ weight,
74
+ bias,
75
+ stride=1,
76
+ padding=fir_length - 1,
77
+ groups=u.shape[1],
78
+ )[..., :L]
79
+
80
+ # handle padding post fir, the only place with biases
81
+ if type(padding_mask) == torch.Tensor:
82
+ z_pre = z_pre * padding_mask[:, None]
83
+
84
+ if inference_params is not None:
85
+ # handle seqlen last and dim last cases for `u`
86
+ if fir_fn != torch.nn.functional.conv1d:
87
+ fir_state = u[:, -fir_length + 1 :].permute(0, 2, 1)
88
+ else:
89
+ fir_state = u[..., -fir_length + 1 :]
90
+ else:
91
+ fir_state = None
92
+
93
+ return z_pre, fir_state
94
+
95
+ def parallel_iir(
96
+ self,
97
+ z_pre,
98
+ h,
99
+ D,
100
+ L,
101
+ poles,
102
+ t,
103
+ dims,
104
+ layer_idx,
105
+ inference_params=None,
106
+ prefill_style="fft",
107
+ fftconv_fn=None,
108
+ padding_mask=None,
109
+ use_flashfft=False,
110
+ column_split_hyena=False,
111
+ long_fir_threshold=None,
112
+ ):
113
+ """Compute the output state of the short convolutional filter."""
114
+ fft_size = 2 * L
115
+ hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
116
+ # Compatibility with training infra that column splits the projections
117
+ if column_split_hyena:
118
+ z = z_pre.reshape(
119
+ z_pre.shape[0],
120
+ num_attention_heads,
121
+ 3 * hidden_size_per_attention_head,
122
+ z_pre.shape[2],
123
+ )
124
+ x2, x1, v = (
125
+ z[:, :, :hidden_size_per_attention_head],
126
+ z[
127
+ :,
128
+ :,
129
+ hidden_size_per_attention_head : 2 * hidden_size_per_attention_head,
130
+ ],
131
+ z[:, :, 2 * hidden_size_per_attention_head :],
132
+ )
133
+ x2, x1, v = (
134
+ x2.reshape(x2.shape[0], -1, x2.shape[-1]),
135
+ x1.reshape(x1.shape[0], -1, x1.shape[-1]),
136
+ v.reshape(v.shape[0], -1, v.shape[-1]),
137
+ )
138
+ else:
139
+ x2, x1, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)
140
+
141
+ x1v = x1 * v
142
+
143
+ if use_flashfft and (L % 2) == 0: # only works with even L
144
+ y = fftconv_fn(
145
+ x1v.to(dtype=torch.bfloat16).contiguous(),
146
+ h.to(dtype=torch.float32),
147
+ )
148
+ X_s = None
149
+
150
+ elif long_fir_threshold is None:
151
+ H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
152
+ X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
153
+ X = X_s[..., : H.shape[-1]]
154
+ if len(z_pre.shape) > 3:
155
+ H = H.unsqueeze(1)
156
+ y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]
157
+ else:
158
+ assert h.shape[0] == 1, "batch size must be 1 for long_fir_threshold"
159
+ h = h[0][:, None] # rearrange to d, 1, l for depthwise conv1d
160
+ h = h[..., :long_fir_threshold]
161
+ y = F.conv1d(
162
+ x1v,
163
+ h.to(dtype=x1v.dtype),
164
+ stride=1,
165
+ groups=x1v.shape[1],
166
+ padding=h.shape[-1] - 1,
167
+ )[..., :L]
168
+
169
+ y = y.to(dtype=x1v.dtype)
170
+ y = (y + x1v * D.unsqueeze(-1)) * x2
171
+ if inference_params is not None:
172
+ if prefill_style == "fft":
173
+ self.prefill_via_modal_fft(
174
+ inference_params=inference_params,
175
+ x1v=x1v,
176
+ X_s=X_s,
177
+ L=L,
178
+ t=t,
179
+ poles=poles,
180
+ dims=dims,
181
+ layer_idx=layer_idx,
182
+ use_flashfft=use_flashfft,
183
+ )
184
+
185
+ elif prefill_style == "recurrence":
186
+ self.prefill_via_direct_recurrence(
187
+ inference_params=inference_params,
188
+ x1v=x1v,
189
+ L=L,
190
+ poles=poles,
191
+ )
192
+
193
+ else:
194
+ raise NotImplementedError
195
+ if self.low_mem_mode:
196
+ del z_pre, x2, x1, v, x1v, h
197
+ torch.cuda.empty_cache()
198
+
199
+ return y.permute(0, 2, 1)
200
+
201
+ def step_fir(self, u, fir_state, weight, bias=None):
202
+ """Step the FIR filter.
203
+
204
+ Note:
205
+ `fir_state` contains the last `short_filter_length - 1` elements of `u`: `u_(L-2), u_{L-1), ...`
206
+ We assume dimensions of `short_filter_weight` to be `[d, 1, short_filter_len]` (SISO / multi SISO layout).
207
+ """
208
+ h0, h = weight[..., 0, -1], weight[..., 0, :-1]
209
+ h0, h = h0[None], h[None]
210
+ y = h0 * u + torch.sum(fir_state * h, dim=-1) + bias
211
+
212
+ # update
213
+ fir_state = torch.roll(fir_state, -1, dims=2)
214
+ fir_state[..., -1] = u
215
+ return y, fir_state
216
+
217
+ def step_iir(self, x2, x1, v, D, residues, poles, iir_state, iir_groups=1):
218
+ x1v = x1 * v
219
+
220
+ residues, poles = (
221
+ torch.view_as_complex(residues.to(torch.float32)),
222
+ torch.view_as_complex(poles.to(torch.float32)),
223
+ )
224
+ # squeeze the dummy seqlen dimension
225
+ # D, state_dim, 1 -> 1, D, state_dim
226
+ residues, poles = residues[..., 0][None], poles[..., 0][None]
227
+ iir_state = poles * iir_state + x1v[..., None]
228
+
229
+ res_state = torch.sum(residues * iir_state, dim=-1).real
230
+
231
+ if iir_groups > 1:
232
+ raise NotImplementedError
233
+ y = x2 * (res_state + D * x1v)
234
+
235
+ return y, iir_state
236
+
237
+ def prefill_via_fir_caching(self, u, inference_params, L, *args, **kwargs):
238
+ """Turns the IIR filter into a FIR and uses a cache for decoding."""
239
+ raise NotImplementedError(":)")
240
+
241
+ def prefill_via_direct_recurrence(self, inference_params, x1v, L, poles, *args, **kwargs):
242
+ """
243
+ Compute the IIR state via explicit SSM recurrence (modal form)
244
+ """
245
+ x1v_ = x1v[..., None, None] # b, d, l, sdim, reim
246
+ x1v_ = x1v_.repeat(1, 1, 1, 1, 2) # b, d, l, sdim, reim
247
+
248
+ state = x1v_[:, :, 0]
249
+ poles = poles[:, :, 0].to(dtype=torch.float32)
250
+
251
+ for i in range(L):
252
+ state = poles * state + x1v_[:, :, i]
253
+ inference_params.state_dict[self.layer_idx] = torch.view_as_complex(
254
+ state.to(dtype=torch.float32)
255
+ )
256
+
257
+ def prefill_via_hybrid_recurrence(
258
+ self, inference_params, u, log_poles, x1v_f_a, L, *args, **kwargs
259
+ ):
260
+ """
261
+ Compute the IIR state via hybrid recurrence-convolution over blocks
262
+ """
263
+ raise NotImplementedError(":)")
264
+
265
+ def prefill_via_scan(self, u, inference_params=None, *args, **kwargs):
266
+ raise NotImplementedError
267
+
268
+ def prefill_via_canonical_fft(self, u, inference_params=None, *args, **kwargs):
269
+ """
270
+ Compute the IIR state via a single FFT with the denominator of the SSM in companion form.
271
+
272
+ This is the most memory efficient "parallelized" prefilling method for Hyena.
273
+
274
+ From: https://arxiv.org/abs/2310.18780
275
+ """
276
+ raise NotImplementedError(":)")
277
+
278
+ def prefill_via_modal_fft(
279
+ self,
280
+ inference_params,
281
+ x1v,
282
+ L,
283
+ poles,
284
+ t,
285
+ dims,
286
+ layer_idx,
287
+ X_s=None,
288
+ use_flashfft=False,
289
+ state_dtype=torch.complex64,
290
+ *args,
291
+ **kwargs,
292
+ ):
293
+ """
294
+ Compute the IIR state via a single FFT, using the poles of the SSM in modal form.
295
+ """
296
+ # When the model has a long convolution derived from a SSM in modal form and prefill_style is "fft",
297
+ # we split the filter into poles and residues and reuse FFT computation on the input.
298
+ # This optimization is currently not supported when using flashfftconv.
299
+ hidden_size, _, _, state_size, hyena_filter_groups = dims
300
+
301
+ if use_flashfft:
302
+ # using real states
303
+ poles = poles.squeeze().reshape(poles.shape[0], -1)[..., None]
304
+
305
+ state_s = poles**t
306
+ if hyena_filter_groups > 1:
307
+ raise NotImplementedError
308
+
309
+ x1v = x1v[:, :, None].repeat(1, 1, 2 * state_size, 1)
310
+ x1v = x1v.reshape(x1v.shape[0], -1, x1v.shape[-1])
311
+ state_s = state_s[None]
312
+
313
+ state = self.fftconv_fn(
314
+ x1v.contiguous(),
315
+ state_s.to(dtype=torch.float32),
316
+ )
317
+ state = state[..., L - 1].reshape(x1v.shape[0], hidden_size, state_size, 2)
318
+ state = torch.view_as_complex(state.contiguous())
319
+ inference_params.state_dict[self.layer_idx] = state.to(dtype=state_dtype)
320
+ else:
321
+ assert X_s is not None
322
+ bs = x1v.shape[0]
323
+ fft_size = 2 * L
324
+ poles = torch.view_as_complex(poles.to(torch.float32))
325
+ state_s = poles**t
326
+ state_S = torch.fft.fft(state_s, n=fft_size).repeat(
327
+ bs, 1, 1, 1
328
+ ) # B, D, state_dim, 2 * L
329
+ if hyena_filter_groups > 1:
330
+ state_S = state_S.repeat_interleave(hidden_size // hyena_filter_groups, 1)
331
+ state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
332
+ inference_params.state_dict[layer_idx] = state[..., L - 1].to(dtype=state_dtype)
333
+
334
+ def _compute_state(self, log_poles, u, t, L, *args, **kwargs):
335
+ """
336
+ Compute the IIR state given an input `u` and log_poles of the modal system.
337
+ """
338
+ bs = u.shape[0]
339
+ fft_size = 2 * L
340
+ U = torch.fft.rfft(u.to(torch.float32), n=fft_size)
341
+ fft_size = 2 * L
342
+ x = (log_poles * t).exp()
343
+ # [batch, hidden_size, state_dim, 2 * seqlen]
344
+ X = torch.fft.fft(x, n=fft_size).repeat(bs, 1, 1, 1)
345
+ state = torch.fft.ifft(U[..., None, :] * X, n=fft_size)[..., :L]
346
+ return state
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.36.0.dev0"
4
+ }
layers.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+
10
+
11
+ class RMSNorm(torch.nn.Module):
12
+ def __init__(self, config):
13
+ super(RMSNorm, self).__init__()
14
+ self.eps, self.hidden_size = config.eps, config.hidden_size
15
+ self.scale = torch.nn.Parameter(torch.ones(self.hidden_size))
16
+ self.register_parameter("scale", self.scale)
17
+ self.use_flash_rmsnorm = config.get("use_flash_rmsnorm", False)
18
+
19
+ if self.use_flash_rmsnorm:
20
+ try:
21
+ from flash_attn.ops.rms_norm import rms_norm as rmsnorm_func
22
+
23
+ self.rmsnorm_func = rmsnorm_func
24
+ except:
25
+ raise ImportError(
26
+ "For `use_flash_rmsnorm`: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/layer_norm`"
27
+ )
28
+
29
+ def forward(self, x):
30
+ if self.use_flash_rmsnorm:
31
+ return self.rmsnorm_func(x, self.scale, self.eps)
32
+ else:
33
+ y = x / (x.norm(2, dim=-1, keepdim=True) * self.hidden_size ** (-1.0 / 2) + self.eps)
34
+ return self.scale * y
35
+
36
+
37
+ class ParallelGatedMLP(nn.Module):
38
+ def __init__(
39
+ self,
40
+ config,
41
+ ):
42
+ super().__init__()
43
+
44
+ multiple_of = config.get("inner_size_multiple_of", 64)
45
+ self.act = F.silu
46
+
47
+ self.multiple_of = multiple_of * config.model_parallel_size
48
+
49
+ inner_size = int(2 * config.hidden_size * 4 / 3)
50
+ inner_size = self.multiple_of * ((inner_size + self.multiple_of - 1) // self.multiple_of)
51
+ # if specified in the config, inner_size will be used instead of the calculated value
52
+ if config.get("inner_mlp_size", None) is not None:
53
+ inner_size = config.inner_mlp_size
54
+
55
+ self.l1 = nn.Linear(
56
+ in_features=config.hidden_size,
57
+ out_features=inner_size,
58
+ bias=False,
59
+ )
60
+ self.l2 = nn.Linear(
61
+ in_features=config.hidden_size,
62
+ out_features=inner_size,
63
+ bias=False,
64
+ )
65
+ self.l3 = nn.Linear(
66
+ in_features=inner_size,
67
+ out_features=config.hidden_size,
68
+ bias=False,
69
+ )
70
+
71
+ def forward(self, z):
72
+ z1, z2 = self.l1(z), self.l2(z)
73
+ if type(z1) == tuple:
74
+ z1 = z1[0]
75
+ if type(z2) == tuple:
76
+ z2 = z2[0]
77
+ y = self.l3(self.act(z1) * z2)
78
+ return y[0] if type(y) == tuple else y
79
+
80
+
81
+ class Embedding(nn.Module):
82
+ _train_dtype = "bf16"
83
+
84
+ def __init__(self, config):
85
+ super().__init__()
86
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
87
+
88
+ def embed(self, input_ids, position_ids=None, tokentype_ids=None):
89
+ embeddings = self.word_embeddings(input_ids)
90
+ return embeddings
91
+
92
+ def unembed(self, u):
93
+ weight = self.word_embeddings.weight
94
+ return torch.matmul(u, weight)
95
+
96
+
97
+ class VocabParallelEmbedding(nn.Embedding):
98
+ "Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py"
99
+
100
+ def __init__(self, config):
101
+ vocab_size, process_group, padding_idx = (
102
+ config.vocab_size,
103
+ config.get("process_group", None),
104
+ config.get("padding_idx", None),
105
+ )
106
+ self.process_group = process_group
107
+ if process_group is not None:
108
+ world_size = torch.distributed.get_world_size(process_group)
109
+ if vocab_size % world_size != 0:
110
+ raise ValueError(
111
+ f"vocab_size ({vocab_size}) must be divisible by " f"world_size ({world_size})"
112
+ )
113
+ if world_size > 1 and padding_idx is not None:
114
+ raise RuntimeError("ParallelEmbedding does not support padding_idx")
115
+ else:
116
+ world_size = 1
117
+ super().__init__(
118
+ vocab_size // world_size,
119
+ embedding_dim=config.hidden_size,
120
+ padding_idx=padding_idx,
121
+ )
122
+
123
+ def embed(self, x: Tensor) -> Tensor:
124
+ if self.process_group is None:
125
+ return self.forward(x)
126
+ else:
127
+ rank = torch.distributed.get_rank(self.process_group)
128
+ vocab_size = self.num_embeddings
129
+ vocab_start_index, vocab_end_index = (
130
+ rank * vocab_size,
131
+ (rank + 1) * vocab_size,
132
+ )
133
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
134
+ input_ids_mask = (x < vocab_start_index) | (x >= vocab_end_index)
135
+ x = x - vocab_start_index
136
+ x[input_ids_mask] = 0
137
+ embeddings = self.forward(x)
138
+ embeddings[input_ids_mask] = 0.0
139
+ # Reduce to the global process group
140
+ torch.distributed.all_reduce(embeddings, group=self.process_group)
141
+ return embeddings
142
+
143
+ def unembed(self, u: Tensor) -> Tensor:
144
+ if self.process_group is None:
145
+ return u @ self.weight.T
146
+ else:
147
+ raise NotImplementedError
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fac4d012a198242ed842c59dc8d76fd511425d3c85e7ccef8501603236dda6f5
3
+ size 4904678112
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2f5b33a06d0d97143163199e115ddcdda2035f149d2a00a839688e86584bdf3
3
+ size 4984649888
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00f581b673be6a324d2490ad8406e3c753fe62fa3b6d811312e9fe77fb8eeaf2
3
+ size 4967601728
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ade0044faaa2a361c2c07dd647d43f8ff2969e85c0eda1e459b5918ba7f96e9
3
+ size 436208224
model.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Together
2
+ # This software is distributed under the terms of the Apache License, Version 2.0
3
+ # Author: Michael Poli
4
+ # Note: MP and PP utilities are removed for ease of use and editing.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .utils import print_rank_0, column_split
11
+ from .cache import InferenceParams, RecurrentInferenceParams
12
+ from .engine import HyenaInferenceEngine
13
+ from .layers import (
14
+ RMSNorm,
15
+ ParallelGatedMLP,
16
+ VocabParallelEmbedding,
17
+ )
18
+
19
+ try:
20
+ from flash_attn.modules.mha import MHA
21
+ except ImportError:
22
+ "flash_attn not installed"
23
+
24
+
25
+ class AttentionBlock(nn.Module):
26
+ def __init__(self, config, layer_idx) -> None:
27
+ super().__init__()
28
+ self.config = config
29
+ self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
30
+ self.layer_idx = layer_idx
31
+ self.proj_groups = config.get("proj_groups", 1)
32
+ dtype = config.get("attn_block_dtype", torch.bfloat16)
33
+ mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
34
+ self.num_attention_heads = config.num_attention_heads
35
+ self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads
36
+
37
+ self.counter = 0
38
+ self.inner_mha_cls = MHA(
39
+ embed_dim=config.hidden_size,
40
+ num_heads=config.num_attention_heads,
41
+ num_heads_kv=config.num_attention_heads // self.proj_groups,
42
+ rotary_emb_dim=config.hidden_size // config.num_attention_heads,
43
+ qkv_proj_bias=config.get("qkv_proj_bias", True),
44
+ rotary_emb_base=config.get("rotary_emb_base", 10000),
45
+ causal=True,
46
+ layer_idx=layer_idx,
47
+ out_proj_bias=config.get("mha_out_proj_bias", True),
48
+ use_flash_attn=self.config.use_flash_attn,
49
+ ).to(dtype=dtype)
50
+
51
+ if self.config.get("smeared_gqa", False):
52
+ self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
53
+ self.inner_mha_cls.rotary_emb.register_buffer(
54
+ "inv_freq", self.inner_mha_cls.rotary_emb.inv_freq
55
+ )
56
+
57
+ self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
58
+
59
+ def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
60
+ if (
61
+ type(padding_mask) == torch.Tensor
62
+ ): # workaround for masking bug in FA. This works because Wqkv does not have bias
63
+ # and attention scores will be also automatically zeroed.
64
+ u = u * padding_mask[..., None]
65
+
66
+ u = (
67
+ self.inner_mha_cls(
68
+ self.pre_norm(u),
69
+ inference_params=inference_params,
70
+ )
71
+ + u
72
+ )
73
+ if type(padding_mask) == torch.Tensor: # guard against bias
74
+ u = u * padding_mask[..., None]
75
+ u = self.mlp(self.post_norm(u)) + u
76
+ return u, None
77
+
78
+
79
+ class ParallelHyenaFilter(nn.Module):
80
+ def __init__(self, config, layer_idx) -> None:
81
+ super().__init__()
82
+ self.config = config
83
+ self.layer_idx = layer_idx
84
+ self.hyena_filter_groups = config.get("hyena_filter_groups", self.config.hidden_size)
85
+
86
+ self.use_flashfft = config.get("use_flashfft", False)
87
+ self.state_size = config.state_size
88
+ self.hidden_size = config.hidden_size
89
+ self.num_filters = config.num_filters
90
+ self.inference_mode = config.get("inference_mode", True)
91
+ self.counter = 0
92
+ self.column_split_hyena = config.get("column_split_hyena", True)
93
+
94
+ assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size
95
+
96
+ self.D = nn.Parameter(torch.zeros(self.hidden_size))
97
+
98
+ # attention heads are not used except to split post short_filter
99
+ # projections in the same way as the checkpoint
100
+ self.num_attention_heads = config.num_attention_heads
101
+ self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
102
+
103
+ # after preprocessing here we can save the new checkpoint
104
+ self.short_filter_length = config.short_filter_length
105
+ self.short_filter_weight = nn.Parameter(
106
+ torch.randn(3 * config.hidden_size, 1, config.short_filter_length)
107
+ )
108
+ self.short_filter_bias = (
109
+ nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None
110
+ )
111
+
112
+ self.engine = HyenaInferenceEngine(layer_idx=layer_idx)
113
+ self.use_flash_depthwise = config.get("use_flash_depthwise", False)
114
+ self.data_dtype = None
115
+
116
+ if self.use_flash_depthwise:
117
+ self.fir_fn = FlashDepthwiseConv1d(
118
+ channels=3 * self.hidden_size,
119
+ kernel_size=self.short_filter_length,
120
+ padding=self.short_filter_length - 1,
121
+ weights=self.short_filter_weight,
122
+ bias=self.short_filter_bias,
123
+ device=None,
124
+ dtype=self.config.get("depthwise_dtype", torch.bfloat16),
125
+ )
126
+ else:
127
+ self.fir_fn = F.conv1d
128
+
129
+ self.fftconv_fn = None
130
+ self.long_fir_threshold = config.get("long_fir_threshold", None)
131
+ if self.long_fir_threshold is not None:
132
+ assert (
133
+ self.use_flashfft is False
134
+ ), "long_fir_threshold not compatible with fused flashfft"
135
+
136
+ self.num_systems = self.hidden_size // self.hyena_filter_groups
137
+ self.poles = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
138
+ self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
139
+ self.h = None
140
+
141
+ def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
142
+ if (
143
+ inference_params is not None
144
+ and self.layer_idx in inference_params.fir_state_dict.keys()
145
+ ):
146
+ return self.sequential_forward(u, inference_params)
147
+
148
+ else:
149
+ return self.parallel_forward(u, inference_params, padding_mask)
150
+
151
+ def parallel_forward(self, u, inference_params=None, padding_mask=None):
152
+ L = u.shape[1]
153
+ z_pre, fir_state = self.engine.parallel_fir(
154
+ self.fir_fn,
155
+ u,
156
+ self.short_filter_weight,
157
+ self.short_filter_bias,
158
+ L,
159
+ fir_length=self.short_filter_length,
160
+ inference_params=inference_params,
161
+ padding_mask=padding_mask,
162
+ )
163
+ if inference_params:
164
+ inference_params.fir_state_dict[self.layer_idx] = fir_state
165
+
166
+ if self.h is None:
167
+ h, filter_dtype, poles, residues = self.compute_filter(L, u.device)
168
+ else:
169
+ h = self.h
170
+ filter_dtype = self.h.dtype
171
+
172
+ if self.hyena_filter_groups > 1:
173
+ h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 1)
174
+
175
+ # if inference_params is not None, we plan to perform generation:
176
+ # prefilling for the IIR portion of the filter is handled by the engine.
177
+ dims = (
178
+ self.hidden_size,
179
+ self.num_attention_heads,
180
+ self.hidden_size_per_attention_head,
181
+ self.state_size,
182
+ self.hyena_filter_groups,
183
+ )
184
+ y = self.engine.parallel_iir(
185
+ z_pre,
186
+ h,
187
+ self.D,
188
+ L,
189
+ t=self.t,
190
+ poles=self.poles,
191
+ dims=dims,
192
+ inference_params=inference_params,
193
+ layer_idx=self.layer_idx,
194
+ prefill_style=self.config.get("prefill_style", "fft"),
195
+ use_flashfft=self.use_flashfft,
196
+ fftconv_fn=self.fftconv_fn,
197
+ column_split_hyena=self.column_split_hyena,
198
+ long_fir_threshold=self.long_fir_threshold,
199
+ padding_mask=padding_mask,
200
+ )
201
+
202
+ return y, inference_params
203
+
204
+ def sequential_forward(self, u, inference_params):
205
+ if self.data_dtype is None:
206
+ self.data_dtype = u.dtype
207
+ if len(u.shape) > 2:
208
+ u = u[:, -1]
209
+
210
+ fir_state, iir_state = (
211
+ inference_params.fir_state_dict[self.layer_idx],
212
+ inference_params.state_dict[self.layer_idx],
213
+ )
214
+
215
+ z_pre, fir_state = self.engine.step_fir(
216
+ u, fir_state, weight=self.short_filter_weight, bias=self.short_filter_bias
217
+ )
218
+ x2, x1, v = (
219
+ column_split(z_pre, self.num_attention_heads, self.hidden_size_per_attention_head)
220
+ if self.column_split_hyena
221
+ else z_pre.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1)
222
+ )
223
+
224
+ y, iir_state = self.engine.step_iir(
225
+ x2,
226
+ x1,
227
+ v,
228
+ self.D,
229
+ self.residues,
230
+ self.poles,
231
+ iir_state,
232
+ iir_groups=self.hyena_filter_groups,
233
+ )
234
+
235
+ inference_params.fir_state_dict[self.layer_idx] = fir_state
236
+ inference_params.state_dict[self.layer_idx] = iir_state
237
+ y = y.to(dtype=self.data_dtype)
238
+ return y[:, None], inference_params
239
+
240
+ def update_time(self, L, device):
241
+ """
242
+ Set [0, 1, ..., L-1] where L is the length of the current batch of inputs.
243
+ If L is greater than the length of the previous batch, then the time vector is
244
+ reinitialized. Otherwise, the time vector is truncated from cache.
245
+ """
246
+ if not hasattr(self, "t"):
247
+ self.t = torch.arange(L, device=device)[None, None]
248
+ elif self.t.shape[-1] < L:
249
+ self.t = torch.arange(L, device=device)[None, None]
250
+ else:
251
+ self.t = self.t[..., :L]
252
+
253
+ def compute_filter(self, L, device):
254
+ self.update_time(L, device)
255
+ filter_dtype = torch.float32
256
+ residues, log_poles = (
257
+ torch.view_as_complex(self.residues.to(filter_dtype)),
258
+ torch.view_as_complex(self.poles.to(filter_dtype)).log(),
259
+ )
260
+ h = (residues * (log_poles * self.t).exp()).real.sum(1)[None]
261
+ return h, filter_dtype, log_poles, residues
262
+
263
+
264
+ class ParallelGatedConvBlock(nn.Module):
265
+ def __init__(self, config, layer_idx) -> None:
266
+ super().__init__()
267
+ self.config = config
268
+ self.layer_idx = layer_idx
269
+ dtype = config.get("hyena_block_dtype", torch.float32)
270
+ mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
271
+ self.pre_norm, self.post_norm = RMSNorm(config).to(dtype=dtype), RMSNorm(config).to(
272
+ dtype=dtype
273
+ )
274
+ self.filter = ParallelHyenaFilter(config, layer_idx).to(dtype=dtype)
275
+ self.projections = nn.Linear(config.hidden_size, 3 * config.hidden_size)
276
+ self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size).to(dtype)
277
+ self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
278
+
279
+ def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
280
+ z = self.projections(self.pre_norm(u))
281
+ if type(padding_mask) == torch.Tensor: # guard against bias
282
+ z = z * padding_mask[..., None]
283
+
284
+ z, inference_params = self.filter(
285
+ z, inference_params=inference_params, padding_mask=padding_mask
286
+ )
287
+
288
+ u = self.out_filter_dense(z) + u
289
+ if type(padding_mask) == torch.Tensor: # guard against bias
290
+ u = u * padding_mask[..., None]
291
+ u = self.mlp(self.post_norm(u)) + u
292
+ return u, inference_params
293
+
294
+
295
+ def get_block(config, layer_idx, flash_fft=None):
296
+ if layer_idx in config.attn_layer_idxs:
297
+ return AttentionBlock(config, layer_idx)
298
+ elif layer_idx in config.hyena_layer_idxs:
299
+ block = ParallelGatedConvBlock(config, layer_idx)
300
+ if config.get("use_flashfft", "False"):
301
+ block.filter.fftconv_fn = flash_fft
302
+ return block
303
+ else:
304
+ raise NotImplementedError
305
+
306
+
307
+ class StripedHyena(nn.Module):
308
+ def __init__(self, config):
309
+ super().__init__()
310
+ self.config = config
311
+ self.embedding_layer = VocabParallelEmbedding(config)
312
+ self.norm = RMSNorm(config) if config.get("final_norm", True) else None
313
+ self.unembed = self.emb if config.tie_embeddings else VocabParallelEmbedding(config)
314
+ self.gradient_checkpointing = False
315
+
316
+ if config.get("use_flashfft", "False"):
317
+ raise NotImplementedError("Please use standalone SH code for other custom kernels")
318
+ else:
319
+ self.flash_fft = None
320
+
321
+ self.blocks = nn.ModuleList(
322
+ get_block(config, layer_idx, flash_fft=self.flash_fft)
323
+ for layer_idx in range(config.num_layers)
324
+ )
325
+
326
+ def forward(self, x, inference_params_dict=None, padding_mask=None):
327
+ L = x.shape[1]
328
+ x = self.embedding_layer.embed(x)
329
+ if inference_params_dict is not None:
330
+ x, inference_params_dict_out = self.stateful_forward(
331
+ x,
332
+ inference_params_dict=inference_params_dict,
333
+ )
334
+ else:
335
+ x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)
336
+ x = self.norm(x)
337
+ x = self.unembed.unembed(x)
338
+ return x, inference_params_dict_out
339
+
340
+ def stateful_forward(self, x, inference_params_dict=None):
341
+ for block_idx, block in enumerate(self.blocks):
342
+ block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
343
+ inference_params = inference_params_dict[block_name]
344
+ x, _ = block(x, inference_params=inference_params)
345
+
346
+ return x, inference_params_dict
347
+
348
+ def stateless_forward(self, x, padding_mask=None):
349
+ if type(padding_mask) == torch.Tensor:
350
+ x = x * padding_mask[..., None]
351
+
352
+ for block_idx, block in enumerate(self.blocks):
353
+ if self.gradient_checkpointing and self.training:
354
+ def create_custom_forward(module):
355
+ def custom_forward(*inputs):
356
+ # None for past_key_value
357
+ return module(*inputs, inference_params=None, padding_mask=padding_mask)
358
+
359
+ return custom_forward
360
+
361
+ x, _ = checkpoint(create_custom_forward(block), x, use_reentrant=False)
362
+ else:
363
+ x, _ = block(x, inference_params=None, padding_mask=padding_mask)
364
+ return x, None
365
+
366
+ def initialize_inference_params(self):
367
+ print_rank_0("Initializing inference params...")
368
+ inference_params_dict = {
369
+ "mha": InferenceParams(
370
+ max_seqlen=self.config.get("max_seqlen", 8192),
371
+ max_batch_size=self.config.get("max_batch_size", 1),
372
+ seqlen_offset=0,
373
+ ),
374
+ "hyena": RecurrentInferenceParams(
375
+ fir_filter_length=self.config.short_filter_length,
376
+ state_dim=self.config.state_size,
377
+ seqlen_offset=0,
378
+ ),
379
+ }
380
+ return inference_params_dict
381
+
382
+ def precompute_filters(self, L, device):
383
+ for block_idx, block in enumerate(self.blocks):
384
+ if type(block) == ParallelGatedConvBlock:
385
+ if type(block.filter) == ParallelHyenaFilter:
386
+ L = block.filter.long_fir_threshold or L
387
+ print_rank_0(f"Precomputing filters, L={L}...")
388
+
389
+ filter_dtype = torch.float16 if L >= 2048 else torch.float32
390
+
391
+ block.filter._set_time(L, device)
392
+ residues, poles = (
393
+ torch.view_as_complex(block.filter.residues.to(torch.float16)),
394
+ torch.view_as_complex(block.filter.poles.to(torch.float16)),
395
+ )
396
+
397
+ block.filter.h = (residues * poles**block.filter.t).real.sum(1)[None]
398
+ block.filter.h = block.filter.h.to(dtype=filter_dtype)
399
+
400
+ def load_poles_residues(self, path):
401
+ "Load different poles and residues for each layer."
402
+ for block_idx, block in enumerate(self.blocks):
403
+ if type(block) == ParallelGatedConvBlock:
404
+ if type(block.filter) == ParallelHyenaFilter:
405
+ print(f"Loading poles and residues for block {block_idx}")
406
+ poles = torch.load(path + f"/approx_poles_{block_idx+1}.pt", map_location="cpu")
407
+ poles = torch.view_as_real(poles)
408
+ residues = torch.load(
409
+ path + f"/approx_residues_{block_idx+1}.pt", map_location="cpu"
410
+ )
411
+ residues = torch.view_as_real(residues)
412
+ poles = poles.permute(1, 0, 2).unsqueeze(-2)
413
+ residues = residues.permute(1, 0, 2).unsqueeze(-2)
414
+
415
+ block.filter.poles = nn.Parameter(poles)
416
+ block.filter.residues = nn.Parameter(residues)
417
+
418
+ def to_bfloat16_except_poles_residues(self):
419
+ """Convert all parameters to bfloat16 except for the poles and residues.
420
+
421
+ Particularly important for longer prompts.
422
+ """
423
+ for k, p in self.named_parameters():
424
+ if "poles" not in k and "residues" not in k:
425
+ p.data = p.data.to(torch.bfloat16)
model.safetensors.index.json ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 15293097984
4
+ },
5
+ "weight_map": {
6
+ "backbone.blocks.0.filter.D": "model-00001-of-00004.safetensors",
7
+ "backbone.blocks.0.filter.poles": "model-00001-of-00004.safetensors",
8
+ "backbone.blocks.0.filter.residues": "model-00001-of-00004.safetensors",
9
+ "backbone.blocks.0.filter.short_filter_bias": "model-00001-of-00004.safetensors",
10
+ "backbone.blocks.0.filter.short_filter_weight": "model-00001-of-00004.safetensors",
11
+ "backbone.blocks.0.mlp.l1.weight": "model-00001-of-00004.safetensors",
12
+ "backbone.blocks.0.mlp.l2.weight": "model-00001-of-00004.safetensors",
13
+ "backbone.blocks.0.mlp.l3.weight": "model-00001-of-00004.safetensors",
14
+ "backbone.blocks.0.out_filter_dense.bias": "model-00001-of-00004.safetensors",
15
+ "backbone.blocks.0.out_filter_dense.weight": "model-00001-of-00004.safetensors",
16
+ "backbone.blocks.0.post_norm.scale": "model-00001-of-00004.safetensors",
17
+ "backbone.blocks.0.pre_norm.scale": "model-00001-of-00004.safetensors",
18
+ "backbone.blocks.0.projections.bias": "model-00001-of-00004.safetensors",
19
+ "backbone.blocks.0.projections.weight": "model-00001-of-00004.safetensors",
20
+ "backbone.blocks.1.inner_mha_cls.Wqkv.weight": "model-00001-of-00004.safetensors",
21
+ "backbone.blocks.1.inner_mha_cls.out_proj.weight": "model-00001-of-00004.safetensors",
22
+ "backbone.blocks.1.inner_mha_cls.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
23
+ "backbone.blocks.1.mlp.l1.weight": "model-00001-of-00004.safetensors",
24
+ "backbone.blocks.1.mlp.l2.weight": "model-00001-of-00004.safetensors",
25
+ "backbone.blocks.1.mlp.l3.weight": "model-00001-of-00004.safetensors",
26
+ "backbone.blocks.1.post_norm.scale": "model-00001-of-00004.safetensors",
27
+ "backbone.blocks.1.pre_norm.scale": "model-00001-of-00004.safetensors",
28
+ "backbone.blocks.10.filter.D": "model-00002-of-00004.safetensors",
29
+ "backbone.blocks.10.filter.poles": "model-00002-of-00004.safetensors",
30
+ "backbone.blocks.10.filter.residues": "model-00002-of-00004.safetensors",
31
+ "backbone.blocks.10.filter.short_filter_bias": "model-00002-of-00004.safetensors",
32
+ "backbone.blocks.10.filter.short_filter_weight": "model-00002-of-00004.safetensors",
33
+ "backbone.blocks.10.mlp.l1.weight": "model-00002-of-00004.safetensors",
34
+ "backbone.blocks.10.mlp.l2.weight": "model-00002-of-00004.safetensors",
35
+ "backbone.blocks.10.mlp.l3.weight": "model-00002-of-00004.safetensors",
36
+ "backbone.blocks.10.out_filter_dense.bias": "model-00002-of-00004.safetensors",
37
+ "backbone.blocks.10.out_filter_dense.weight": "model-00002-of-00004.safetensors",
38
+ "backbone.blocks.10.post_norm.scale": "model-00002-of-00004.safetensors",
39
+ "backbone.blocks.10.pre_norm.scale": "model-00002-of-00004.safetensors",
40
+ "backbone.blocks.10.projections.bias": "model-00002-of-00004.safetensors",
41
+ "backbone.blocks.10.projections.weight": "model-00002-of-00004.safetensors",
42
+ "backbone.blocks.11.inner_mha_cls.Wqkv.weight": "model-00002-of-00004.safetensors",
43
+ "backbone.blocks.11.inner_mha_cls.out_proj.weight": "model-00002-of-00004.safetensors",
44
+ "backbone.blocks.11.inner_mha_cls.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
45
+ "backbone.blocks.11.mlp.l1.weight": "model-00002-of-00004.safetensors",
46
+ "backbone.blocks.11.mlp.l2.weight": "model-00002-of-00004.safetensors",
47
+ "backbone.blocks.11.mlp.l3.weight": "model-00002-of-00004.safetensors",
48
+ "backbone.blocks.11.post_norm.scale": "model-00002-of-00004.safetensors",
49
+ "backbone.blocks.11.pre_norm.scale": "model-00002-of-00004.safetensors",
50
+ "backbone.blocks.12.filter.D": "model-00002-of-00004.safetensors",
51
+ "backbone.blocks.12.filter.poles": "model-00002-of-00004.safetensors",
52
+ "backbone.blocks.12.filter.residues": "model-00002-of-00004.safetensors",
53
+ "backbone.blocks.12.filter.short_filter_bias": "model-00002-of-00004.safetensors",
54
+ "backbone.blocks.12.filter.short_filter_weight": "model-00002-of-00004.safetensors",
55
+ "backbone.blocks.12.mlp.l1.weight": "model-00002-of-00004.safetensors",
56
+ "backbone.blocks.12.mlp.l2.weight": "model-00002-of-00004.safetensors",
57
+ "backbone.blocks.12.mlp.l3.weight": "model-00002-of-00004.safetensors",
58
+ "backbone.blocks.12.out_filter_dense.bias": "model-00002-of-00004.safetensors",
59
+ "backbone.blocks.12.out_filter_dense.weight": "model-00002-of-00004.safetensors",
60
+ "backbone.blocks.12.post_norm.scale": "model-00002-of-00004.safetensors",
61
+ "backbone.blocks.12.pre_norm.scale": "model-00002-of-00004.safetensors",
62
+ "backbone.blocks.12.projections.bias": "model-00002-of-00004.safetensors",
63
+ "backbone.blocks.12.projections.weight": "model-00002-of-00004.safetensors",
64
+ "backbone.blocks.13.inner_mha_cls.Wqkv.weight": "model-00002-of-00004.safetensors",
65
+ "backbone.blocks.13.inner_mha_cls.out_proj.weight": "model-00002-of-00004.safetensors",
66
+ "backbone.blocks.13.inner_mha_cls.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
67
+ "backbone.blocks.13.mlp.l1.weight": "model-00002-of-00004.safetensors",
68
+ "backbone.blocks.13.mlp.l2.weight": "model-00002-of-00004.safetensors",
69
+ "backbone.blocks.13.mlp.l3.weight": "model-00002-of-00004.safetensors",
70
+ "backbone.blocks.13.post_norm.scale": "model-00002-of-00004.safetensors",
71
+ "backbone.blocks.13.pre_norm.scale": "model-00002-of-00004.safetensors",
72
+ "backbone.blocks.14.filter.D": "model-00002-of-00004.safetensors",
73
+ "backbone.blocks.14.filter.poles": "model-00002-of-00004.safetensors",
74
+ "backbone.blocks.14.filter.residues": "model-00002-of-00004.safetensors",
75
+ "backbone.blocks.14.filter.short_filter_bias": "model-00002-of-00004.safetensors",
76
+ "backbone.blocks.14.filter.short_filter_weight": "model-00002-of-00004.safetensors",
77
+ "backbone.blocks.14.mlp.l1.weight": "model-00002-of-00004.safetensors",
78
+ "backbone.blocks.14.mlp.l2.weight": "model-00002-of-00004.safetensors",
79
+ "backbone.blocks.14.mlp.l3.weight": "model-00002-of-00004.safetensors",
80
+ "backbone.blocks.14.out_filter_dense.bias": "model-00002-of-00004.safetensors",
81
+ "backbone.blocks.14.out_filter_dense.weight": "model-00002-of-00004.safetensors",
82
+ "backbone.blocks.14.post_norm.scale": "model-00002-of-00004.safetensors",
83
+ "backbone.blocks.14.pre_norm.scale": "model-00002-of-00004.safetensors",
84
+ "backbone.blocks.14.projections.bias": "model-00002-of-00004.safetensors",
85
+ "backbone.blocks.14.projections.weight": "model-00002-of-00004.safetensors",
86
+ "backbone.blocks.15.inner_mha_cls.Wqkv.weight": "model-00002-of-00004.safetensors",
87
+ "backbone.blocks.15.inner_mha_cls.out_proj.weight": "model-00002-of-00004.safetensors",
88
+ "backbone.blocks.15.inner_mha_cls.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
89
+ "backbone.blocks.15.mlp.l1.weight": "model-00002-of-00004.safetensors",
90
+ "backbone.blocks.15.mlp.l2.weight": "model-00002-of-00004.safetensors",
91
+ "backbone.blocks.15.mlp.l3.weight": "model-00002-of-00004.safetensors",
92
+ "backbone.blocks.15.post_norm.scale": "model-00002-of-00004.safetensors",
93
+ "backbone.blocks.15.pre_norm.scale": "model-00002-of-00004.safetensors",
94
+ "backbone.blocks.16.filter.D": "model-00002-of-00004.safetensors",
95
+ "backbone.blocks.16.filter.poles": "model-00002-of-00004.safetensors",
96
+ "backbone.blocks.16.filter.residues": "model-00002-of-00004.safetensors",
97
+ "backbone.blocks.16.filter.short_filter_bias": "model-00002-of-00004.safetensors",
98
+ "backbone.blocks.16.filter.short_filter_weight": "model-00002-of-00004.safetensors",
99
+ "backbone.blocks.16.mlp.l1.weight": "model-00002-of-00004.safetensors",
100
+ "backbone.blocks.16.mlp.l2.weight": "model-00002-of-00004.safetensors",
101
+ "backbone.blocks.16.mlp.l3.weight": "model-00002-of-00004.safetensors",
102
+ "backbone.blocks.16.out_filter_dense.bias": "model-00002-of-00004.safetensors",
103
+ "backbone.blocks.16.out_filter_dense.weight": "model-00002-of-00004.safetensors",
104
+ "backbone.blocks.16.post_norm.scale": "model-00002-of-00004.safetensors",
105
+ "backbone.blocks.16.pre_norm.scale": "model-00002-of-00004.safetensors",
106
+ "backbone.blocks.16.projections.bias": "model-00002-of-00004.safetensors",
107
+ "backbone.blocks.16.projections.weight": "model-00002-of-00004.safetensors",
108
+ "backbone.blocks.17.inner_mha_cls.Wqkv.weight": "model-00002-of-00004.safetensors",
109
+ "backbone.blocks.17.inner_mha_cls.out_proj.weight": "model-00002-of-00004.safetensors",
110
+ "backbone.blocks.17.inner_mha_cls.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
111
+ "backbone.blocks.17.mlp.l1.weight": "model-00002-of-00004.safetensors",
112
+ "backbone.blocks.17.mlp.l2.weight": "model-00002-of-00004.safetensors",
113
+ "backbone.blocks.17.mlp.l3.weight": "model-00002-of-00004.safetensors",
114
+ "backbone.blocks.17.post_norm.scale": "model-00002-of-00004.safetensors",
115
+ "backbone.blocks.17.pre_norm.scale": "model-00002-of-00004.safetensors",
116
+ "backbone.blocks.18.filter.D": "model-00002-of-00004.safetensors",
117
+ "backbone.blocks.18.filter.poles": "model-00002-of-00004.safetensors",
118
+ "backbone.blocks.18.filter.residues": "model-00002-of-00004.safetensors",
119
+ "backbone.blocks.18.filter.short_filter_bias": "model-00002-of-00004.safetensors",
120
+ "backbone.blocks.18.filter.short_filter_weight": "model-00002-of-00004.safetensors",
121
+ "backbone.blocks.18.mlp.l1.weight": "model-00002-of-00004.safetensors",
122
+ "backbone.blocks.18.mlp.l2.weight": "model-00002-of-00004.safetensors",
123
+ "backbone.blocks.18.mlp.l3.weight": "model-00002-of-00004.safetensors",
124
+ "backbone.blocks.18.out_filter_dense.bias": "model-00002-of-00004.safetensors",
125
+ "backbone.blocks.18.out_filter_dense.weight": "model-00002-of-00004.safetensors",
126
+ "backbone.blocks.18.post_norm.scale": "model-00002-of-00004.safetensors",
127
+ "backbone.blocks.18.pre_norm.scale": "model-00002-of-00004.safetensors",
128
+ "backbone.blocks.18.projections.bias": "model-00002-of-00004.safetensors",
129
+ "backbone.blocks.18.projections.weight": "model-00002-of-00004.safetensors",
130
+ "backbone.blocks.19.inner_mha_cls.Wqkv.weight": "model-00002-of-00004.safetensors",
131
+ "backbone.blocks.19.inner_mha_cls.out_proj.weight": "model-00002-of-00004.safetensors",
132
+ "backbone.blocks.19.inner_mha_cls.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
133
+ "backbone.blocks.19.mlp.l1.weight": "model-00002-of-00004.safetensors",
134
+ "backbone.blocks.19.mlp.l2.weight": "model-00002-of-00004.safetensors",
135
+ "backbone.blocks.19.mlp.l3.weight": "model-00002-of-00004.safetensors",
136
+ "backbone.blocks.19.post_norm.scale": "model-00002-of-00004.safetensors",
137
+ "backbone.blocks.19.pre_norm.scale": "model-00002-of-00004.safetensors",
138
+ "backbone.blocks.2.filter.D": "model-00001-of-00004.safetensors",
139
+ "backbone.blocks.2.filter.poles": "model-00001-of-00004.safetensors",
140
+ "backbone.blocks.2.filter.residues": "model-00001-of-00004.safetensors",
141
+ "backbone.blocks.2.filter.short_filter_bias": "model-00001-of-00004.safetensors",
142
+ "backbone.blocks.2.filter.short_filter_weight": "model-00001-of-00004.safetensors",
143
+ "backbone.blocks.2.mlp.l1.weight": "model-00001-of-00004.safetensors",
144
+ "backbone.blocks.2.mlp.l2.weight": "model-00001-of-00004.safetensors",
145
+ "backbone.blocks.2.mlp.l3.weight": "model-00001-of-00004.safetensors",
146
+ "backbone.blocks.2.out_filter_dense.bias": "model-00001-of-00004.safetensors",
147
+ "backbone.blocks.2.out_filter_dense.weight": "model-00001-of-00004.safetensors",
148
+ "backbone.blocks.2.post_norm.scale": "model-00001-of-00004.safetensors",
149
+ "backbone.blocks.2.pre_norm.scale": "model-00001-of-00004.safetensors",
150
+ "backbone.blocks.2.projections.bias": "model-00001-of-00004.safetensors",
151
+ "backbone.blocks.2.projections.weight": "model-00001-of-00004.safetensors",
152
+ "backbone.blocks.20.filter.D": "model-00002-of-00004.safetensors",
153
+ "backbone.blocks.20.filter.poles": "model-00002-of-00004.safetensors",
154
+ "backbone.blocks.20.filter.residues": "model-00002-of-00004.safetensors",
155
+ "backbone.blocks.20.filter.short_filter_bias": "model-00002-of-00004.safetensors",
156
+ "backbone.blocks.20.filter.short_filter_weight": "model-00002-of-00004.safetensors",
157
+ "backbone.blocks.20.mlp.l1.weight": "model-00003-of-00004.safetensors",
158
+ "backbone.blocks.20.mlp.l2.weight": "model-00003-of-00004.safetensors",
159
+ "backbone.blocks.20.mlp.l3.weight": "model-00003-of-00004.safetensors",
160
+ "backbone.blocks.20.out_filter_dense.bias": "model-00002-of-00004.safetensors",
161
+ "backbone.blocks.20.out_filter_dense.weight": "model-00002-of-00004.safetensors",
162
+ "backbone.blocks.20.post_norm.scale": "model-00002-of-00004.safetensors",
163
+ "backbone.blocks.20.pre_norm.scale": "model-00002-of-00004.safetensors",
164
+ "backbone.blocks.20.projections.bias": "model-00002-of-00004.safetensors",
165
+ "backbone.blocks.20.projections.weight": "model-00002-of-00004.safetensors",
166
+ "backbone.blocks.21.inner_mha_cls.Wqkv.weight": "model-00003-of-00004.safetensors",
167
+ "backbone.blocks.21.inner_mha_cls.out_proj.weight": "model-00003-of-00004.safetensors",
168
+ "backbone.blocks.21.inner_mha_cls.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
169
+ "backbone.blocks.21.mlp.l1.weight": "model-00003-of-00004.safetensors",
170
+ "backbone.blocks.21.mlp.l2.weight": "model-00003-of-00004.safetensors",
171
+ "backbone.blocks.21.mlp.l3.weight": "model-00003-of-00004.safetensors",
172
+ "backbone.blocks.21.post_norm.scale": "model-00003-of-00004.safetensors",
173
+ "backbone.blocks.21.pre_norm.scale": "model-00003-of-00004.safetensors",
174
+ "backbone.blocks.22.filter.D": "model-00003-of-00004.safetensors",
175
+ "backbone.blocks.22.filter.poles": "model-00003-of-00004.safetensors",
176
+ "backbone.blocks.22.filter.residues": "model-00003-of-00004.safetensors",
177
+ "backbone.blocks.22.filter.short_filter_bias": "model-00003-of-00004.safetensors",
178
+ "backbone.blocks.22.filter.short_filter_weight": "model-00003-of-00004.safetensors",
179
+ "backbone.blocks.22.mlp.l1.weight": "model-00003-of-00004.safetensors",
180
+ "backbone.blocks.22.mlp.l2.weight": "model-00003-of-00004.safetensors",
181
+ "backbone.blocks.22.mlp.l3.weight": "model-00003-of-00004.safetensors",
182
+ "backbone.blocks.22.out_filter_dense.bias": "model-00003-of-00004.safetensors",
183
+ "backbone.blocks.22.out_filter_dense.weight": "model-00003-of-00004.safetensors",
184
+ "backbone.blocks.22.post_norm.scale": "model-00003-of-00004.safetensors",
185
+ "backbone.blocks.22.pre_norm.scale": "model-00003-of-00004.safetensors",
186
+ "backbone.blocks.22.projections.bias": "model-00003-of-00004.safetensors",
187
+ "backbone.blocks.22.projections.weight": "model-00003-of-00004.safetensors",
188
+ "backbone.blocks.23.inner_mha_cls.Wqkv.weight": "model-00003-of-00004.safetensors",
189
+ "backbone.blocks.23.inner_mha_cls.out_proj.weight": "model-00003-of-00004.safetensors",
190
+ "backbone.blocks.23.inner_mha_cls.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
191
+ "backbone.blocks.23.mlp.l1.weight": "model-00003-of-00004.safetensors",
192
+ "backbone.blocks.23.mlp.l2.weight": "model-00003-of-00004.safetensors",
193
+ "backbone.blocks.23.mlp.l3.weight": "model-00003-of-00004.safetensors",
194
+ "backbone.blocks.23.post_norm.scale": "model-00003-of-00004.safetensors",
195
+ "backbone.blocks.23.pre_norm.scale": "model-00003-of-00004.safetensors",
196
+ "backbone.blocks.24.filter.D": "model-00003-of-00004.safetensors",
197
+ "backbone.blocks.24.filter.poles": "model-00003-of-00004.safetensors",
198
+ "backbone.blocks.24.filter.residues": "model-00003-of-00004.safetensors",
199
+ "backbone.blocks.24.filter.short_filter_bias": "model-00003-of-00004.safetensors",
200
+ "backbone.blocks.24.filter.short_filter_weight": "model-00003-of-00004.safetensors",
201
+ "backbone.blocks.24.mlp.l1.weight": "model-00003-of-00004.safetensors",
202
+ "backbone.blocks.24.mlp.l2.weight": "model-00003-of-00004.safetensors",
203
+ "backbone.blocks.24.mlp.l3.weight": "model-00003-of-00004.safetensors",
204
+ "backbone.blocks.24.out_filter_dense.bias": "model-00003-of-00004.safetensors",
205
+ "backbone.blocks.24.out_filter_dense.weight": "model-00003-of-00004.safetensors",
206
+ "backbone.blocks.24.post_norm.scale": "model-00003-of-00004.safetensors",
207
+ "backbone.blocks.24.pre_norm.scale": "model-00003-of-00004.safetensors",
208
+ "backbone.blocks.24.projections.bias": "model-00003-of-00004.safetensors",
209
+ "backbone.blocks.24.projections.weight": "model-00003-of-00004.safetensors",
210
+ "backbone.blocks.25.inner_mha_cls.Wqkv.weight": "model-00003-of-00004.safetensors",
211
+ "backbone.blocks.25.inner_mha_cls.out_proj.weight": "model-00003-of-00004.safetensors",
212
+ "backbone.blocks.25.inner_mha_cls.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
213
+ "backbone.blocks.25.mlp.l1.weight": "model-00003-of-00004.safetensors",
214
+ "backbone.blocks.25.mlp.l2.weight": "model-00003-of-00004.safetensors",
215
+ "backbone.blocks.25.mlp.l3.weight": "model-00003-of-00004.safetensors",
216
+ "backbone.blocks.25.post_norm.scale": "model-00003-of-00004.safetensors",
217
+ "backbone.blocks.25.pre_norm.scale": "model-00003-of-00004.safetensors",
218
+ "backbone.blocks.26.filter.D": "model-00003-of-00004.safetensors",
219
+ "backbone.blocks.26.filter.poles": "model-00003-of-00004.safetensors",
220
+ "backbone.blocks.26.filter.residues": "model-00003-of-00004.safetensors",
221
+ "backbone.blocks.26.filter.short_filter_bias": "model-00003-of-00004.safetensors",
222
+ "backbone.blocks.26.filter.short_filter_weight": "model-00003-of-00004.safetensors",
223
+ "backbone.blocks.26.mlp.l1.weight": "model-00003-of-00004.safetensors",
224
+ "backbone.blocks.26.mlp.l2.weight": "model-00003-of-00004.safetensors",
225
+ "backbone.blocks.26.mlp.l3.weight": "model-00003-of-00004.safetensors",
226
+ "backbone.blocks.26.out_filter_dense.bias": "model-00003-of-00004.safetensors",
227
+ "backbone.blocks.26.out_filter_dense.weight": "model-00003-of-00004.safetensors",
228
+ "backbone.blocks.26.post_norm.scale": "model-00003-of-00004.safetensors",
229
+ "backbone.blocks.26.pre_norm.scale": "model-00003-of-00004.safetensors",
230
+ "backbone.blocks.26.projections.bias": "model-00003-of-00004.safetensors",
231
+ "backbone.blocks.26.projections.weight": "model-00003-of-00004.safetensors",
232
+ "backbone.blocks.27.inner_mha_cls.Wqkv.weight": "model-00003-of-00004.safetensors",
233
+ "backbone.blocks.27.inner_mha_cls.out_proj.weight": "model-00003-of-00004.safetensors",
234
+ "backbone.blocks.27.inner_mha_cls.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
235
+ "backbone.blocks.27.mlp.l1.weight": "model-00003-of-00004.safetensors",
236
+ "backbone.blocks.27.mlp.l2.weight": "model-00003-of-00004.safetensors",
237
+ "backbone.blocks.27.mlp.l3.weight": "model-00003-of-00004.safetensors",
238
+ "backbone.blocks.27.post_norm.scale": "model-00003-of-00004.safetensors",
239
+ "backbone.blocks.27.pre_norm.scale": "model-00003-of-00004.safetensors",
240
+ "backbone.blocks.28.filter.D": "model-00003-of-00004.safetensors",
241
+ "backbone.blocks.28.filter.poles": "model-00003-of-00004.safetensors",
242
+ "backbone.blocks.28.filter.residues": "model-00003-of-00004.safetensors",
243
+ "backbone.blocks.28.filter.short_filter_bias": "model-00003-of-00004.safetensors",
244
+ "backbone.blocks.28.filter.short_filter_weight": "model-00003-of-00004.safetensors",
245
+ "backbone.blocks.28.mlp.l1.weight": "model-00003-of-00004.safetensors",
246
+ "backbone.blocks.28.mlp.l2.weight": "model-00003-of-00004.safetensors",
247
+ "backbone.blocks.28.mlp.l3.weight": "model-00003-of-00004.safetensors",
248
+ "backbone.blocks.28.out_filter_dense.bias": "model-00003-of-00004.safetensors",
249
+ "backbone.blocks.28.out_filter_dense.weight": "model-00003-of-00004.safetensors",
250
+ "backbone.blocks.28.post_norm.scale": "model-00003-of-00004.safetensors",
251
+ "backbone.blocks.28.pre_norm.scale": "model-00003-of-00004.safetensors",
252
+ "backbone.blocks.28.projections.bias": "model-00003-of-00004.safetensors",
253
+ "backbone.blocks.28.projections.weight": "model-00003-of-00004.safetensors",
254
+ "backbone.blocks.29.inner_mha_cls.Wqkv.weight": "model-00003-of-00004.safetensors",
255
+ "backbone.blocks.29.inner_mha_cls.out_proj.weight": "model-00003-of-00004.safetensors",
256
+ "backbone.blocks.29.inner_mha_cls.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
257
+ "backbone.blocks.29.mlp.l1.weight": "model-00003-of-00004.safetensors",
258
+ "backbone.blocks.29.mlp.l2.weight": "model-00003-of-00004.safetensors",
259
+ "backbone.blocks.29.mlp.l3.weight": "model-00003-of-00004.safetensors",
260
+ "backbone.blocks.29.post_norm.scale": "model-00003-of-00004.safetensors",
261
+ "backbone.blocks.29.pre_norm.scale": "model-00003-of-00004.safetensors",
262
+ "backbone.blocks.3.inner_mha_cls.Wqkv.weight": "model-00001-of-00004.safetensors",
263
+ "backbone.blocks.3.inner_mha_cls.out_proj.weight": "model-00001-of-00004.safetensors",
264
+ "backbone.blocks.3.inner_mha_cls.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
265
+ "backbone.blocks.3.mlp.l1.weight": "model-00001-of-00004.safetensors",
266
+ "backbone.blocks.3.mlp.l2.weight": "model-00001-of-00004.safetensors",
267
+ "backbone.blocks.3.mlp.l3.weight": "model-00001-of-00004.safetensors",
268
+ "backbone.blocks.3.post_norm.scale": "model-00001-of-00004.safetensors",
269
+ "backbone.blocks.3.pre_norm.scale": "model-00001-of-00004.safetensors",
270
+ "backbone.blocks.30.filter.D": "model-00003-of-00004.safetensors",
271
+ "backbone.blocks.30.filter.poles": "model-00003-of-00004.safetensors",
272
+ "backbone.blocks.30.filter.residues": "model-00003-of-00004.safetensors",
273
+ "backbone.blocks.30.filter.short_filter_bias": "model-00003-of-00004.safetensors",
274
+ "backbone.blocks.30.filter.short_filter_weight": "model-00003-of-00004.safetensors",
275
+ "backbone.blocks.30.mlp.l1.weight": "model-00003-of-00004.safetensors",
276
+ "backbone.blocks.30.mlp.l2.weight": "model-00003-of-00004.safetensors",
277
+ "backbone.blocks.30.mlp.l3.weight": "model-00003-of-00004.safetensors",
278
+ "backbone.blocks.30.out_filter_dense.bias": "model-00003-of-00004.safetensors",
279
+ "backbone.blocks.30.out_filter_dense.weight": "model-00003-of-00004.safetensors",
280
+ "backbone.blocks.30.post_norm.scale": "model-00003-of-00004.safetensors",
281
+ "backbone.blocks.30.pre_norm.scale": "model-00003-of-00004.safetensors",
282
+ "backbone.blocks.30.projections.bias": "model-00003-of-00004.safetensors",
283
+ "backbone.blocks.30.projections.weight": "model-00003-of-00004.safetensors",
284
+ "backbone.blocks.31.inner_mha_cls.Wqkv.weight": "model-00004-of-00004.safetensors",
285
+ "backbone.blocks.31.inner_mha_cls.out_proj.weight": "model-00004-of-00004.safetensors",
286
+ "backbone.blocks.31.inner_mha_cls.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
287
+ "backbone.blocks.31.mlp.l1.weight": "model-00004-of-00004.safetensors",
288
+ "backbone.blocks.31.mlp.l2.weight": "model-00004-of-00004.safetensors",
289
+ "backbone.blocks.31.mlp.l3.weight": "model-00004-of-00004.safetensors",
290
+ "backbone.blocks.31.post_norm.scale": "model-00003-of-00004.safetensors",
291
+ "backbone.blocks.31.pre_norm.scale": "model-00003-of-00004.safetensors",
292
+ "backbone.blocks.4.filter.D": "model-00001-of-00004.safetensors",
293
+ "backbone.blocks.4.filter.poles": "model-00001-of-00004.safetensors",
294
+ "backbone.blocks.4.filter.residues": "model-00001-of-00004.safetensors",
295
+ "backbone.blocks.4.filter.short_filter_bias": "model-00001-of-00004.safetensors",
296
+ "backbone.blocks.4.filter.short_filter_weight": "model-00001-of-00004.safetensors",
297
+ "backbone.blocks.4.mlp.l1.weight": "model-00001-of-00004.safetensors",
298
+ "backbone.blocks.4.mlp.l2.weight": "model-00001-of-00004.safetensors",
299
+ "backbone.blocks.4.mlp.l3.weight": "model-00001-of-00004.safetensors",
300
+ "backbone.blocks.4.out_filter_dense.bias": "model-00001-of-00004.safetensors",
301
+ "backbone.blocks.4.out_filter_dense.weight": "model-00001-of-00004.safetensors",
302
+ "backbone.blocks.4.post_norm.scale": "model-00001-of-00004.safetensors",
303
+ "backbone.blocks.4.pre_norm.scale": "model-00001-of-00004.safetensors",
304
+ "backbone.blocks.4.projections.bias": "model-00001-of-00004.safetensors",
305
+ "backbone.blocks.4.projections.weight": "model-00001-of-00004.safetensors",
306
+ "backbone.blocks.5.inner_mha_cls.Wqkv.weight": "model-00001-of-00004.safetensors",
307
+ "backbone.blocks.5.inner_mha_cls.out_proj.weight": "model-00001-of-00004.safetensors",
308
+ "backbone.blocks.5.inner_mha_cls.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
309
+ "backbone.blocks.5.mlp.l1.weight": "model-00001-of-00004.safetensors",
310
+ "backbone.blocks.5.mlp.l2.weight": "model-00001-of-00004.safetensors",
311
+ "backbone.blocks.5.mlp.l3.weight": "model-00001-of-00004.safetensors",
312
+ "backbone.blocks.5.post_norm.scale": "model-00001-of-00004.safetensors",
313
+ "backbone.blocks.5.pre_norm.scale": "model-00001-of-00004.safetensors",
314
+ "backbone.blocks.6.filter.D": "model-00001-of-00004.safetensors",
315
+ "backbone.blocks.6.filter.poles": "model-00001-of-00004.safetensors",
316
+ "backbone.blocks.6.filter.residues": "model-00001-of-00004.safetensors",
317
+ "backbone.blocks.6.filter.short_filter_bias": "model-00001-of-00004.safetensors",
318
+ "backbone.blocks.6.filter.short_filter_weight": "model-00001-of-00004.safetensors",
319
+ "backbone.blocks.6.mlp.l1.weight": "model-00001-of-00004.safetensors",
320
+ "backbone.blocks.6.mlp.l2.weight": "model-00001-of-00004.safetensors",
321
+ "backbone.blocks.6.mlp.l3.weight": "model-00001-of-00004.safetensors",
322
+ "backbone.blocks.6.out_filter_dense.bias": "model-00001-of-00004.safetensors",
323
+ "backbone.blocks.6.out_filter_dense.weight": "model-00001-of-00004.safetensors",
324
+ "backbone.blocks.6.post_norm.scale": "model-00001-of-00004.safetensors",
325
+ "backbone.blocks.6.pre_norm.scale": "model-00001-of-00004.safetensors",
326
+ "backbone.blocks.6.projections.bias": "model-00001-of-00004.safetensors",
327
+ "backbone.blocks.6.projections.weight": "model-00001-of-00004.safetensors",
328
+ "backbone.blocks.7.inner_mha_cls.Wqkv.weight": "model-00001-of-00004.safetensors",
329
+ "backbone.blocks.7.inner_mha_cls.out_proj.weight": "model-00001-of-00004.safetensors",
330
+ "backbone.blocks.7.inner_mha_cls.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
331
+ "backbone.blocks.7.mlp.l1.weight": "model-00001-of-00004.safetensors",
332
+ "backbone.blocks.7.mlp.l2.weight": "model-00001-of-00004.safetensors",
333
+ "backbone.blocks.7.mlp.l3.weight": "model-00001-of-00004.safetensors",
334
+ "backbone.blocks.7.post_norm.scale": "model-00001-of-00004.safetensors",
335
+ "backbone.blocks.7.pre_norm.scale": "model-00001-of-00004.safetensors",
336
+ "backbone.blocks.8.filter.D": "model-00001-of-00004.safetensors",
337
+ "backbone.blocks.8.filter.poles": "model-00001-of-00004.safetensors",
338
+ "backbone.blocks.8.filter.residues": "model-00001-of-00004.safetensors",
339
+ "backbone.blocks.8.filter.short_filter_bias": "model-00001-of-00004.safetensors",
340
+ "backbone.blocks.8.filter.short_filter_weight": "model-00001-of-00004.safetensors",
341
+ "backbone.blocks.8.mlp.l1.weight": "model-00001-of-00004.safetensors",
342
+ "backbone.blocks.8.mlp.l2.weight": "model-00001-of-00004.safetensors",
343
+ "backbone.blocks.8.mlp.l3.weight": "model-00001-of-00004.safetensors",
344
+ "backbone.blocks.8.out_filter_dense.bias": "model-00001-of-00004.safetensors",
345
+ "backbone.blocks.8.out_filter_dense.weight": "model-00001-of-00004.safetensors",
346
+ "backbone.blocks.8.post_norm.scale": "model-00001-of-00004.safetensors",
347
+ "backbone.blocks.8.pre_norm.scale": "model-00001-of-00004.safetensors",
348
+ "backbone.blocks.8.projections.bias": "model-00001-of-00004.safetensors",
349
+ "backbone.blocks.8.projections.weight": "model-00001-of-00004.safetensors",
350
+ "backbone.blocks.9.inner_mha_cls.Wqkv.weight": "model-00001-of-00004.safetensors",
351
+ "backbone.blocks.9.inner_mha_cls.out_proj.weight": "model-00001-of-00004.safetensors",
352
+ "backbone.blocks.9.inner_mha_cls.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
353
+ "backbone.blocks.9.mlp.l1.weight": "model-00001-of-00004.safetensors",
354
+ "backbone.blocks.9.mlp.l2.weight": "model-00002-of-00004.safetensors",
355
+ "backbone.blocks.9.mlp.l3.weight": "model-00002-of-00004.safetensors",
356
+ "backbone.blocks.9.post_norm.scale": "model-00001-of-00004.safetensors",
357
+ "backbone.blocks.9.pre_norm.scale": "model-00001-of-00004.safetensors",
358
+ "backbone.embedding_layer.weight": "model-00001-of-00004.safetensors",
359
+ "backbone.norm.scale": "model-00001-of-00004.safetensors",
360
+ "backbone.unembed.weight": "model-00001-of-00004.safetensors"
361
+ }
362
+ }
modeling_hyena.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """StripedHyena custom code port for the Hugging Face Hub"""
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from .configuration_hyena import StripedHyenaConfig
7
+ from transformers import PreTrainedModel
8
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
9
+ from transformers.utils import logging
10
+ from typing import Optional, Tuple, Union
11
+ from .model import StripedHyena
12
+ from .utils import dotdict
13
+ from .cache import InferenceParams
14
+ from .engine import HyenaInferenceEngine
15
+ from .layers import RMSNorm
16
+ from .utils import dotdict, column_split
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+
21
+ class StripedHyenaPreTrainedModel(PreTrainedModel):
22
+ config_class = StripedHyenaConfig
23
+ base_model_prefix = "sh"
24
+ supports_gradient_checkpointing = False
25
+ _no_split_modules = ["AttentionBlock", "ParallelGatedConvBlock"]
26
+ _skip_keys_device_placement = "past_key_values"
27
+ _keys_to_ignore_on_load_missing = [r"freq"]
28
+ _keys_to_ignore_on_load_unexpected = [r"fftconv", r"twiddle_factors"]
29
+ _supports_flash_attn_2 = True
30
+
31
+
32
+ class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
33
+ supports_gradient_checkpointing = True
34
+
35
+ def __init__(self, config, **kwargs):
36
+ super().__init__(config, **kwargs)
37
+ model_config = dotdict(config.to_dict())
38
+ self.backbone = StripedHyena(model_config)
39
+ self.backbone.gradient_checkpointing = False
40
+ self.config = config
41
+ vocab_size = config.vocab_size
42
+ if vocab_size % config.make_vocab_size_divisible_by != 0:
43
+ vocab_size += config.make_vocab_size_divisible_by - (
44
+ vocab_size % config.make_vocab_size_divisible_by
45
+ )
46
+ self.vocab_size = vocab_size
47
+ self.post_init()
48
+ self.force_dtype()
49
+
50
+ def force_dtype(self):
51
+ self.backbone.to_bfloat16_except_poles_residues()
52
+
53
+ def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
54
+ self.backbone.gradient_checkpointing = enable
55
+
56
+ def get_input_embeddings(self):
57
+ return self.backbone.embedding_layer
58
+
59
+ def forward(
60
+ self,
61
+ input_ids: torch.LongTensor = None,
62
+ attention_mask: Optional[torch.LongTensor] = None,
63
+ labels: Optional[torch.LongTensor] = None,
64
+ use_cache: Optional[bool] = None,
65
+ output_attentions: Optional[bool] = None,
66
+ output_hidden_states: Optional[bool] = None,
67
+ past_key_values=None,
68
+ inputs_embeds: Optional[torch.FloatTensor] = None,
69
+ return_dict: Optional[bool] = None,
70
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
71
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
72
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
73
+
74
+ if use_cache:
75
+ if self.backbone.gradient_checkpointing and self.backbone.training:
76
+ logger.warning_once(
77
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
78
+ )
79
+ use_cache = False
80
+ elif labels is not None:
81
+ logger.warning_once(
82
+ "`use_cache=True` is incompatible with loss calculation. Setting `use_cache=False`..."
83
+ )
84
+ use_cache = False
85
+
86
+ inputs = input_ids
87
+ if use_cache:
88
+ if past_key_values is None:
89
+ past_key_values = self.backbone.initialize_inference_params()
90
+
91
+ batch_size = input_ids.shape[0]
92
+ past_key_values["mha"].max_batch_size = batch_size
93
+ past_key_values["hyena"].max_batch_size = batch_size
94
+ else:
95
+ seqlen_offset = past_key_values["mha"].seqlen_offset
96
+ if seqlen_offset == 0:
97
+ # second loop through generate will have prompt_len + 1 as seqlen
98
+ seqlen_offset = input_ids.shape[-1] - 1
99
+ past_key_values["hyena"].seqlen_offset = seqlen_offset
100
+ past_key_values["mha"].seqlen_offset = seqlen_offset
101
+ else:
102
+ past_key_values["mha"].seqlen_offset += 1
103
+ past_key_values["hyena"].seqlen_offset += 1
104
+
105
+ inputs = input_ids[
106
+ :,
107
+ -1:,
108
+ ]
109
+
110
+ logits, past_key_values = self.backbone(
111
+ inputs,
112
+ padding_mask=attention_mask,
113
+ inference_params_dict=past_key_values if use_cache else None,
114
+ )
115
+
116
+ loss = None
117
+ if labels is not None:
118
+ shift_logits = logits[..., :-1, :].contiguous()
119
+ shift_labels = labels[..., 1:].contiguous()
120
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
121
+ shift_labels = shift_labels.view(-1)
122
+ shift_labels = shift_labels.to(shift_logits.device)
123
+ loss = F.cross_entropy(shift_logits, shift_labels)
124
+
125
+ if return_dict:
126
+ return CausalLMOutputWithPast(
127
+ logits=logits,
128
+ hidden_states=None,
129
+ past_key_values=past_key_values if use_cache else None,
130
+ loss=loss,
131
+ )
132
+ else:
133
+ return logits
134
+
135
+ @classmethod
136
+ def can_generate(cls) -> bool:
137
+ return True
138
+
139
+ def prepare_inputs_for_generation(
140
+ self, input_ids, attention_mask=None, past_key_values=None, **kwargs
141
+ ):
142
+ return {
143
+ "input_ids": input_ids,
144
+ "attention_mask": attention_mask,
145
+ "past_key_values": past_key_values,
146
+ }
utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def column_split(x, num_heads, head_size):
5
+ """Split a tensor with `num_heads` alongside the head dimension, instead of
6
+ across heads. Fixed to three projections
7
+ """
8
+
9
+ x_reshaped = x.reshape(
10
+ x.shape[0],
11
+ num_heads,
12
+ 3 * head_size,
13
+ )
14
+
15
+ x2, x1, v = (
16
+ x_reshaped[:, :, :head_size],
17
+ x_reshaped[
18
+ :,
19
+ :,
20
+ head_size : 2 * head_size,
21
+ ],
22
+ x_reshaped[:, :, 2 * head_size :],
23
+ )
24
+ x2, x1, v = (
25
+ x2.reshape(x2.shape[0], -1),
26
+ x1.reshape(x1.shape[0], -1),
27
+ v.reshape(v.shape[0], -1),
28
+ )
29
+ return x2, x1, v
30
+
31
+
32
+ def get_init_from_string(init_str):
33
+ if type(init_str) == str:
34
+ if init_str == "torch.nn.init.zeros_":
35
+ return torch.nn.init.zeros_
36
+ elif init_str == "torch.nn.init.xavier_uniform_":
37
+ return torch.nn.init.xavier_uniform_
38
+ elif init_str == "torch.nn.init.xavier_normal_":
39
+ return torch.nn.init.xavier_normal_
40
+ else:
41
+ raise ValueError(f"Unrecognized init {init_str}")
42
+
43
+
44
+ def print_rank_0(message, debug=False, end="\n"):
45
+ """Print from rank 0 only."""
46
+ if torch.distributed.is_initialized():
47
+ if torch.distributed.get_rank() == 0:
48
+ print(message, flush=True, end=end)
49
+ else:
50
+ print(message, flush=True, end=end)
51
+
52
+
53
+ class dotdict(dict):
54
+ """dot.notation access to dictionary attributes"""
55
+
56
+ __getattr__ = dict.get
57
+ __setattr__ = dict.__setitem__
58
+ __delattr__ = dict.__delitem__
59
+
60
+
61
+ def ensure_divisibility(numerator, denominator):
62
+ """Ensure that numerator is divisible by the denominator."""
63
+ assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
64
+
65
+
66
+ def divide(numerator, denominator):
67
+ """Ensure that numerator is divisible by the denominator and return
68
+ the division value."""
69
+ ensure_divisibility(numerator, denominator)
70
+ return numerator // denominator
71
+
72
+
73
+ class VocabUtility:
74
+ """Split the vocabulary into `world_size` chunks amd return the
75
+ first and last index of the vocabulary belonging to the `rank`
76
+ partition: Note that indices in [first, last]"""
77
+
78
+ @staticmethod
79
+ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
80
+ index_f = rank * per_partition_vocab_size
81
+ index_l = index_f + per_partition_vocab_size
82
+ return index_f, index_l
83
+
84
+ @staticmethod
85
+ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
86
+ per_partition_vocab_size = divide(global_vocab_size, world_size)
87
+ return VocabUtility.vocab_range_from_per_partition_vocab_size(
88
+ per_partition_vocab_size, rank, world_size
89
+ )