yairschiff committed on
Commit e3f8061 · verified · 1 Parent(s): 0600f8f

Ensure weights are tied for BiMamba (if applicable) when loaded from_pretrained

Files changed (1):
  1. modeling_caduceus.py +31 -4
modeling_caduceus.py CHANGED
@@ -1,5 +1,4 @@
 """Caduceus model for Hugging Face.
-
 """
 
 import inspect
@@ -46,7 +45,6 @@ def create_block(
     dtype=None,
 ):
     """Create Caduceus block.
-
     Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
     """
     if ssm_cfg is None:
@@ -121,7 +119,6 @@ class BiMambaWrapper(nn.Module):
 
     def forward(self, hidden_states, inference_params=None):
         """Bidirectional-enabled forward pass
-
         hidden_states: (B, L, D)
         Returns: same shape as hidden_states
         """
@@ -360,6 +357,24 @@ class Caduceus(CaduceusPreTrainedModel):
         factory_kwargs = {"device": device, "dtype": dtype}
         self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
 
+    def maybe_weight_tie_mamba(self):
+        if getattr(self.config, 'bidirectional', False) and getattr(self.config, 'bidirectional_weight_tie', False):
+            if getattr(self.config, 'rcps', False):
+                for layer in self.backbone.layers:
+                    layer.mixer.submodule.mamba_rev.in_proj.weight = layer.mixer.submodule.mamba_fwd.in_proj.weight
+                    layer.mixer.submodule.mamba_rev.in_proj.bias = layer.mixer.submodule.mamba_fwd.in_proj.bias
+                    layer.mixer.submodule.mamba_rev.out_proj.weight = layer.mixer.submodule.mamba_fwd.out_proj.weight
+                    layer.mixer.submodule.mamba_rev.out_proj.bias = layer.mixer.submodule.mamba_fwd.out_proj.bias
+            else:
+                for layer in self.backbone.layers:
+                    layer.mixer.mamba_rev.in_proj.weight = layer.mixer.mamba_fwd.in_proj.weight
+                    layer.mixer.mamba_rev.in_proj.bias = layer.mixer.mamba_fwd.in_proj.bias
+                    layer.mixer.mamba_rev.out_proj.weight = layer.mixer.mamba_fwd.out_proj.weight
+                    layer.mixer.mamba_rev.out_proj.bias = layer.mixer.mamba_fwd.out_proj.bias
+
+    def tie_weights(self):
+        self.maybe_weight_tie_mamba()
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -431,8 +446,12 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
             raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
         self.lm_head = new_embeddings
 
+    def maybe_weight_tie_mamba(self):
+        self.caduceus.maybe_weight_tie_mamba()
+
     def tie_weights(self):
         """Tie weights, accounting for RCPS."""
+        self.maybe_weight_tie_mamba()
         if self.config.rcps:
             self.lm_head.set_weight(self.get_input_embeddings().weight)
         else:
@@ -445,7 +464,7 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
     def set_decoder(self, decoder):
         """Set decoder (backbone) for the model."""
         self.caduceus = decoder
-
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -536,6 +555,13 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
         if self.pooling_strategy == "first":  # Use embedding of first token in the sequence
             return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]
 
+    def maybe_weight_tie_mamba(self):
+        self.caduceus.maybe_weight_tie_mamba()
+
+    def tie_weights(self):
+        self.maybe_weight_tie_mamba()
+        super().tie_weights()
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -543,6 +569,7 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
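
For reference, a minimal way to check that the fix takes effect after loading: PreTrainedModel.from_pretrained calls tie_weights() once the state dict is loaded, so with this commit the reverse-direction Mamba projections should end up as the same Parameter objects as the forward-direction ones. The sketch below is not part of the commit: the repo id is a placeholder, it assumes a non-RCPS checkpoint with bidirectional=True and bidirectional_weight_tie=True whose auto_map exposes CaduceusForMaskedLM, and the attribute path follows the diff above.

from transformers import AutoModelForMaskedLM

# Placeholder repo id; any Caduceus checkpoint with bidirectional weight tying
# (and rcps=False) should behave the same way.
model = AutoModelForMaskedLM.from_pretrained(
    "org/caduceus-checkpoint",  # hypothetical repo id
    trust_remote_code=True,     # loads the repo's custom modeling_caduceus.py
)

# from_pretrained invokes tie_weights() after loading, which now calls
# maybe_weight_tie_mamba(); forward and reverse projections should share Parameters.
layer = model.caduceus.backbone.layers[0]
assert layer.mixer.mamba_rev.in_proj.weight is layer.mixer.mamba_fwd.in_proj.weight
assert layer.mixer.mamba_rev.out_proj.weight is layer.mixer.mamba_fwd.out_proj.weight

Because the tying re-assigns Parameter references rather than copying tensors, the forward and reverse Mamba blocks keep sharing a single set of projection weights during any subsequent fine-tuning as well.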