Maxime
Maxime
commited on
add noisy embedding (#721)
Browse files* add noisy embedding
* fix format
* Update README.md
* Update README.md
* linter issues
* caseus fixes
---------
Co-authored-by: Maxime <[email protected]>
README.md
CHANGED
@@ -672,6 +672,11 @@ adam_epsilon:
|
|
672 |
# Gradient clipping max norm
|
673 |
max_grad_norm:
|
674 |
|
|
|
|
|
|
|
|
|
|
|
675 |
# Whether to bettertransformers
|
676 |
flash_optimum:
|
677 |
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
|
|
672 |
# Gradient clipping max norm
|
673 |
max_grad_norm:
|
674 |
|
675 |
+
# Augmentation techniques
|
676 |
+
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
677 |
+
# currently only supported on Llama and Mistral
|
678 |
+
noisy_embedding_alpha:
|
679 |
+
|
680 |
# Whether to bettertransformers
|
681 |
flash_optimum:
|
682 |
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
src/axolotl/monkeypatch/llama_embeddings_hijack.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import transformers.models.llama.modeling_llama
|
7 |
+
from transformers.utils import logging
|
8 |
+
|
9 |
+
logger = logging.get_logger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
|
13 |
+
# pylint: disable=duplicate-code
|
14 |
+
def noised_embed(orig_embed, noise_alpha, model):
|
15 |
+
def new_func(input_ids):
|
16 |
+
# during training, we add noise to the embedding
|
17 |
+
# during generation, we don't add noise to the embedding
|
18 |
+
if model.training:
|
19 |
+
embed_init = orig_embed(input_ids)
|
20 |
+
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
21 |
+
mag_norm = noise_alpha / torch.sqrt(dims)
|
22 |
+
return embed_init + torch.zeros_like(embed_init).uniform_(
|
23 |
+
-mag_norm, mag_norm
|
24 |
+
)
|
25 |
+
return orig_embed(input_ids)
|
26 |
+
|
27 |
+
return new_func
|
28 |
+
|
29 |
+
def post_init(orig_post_init):
|
30 |
+
def new_func(self):
|
31 |
+
orig_post_init(self)
|
32 |
+
self.embed_tokens.forward = noised_embed(
|
33 |
+
self.embed_tokens.forward, noise_alpha, self
|
34 |
+
)
|
35 |
+
|
36 |
+
return new_func
|
37 |
+
|
38 |
+
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
|
39 |
+
transformers.models.llama.modeling_llama.LlamaModel.post_init
|
40 |
+
)
|
src/axolotl/monkeypatch/mistral_embeddings_hijack.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import transformers.models.mistral.modeling_mistral
|
7 |
+
from transformers.utils import logging
|
8 |
+
|
9 |
+
logger = logging.get_logger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
|
13 |
+
# pylint: disable=duplicate-code
|
14 |
+
def noised_embed(orig_embed, noise_alpha, model):
|
15 |
+
def new_func(input_ids):
|
16 |
+
# during training, we add noise to the embedding
|
17 |
+
# during generation, we don't add noise to the embedding
|
18 |
+
if model.training:
|
19 |
+
embed_init = orig_embed(input_ids)
|
20 |
+
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
21 |
+
mag_norm = noise_alpha / torch.sqrt(dims)
|
22 |
+
return embed_init + torch.zeros_like(embed_init).uniform_(
|
23 |
+
-mag_norm, mag_norm
|
24 |
+
)
|
25 |
+
return orig_embed(input_ids)
|
26 |
+
|
27 |
+
return new_func
|
28 |
+
|
29 |
+
def post_init(orig_post_init):
|
30 |
+
def new_func(self):
|
31 |
+
orig_post_init(self)
|
32 |
+
self.embed_tokens.forward = noised_embed(
|
33 |
+
self.embed_tokens.forward, noise_alpha, self
|
34 |
+
)
|
35 |
+
|
36 |
+
return new_func
|
37 |
+
|
38 |
+
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
|
39 |
+
transformers.models.mistral.modeling_mistral.MistralModel.post_init
|
40 |
+
)
|
src/axolotl/utils/models.py
CHANGED
@@ -180,6 +180,26 @@ def load_model(
|
|
180 |
LOG.info("patching with flash attention")
|
181 |
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
184 |
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
185 |
replace_llama_rope_with_xpos_rope,
|
|
|
180 |
LOG.info("patching with flash attention")
|
181 |
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
182 |
|
183 |
+
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
|
184 |
+
from axolotl.monkeypatch.llama_embeddings_hijack import (
|
185 |
+
replace_llama_embeddings_with_uniform_distribution,
|
186 |
+
)
|
187 |
+
|
188 |
+
LOG.info("patching with noisy embeddings")
|
189 |
+
replace_llama_embeddings_with_uniform_distribution(
|
190 |
+
noise_alpha=cfg.noisy_embedding_alpha
|
191 |
+
)
|
192 |
+
|
193 |
+
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
|
194 |
+
from axolotl.monkeypatch.mistral_embeddings_hijack import (
|
195 |
+
replace_mistral_embeddings_with_uniform_distribution,
|
196 |
+
)
|
197 |
+
|
198 |
+
LOG.info("patching with noisy embeddings")
|
199 |
+
replace_mistral_embeddings_with_uniform_distribution(
|
200 |
+
noise_alpha=cfg.noisy_embedding_alpha
|
201 |
+
)
|
202 |
+
|
203 |
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
204 |
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
205 |
replace_llama_rope_with_xpos_rope,
|