winglian commited on
Commit
0f10080
·
unverified ·
1 Parent(s): ead34c5

be more robust about checking embedding modules for lora finetunes (#1074) [skip ci]

Browse files

* be more robust about checking embedding modules for lora finetunes

* update dynamic error message

src/axolotl/utils/config.py CHANGED
@@ -151,6 +151,10 @@ def normalize_config(cfg):
151
 
152
 
153
  def validate_config(cfg):
 
 
 
 
154
  if is_torch_bf16_gpu_available():
155
  if not cfg.bf16 and not cfg.bfloat16:
156
  LOG.info("bf16 support detected, but not enabled for this configuration.")
@@ -443,20 +447,6 @@ def validate_config(cfg):
443
  if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
444
  raise ValueError("neftune_noise_alpha must be > 0.0")
445
 
446
- if (
447
- cfg.adapter
448
- and cfg.tokens
449
- and (
450
- not cfg.lora_modules_to_save
451
- or not all(
452
- x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
453
- )
454
- )
455
- ):
456
- raise ValueError(
457
- "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
458
- )
459
-
460
  if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
461
  raise ValueError(
462
  "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
 
151
 
152
 
153
  def validate_config(cfg):
154
+ """
155
+ This is a "pre-validation" step that handles the yaml configuration before we have any
156
+ information about the model architecture
157
+ """
158
  if is_torch_bf16_gpu_available():
159
  if not cfg.bf16 and not cfg.bfloat16:
160
  LOG.info("bf16 support detected, but not enabled for this configuration.")
 
447
  if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
448
  raise ValueError("neftune_noise_alpha must be > 0.0")
449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
451
  raise ValueError(
452
  "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
src/axolotl/utils/lora_embeddings.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ helpers for lora embeddings
3
+ """
4
+
5
+
6
+ def get_linear_embedding_layers(model_type):
7
+ """
8
+ returns the linear embedding layers needed for loras, dependent on the model arch
9
+ """
10
+ if model_type == "phi-msft":
11
+ return ["embd", "lm_head.linear"]
12
+ return ["lm_head", "embed_tokens"]
src/axolotl/utils/models.py CHANGED
@@ -2,7 +2,7 @@
2
  import logging
3
  import math
4
  import os
5
- from typing import Any, Optional, Tuple # noqa: F401
6
 
7
  import addict
8
  import bitsandbytes as bnb
@@ -28,12 +28,16 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
28
  from axolotl.utils.bench import log_gpu_memory_usage
29
  from axolotl.utils.chat_templates import chat_templates
30
  from axolotl.utils.dict import DictDefault
 
31
 
32
  LOG = logging.getLogger("axolotl")
33
 
34
 
35
- def check_model_config(cfg: DictDefault, model_config: AutoConfig):
36
- quant_config_exists = hasattr(model_config, "quantization_config")
 
 
 
37
  quant_config_method_is_gptq = (
38
  quant_config_exists
39
  and "quant_method" in model_config.quantization_config
@@ -52,6 +56,20 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
52
  "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
53
  )
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def load_model_config(cfg):
57
  model_config_name = cfg.base_model_config or cfg.base_model
@@ -139,6 +157,7 @@ def load_tokenizer(cfg):
139
  setattr(tokenizer, attr_name, "<|endoftext|>")
140
 
141
  if cfg.special_tokens:
 
142
  for k, val in cfg.special_tokens.items():
143
  # check if new special token is not already in tokenizer and
144
  # is adapter training to make sure lora_modules_to_save is set
@@ -149,14 +168,15 @@ def load_tokenizer(cfg):
149
  and (
150
  not cfg.lora_modules_to_save
151
  or not all(
152
- x in cfg.lora_modules_to_save
153
- for x in ["embed_tokens", "lm_head"]
154
  )
155
  )
156
- and (model_config.model_type in ["llama", "mistral", "mixtral"])
157
  ):
 
 
 
158
  raise ValueError(
159
- "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
160
  )
161
 
162
  tokenizer.add_special_tokens(
 
2
  import logging
3
  import math
4
  import os
5
+ from typing import Any, Optional, Tuple, Union # noqa: F401
6
 
7
  import addict
8
  import bitsandbytes as bnb
 
28
  from axolotl.utils.bench import log_gpu_memory_usage
29
  from axolotl.utils.chat_templates import chat_templates
30
  from axolotl.utils.dict import DictDefault
31
+ from axolotl.utils.lora_embeddings import get_linear_embedding_layers
32
 
33
  LOG = logging.getLogger("axolotl")
34
 
35
 
36
+ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
37
+ quant_config_exists = (
38
+ hasattr(model_config, "quantization_config")
39
+ and model_config.quantization_config
40
+ )
41
  quant_config_method_is_gptq = (
42
  quant_config_exists
43
  and "quant_method" in model_config.quantization_config
 
56
  "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
57
  )
58
 
59
+ lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
60
+ if (
61
+ cfg.adapter
62
+ and cfg.tokens
63
+ and (
64
+ not cfg.lora_modules_to_save
65
+ or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save)
66
+ )
67
+ ):
68
+ lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save))
69
+ raise ValueError(
70
+ f"`lora_modules_to_save` not properly set when adding new tokens. Please include {lora_modules_to_save} in `lora_modules_to_save`."
71
+ )
72
+
73
 
74
  def load_model_config(cfg):
75
  model_config_name = cfg.base_model_config or cfg.base_model
 
157
  setattr(tokenizer, attr_name, "<|endoftext|>")
158
 
159
  if cfg.special_tokens:
160
+ lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
161
  for k, val in cfg.special_tokens.items():
162
  # check if new special token is not already in tokenizer and
163
  # is adapter training to make sure lora_modules_to_save is set
 
168
  and (
169
  not cfg.lora_modules_to_save
170
  or not all(
171
+ x in cfg.lora_modules_to_save for x in lora_modules_to_save
 
172
  )
173
  )
 
174
  ):
175
+ lora_modules_to_save = ", ".join(
176
+ [f"`{x}`" for x in lora_modules_to_save]
177
+ )
178
  raise ValueError(
179
+ f"Please set lora_modules_to_save to {lora_modules_to_save} when using an adapter and changing the special tokens."
180
  )
181
 
182
  tokenizer.add_special_tokens(
tests/test_validation.py CHANGED
@@ -10,12 +10,13 @@ from transformers.utils import is_torch_bf16_gpu_available
10
 
11
  from axolotl.utils.config import validate_config
12
  from axolotl.utils.dict import DictDefault
 
13
  from axolotl.utils.wandb_ import setup_wandb_env_vars
14
 
15
 
16
- class ValidationTest(unittest.TestCase):
17
  """
18
- Test the validation module
19
  """
20
 
21
  _caplog: Optional[pytest.LogCaptureFixture] = None
@@ -24,6 +25,12 @@ class ValidationTest(unittest.TestCase):
24
  def inject_fixtures(self, caplog):
25
  self._caplog = caplog
26
 
 
 
 
 
 
 
27
  def test_load_4bit_deprecate(self):
28
  cfg = DictDefault(
29
  {
@@ -687,16 +694,23 @@ class ValidationTest(unittest.TestCase):
687
 
688
  validate_config(cfg)
689
 
690
- def test_add_tokens_adapter(self):
 
 
 
 
 
 
691
  cfg = DictDefault(
692
  {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
693
  )
 
694
 
695
  with pytest.raises(
696
  ValueError,
697
- match=r".*lora_modules_to_save not properly set yet adding new tokens*",
698
  ):
699
- validate_config(cfg)
700
 
701
  cfg = DictDefault(
702
  {
@@ -709,9 +723,9 @@ class ValidationTest(unittest.TestCase):
709
 
710
  with pytest.raises(
711
  ValueError,
712
- match=r".*lora_modules_to_save not properly set yet adding new tokens*",
713
  ):
714
- validate_config(cfg)
715
 
716
  cfg = DictDefault(
717
  {
@@ -722,10 +736,48 @@ class ValidationTest(unittest.TestCase):
722
  }
723
  )
724
 
725
- validate_config(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726
 
727
 
728
- class ValidationWandbTest(ValidationTest):
729
  """
730
  Validation test for wandb
731
  """
 
10
 
11
  from axolotl.utils.config import validate_config
12
  from axolotl.utils.dict import DictDefault
13
+ from axolotl.utils.models import check_model_config
14
  from axolotl.utils.wandb_ import setup_wandb_env_vars
15
 
16
 
17
+ class BaseValidation(unittest.TestCase):
18
  """
19
+ Base validation module to setup the log capture
20
  """
21
 
22
  _caplog: Optional[pytest.LogCaptureFixture] = None
 
25
  def inject_fixtures(self, caplog):
26
  self._caplog = caplog
27
 
28
+
29
+ class ValidationTest(BaseValidation):
30
+ """
31
+ Test the validation module
32
+ """
33
+
34
  def test_load_4bit_deprecate(self):
35
  cfg = DictDefault(
36
  {
 
694
 
695
  validate_config(cfg)
696
 
697
+
698
+ class ValidationCheckModelConfig(BaseValidation):
699
+ """
700
+ Test the validation for the config when the model config is available
701
+ """
702
+
703
+ def test_llama_add_tokens_adapter(self):
704
  cfg = DictDefault(
705
  {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
706
  )
707
+ model_config = DictDefault({"model_type": "llama"})
708
 
709
  with pytest.raises(
710
  ValueError,
711
+ match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
712
  ):
713
+ check_model_config(cfg, model_config)
714
 
715
  cfg = DictDefault(
716
  {
 
723
 
724
  with pytest.raises(
725
  ValueError,
726
+ match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
727
  ):
728
+ check_model_config(cfg, model_config)
729
 
730
  cfg = DictDefault(
731
  {
 
736
  }
737
  )
738
 
739
+ check_model_config(cfg, model_config)
740
+
741
+ def test_phi2_add_tokens_adapter(self):
742
+ cfg = DictDefault(
743
+ {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
744
+ )
745
+ model_config = DictDefault({"model_type": "phi-msft"})
746
+
747
+ with pytest.raises(
748
+ ValueError,
749
+ match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
750
+ ):
751
+ check_model_config(cfg, model_config)
752
+
753
+ cfg = DictDefault(
754
+ {
755
+ "adapter": "qlora",
756
+ "load_in_4bit": True,
757
+ "tokens": ["<|imstart|>"],
758
+ "lora_modules_to_save": ["embed_tokens", "lm_head"],
759
+ }
760
+ )
761
+
762
+ with pytest.raises(
763
+ ValueError,
764
+ match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
765
+ ):
766
+ check_model_config(cfg, model_config)
767
+
768
+ cfg = DictDefault(
769
+ {
770
+ "adapter": "qlora",
771
+ "load_in_4bit": True,
772
+ "tokens": ["<|imstart|>"],
773
+ "lora_modules_to_save": ["embd", "lm_head.linear"],
774
+ }
775
+ )
776
+
777
+ check_model_config(cfg, model_config)
778
 
779
 
780
+ class ValidationWandbTest(BaseValidation):
781
  """
782
  Validation test for wandb
783
  """