corbt committed
Commit 9ec2077 · unverified · 1 Parent(s): 590d603

Make dataset_processes configurable (#651)


I'm using the Axolotl script to train models on https://modal.com serverless GPUs. Unfortunately, their environment seems to have a bug: if I run `datasets.filter` with too high a `num_proc`, it throws an error and dies.

This PR adds a new configuration option, `dataset_processes`, which lets you explicitly set the number of processes used to map/filter the dataset. If not set, it defaults to the current behavior of using `os.cpu_count()`.
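
For illustration, here is a minimal, self-contained sketch of the intended behavior (this is not Axolotl's actual code; the `Cfg` stand-in and the toy dataset are made up for this example):

```python
import os
from datasets import Dataset

# Stand-in for the Axolotl config; in practice `dataset_processes` comes
# from the YAML config file. Leaving it unset simulates the default path.
class Cfg:
    dataset_processes = None  # e.g. set to 4 to cap preprocessing workers

cfg = Cfg()

# The PR's default: fall back to os.cpu_count() when the option is not set.
num_proc = cfg.dataset_processes or os.cpu_count()

# Toy dataset standing in for the tokenized training set.
ds = Dataset.from_dict({"input_ids": [[1] * n for n in (8, 2048, 16)]})

# Filter with a bounded number of worker processes, mirroring how the
# trainer passes num_proc to `datasets.filter`.
ds = ds.filter(lambda ex: len(ex["input_ids"]) <= 1024, num_proc=num_proc)
print(len(ds))  # 2 examples survive the length filter
```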

README.md CHANGED
@@ -487,6 +487,9 @@ datasets:
 dataset_prepared_path: data/last_run_prepared
 # push prepared dataset to hub
 push_dataset_to_hub: # repo path
+# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
+# if not set.
+dataset_processes: # defaults to os.cpu_count() if not set
 # push checkpoints to hub
 hub_model_id: # repo path to push finetuned model
 # how to push checkpoints to hub
src/axolotl/utils/config.py CHANGED
@@ -75,6 +75,8 @@ def normalize_config(cfg):
     else:
         cfg.torch_dtype = torch.float32

+    cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
+
     model_config = load_model_config(cfg)
     cfg.model_config_type = model_config.model_type

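One detail of this default worth calling out: because it uses `or`, any falsy value (an explicit `0`, an empty string) also falls back to `os.cpu_count()`. A quick sketch of that resolution logic (the helper name is hypothetical, it just mirrors the line above):

```python
import os

def resolve_dataset_processes(value):
    # Mirrors the normalize_config line: falsy values fall back to CPU count.
    return value or os.cpu_count()

print(resolve_dataset_processes(4))     # 4
print(resolve_dataset_processes(None))  # os.cpu_count()
print(resolve_dataset_processes(0))     # also os.cpu_count(), since 0 is falsy
```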
src/axolotl/utils/trainer.py CHANGED
@@ -400,19 +400,25 @@ def disable_datasets_caching():
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
     with zero_first(is_main_process()):
-        train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
+        train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
         if eval_dataset:
-            eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
+            eval_dataset = eval_dataset.filter(
+                drop_long, num_proc=cfg.dataset_processes
+            )

         if cfg.group_by_length:
-            train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
+            train_dataset = train_dataset.map(
+                add_length, num_proc=cfg.dataset_processes
+            )

         if cfg.sample_packing:
-            train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
+            train_dataset = train_dataset.map(
+                add_position_ids, num_proc=cfg.dataset_processes
+            )
             if cfg.eval_sample_packing is not False:
                 if eval_dataset:
                     eval_dataset = eval_dataset.map(
-                        add_position_ids, num_proc=os.cpu_count()
+                        add_position_ids, num_proc=cfg.dataset_processes
                     )

     # Phi doesn't want the attention_mask feature when training
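
For reference, a small standalone example of the `datasets` `map` call these changes bound (the toy data and the simplified `add_length` are illustrative, not Axolotl's exact helpers):

```python
from datasets import Dataset

def add_length(example):
    # Simplified version of Axolotl's add_length: record the sequence length.
    return {"length": len(example["input_ids"])}

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5], [6]]})

# num_proc caps the number of worker processes, just as
# num_proc=cfg.dataset_processes does in process_datasets_for_packing.
ds = ds.map(add_length, num_proc=2)
print(ds["length"])  # [3, 2, 1]
```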