tmm1 committed
Commit 78b9efb · 1 Parent(s): 312a9fa

scope flash-attn+qlora fix correctly, scope to llama, add comment

Files changed (1): src/axolotl/utils/models.py (+8, -6)
src/axolotl/utils/models.py CHANGED
@@ -333,13 +333,15 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
+    # LlamaRMSNorm layers are in fp32 after the kbit call, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
                 module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)
 
     model, lora_config = load_adapter(model, cfg, adapter)
 
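For context, here is a minimal standalone sketch of the same dtype fix-up, assuming a Llama-derived checkpoint loaded in 4-bit via bitsandbytes and prepared with PEFT's prepare_model_for_kbit_training; the model name, dtype, and config values are illustrative and not part of this commit.

# Minimal sketch (assumed setup, not from this repo): reproduce the patch's
# dtype fix-up after k-bit preparation so flash-attn sees fp16/bf16 norms.
import torch
from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

torch_dtype = torch.bfloat16  # compute dtype flash-attn expects

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # illustrative llama-derived checkpoint
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch_dtype,
    ),
)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Same loop as the patch: prepare_model_for_kbit_training upcasts the
# LlamaRMSNorm layers to fp32, so cast them (plus the lm_head/embed_tokens
# weights) back to the compute dtype before running with flash attention.
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch_dtype)
    if "lm_head" in name or "embed_tokens" in name:
        if hasattr(module, "weight"):
            module.to(torch_dtype)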