don't pass rope_scaling kwarg if it's None (#383)
src/axolotl/utils/models.py
@@ -229,8 +229,12 @@ def load_model(
     elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
         from transformers import LlamaForCausalLM

+        config_kwargs = {}
+        if cfg.rope_scaling:
+            config_kwargs["rope_scaling"] = cfg.rope_scaling
         config = LlamaConfig.from_pretrained(
-            base_model_config, rope_scaling=cfg.rope_scaling
+            base_model_config,
+            **config_kwargs,
         )
         model = LlamaForCausalLM.from_pretrained(
             base_model,