Jackmin108
commited on
Commit
·
c120089
1
Parent(s):
027219a
feat: rotary base as a property
Browse filesSigned-off-by: Meow <[email protected]>
- configuration_xlm_roberta.py +2 -0
- modeling_lora.py +8 -0
- modeling_xlm_roberta.py +14 -3
- rotary.py +12 -1
configuration_xlm_roberta.py
CHANGED
@@ -20,6 +20,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
20 |
bos_token_id=0,
|
21 |
eos_token_id=2,
|
22 |
position_embedding_type="absolute",
|
|
|
23 |
use_cache=True,
|
24 |
classifier_dropout=None,
|
25 |
lora_adaptations=None,
|
@@ -52,6 +53,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
52 |
self.initializer_range = initializer_range
|
53 |
self.layer_norm_eps = layer_norm_eps
|
54 |
self.position_embedding_type = position_embedding_type
|
|
|
55 |
self.use_cache = use_cache
|
56 |
self.classifier_dropout = classifier_dropout
|
57 |
self.load_trained_adapters = load_trained_adapters
|
|
|
20 |
bos_token_id=0,
|
21 |
eos_token_id=2,
|
22 |
position_embedding_type="absolute",
|
23 |
+
rotary_emb_base=10000.0,
|
24 |
use_cache=True,
|
25 |
classifier_dropout=None,
|
26 |
lora_adaptations=None,
|
|
|
53 |
self.initializer_range = initializer_range
|
54 |
self.layer_norm_eps = layer_norm_eps
|
55 |
self.position_embedding_type = position_embedding_type
|
56 |
+
self.rotary_emb_base = rotary_emb_base
|
57 |
self.use_cache = use_cache
|
58 |
self.classifier_dropout = classifier_dropout
|
59 |
self.load_trained_adapters = load_trained_adapters
|
modeling_lora.py
CHANGED
@@ -265,6 +265,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
265 |
self.main_params_trainable = config.lora_main_params_trainable
|
266 |
|
267 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
@property
|
269 |
def main_params_trainable(self):
|
270 |
return self._main_params_trainable
|
|
|
265 |
self.main_params_trainable = config.lora_main_params_trainable
|
266 |
|
267 |
|
268 |
+
@property
|
269 |
+
def rotary_emb_base(self):
|
270 |
+
return self.roberta.rotary_emb_base
|
271 |
+
|
272 |
+
@rotary_emb_base.setter
|
273 |
+
def rotary_emb_base(self, base):
|
274 |
+
self.roberta.rotary_emb_base = base
|
275 |
+
|
276 |
@property
|
277 |
def main_params_trainable(self):
|
278 |
return self._main_params_trainable
|
modeling_xlm_roberta.py
CHANGED
@@ -93,7 +93,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
93 |
rotary_kwargs["rotary_emb_dim"] = getattr(
|
94 |
config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
|
95 |
)
|
96 |
-
rotary_kwargs["rotary_emb_base"] =
|
97 |
rotary_kwargs["rotary_emb_scale_base"] = getattr(
|
98 |
config, "rotary_emb_scale_base", None
|
99 |
)
|
@@ -453,6 +453,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
453 |
|
454 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
455 |
self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
|
|
|
456 |
|
457 |
@torch.inference_mode()
|
458 |
def encode(
|
@@ -601,7 +602,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
601 |
self.train(is_training)
|
602 |
return all_embeddings
|
603 |
|
604 |
-
|
605 |
def truncate_embeddings(self, embeddings, truncate_dim):
|
606 |
if not self.config.matryoshka_dimensions:
|
607 |
logger.warning(
|
@@ -624,12 +624,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
624 |
input_mask_expanded.sum(1), min=1e-9
|
625 |
)
|
626 |
|
627 |
-
|
628 |
def cls_pooling(
|
629 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
630 |
):
|
631 |
return token_embeddings[:,0]
|
632 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
633 |
|
634 |
def forward(
|
635 |
self,
|
|
|
93 |
rotary_kwargs["rotary_emb_dim"] = getattr(
|
94 |
config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
|
95 |
)
|
96 |
+
rotary_kwargs["rotary_emb_base"] = config.rotary_emb_base
|
97 |
rotary_kwargs["rotary_emb_scale_base"] = getattr(
|
98 |
config, "rotary_emb_scale_base", None
|
99 |
)
|
|
|
453 |
|
454 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
455 |
self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
|
456 |
+
self._rotary_emb_base = config.rotary_emb_base
|
457 |
|
458 |
@torch.inference_mode()
|
459 |
def encode(
|
|
|
602 |
self.train(is_training)
|
603 |
return all_embeddings
|
604 |
|
|
|
605 |
def truncate_embeddings(self, embeddings, truncate_dim):
|
606 |
if not self.config.matryoshka_dimensions:
|
607 |
logger.warning(
|
|
|
624 |
input_mask_expanded.sum(1), min=1e-9
|
625 |
)
|
626 |
|
|
|
627 |
def cls_pooling(
|
628 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
629 |
):
|
630 |
return token_embeddings[:,0]
|
631 |
|
632 |
+
@property
|
633 |
+
def rotary_emb_base(self):
|
634 |
+
return self._rotary_emb_base
|
635 |
+
|
636 |
+
@rotary_emb_base.setter
|
637 |
+
def rotary_emb_base(self, base):
|
638 |
+
if not isinstance(base, (int, float)):
|
639 |
+
raise TypeError("Base must be an integer or float")
|
640 |
+
logger.info(f'Changing RoPE base value to {base}')
|
641 |
+
for layer in self.encoder.layers:
|
642 |
+
layer.mixer.rotary_emb.base = base
|
643 |
+
self._rotary_emb_base = base
|
644 |
|
645 |
def forward(
|
646 |
self,
|
rotary.py
CHANGED
@@ -443,7 +443,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
443 |
"""
|
444 |
super().__init__()
|
445 |
self.dim = dim
|
446 |
-
self.
|
447 |
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
448 |
# Generate and save the inverse frequency buffer (non trainable)
|
449 |
inv_freq = self._compute_inv_freq(device)
|
@@ -463,6 +463,17 @@ class RotaryEmbedding(torch.nn.Module):
|
|
463 |
self._cos_k_cached = None
|
464 |
self._sin_k_cached = None
|
465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
def _compute_inv_freq(self, device=None):
|
467 |
return 1.0 / (
|
468 |
self.base
|
|
|
443 |
"""
|
444 |
super().__init__()
|
445 |
self.dim = dim
|
446 |
+
self._base = float(base)
|
447 |
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
448 |
# Generate and save the inverse frequency buffer (non trainable)
|
449 |
inv_freq = self._compute_inv_freq(device)
|
|
|
463 |
self._cos_k_cached = None
|
464 |
self._sin_k_cached = None
|
465 |
|
466 |
+
@property
|
467 |
+
def base(self):
|
468 |
+
return self._base
|
469 |
+
|
470 |
+
@base.setter
|
471 |
+
def base(self, new_base):
|
472 |
+
if new_base > 0:
|
473 |
+
self._base = float(new_base)
|
474 |
+
else:
|
475 |
+
raise ValueError("Rotary base value must be positive")
|
476 |
+
|
477 |
def _compute_inv_freq(self, device=None):
|
478 |
return 1.0 / (
|
479 |
self.base
|