Commit
·
f92de46
1
Parent(s):
e46abf8
Upload HyenaDNAForCausalLM
Browse files- modeling_hyena.py +22 -17
modeling_hyena.py
CHANGED
@@ -19,8 +19,8 @@ def fftconv(u, k, D):
|
|
19 |
seqlen = u.shape[-1]
|
20 |
fft_size = 2 * seqlen
|
21 |
|
22 |
-
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
|
23 |
-
u_f = torch.fft.rfft(u.to(dtype=
|
24 |
|
25 |
if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
|
26 |
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
|
@@ -60,11 +60,9 @@ class HyenaPositionalEmbedding(nn.Module):
|
|
60 |
w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
|
61 |
|
62 |
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
|
63 |
-
|
64 |
-
|
65 |
-
z
|
66 |
-
z = torch.cat([t, z.real, z.imag], dim=-1)
|
67 |
-
# TODO Set z's LR to lr_pos_emb
|
68 |
self.z = nn.Parameter(z, requires_grad=True)
|
69 |
self.register_buffer("t", t)
|
70 |
|
@@ -147,7 +145,7 @@ class HyenaFilter(nn.Module):
|
|
147 |
|
148 |
def filter(self, L, *args, **kwargs):
|
149 |
z, t = self.pos_emb(L)
|
150 |
-
h = self.implicit_filter(z)
|
151 |
h = self.modulation(t, h)
|
152 |
return h
|
153 |
|
@@ -349,8 +347,15 @@ class HyenaDNAPreTrainedModel(PreTrainedModel):
|
|
349 |
supports_gradient_checkpointing = True
|
350 |
_no_split_modules = ["HyenaBlock"]
|
351 |
_skip_keys_device_placement = "past_key_values"
|
352 |
-
|
353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
355 |
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
356 |
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
@@ -368,8 +373,8 @@ class HyenaDNAPreTrainedModel(PreTrainedModel):
|
|
368 |
|
369 |
|
370 |
class HyenaDNAModel(HyenaDNAPreTrainedModel):
|
371 |
-
def __init__(self, config) -> None:
|
372 |
-
super().__init__(config)
|
373 |
|
374 |
self.backbone = HyenaLMBackbone(config)
|
375 |
self.config = config
|
@@ -395,8 +400,8 @@ class HyenaDNAModel(HyenaDNAPreTrainedModel):
|
|
395 |
|
396 |
class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
|
397 |
|
398 |
-
def __init__(self, config):
|
399 |
-
super().__init__(config)
|
400 |
self.hyena = HyenaDNAModel(config)
|
401 |
vocab_size = config.vocab_size
|
402 |
if vocab_size % config.pad_vocab_size_multiple != 0:
|
@@ -476,9 +481,9 @@ class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
|
|
476 |
|
477 |
|
478 |
class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
|
479 |
-
def __init__(self, config):
|
480 |
-
super().__init__(config)
|
481 |
-
self.num_labels = config.num_labels
|
482 |
self.hyena = HyenaDNAModel(config)
|
483 |
self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
|
484 |
|
|
|
19 |
seqlen = u.shape[-1]
|
20 |
fft_size = 2 * seqlen
|
21 |
|
22 |
+
k_f = torch.fft.rfft(k.to(torch.float32), n=fft_size) / fft_size
|
23 |
+
u_f = torch.fft.rfft(u.to(dtype=torch.float32), n=fft_size)
|
24 |
|
25 |
if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
|
26 |
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
|
|
|
60 |
w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
|
61 |
|
62 |
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
|
63 |
+
|
64 |
+
z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
|
65 |
+
# The original code sets z's LR to lr_pos_emb, which is 1e-5 by default
|
|
|
|
|
66 |
self.z = nn.Parameter(z, requires_grad=True)
|
67 |
self.register_buffer("t", t)
|
68 |
|
|
|
145 |
|
146 |
def filter(self, L, *args, **kwargs):
|
147 |
z, t = self.pos_emb(L)
|
148 |
+
h = self.implicit_filter(z.to(dtype=self.implicit_filter[0].weight.dtype))
|
149 |
h = self.modulation(t, h)
|
150 |
return h
|
151 |
|
|
|
347 |
supports_gradient_checkpointing = True
|
348 |
_no_split_modules = ["HyenaBlock"]
|
349 |
_skip_keys_device_placement = "past_key_values"
|
350 |
+
_keys_to_ignore_on_load_missing = [r"freq"] # Shared tensors that safetensors merges
|
351 |
+
|
352 |
+
def _init_weights(self, module, initializer_range=0.02):
|
353 |
+
if isinstance(module, nn.Linear):
|
354 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
355 |
+
if module.bias is not None:
|
356 |
+
nn.init.zeros_(module.bias)
|
357 |
+
elif isinstance(module, nn.Embedding):
|
358 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
359 |
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
360 |
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
361 |
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
|
|
373 |
|
374 |
|
375 |
class HyenaDNAModel(HyenaDNAPreTrainedModel):
|
376 |
+
def __init__(self, config, **kwargs) -> None:
|
377 |
+
super().__init__(config, **kwargs)
|
378 |
|
379 |
self.backbone = HyenaLMBackbone(config)
|
380 |
self.config = config
|
|
|
400 |
|
401 |
class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
|
402 |
|
403 |
+
def __init__(self, config, **kwargs):
|
404 |
+
super().__init__(config, **kwargs)
|
405 |
self.hyena = HyenaDNAModel(config)
|
406 |
vocab_size = config.vocab_size
|
407 |
if vocab_size % config.pad_vocab_size_multiple != 0:
|
|
|
481 |
|
482 |
|
483 |
class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
|
484 |
+
def __init__(self, config, **kwargs):
|
485 |
+
super().__init__(config, **kwargs)
|
486 |
+
self.num_labels = kwargs.get("num_labels", config.num_labels)
|
487 |
self.hyena = HyenaDNAModel(config)
|
488 |
self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
|
489 |
|