winglian commited on
Commit
b3f5e00
·
unverified ·
1 Parent(s): 5247c50

use save_strategy from config if available (#434)

Browse files

* use save_strategy from config if available

* update docs for save_strategy

Files changed (2) hide show
  1. README.md +1 -0
  2. src/axolotl/utils/trainer.py +7 -1
README.md CHANGED
@@ -472,6 +472,7 @@ warmup_steps: 100
472
  learning_rate: 0.00003
473
  lr_quadratic_warmup:
474
  logging_steps:
 
475
  save_steps: # leave empty to save at each epoch
476
  eval_steps:
477
  save_total_limit: # checkpoints saved at a time
 
472
  learning_rate: 0.00003
473
  lr_quadratic_warmup:
474
  logging_steps:
475
+ save_strategy: # set to `no` to skip checkpoint saves
476
  save_steps: # leave empty to save at each epoch
477
  eval_steps:
478
  save_total_limit: # checkpoints saved at a time
src/axolotl/utils/trainer.py CHANGED
@@ -457,6 +457,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
457
  # we have an eval set, but no steps defined, use epoch
458
  training_arguments_kwargs["evaluation_strategy"] = "epoch"
459
 
 
 
 
 
 
 
 
460
  training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
461
  max_steps=total_num_steps if cfg.max_steps else -1,
462
  max_seq_length=cfg.sequence_len,
@@ -468,7 +475,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
468
  eval_accumulation_steps=cfg.gradient_accumulation_steps,
469
  num_train_epochs=cfg.num_epochs,
470
  learning_rate=cfg.learning_rate,
471
- save_strategy="steps" if cfg.save_steps else "epoch",
472
  save_steps=cfg.save_steps,
473
  output_dir=cfg.output_dir,
474
  save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
 
457
  # we have an eval set, but no steps defined, use epoch
458
  training_arguments_kwargs["evaluation_strategy"] = "epoch"
459
 
460
+ if cfg.save_strategy:
461
+ training_arguments_kwargs["save_strategy"] = cfg.save_strategy
462
+ else:
463
+ training_arguments_kwargs["save_strategy"] = (
464
+ "steps" if cfg.save_steps else "epoch",
465
+ )
466
+
467
  training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
468
  max_steps=total_num_steps if cfg.max_steps else -1,
469
  max_seq_length=cfg.sequence_len,
 
475
  eval_accumulation_steps=cfg.gradient_accumulation_steps,
476
  num_train_epochs=cfg.num_epochs,
477
  learning_rate=cfg.learning_rate,
 
478
  save_steps=cfg.save_steps,
479
  output_dir=cfg.output_dir,
480
  save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,