yairschiff
committed on
Ensure weights are tied for BiMamba (if applicable) when loaded from_pretrained
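
For context (not part of the diff): Hugging Face's PreTrainedModel.from_pretrained() calls tie_weights() on the instantiated model after the checkpoint is loaded, so overriding tie_weights() is the hook this commit uses to re-establish the forward/reverse BiMamba parameter sharing on load; otherwise the reverse-direction projections could end up with their own tensors instead of sharing the forward-direction ones. Below is a minimal sketch of the same pattern; TinyBiConfig / TinyBiModel and their fields are hypothetical illustrations, not part of modeling_caduceus.py.

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class TinyBiConfig(PretrainedConfig):
    model_type = "tiny_bi"  # hypothetical model type, for illustration only

    def __init__(self, d_model=16, bidirectional_weight_tie=True, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.bidirectional_weight_tie = bidirectional_weight_tie


class TinyBiModel(PreTrainedModel):
    config_class = TinyBiConfig

    def __init__(self, config):
        super().__init__(config)
        # Forward- and reverse-direction projections, analogous to mamba_fwd / mamba_rev.
        self.in_proj_fwd = nn.Linear(config.d_model, config.d_model)
        self.in_proj_rev = nn.Linear(config.d_model, config.d_model)
        self.post_init()  # runs tie_weights() once at construction time

    def tie_weights(self):
        # from_pretrained() invokes this again after loading the state dict,
        # so the parameter sharing survives a save/load round trip.
        if getattr(self.config, "bidirectional_weight_tie", False):
            self.in_proj_rev.weight = self.in_proj_fwd.weight
            self.in_proj_rev.bias = self.in_proj_fwd.bias
        super().tie_weights()

    def forward(self, x):
        return self.in_proj_fwd(x) + self.in_proj_rev(x)
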
modeling_caduceus.py CHANGED  (+31 -4)

@@ -1,5 +1,4 @@
 """Caduceus model for Hugging Face.
-
 """
 
 import inspect
@@ -46,7 +45,6 @@ def create_block(
         dtype=None,
 ):
     """Create Caduceus block.
-
     Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
     """
     if ssm_cfg is None:
@@ -121,7 +119,6 @@ class BiMambaWrapper(nn.Module):
 
     def forward(self, hidden_states, inference_params=None):
         """Bidirectional-enabled forward pass
-
         hidden_states: (B, L, D)
         Returns: same shape as hidden_states
         """
@@ -360,6 +357,24 @@ class Caduceus(CaduceusPreTrainedModel):
         factory_kwargs = {"device": device, "dtype": dtype}
         self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
 
+    def maybe_weight_tie_mamba(self):
+        if getattr(self.config, 'bidirectional', False) and getattr(self.config, 'bidirectional_weight_tie', False):
+            if getattr(self.config, 'rcps', False):
+                for layer in self.backbone.layers:
+                    layer.mixer.submodule.mamba_rev.in_proj.weight = layer.mixer.submodule.mamba_fwd.in_proj.weight
+                    layer.mixer.submodule.mamba_rev.in_proj.bias = layer.mixer.submodule.mamba_fwd.in_proj.bias
+                    layer.mixer.submodule.mamba_rev.out_proj.weight = layer.mixer.submodule.mamba_fwd.out_proj.weight
+                    layer.mixer.submodule.mamba_rev.out_proj.bias = layer.mixer.submodule.mamba_fwd.out_proj.bias
+            else:
+                for layer in self.backbone.layers:
+                    layer.mixer.mamba_rev.in_proj.weight = layer.mixer.mamba_fwd.in_proj.weight
+                    layer.mixer.mamba_rev.in_proj.bias = layer.mixer.mamba_fwd.in_proj.bias
+                    layer.mixer.mamba_rev.out_proj.weight = layer.mixer.mamba_fwd.out_proj.weight
+                    layer.mixer.mamba_rev.out_proj.bias = layer.mixer.mamba_fwd.out_proj.bias
+
+    def tie_weights(self):
+        self.maybe_weight_tie_mamba()
+
     def forward(
             self,
             input_ids: torch.LongTensor = None,
@@ -431,8 +446,12 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
             raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
         self.lm_head = new_embeddings
 
+    def maybe_weight_tie_mamba(self):
+        self.caduceus.maybe_weight_tie_mamba()
+
     def tie_weights(self):
         """Tie weights, accounting for RCPS."""
+        self.maybe_weight_tie_mamba()
         if self.config.rcps:
             self.lm_head.set_weight(self.get_input_embeddings().weight)
         else:
@@ -445,7 +464,7 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
     def set_decoder(self, decoder):
         """Set decoder (backbone) for the model."""
         self.caduceus = decoder
-
+
     def forward(
             self,
             input_ids: torch.LongTensor = None,
@@ -536,6 +555,13 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
         if self.pooling_strategy == "first":  # Use embedding of first token in the sequence
             return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]
 
+    def maybe_weight_tie_mamba(self):
+        self.caduceus.maybe_weight_tie_mamba()
+
+    def tie_weights(self):
+        self.maybe_weight_tie_mamba()
+        super().tie_weights()
+
     def forward(
             self,
             input_ids: torch.LongTensor = None,
@@ -543,6 +569,7 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
             labels: Optional[torch.LongTensor] = None,
             output_hidden_states: Optional[bool] = None,
             return_dict: Optional[bool] = None,
+            **kwargs,
     ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
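
A quick check that the tie survives loading (assumes a checkpoint configured with bidirectional=True and bidirectional_weight_tie=True; "org/caduceus-checkpoint" is a placeholder, not a real model id):

from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("org/caduceus-checkpoint", trust_remote_code=True)

if model.config.bidirectional and model.config.bidirectional_weight_tie:
    for layer in model.caduceus.backbone.layers:
        # With rcps=True the Mamba modules sit one level deeper, under mixer.submodule.
        mixer = layer.mixer.submodule if model.config.rcps else layer.mixer
        assert mixer.mamba_rev.in_proj.weight is mixer.mamba_fwd.in_proj.weight
        assert mixer.mamba_rev.out_proj.weight is mixer.mamba_fwd.out_proj.weight
    print("BiMamba in/out projections share storage after from_pretrained.")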