winglian committed on
Commit fe0b768 · 1 Parent(s): e944311

match up gradient checkpointing when using lora w/ config

Files changed (1)
  1. src/axolotl/utils/models.py +3 -1
src/axolotl/utils/models.py CHANGED
@@ -305,7 +305,9 @@ def load_model(
             or (cfg.adapter == "qlora" and cfg.load_in_4bit)
         ):
             logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
-            model = prepare_model_for_kbit_training(model)
+            model = prepare_model_for_kbit_training(
+                model, use_gradient_checkpointing=cfg.gradient_checkpointing
+            )
 
     model, lora_config = load_adapter(model, cfg, adapter)
 
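
A minimal sketch of the effect of this change, outside the axolotl codebase: PEFT's `prepare_model_for_kbit_training` accepts a `use_gradient_checkpointing` argument (defaulting to `True`), and the commit wires it to the same `gradient_checkpointing` setting in the axolotl config rather than leaving it at the default. The `cfg` object and model name below are illustrative, not the project's actual config class.

```python
# Illustrative only: shows the config flag driving PEFT's k-bit preparation.
from types import SimpleNamespace

from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Hypothetical config mirroring the relevant axolotl fields.
cfg = SimpleNamespace(adapter="qlora", load_in_4bit=True, gradient_checkpointing=True)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # small model purely for illustration
    quantization_config=BitsAndBytesConfig(load_in_4bit=cfg.load_in_4bit),
)

# Before this commit the call was prepare_model_for_kbit_training(model),
# so gradient checkpointing was always enabled regardless of the config.
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
```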