Birch-san tmm1 commited on
Commit
8e197f6
·
unverified ·
1 Parent(s): 267b7b2

pad_to_worst_case_seq_len boolean, for testing memory limits (#498)

Browse files

* pad_to_worst_case_seq_len boolean, for testing memory limits

* remove collator_pad_to_longest option since it does nothing

see docs: https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorWithPadding.padding

True and "longest" mean the same thing

* rename to `pad_to_sequence_len, and ensure 64 alignment

---------

Co-authored-by: Aman Karmani <[email protected]>

README.md CHANGED
@@ -459,6 +459,9 @@ dataset_shard_idx:
459
  # the maximum length of an input to train with, this should typically be less than 2048
460
  # as most models have a token/context limit of 2048
461
  sequence_len: 2048
 
 
 
462
  # max sequence length to concatenate training samples together up to
463
  # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
464
  # FutureWarning: This will soon be DEPRECATED
@@ -610,9 +613,6 @@ deepspeed:
610
  # Path to torch distx for optim 'adamw_anyprecision'
611
  torchdistx_path:
612
 
613
- # Set padding for data collator to 'longest'
614
- collator_pad_to_longest:
615
-
616
  # Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
617
  pretraining_dataset:
618
 
 
459
  # the maximum length of an input to train with, this should typically be less than 2048
460
  # as most models have a token/context limit of 2048
461
  sequence_len: 2048
462
+ # pad inputs so each step uses constant sized buffers
463
+ # this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
464
+ pad_to_sequence_len:
465
  # max sequence length to concatenate training samples together up to
466
  # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
467
  # FutureWarning: This will soon be DEPRECATED
 
613
  # Path to torch distx for optim 'adamw_anyprecision'
614
  torchdistx_path:
615
 
 
 
 
616
  # Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
617
  pretraining_dataset:
618
 
examples/pythia-12b/config.yml CHANGED
@@ -47,4 +47,3 @@ local_rank:
47
  gradient_checkpointing: true
48
  fsdp:
49
  fsdp_config:
50
- collator_pad_to_longest: true
 
47
  gradient_checkpointing: true
48
  fsdp:
49
  fsdp_config:
 
src/axolotl/utils/trainer.py CHANGED
@@ -585,10 +585,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
585
  callbacks.append(SaveBetterTransformerModelCallback)
586
 
587
  data_collator_kwargs = {
588
- "padding": True,
589
  }
590
- if cfg.collator_pad_to_longest:
591
- data_collator_kwargs["padding"] = "longest"
592
  else:
593
  # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
594
  # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
 
585
  callbacks.append(SaveBetterTransformerModelCallback)
586
 
587
  data_collator_kwargs = {
588
+ "padding": True, # True/"longest" is the default
589
  }
590
+ if cfg.pad_to_sequence_len:
591
+ data_collator_kwargs["pad_to_multiple_of"] = 64 * round(cfg.sequence_len / 64)
592
  else:
593
  # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
594
  # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html