Make dataset_processes configurable (#651)
I'm using the Axolotl script to train models on https://modal.com serverless GPUs. Unfortunately, their environment seems to have a bug where running `datasets.filter` with too high a `num_proc` throws an error and dies.
This PR adds a new configuration option, `dataset_processes`, which lets you explicitly set the number of processes used to map and filter the dataset. If not set, it defaults to the current behavior of `os.cpu_count()`.
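For illustration, here is a minimal sketch of the resolution logic (the same pattern the `config.py` change below uses); the `SimpleNamespace` config object is a hypothetical stand-in for axolotl's `cfg`:

```python
import os
from types import SimpleNamespace

# Hypothetical stand-in for axolotl's cfg object.
cfg = SimpleNamespace(dataset_processes=None)

# Unset -> fall back to the machine's CPU count (the previous hard-coded behavior).
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
print(cfg.dataset_processes)  # e.g. 8 on an 8-core machine

# Explicitly set -> the value is kept, and is later forwarded as num_proc
# to every .map()/.filter() call on the dataset.
cfg = SimpleNamespace(dataset_processes=2)
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
print(cfg.dataset_processes)  # 2
```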
- README.md +3 -0
- src/axolotl/utils/config.py +2 -0
- src/axolotl/utils/trainer.py +11 -5
README.md

```diff
@@ -487,6 +487,9 @@ datasets:
 dataset_prepared_path: data/last_run_prepared
 # push prepared dataset to hub
 push_dataset_to_hub: # repo path
+# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
+# if not set.
+dataset_processes: # defaults to os.cpu_count() if not set
 # push checkpoints to hub
 hub_model_id: # repo path to push finetuned model
 # how to push checkpoints to hub
```
src/axolotl/utils/config.py

```diff
@@ -75,6 +75,8 @@ def normalize_config(cfg):
     else:
         cfg.torch_dtype = torch.float32
 
+    cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
+
     model_config = load_model_config(cfg)
     cfg.model_config_type = model_config.model_type
 
```
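One consequence of using `or` for the fallback, worth noting: any falsy value, not just a missing option, resolves to the CPU count, so setting `dataset_processes: 0` behaves the same as leaving it unset. A quick check:

```python
import os

for configured in (None, 0, 4):
    # `or` falls back on every falsy value, so 0 behaves like "unset".
    print(configured, "->", configured or os.cpu_count())
```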
src/axolotl/utils/trainer.py

```diff
@@ -400,19 +400,25 @@ def disable_datasets_caching():
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
     with zero_first(is_main_process()):
-        train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
+        train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
         if eval_dataset:
-            eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
+            eval_dataset = eval_dataset.filter(
+                drop_long, num_proc=cfg.dataset_processes
+            )
 
         if cfg.group_by_length:
-            train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
+            train_dataset = train_dataset.map(
+                add_length, num_proc=cfg.dataset_processes
+            )
 
         if cfg.sample_packing:
-            train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
+            train_dataset = train_dataset.map(
+                add_position_ids, num_proc=cfg.dataset_processes
+            )
             if cfg.eval_sample_packing is not False:
                 if eval_dataset:
                     eval_dataset = eval_dataset.map(
-                        add_position_ids, num_proc=os.cpu_count()
+                        add_position_ids, num_proc=cfg.dataset_processes
                     )
 
     # Phi doesn't want the attention_mask feature when training
```
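And to show the patched call pattern end to end, a minimal runnable sketch assuming only the Hugging Face `datasets` library; `drop_long_seq` here is a simplified stand-in for axolotl's real helper, and the config object is again hypothetical:

```python
from functools import partial
from types import SimpleNamespace

from datasets import Dataset


def drop_long_seq(sample, sequence_len=2048):
    # Simplified stand-in: keep only rows that fit in the context window.
    return len(sample["input_ids"]) <= sequence_len


# On a constrained host (e.g. a Modal serverless worker), pin the worker
# count instead of inheriting os.cpu_count().
cfg = SimpleNamespace(sequence_len=4, dataset_processes=2)

drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
ds = Dataset.from_dict({"input_ids": [[1, 2], [1, 2, 3, 4, 5], [1], [1, 2, 3]]})
ds = ds.filter(drop_long, num_proc=cfg.dataset_processes)
print(len(ds))  # 3 -- the five-token row was dropped
```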