Jan Philipp Harries Jan Philipp Harries commited on
Commit
2f586d1
1 Parent(s): 9845c5e

Fix pretraining with iterable/streaming Dataset (#556)

Browse files

* return without packing prep/len

* fix remove columns

* fix encode arguments

* add error when max steps not set

* fix test

---------

Co-authored-by: Jan Philipp Harries <[email protected]>

src/axolotl/utils/config.py CHANGED
@@ -191,6 +191,10 @@ def validate_config(cfg):
191
  LOG.warning(
192
  "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
193
  )
 
 
 
 
194
 
195
  if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
196
  not cfg.optimizer or "adamw" not in cfg.optimizer
 
191
  LOG.warning(
192
  "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
193
  )
194
+ if cfg.pretraining_dataset and not cfg.max_steps:
195
+ raise ValueError(
196
+ "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
197
+ )
198
 
199
  if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
200
  not cfg.optimizer or "adamw" not in cfg.optimizer
src/axolotl/utils/data.py CHANGED
@@ -3,7 +3,7 @@ import functools
3
  import hashlib
4
  import logging
5
  from pathlib import Path
6
- from typing import Tuple, Union
7
 
8
  import torch
9
  from datasets import (
@@ -74,6 +74,7 @@ def prepare_dataset(cfg, tokenizer):
74
  # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
75
  train_dataset = train_dataset.with_format("torch")
76
  eval_dataset = None
 
77
 
78
  with zero_first(is_main_process()):
79
  train_dataset, eval_dataset = process_datasets_for_packing(
@@ -527,9 +528,11 @@ def load_prepare_datasets(
527
  return train_dataset, eval_dataset
528
 
529
 
530
- def encode_pretraining(tokenizer, max_tokens, examples):
 
 
531
  res = tokenizer(
532
- examples["text"],
533
  truncation=True,
534
  max_length=max_tokens - 2,
535
  add_special_tokens=True,
@@ -637,6 +640,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
637
  encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
638
  dataset = load_dataset(path, streaming=True, split="train")
639
  dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
640
- # TODO dynamically figure out which columns/features to remove
641
- dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
 
 
 
 
 
 
642
  return dataset
 
3
  import hashlib
4
  import logging
5
  from pathlib import Path
6
+ from typing import Dict, List, Tuple, Union
7
 
8
  import torch
9
  from datasets import (
 
74
  # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
75
  train_dataset = train_dataset.with_format("torch")
76
  eval_dataset = None
77
+ return train_dataset, eval_dataset, cfg.max_steps
78
 
79
  with zero_first(is_main_process()):
80
  train_dataset, eval_dataset = process_datasets_for_packing(
 
528
  return train_dataset, eval_dataset
529
 
530
 
531
+ def encode_pretraining(
532
+ tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
533
+ ) -> Dict[str, List]:
534
  res = tokenizer(
535
+ examples,
536
  truncation=True,
537
  max_length=max_tokens - 2,
538
  add_special_tokens=True,
 
640
  encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
641
  dataset = load_dataset(path, streaming=True, split="train")
642
  dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
643
+ dataset = dataset.map(
644
+ encode,
645
+ batched=True,
646
+ input_columns="text",
647
+ remove_columns=[
648
+ "text",
649
+ ],
650
+ )
651
  return dataset
tests/test_data.py CHANGED
@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
35
  "hello, hello",
36
  ]
37
  }
38
- result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
39
 
40
  self.assertEqual(len(result["input_ids"]), 3)
41
 
 
35
  "hello, hello",
36
  ]
37
  }
38
+ result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
39
 
40
  self.assertEqual(len(result["input_ids"]), 3)
41