yairschiff
commited on
Enable mambav2 compat
Browse files- modeling_caduceus.py +29 -16
modeling_caduceus.py
CHANGED
@@ -2,21 +2,29 @@
|
|
2 |
|
3 |
"""
|
4 |
|
|
|
5 |
import math
|
6 |
from functools import partial
|
7 |
from typing import Optional, Tuple, Union
|
8 |
|
9 |
import torch
|
10 |
-
from mamba_ssm.modules.mamba_simple import Mamba
|
|
|
|
|
|
|
|
|
11 |
from torch import nn
|
12 |
from torch.nn import functional as F
|
13 |
from transformers import PreTrainedModel
|
14 |
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
|
15 |
|
16 |
try:
|
17 |
-
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
18 |
except ImportError:
|
19 |
-
|
|
|
|
|
|
|
20 |
|
21 |
from .configuration_caduceus import CaduceusConfig
|
22 |
from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
|
@@ -54,13 +62,24 @@ def create_block(
|
|
54 |
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
55 |
)
|
56 |
block_cls = RCPSMambaBlock if rcps else Block
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
block.layer_idx = layer_idx
|
65 |
return block
|
66 |
|
@@ -497,12 +516,6 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
|
|
497 |
|
498 |
# Initialize weights and apply final processing
|
499 |
self.post_init()
|
500 |
-
self.init_scorer()
|
501 |
-
|
502 |
-
def init_scorer(self, initializer_range=0.02):
|
503 |
-
initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \
|
504 |
-
if self.config.initializer_cfg is not None else initializer_range
|
505 |
-
self.score.weight.data.normal_(std=initializer_range)
|
506 |
|
507 |
def get_input_embeddings(self):
|
508 |
return self.caduceus.backbone.embeddings.word_embeddings
|
|
|
2 |
|
3 |
"""
|
4 |
|
5 |
+
import inspect
|
6 |
import math
|
7 |
from functools import partial
|
8 |
from typing import Optional, Tuple, Union
|
9 |
|
10 |
import torch
|
11 |
+
from mamba_ssm.modules.mamba_simple import Mamba
|
12 |
+
try:
|
13 |
+
from mamba_ssm.modules.mamba_simple import Block # Legacy mambav1 file structure
|
14 |
+
except ImportError:
|
15 |
+
from mamba_ssm.modules.block import Block # mambav2 file structure
|
16 |
from torch import nn
|
17 |
from torch.nn import functional as F
|
18 |
from transformers import PreTrainedModel
|
19 |
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
|
20 |
|
21 |
try:
|
22 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
|
23 |
except ImportError:
|
24 |
+
try:
|
25 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
|
26 |
+
except ImportError:
|
27 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
28 |
|
29 |
from .configuration_caduceus import CaduceusConfig
|
30 |
from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
|
|
|
62 |
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
63 |
)
|
64 |
block_cls = RCPSMambaBlock if rcps else Block
|
65 |
+
# mambav2 compatibility
|
66 |
+
if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
|
67 |
+
block = block_cls(
|
68 |
+
d_model,
|
69 |
+
mixer_cls,
|
70 |
+
mlp_cls=nn.Identity,
|
71 |
+
norm_cls=norm_cls,
|
72 |
+
fused_add_norm=fused_add_norm,
|
73 |
+
residual_in_fp32=residual_in_fp32,
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
block = block_cls(
|
77 |
+
d_model,
|
78 |
+
mixer_cls,
|
79 |
+
norm_cls=norm_cls,
|
80 |
+
fused_add_norm=fused_add_norm,
|
81 |
+
residual_in_fp32=residual_in_fp32,
|
82 |
+
)
|
83 |
block.layer_idx = layer_idx
|
84 |
return block
|
85 |
|
|
|
516 |
|
517 |
# Initialize weights and apply final processing
|
518 |
self.post_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
519 |
|
520 |
def get_input_embeddings(self):
|
521 |
return self.caduceus.backbone.embeddings.word_embeddings
|