fix typo
Browse files
src/axolotl/utils/models.py
CHANGED
@@ -333,7 +333,7 @@ def load_model(
 333         model, use_gradient_checkpointing=cfg.gradient_checkpointing
 334     )
 335
-336     # LlamaRMSNorm layers are in fp32 after
+336     # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
 337     # convert them back to fp16/bf16 for flash-attn compatibility.
 338     if cfg.flash_attention and cfg.is_llama_derived_model:
 339         for name, module in model.named_modules():