tmm1 committed on
Commit
2eda9e0
·
1 Parent(s): 78b9efb
Files changed (1) hide show
  1. src/axolotl/utils/models.py +1 -1
src/axolotl/utils/models.py CHANGED
@@ -333,7 +333,7 @@ def load_model(
333
  model, use_gradient_checkpointing=cfg.gradient_checkpointing
334
  )
335
 
336
- # LlamaRMSNorm layers are in fp32 after kit call, so we need to
337
  # convert them back to fp16/bf16 for flash-attn compatibility.
338
  if cfg.flash_attention and cfg.is_llama_derived_model:
339
  for name, module in model.named_modules():
 
333
  model, use_gradient_checkpointing=cfg.gradient_checkpointing
334
  )
335
 
336
+ # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
337
  # convert them back to fp16/bf16 for flash-attn compatibility.
338
  if cfg.flash_attention and cfg.is_llama_derived_model:
339
  for name, module in model.named_modules():