Angainor Development committed
Commit 813cfa4 · unverified · 1 parent: 193c73b

WIP: Rely on cfg.inference

Files changed (1): src/axolotl/utils/models.py (+3, -4)

src/axolotl/utils/models.py:
@@ -80,8 +80,7 @@ def load_model(
     model_type,
     tokenizer,
     cfg,
-    adapter="lora",
-    inference=False,
+    adapter="lora"
 ):
     # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
@@ -95,7 +94,7 @@ def load_model(
     )
 
     if is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
+        if cfg.device not in ["mps", "cpu"] and cfg.inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             logging.info("patching with flash attention")
@@ -402,7 +401,7 @@ def load_lora(model, cfg):
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
-            is_trainable=True,
+            is_trainable=not cfg.inference,
             device_map=cfg.device_map,
             # torch_dtype=torch.float16,
         )
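
In effect, the commit moves the inference flag from a load_model argument onto the config object: flash-attention patching and LoRA adapter trainability are now gated by cfg.inference. A minimal sketch of that control flow (plain Python, not the actual axolotl module; the helper names below are hypothetical and cfg is assumed to be a dict-like config exposing .device and .inference):

def should_patch_flash_attn(cfg):
    # Flash-attention patching is skipped on mps/cpu devices and when running inference.
    return cfg.device not in ["mps", "cpu"] and cfg.inference is False

def lora_is_trainable(cfg):
    # The loaded PEFT adapter is only trainable when not in inference mode.
    return not cfg.inference

Callers that previously passed inference=True to load_model would presumably now set the flag on the config (e.g. cfg.inference = True) before loading, since the parameter no longer exists.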