scope flash-attn+qlora fix correctly, scope to llama, add comment
src/axolotl/utils/models.py
@@ -333,13 +333,15 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
-
-
-
-
-
-            if hasattr(module, "weight"):
+    # LlamaRMSNorm layers are in fp32 after kbit call, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
                 module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)
 
     model, lora_config = load_adapter(model, cfg, adapter)
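For readers following along outside the diff, the added block can be read as a small standalone routine: after the k-bit preparation step the LlamaRMSNorm weights sit in fp32 (per the comment above), and flash-attn needs the forward pass back in fp16/bf16, so the affected modules are cast to the configured dtype. The sketch below is a minimal illustration of that idea, not axolotl's API; the helper name cast_norms_for_flash_attn and its arguments are hypothetical.

import torch


def cast_norms_for_flash_attn(model: torch.nn.Module, torch_dtype: torch.dtype) -> None:
    """Minimal sketch of the cast performed in the diff above (hypothetical helper)."""
    for name, module in model.named_modules():
        # Norm layers come out of k-bit preparation in fp32; move them back to
        # the training dtype so flash-attn sees a uniform fp16/bf16 pass.
        if "norm" in name:
            module.to(torch_dtype)
        # The input embeddings and LM head are typically left unquantized and can
        # end up upcast as well; cast them back, guarding modules without a weight.
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(torch_dtype)

A quick sanity check after the cast is to inspect the parameter dtypes, e.g. sorted({p.dtype for n, p in model.named_parameters() if "norm" in n}), which should report only the configured half-precision dtype for the norm weights.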