Nanobit committed on
Commit
73e9ea4
·
unverified ·
2 Parent(s): f8d3798 df9528f

Merge pull request #143 from NanoCode012/fix/deprecate-prepare-8bit-training

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/models.py +4 -3
src/axolotl/utils/models.py CHANGED
@@ -128,7 +128,8 @@ def load_model(
128
  )
129
 
130
  replace_peft_model_with_int4_lora_model()
131
- from peft import prepare_model_for_int8_training
 
132
  except Exception as err:
133
  logging.exception(err)
134
  raise err
@@ -269,8 +270,8 @@ def load_model(
269
  (cfg.adapter == "lora" and load_in_8bit)
270
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
271
  ):
272
- logging.info("converting PEFT model w/ prepare_model_for_int8_training")
273
- model = prepare_model_for_int8_training(model)
274
 
275
  model, lora_config = load_adapter(model, cfg, adapter)
276
 
 
128
  )
129
 
130
  replace_peft_model_with_int4_lora_model()
131
+ else:
132
+ from peft import prepare_model_for_kbit_training
133
  except Exception as err:
134
  logging.exception(err)
135
  raise err
 
270
  (cfg.adapter == "lora" and load_in_8bit)
271
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
272
  ):
273
+ logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
274
+ model = prepare_model_for_kbit_training(model)
275
 
276
  model, lora_config = load_adapter(model, cfg, adapter)
277