xzuyn
commited on
Add `layers_to_transform` for `lora_config` (#1118)
Browse files- README.md +2 -1
- src/axolotl/utils/config.py +5 -0
- src/axolotl/utils/models.py +1 -0
- tests/test_validation.py +15 -0
README.md
CHANGED
@@ -677,7 +677,8 @@ lora_target_modules:
|
|
677 |
# - gate_proj
|
678 |
# - down_proj
|
679 |
# - up_proj
|
680 |
-
lora_target_linear: # If true, will target all linear
|
|
|
681 |
|
682 |
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
683 |
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
|
|
677 |
# - gate_proj
|
678 |
# - down_proj
|
679 |
# - up_proj
|
680 |
+
lora_target_linear: # If true, will target all linear modules
|
681 |
+
peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
|
682 |
|
683 |
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
684 |
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
src/axolotl/utils/config.py
CHANGED
@@ -257,6 +257,11 @@ def validate_config(cfg):
|
|
257 |
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
258 |
raise ValueError("Fused modules are not supported with LoRA")
|
259 |
|
|
|
|
|
|
|
|
|
|
|
260 |
if cfg.relora_steps:
|
261 |
if cfg.adapter not in ("lora", "qlora"):
|
262 |
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
|
|
257 |
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
258 |
raise ValueError("Fused modules are not supported with LoRA")
|
259 |
|
260 |
+
if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
|
261 |
+
raise ValueError(
|
262 |
+
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
263 |
+
)
|
264 |
+
|
265 |
if cfg.relora_steps:
|
266 |
if cfg.adapter not in ("lora", "qlora"):
|
267 |
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
src/axolotl/utils/models.py
CHANGED
@@ -733,6 +733,7 @@ def load_lora(model, cfg, inference=False):
|
|
733 |
r=cfg.lora_r,
|
734 |
lora_alpha=cfg.lora_alpha,
|
735 |
target_modules=lora_target_modules,
|
|
|
736 |
lora_dropout=cfg.lora_dropout,
|
737 |
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
738 |
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
|
|
733 |
r=cfg.lora_r,
|
734 |
lora_alpha=cfg.lora_alpha,
|
735 |
target_modules=lora_target_modules,
|
736 |
+
layers_to_transform=cfg.peft_layers_to_transform,
|
737 |
lora_dropout=cfg.lora_dropout,
|
738 |
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
739 |
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
tests/test_validation.py
CHANGED
@@ -694,6 +694,21 @@ class ValidationTest(BaseValidation):
|
|
694 |
|
695 |
validate_config(cfg)
|
696 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
697 |
|
698 |
class ValidationCheckModelConfig(BaseValidation):
|
699 |
"""
|
|
|
694 |
|
695 |
validate_config(cfg)
|
696 |
|
697 |
+
def test_unfrozen_parameters_w_peft_layers_to_transform(self):
|
698 |
+
cfg = DictDefault(
|
699 |
+
{
|
700 |
+
"adapter": "lora",
|
701 |
+
"unfrozen_parameters": ["model.layers.2[0-9]+.block_sparse_moe.gate.*"],
|
702 |
+
"peft_layers_to_transform": [0, 1],
|
703 |
+
}
|
704 |
+
)
|
705 |
+
|
706 |
+
with pytest.raises(
|
707 |
+
ValueError,
|
708 |
+
match=r".*can have unexpected behavior*",
|
709 |
+
):
|
710 |
+
validate_config(cfg)
|
711 |
+
|
712 |
|
713 |
class ValidationCheckModelConfig(BaseValidation):
|
714 |
"""
|