Jackmin108 commited on
Commit
c120089
·
1 Parent(s): 027219a

feat: rotary base as a property

Browse files

Signed-off-by: Meow <[email protected]>

configuration_xlm_roberta.py CHANGED
@@ -20,6 +20,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
20
  bos_token_id=0,
21
  eos_token_id=2,
22
  position_embedding_type="absolute",
 
23
  use_cache=True,
24
  classifier_dropout=None,
25
  lora_adaptations=None,
@@ -52,6 +53,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
52
  self.initializer_range = initializer_range
53
  self.layer_norm_eps = layer_norm_eps
54
  self.position_embedding_type = position_embedding_type
 
55
  self.use_cache = use_cache
56
  self.classifier_dropout = classifier_dropout
57
  self.load_trained_adapters = load_trained_adapters
 
20
  bos_token_id=0,
21
  eos_token_id=2,
22
  position_embedding_type="absolute",
23
+ rotary_emb_base=10000.0,
24
  use_cache=True,
25
  classifier_dropout=None,
26
  lora_adaptations=None,
 
53
  self.initializer_range = initializer_range
54
  self.layer_norm_eps = layer_norm_eps
55
  self.position_embedding_type = position_embedding_type
56
+ self.rotary_emb_base = rotary_emb_base
57
  self.use_cache = use_cache
58
  self.classifier_dropout = classifier_dropout
59
  self.load_trained_adapters = load_trained_adapters
modeling_lora.py CHANGED
@@ -265,6 +265,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
265
  self.main_params_trainable = config.lora_main_params_trainable
266
 
267
 
 
 
 
 
 
 
 
 
268
  @property
269
  def main_params_trainable(self):
270
  return self._main_params_trainable
 
265
  self.main_params_trainable = config.lora_main_params_trainable
266
 
267
 
268
+ @property
269
+ def rotary_emb_base(self):
270
+ return self.roberta.rotary_emb_base
271
+
272
+ @rotary_emb_base.setter
273
+ def rotary_emb_base(self, base):
274
+ self.roberta.rotary_emb_base = base
275
+
276
  @property
277
  def main_params_trainable(self):
278
  return self._main_params_trainable
modeling_xlm_roberta.py CHANGED
@@ -93,7 +93,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
93
  rotary_kwargs["rotary_emb_dim"] = getattr(
94
  config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
95
  )
96
- rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
97
  rotary_kwargs["rotary_emb_scale_base"] = getattr(
98
  config, "rotary_emb_scale_base", None
99
  )
@@ -453,6 +453,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
453
 
454
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
455
  self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
 
456
 
457
  @torch.inference_mode()
458
  def encode(
@@ -601,7 +602,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
601
  self.train(is_training)
602
  return all_embeddings
603
 
604
-
605
  def truncate_embeddings(self, embeddings, truncate_dim):
606
  if not self.config.matryoshka_dimensions:
607
  logger.warning(
@@ -624,12 +624,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
624
  input_mask_expanded.sum(1), min=1e-9
625
  )
626
 
627
-
628
  def cls_pooling(
629
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
630
  ):
631
  return token_embeddings[:,0]
632
 
 
 
 
 
 
 
 
 
 
 
 
 
633
 
634
  def forward(
635
  self,
 
93
  rotary_kwargs["rotary_emb_dim"] = getattr(
94
  config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
95
  )
96
+ rotary_kwargs["rotary_emb_base"] = config.rotary_emb_base
97
  rotary_kwargs["rotary_emb_scale_base"] = getattr(
98
  config, "rotary_emb_scale_base", None
99
  )
 
453
 
454
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
455
  self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
456
+ self._rotary_emb_base = config.rotary_emb_base
457
 
458
  @torch.inference_mode()
459
  def encode(
 
602
  self.train(is_training)
603
  return all_embeddings
604
 
 
605
  def truncate_embeddings(self, embeddings, truncate_dim):
606
  if not self.config.matryoshka_dimensions:
607
  logger.warning(
 
624
  input_mask_expanded.sum(1), min=1e-9
625
  )
626
 
 
627
  def cls_pooling(
628
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
629
  ):
630
  return token_embeddings[:,0]
631
 
632
+ @property
633
+ def rotary_emb_base(self):
634
+ return self._rotary_emb_base
635
+
636
+ @rotary_emb_base.setter
637
+ def rotary_emb_base(self, base):
638
+ if not isinstance(base, (int, float)):
639
+ raise TypeError("Base must be an integer or float")
640
+ logger.info(f'Changing RoPE base value to {base}')
641
+ for layer in self.encoder.layers:
642
+ layer.mixer.rotary_emb.base = base
643
+ self._rotary_emb_base = base
644
 
645
  def forward(
646
  self,
rotary.py CHANGED
@@ -443,7 +443,7 @@ class RotaryEmbedding(torch.nn.Module):
443
  """
444
  super().__init__()
445
  self.dim = dim
446
- self.base = float(base)
447
  self.pos_idx_in_fp32 = pos_idx_in_fp32
448
  # Generate and save the inverse frequency buffer (non trainable)
449
  inv_freq = self._compute_inv_freq(device)
@@ -463,6 +463,17 @@ class RotaryEmbedding(torch.nn.Module):
463
  self._cos_k_cached = None
464
  self._sin_k_cached = None
465
 
 
 
 
 
 
 
 
 
 
 
 
466
  def _compute_inv_freq(self, device=None):
467
  return 1.0 / (
468
  self.base
 
443
  """
444
  super().__init__()
445
  self.dim = dim
446
+ self._base = float(base)
447
  self.pos_idx_in_fp32 = pos_idx_in_fp32
448
  # Generate and save the inverse frequency buffer (non trainable)
449
  inv_freq = self._compute_inv_freq(device)
 
463
  self._cos_k_cached = None
464
  self._sin_k_cached = None
465
 
466
+ @property
467
+ def base(self):
468
+ return self._base
469
+
470
+ @base.setter
471
+ def base(self, new_base):
472
+ if new_base > 0:
473
+ self._base = float(new_base)
474
+ else:
475
+ raise ValueError("Rotary base value must be positive")
476
+
477
  def _compute_inv_freq(self, device=None):
478
  return 1.0 / (
479
  self.base