kallewoof
commited on
fix: switch to using the HuggingFace Transformers NEFT implementation (#941)
Browse files* fix: switch to using the HuggingFace Transformers NEFT implementation
* linter
* add support for noisy_embedding_alpha with a warning about it being renamed
* restore pre/posttrain_hooks
* move validation of NEFT noise alpha into validate_config()
* linter
- README.md +1 -1
- src/axolotl/core/trainer_builder.py +6 -0
- src/axolotl/monkeypatch/neft_embeddings.py +0 -65
- src/axolotl/train.py +2 -5
- src/axolotl/utils/config.py +14 -0
README.md
CHANGED
@@ -774,7 +774,7 @@ max_grad_norm:
|
|
774 |
# Augmentation techniques
|
775 |
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
776 |
# currently only supported on Llama and Mistral
|
777 |
-
|
778 |
|
779 |
# Whether to bettertransformers
|
780 |
flash_optimum:
|
|
|
774 |
# Augmentation techniques
|
775 |
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
776 |
# currently only supported on Llama and Mistral
|
777 |
+
neftune_noise_alpha:
|
778 |
|
779 |
# Whether to bettertransformers
|
780 |
flash_optimum:
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -712,6 +712,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
712 |
training_arguments_kwargs
|
713 |
)
|
714 |
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
|
|
|
|
|
|
|
|
|
|
|
|
715 |
training_args = (
|
716 |
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
717 |
**training_arguments_kwargs,
|
|
|
712 |
training_arguments_kwargs
|
713 |
)
|
714 |
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
715 |
+
|
716 |
+
if self.cfg.neftune_noise_alpha is not None:
|
717 |
+
training_arguments_kwargs[
|
718 |
+
"neftune_noise_alpha"
|
719 |
+
] = self.cfg.neftune_noise_alpha
|
720 |
+
|
721 |
training_args = (
|
722 |
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
723 |
**training_arguments_kwargs,
|
src/axolotl/monkeypatch/neft_embeddings.py
DELETED
@@ -1,65 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
|
3 |
-
"""
|
4 |
-
import torch
|
5 |
-
from peft import PeftModel
|
6 |
-
from transformers import PreTrainedModel
|
7 |
-
|
8 |
-
|
9 |
-
def patch_neft(alpha, model):
|
10 |
-
embeddings = None
|
11 |
-
if isinstance(model, PreTrainedModel):
|
12 |
-
embeddings = model.get_input_embeddings()
|
13 |
-
if isinstance(model, PeftModel):
|
14 |
-
embeddings = model.base_model.get_input_embeddings()
|
15 |
-
if not embeddings:
|
16 |
-
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
17 |
-
embeddings.noisy_embedding_alpha = alpha
|
18 |
-
old_forward = embeddings.forward
|
19 |
-
|
20 |
-
# This hack seems to be needed to properly use a custom forward pass
|
21 |
-
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
22 |
-
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
|
23 |
-
embeddings, embeddings.__class__
|
24 |
-
)
|
25 |
-
setattr(embeddings, "forward", bound_method)
|
26 |
-
|
27 |
-
embeddings._old_forward = old_forward # pylint: disable=protected-access
|
28 |
-
return model
|
29 |
-
|
30 |
-
|
31 |
-
def unpatch_neft(model):
|
32 |
-
embeddings = None
|
33 |
-
if isinstance(model, PreTrainedModel):
|
34 |
-
embeddings = model.get_input_embeddings()
|
35 |
-
if isinstance(model, PeftModel):
|
36 |
-
embeddings = model.base_model.get_input_embeddings()
|
37 |
-
if not embeddings:
|
38 |
-
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
39 |
-
if hasattr(embeddings, "_old_forward"):
|
40 |
-
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
|
41 |
-
del embeddings._old_forward # pylint: disable=protected-access
|
42 |
-
del embeddings.noisy_embedding_alpha
|
43 |
-
|
44 |
-
|
45 |
-
def neft_forward(self, inputs: torch.Tensor):
|
46 |
-
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
|
47 |
-
|
48 |
-
if self.training:
|
49 |
-
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
50 |
-
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
|
51 |
-
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
|
52 |
-
-mag_norm, mag_norm
|
53 |
-
)
|
54 |
-
|
55 |
-
return embeddings
|
56 |
-
|
57 |
-
|
58 |
-
def pretrain_hook(cfg, trainer):
|
59 |
-
if cfg.noisy_embedding_alpha:
|
60 |
-
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
|
61 |
-
|
62 |
-
|
63 |
-
def post_train_hook(cfg, trainer):
|
64 |
-
if cfg.noisy_embedding_alpha:
|
65 |
-
unpatch_neft(trainer.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/axolotl/train.py
CHANGED
@@ -16,7 +16,6 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
|
16 |
|
17 |
from axolotl.common.cli import TrainerCliArgs
|
18 |
from axolotl.logging_config import configure_logging
|
19 |
-
from axolotl.monkeypatch import neft_embeddings
|
20 |
from axolotl.utils.dict import DictDefault
|
21 |
from axolotl.utils.freeze import freeze_parameters_except
|
22 |
from axolotl.utils.models import load_model, load_tokenizer
|
@@ -180,21 +179,19 @@ def train(
|
|
180 |
return model, tokenizer
|
181 |
|
182 |
|
183 |
-
def pretrain_hooks(
|
184 |
"""
|
185 |
Run hooks right before kicking off the training
|
186 |
:param cfg:
|
187 |
:param trainer:
|
188 |
:return:
|
189 |
"""
|
190 |
-
neft_embeddings.pretrain_hook(cfg, trainer)
|
191 |
|
192 |
|
193 |
-
def post_train_hooks(
|
194 |
"""
|
195 |
Run hooks right after training completes
|
196 |
:param cfg:
|
197 |
:param trainer:
|
198 |
:return:
|
199 |
"""
|
200 |
-
neft_embeddings.post_train_hook(cfg, trainer)
|
|
|
16 |
|
17 |
from axolotl.common.cli import TrainerCliArgs
|
18 |
from axolotl.logging_config import configure_logging
|
|
|
19 |
from axolotl.utils.dict import DictDefault
|
20 |
from axolotl.utils.freeze import freeze_parameters_except
|
21 |
from axolotl.utils.models import load_model, load_tokenizer
|
|
|
179 |
return model, tokenizer
|
180 |
|
181 |
|
182 |
+
def pretrain_hooks(_cfg, _trainer):
|
183 |
"""
|
184 |
Run hooks right before kicking off the training
|
185 |
:param cfg:
|
186 |
:param trainer:
|
187 |
:return:
|
188 |
"""
|
|
|
189 |
|
190 |
|
191 |
+
def post_train_hooks(_cfg, _trainer):
|
192 |
"""
|
193 |
Run hooks right after training completes
|
194 |
:param cfg:
|
195 |
:param trainer:
|
196 |
:return:
|
197 |
"""
|
|
src/axolotl/utils/config.py
CHANGED
@@ -434,6 +434,20 @@ def validate_config(cfg):
|
|
434 |
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
435 |
)
|
436 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
437 |
# TODO
|
438 |
# MPT 7b
|
439 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
434 |
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
435 |
)
|
436 |
|
437 |
+
if cfg.noisy_embedding_alpha is not None:
|
438 |
+
# Deprecated, use neftune_noise_alpha
|
439 |
+
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
440 |
+
if cfg.neftune_noise_alpha is None:
|
441 |
+
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
442 |
+
else:
|
443 |
+
# User is providing both; bail and have them sort out their settings
|
444 |
+
raise ValueError(
|
445 |
+
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
446 |
+
)
|
447 |
+
|
448 |
+
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
449 |
+
raise ValueError("neftune_noise_alpha must be > 0.0")
|
450 |
+
|
451 |
# TODO
|
452 |
# MPT 7b
|
453 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|