match up gradient checkpointing when using lora w config
src/axolotl/utils/models.py
@@ -305,7 +305,9 @@ def load_model(
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
-        model = prepare_model_for_kbit_training(model)
+        model = prepare_model_for_kbit_training(
+            model, use_gradient_checkpointing=cfg.gradient_checkpointing
+        )

     model, lora_config = load_adapter(model, cfg, adapter)

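For context, the new keyword argument is forwarded straight through to PEFT, so the user's `gradient_checkpointing` setting is honored instead of PEFT's default when the model is prepared for LoRA/QLoRA training. Below is a minimal sketch of the same wiring outside axolotl; the model name, the `SimpleNamespace` stand-in for axolotl's parsed config, and the guard condition are assumptions for illustration, not code from this repo:

# Minimal sketch, not axolotl code: forward a user-facing
# `gradient_checkpointing` flag to peft.prepare_model_for_kbit_training,
# mirroring the change in the diff above.
from types import SimpleNamespace

from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Stand-in for axolotl's parsed YAML config (placeholder values).
cfg = SimpleNamespace(adapter="qlora", load_in_4bit=True, gradient_checkpointing=True)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # small placeholder model for the sketch
    quantization_config=BitsAndBytesConfig(load_in_4bit=cfg.load_in_4bit),
    device_map="auto",
)

if cfg.adapter == "qlora" and cfg.load_in_4bit:
    # Pass the config flag through so gradient checkpointing is only
    # enabled during k-bit preparation when the user asked for it.
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=cfg.gradient_checkpointing
    )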