allow overriding of model_config parameters from the YML (#853)

* allow overriding of model_config parameters from the YML
* remove old logging, update readme
* move the updating of model config to the load_model_config function
* add warning for deprecated rope_scaling in the root of the YML config
- README.md +8 -4
- src/axolotl/utils/config.py +3 -0
- src/axolotl/utils/models.py +21 -36
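
In effect, any key placed under `model_config:` in the YML is now copied onto the HuggingFace `AutoConfig` with `setattr` before the model is instantiated. A minimal standalone sketch of that mechanism, with a placeholder checkpoint name and YML snippet that are not taken from this PR:

```python
# Sketch of the override mechanism: keys under `model_config:` in the YML
# are copied onto the loaded AutoConfig via setattr before the model load.
# The checkpoint name and YML content are illustrative placeholders.
import yaml
from transformers import AutoConfig

raw_yml = """
base_model: NousResearch/Llama-2-7b-hf
model_config:
  rope_scaling:
    type: linear
    factor: 2.0
"""

cfg = yaml.safe_load(raw_yml)

model_config = AutoConfig.from_pretrained(cfg["base_model"])
for key, val in (cfg.get("model_config") or {}).items():
    setattr(model_config, key, val)

print(model_config.rope_scaling)  # -> {'type': 'linear', 'factor': 2.0}
```

Because the loop is generic, the same `model_config:` block can override other fields on the base model configuration as well, not only `rope_scaling`.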
README.md
CHANGED
```diff
@@ -489,6 +489,14 @@ is_llama_derived_model:
 # Please note that if you set this to true, `padding_side` will be set to "left" by default
 is_mistral_derived_model:
 
+# optional overrides to the base model configuration
+model_config:
+  # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
+  rope_scaling:
+    type: # linear | dynamic
+    factor: # float
+
+
 # Whether you are training a 4-bit GPTQ quantized model
 gptq: true
 gptq_groupsize: 128 # group size
@@ -756,10 +764,6 @@ landmark_attention:
 # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
 # LLaMA only
 xpos_rope:
-# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
-rope_scaling:
-  type: # linear | dynamic
-  factor: # float
 
 # Resume from a specific checkpoint dir
 resume_from_checkpoint:
```
src/axolotl/utils/config.py
CHANGED
```diff
@@ -369,6 +369,9 @@ def validate_config(cfg):
         "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
     )
 
+    if cfg.rope_scaling:
+        LOG.warning("`rope_scaling` should now be a key under `model_config`")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
```
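
A root-level `rope_scaling` key is now only flagged during validation; it is no longer forwarded to the model. A standalone illustration of that behaviour (a simplified stand-in, not axolotl's actual `validate_config`):

```python
# Simplified stand-in for the new validation check: a legacy root-level
# `rope_scaling` key triggers a deprecation warning, while a migrated
# config that nests it under `model_config` does not.
import logging

logging.basicConfig(level=logging.WARNING)
LOG = logging.getLogger("axolotl")


def check_rope_scaling(cfg: dict) -> None:
    if cfg.get("rope_scaling"):
        LOG.warning("`rope_scaling` should now be a key under `model_config`")


check_rope_scaling({"rope_scaling": {"type": "linear", "factor": 2.0}})  # warns
check_rope_scaling({"model_config": {"rope_scaling": {"type": "linear", "factor": 2.0}}})  # silent
```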
src/axolotl/utils/models.py
CHANGED
```diff
@@ -17,7 +17,6 @@ from transformers import ( # noqa: F401
     AutoTokenizer,
     BitsAndBytesConfig,
     GPTQConfig,
-    LlamaConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
@@ -32,9 +31,14 @@ LOG = logging.getLogger("axolotl")
 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model
     trust_remote_code = cfg.trust_remote_code is True
-    return AutoConfig.from_pretrained(
+    model_config = AutoConfig.from_pretrained(
         model_config_name, trust_remote_code=trust_remote_code
     )
+    if cfg.model_config:
+        for key, val in cfg.model_config.items():
+            setattr(model_config, key, val)
+
+    return model_config
 
 
 def load_tokenizer(cfg):
@@ -51,7 +55,7 @@ def load_tokenizer(cfg):
     if cfg.tokenizer_type:
         tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
 
-    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
     tokenizer = tokenizer_cls.from_pretrained(
         tokenizer_config,
         trust_remote_code=cfg.trust_remote_code or False,
@@ -110,7 +114,6 @@ def load_model(
     Load a model for a given configuration and tokenizer.
     """
     base_model = cfg.base_model
-    base_model_config = cfg.base_model_config
     model_type = cfg.model_type
     model_config = load_model_config(cfg)
 
@@ -238,16 +241,9 @@ def load_model(
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
 
-            config_kwargs = {}
-            if cfg.rope_scaling:
-                config_kwargs["rope_scaling"] = cfg.rope_scaling
-            config = LlamaConfig.from_pretrained(
-                base_model_config,
-                **config_kwargs,
-            )
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
-                config=config,
+                config=model_config,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
@@ -305,66 +301,55 @@ def load_model(
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
+                    config=model_config,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
             else:
                 model = getattr(transformers, model_type).from_pretrained(
                     base_model,
+                    config=model_config,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
         else:
-            config = AutoConfig.from_pretrained(
-                base_model,
-                trust_remote_code=cfg.trust_remote_code or False,
-            )
             # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
             # when training starts
             if (
-                hasattr(config, "max_seq_len")
-                and config.max_seq_len
-                and cfg.sequence_len > config.max_seq_len
+                hasattr(model_config, "max_seq_len")
+                and model_config.max_seq_len
+                and cfg.sequence_len > model_config.max_seq_len
             ):
-                config.max_seq_len = cfg.sequence_len
+                model_config.max_seq_len = cfg.sequence_len
                 LOG.warning(f"increasing context length to {cfg.sequence_len}")
             elif (
-                hasattr(config, "max_sequence_length")
-                and config.max_sequence_length
-                and cfg.sequence_len > config.max_sequence_length
+                hasattr(model_config, "max_sequence_length")
+                and model_config.max_sequence_length
+                and cfg.sequence_len > model_config.max_sequence_length
             ):
-                config.max_sequence_length = cfg.sequence_len
+                model_config.max_sequence_length = cfg.sequence_len
                 LOG.warning(f"increasing context length to {cfg.sequence_len}")
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    config=config,
+                    config=model_config,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
             else:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    config=config,
+                    config=model_config,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
     except Exception as err:  # pylint: disable=broad-exception-caught
-        LOG.error(
-            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
-        )
         LOG.exception(err)
-        model = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            trust_remote_code=cfg.trust_remote_code or False,
-            **model_kwargs,
-        )
+        raise err
 
     embeddings_len = (
         math.ceil(len(tokenizer) / 32) * 32
```
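
Downstream, nothing else has to change: the `AutoConfig` returned by `load_model_config` is passed to `from_pretrained` through `config=`, so the values overridden from the YML are what the model is actually built with. A rough usage sketch under that assumption (the checkpoint name is a placeholder and the full model load is shown only for illustration):

```python
# Sketch of how an overridden config reaches the model: the mutated
# AutoConfig is passed via `config=`, so it is used instead of the
# checkpoint's own config.json. The checkpoint name is a placeholder.
from transformers import AutoConfig, AutoModelForCausalLM

base_model = "NousResearch/Llama-2-7b-hf"  # placeholder checkpoint

model_config = AutoConfig.from_pretrained(base_model)
setattr(model_config, "rope_scaling", {"type": "dynamic", "factor": 2.0})

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    config=model_config,  # the override travels with the config object
)
```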