Jan Philipp Harries
Jan Philipp Harries
commited on
Commit
•
2f586d1
1
Parent(s):
9845c5e
Fix pretraining with iterable/streaming Dataset (#556)
Browse files* return without packing prep/len
* fix remove columns
* fix encode arguments
* add error when max steps not set
* fix test
---------
Co-authored-by: Jan Philipp Harries <[email protected]>
- src/axolotl/utils/config.py +4 -0
- src/axolotl/utils/data.py +14 -5
- tests/test_data.py +1 -1
src/axolotl/utils/config.py
CHANGED
@@ -191,6 +191,10 @@ def validate_config(cfg):
|
|
191 |
LOG.warning(
|
192 |
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
193 |
)
|
|
|
|
|
|
|
|
|
194 |
|
195 |
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
196 |
not cfg.optimizer or "adamw" not in cfg.optimizer
|
|
|
191 |
LOG.warning(
|
192 |
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
193 |
)
|
194 |
+
if cfg.pretraining_dataset and not cfg.max_steps:
|
195 |
+
raise ValueError(
|
196 |
+
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
|
197 |
+
)
|
198 |
|
199 |
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
200 |
not cfg.optimizer or "adamw" not in cfg.optimizer
|
src/axolotl/utils/data.py
CHANGED
@@ -3,7 +3,7 @@ import functools
|
|
3 |
import hashlib
|
4 |
import logging
|
5 |
from pathlib import Path
|
6 |
-
from typing import Tuple, Union
|
7 |
|
8 |
import torch
|
9 |
from datasets import (
|
@@ -74,6 +74,7 @@ def prepare_dataset(cfg, tokenizer):
|
|
74 |
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
75 |
train_dataset = train_dataset.with_format("torch")
|
76 |
eval_dataset = None
|
|
|
77 |
|
78 |
with zero_first(is_main_process()):
|
79 |
train_dataset, eval_dataset = process_datasets_for_packing(
|
@@ -527,9 +528,11 @@ def load_prepare_datasets(
|
|
527 |
return train_dataset, eval_dataset
|
528 |
|
529 |
|
530 |
-
def encode_pretraining(
|
|
|
|
|
531 |
res = tokenizer(
|
532 |
-
examples
|
533 |
truncation=True,
|
534 |
max_length=max_tokens - 2,
|
535 |
add_special_tokens=True,
|
@@ -637,6 +640,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
|
637 |
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
638 |
dataset = load_dataset(path, streaming=True, split="train")
|
639 |
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
640 |
-
|
641 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
642 |
return dataset
|
|
|
3 |
import hashlib
|
4 |
import logging
|
5 |
from pathlib import Path
|
6 |
+
from typing import Dict, List, Tuple, Union
|
7 |
|
8 |
import torch
|
9 |
from datasets import (
|
|
|
74 |
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
75 |
train_dataset = train_dataset.with_format("torch")
|
76 |
eval_dataset = None
|
77 |
+
return train_dataset, eval_dataset, cfg.max_steps
|
78 |
|
79 |
with zero_first(is_main_process()):
|
80 |
train_dataset, eval_dataset = process_datasets_for_packing(
|
|
|
528 |
return train_dataset, eval_dataset
|
529 |
|
530 |
|
531 |
+
def encode_pretraining(
|
532 |
+
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
|
533 |
+
) -> Dict[str, List]:
|
534 |
res = tokenizer(
|
535 |
+
examples,
|
536 |
truncation=True,
|
537 |
max_length=max_tokens - 2,
|
538 |
add_special_tokens=True,
|
|
|
640 |
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
641 |
dataset = load_dataset(path, streaming=True, split="train")
|
642 |
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
643 |
+
dataset = dataset.map(
|
644 |
+
encode,
|
645 |
+
batched=True,
|
646 |
+
input_columns="text",
|
647 |
+
remove_columns=[
|
648 |
+
"text",
|
649 |
+
],
|
650 |
+
)
|
651 |
return dataset
|
tests/test_data.py
CHANGED
@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
|
|
35 |
"hello, hello",
|
36 |
]
|
37 |
}
|
38 |
-
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
|
39 |
|
40 |
self.assertEqual(len(result["input_ids"]), 3)
|
41 |
|
|
|
35 |
"hello, hello",
|
36 |
]
|
37 |
}
|
38 |
+
result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
|
39 |
|
40 |
self.assertEqual(len(result["input_ids"]), 3)
|
41 |
|