winglian committed (unverified)
Commit 1bc1186 · Parent: b3a61e8

allow overriding of model_config parameters from the YML (#853)

* allow overriding of model_config parameters from the YML

* remove old logging, update readme

* move the updating of model config to the load_model_config function

* add warning for deprecated rope_scaling in the root of the YML config

README.md CHANGED
@@ -489,6 +489,14 @@ is_llama_derived_model:
 # Please note that if you set this to true, `padding_side` will be set to "left" by default
 is_mistral_derived_model:
 
+# optional overrides to the base model configuration
+model_config:
+  # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
+  rope_scaling:
+    type: # linear | dynamic
+    factor: # float
+
+
 # Whether you are training a 4-bit GPTQ quantized model
 gptq: true
 gptq_groupsize: 128 # group size
@@ -756,10 +764,6 @@ landmark_attention:
 # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
 # LLaMA only
 xpos_rope:
-# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
-rope_scaling:
-  type: # linear | dynamic
-  factor: # float
 
 # Resume from a specific checkpoint dir
 resume_from_checkpoint:
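
For reference, here is a minimal sketch of what the new README block enables, assuming PyYAML and transformers are installed; the model name and the `linear`/`2.0` rope values are illustrative placeholders, not part of this commit. Every key under `model_config` in the YML is simply copied onto the loaded Hugging Face config object, which is what the updated `load_model_config` below does.

# Minimal sketch of applying the YML `model_config` overrides to a Hugging Face
# config object. The model name and the rope_scaling values are placeholders.
import yaml
from transformers import AutoConfig

yml = """
model_config:
  rope_scaling:
    type: linear   # linear | dynamic
    factor: 2.0
"""
cfg = yaml.safe_load(yml)

model_config = AutoConfig.from_pretrained("NousResearch/Llama-2-7b-hf")
for key, val in (cfg.get("model_config") or {}).items():
    setattr(model_config, key, val)  # same mechanism as the updated load_model_config()

print(model_config.rope_scaling)  # {'type': 'linear', 'factor': 2.0}
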
src/axolotl/utils/config.py CHANGED
@@ -369,6 +369,9 @@ def validate_config(cfg):
             "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
         )
 
+    if cfg.rope_scaling:
+        LOG.warning("`rope_scaling` should now be a key under `model_config`")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
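
The new check only warns; a root-level `rope_scaling` key is no longer forwarded to the model, so existing configs need to be migrated. A small sketch of the before/after shape, using plain Python dicts to stand in for the parsed YML (the base model name and scaling values are placeholders, not axolotl code):

# Sketch of migrating a deprecated root-level `rope_scaling` entry. Plain dicts
# stand in for the parsed YML config; names and values are placeholders.
deprecated_cfg = {
    "base_model": "NousResearch/Llama-2-7b-hf",
    "rope_scaling": {"type": "linear", "factor": 2.0},  # now only triggers the warning above
}

migrated_cfg = {
    "base_model": deprecated_cfg["base_model"],
    # nest the same block under `model_config` so load_model_config() applies it
    "model_config": {"rope_scaling": deprecated_cfg["rope_scaling"]},
}
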
src/axolotl/utils/models.py CHANGED
@@ -17,7 +17,6 @@ from transformers import ( # noqa: F401
     AutoTokenizer,
     BitsAndBytesConfig,
     GPTQConfig,
-    LlamaConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
@@ -32,9 +31,14 @@ LOG = logging.getLogger("axolotl")
 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model
     trust_remote_code = cfg.trust_remote_code is True
-    return AutoConfig.from_pretrained(
+    model_config = AutoConfig.from_pretrained(
         model_config_name, trust_remote_code=trust_remote_code
     )
+    if cfg.model_config:
+        for key, val in cfg.model_config.items():
+            setattr(model_config, key, val)
+
+    return model_config
 
 
 def load_tokenizer(cfg):
@@ -51,7 +55,7 @@ def load_tokenizer(cfg):
     if cfg.tokenizer_type:
         tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
 
-    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
     tokenizer = tokenizer_cls.from_pretrained(
         tokenizer_config,
         trust_remote_code=cfg.trust_remote_code or False,
@@ -110,7 +114,6 @@ def load_model(
     Load a model for a given configuration and tokenizer.
     """
     base_model = cfg.base_model
-    base_model_config = cfg.base_model_config
     model_type = cfg.model_type
     model_config = load_model_config(cfg)
 
@@ -238,16 +241,9 @@ def load_model(
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
 
-            config_kwargs = {}
-            if cfg.rope_scaling:
-                config_kwargs["rope_scaling"] = cfg.rope_scaling
-            config = LlamaConfig.from_pretrained(
-                base_model_config,
-                **config_kwargs,
-            )
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
-                config=config,
+                config=model_config,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
@@ -305,66 +301,55 @@ def load_model(
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
+                    config=model_config,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
             else:
                 model = getattr(transformers, model_type).from_pretrained(
                     base_model,
+                    config=model_config,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
         else:
-            config = AutoConfig.from_pretrained(
-                base_model,
-                trust_remote_code=cfg.trust_remote_code or False,
-            )
             # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
             # when training starts
             if (
-                hasattr(config, "max_seq_len")
-                and config.max_seq_len
-                and cfg.sequence_len > config.max_seq_len
+                hasattr(model_config, "max_seq_len")
+                and model_config.max_seq_len
+                and cfg.sequence_len > model_config.max_seq_len
             ):
-                config.max_seq_len = cfg.sequence_len
+                model_config.max_seq_len = cfg.sequence_len
                 LOG.warning(f"increasing context length to {cfg.sequence_len}")
             elif (
-                hasattr(config, "max_sequence_length")
-                and config.max_sequence_length
-                and cfg.sequence_len > config.max_sequence_length
+                hasattr(model_config, "max_sequence_length")
+                and model_config.max_sequence_length
+                and cfg.sequence_len > model_config.max_sequence_length
             ):
-                config.max_sequence_length = cfg.sequence_len
+                model_config.max_sequence_length = cfg.sequence_len
                 LOG.warning(f"increasing context length to {cfg.sequence_len}")
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    config=config,
+                    config=model_config,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
            else:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    config=config,
+                    config=model_config,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
     except Exception as err:  # pylint: disable=broad-exception-caught
-        LOG.error(
-            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
-        )
         LOG.exception(err)
-        model = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            trust_remote_code=cfg.trust_remote_code or False,
-            **model_kwargs,
-        )
+        raise err
 
     embeddings_len = (
         math.ceil(len(tokenizer) / 32) * 32
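
Taken together, the models.py changes mean every `from_pretrained` call receives the single config object built by `load_model_config`, and load failures now propagate instead of silently retrying with `AutoModelForCausalLM`. Below is a rough, simplified sketch of the resulting flow; the function names and signatures are not axolotl's, and the model name is a placeholder.

# Simplified sketch of the post-change flow: build one config object, apply the
# YML `model_config` overrides, pass it to from_pretrained, and let load
# failures propagate instead of retrying. Not axolotl's actual signatures.
from typing import Optional

from transformers import AutoConfig, AutoModelForCausalLM


def sketch_load_model_config(base_model: str, overrides: Optional[dict]):
    model_config = AutoConfig.from_pretrained(base_model)
    for key, val in (overrides or {}).items():
        setattr(model_config, key, val)
    return model_config


def sketch_load_model(base_model: str, overrides: Optional[dict]):
    model_config = sketch_load_model_config(base_model, overrides)
    try:
        return AutoModelForCausalLM.from_pretrained(base_model, config=model_config)
    except Exception as err:
        # previously this fell back to a plain AutoModelForCausalLM load;
        # after this commit the error is re-raised
        raise err


# Example usage (placeholder model name):
# model = sketch_load_model(
#     "NousResearch/Llama-2-7b-hf",
#     {"rope_scaling": {"type": "linear", "factor": 2.0}},
# )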