Jan Philipp Harries
Jan Philipp Harries
commited on
Added advanced DDP args (#515)
Browse files* add ddp_config
* add advanced ddp config
* add ddp_config
* add advanced ddp config
---------
Co-authored-by: Jan Philipp Harries <[email protected]>
- README.md +5 -0
- src/axolotl/utils/trainer.py +9 -0
README.md
CHANGED
@@ -623,6 +623,11 @@ fsdp_config:
|
|
623 |
# Deepspeed config path
|
624 |
deepspeed:
|
625 |
|
|
|
|
|
|
|
|
|
|
|
626 |
# Path to torch distx for optim 'adamw_anyprecision'
|
627 |
torchdistx_path:
|
628 |
|
|
|
623 |
# Deepspeed config path
|
624 |
deepspeed:
|
625 |
|
626 |
+
# Advanced DDP Arguments
|
627 |
+
ddp_timeout:
|
628 |
+
ddp_bucket_cap_mb:
|
629 |
+
ddp_broadcast_buffers:
|
630 |
+
|
631 |
# Path to torch distx for optim 'adamw_anyprecision'
|
632 |
torchdistx_path:
|
633 |
|
src/axolotl/utils/trainer.py
CHANGED
@@ -579,6 +579,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
579 |
if cfg.bench_dataset:
|
580 |
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
582 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
583 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
584 |
max_seq_length=cfg.sequence_len,
|
|
|
579 |
if cfg.bench_dataset:
|
580 |
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
581 |
|
582 |
+
# DDP Config
|
583 |
+
if cfg.ddp_timeout:
|
584 |
+
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
|
585 |
+
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
|
586 |
+
if cfg.ddp_bucket_cap_mb:
|
587 |
+
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
|
588 |
+
if cfg.ddp_broadcast_buffers is not None:
|
589 |
+
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
|
590 |
+
|
591 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
592 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
593 |
max_seq_length=cfg.sequence_len,
|