monsoon-nlp
committed on
fix pretraining_ on odd datasets (#1463)
Browse files
* can configure name of split of pretraining dataset
* streaming data and dataset map
* text column customized
* allow text_column to be set in pretrain
* pretrain type
* load a bit of the dataset
* fix dataset where splits have separate configs
* ok name param here is the config
* whitespace
src/axolotl/prompt_strategies/pretrain.py
CHANGED
@@ -20,10 +20,11 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
|
20 |
def supports_batched(self):
|
21 |
return True
|
22 |
|
23 |
-
def __init__(self, *args, max_length=None, **kwargs):
|
24 |
super().__init__(*args, **kwargs)
|
25 |
if max_length:
|
26 |
self.max_length = max_length
|
|
|
27 |
|
28 |
def _tokenize(
|
29 |
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
@@ -44,7 +45,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
|
44 |
return res
|
45 |
|
46 |
def tokenize_prompt(self, prompt):
|
47 |
-
return self._tokenize(prompt[
|
48 |
|
49 |
|
50 |
def load(tokenizer, cfg):
|
@@ -53,6 +54,7 @@ def load(tokenizer, cfg):
|
|
53 |
tokenizer,
|
54 |
cfg.train_on_inputs,
|
55 |
cfg.sequence_len,
|
|
|
56 |
max_length=cfg.sequence_len * 64,
|
57 |
)
|
58 |
return strat
|
|
|
20 |
def supports_batched(self):
|
21 |
return True
|
22 |
|
23 |
+
def __init__(self, *args, max_length=None, text_column="text", **kwargs):
|
24 |
super().__init__(*args, **kwargs)
|
25 |
if max_length:
|
26 |
self.max_length = max_length
|
27 |
+
self.text_column = text_column
|
28 |
|
29 |
def _tokenize(
|
30 |
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
|
|
45 |
return res
|
46 |
|
47 |
def tokenize_prompt(self, prompt):
|
48 |
+
return self._tokenize(prompt[self.text_column])
|
49 |
|
50 |
|
51 |
def load(tokenizer, cfg):
|
|
|
54 |
tokenizer,
|
55 |
cfg.train_on_inputs,
|
56 |
cfg.sequence_len,
|
57 |
+
text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
|
58 |
max_length=cfg.sequence_len * 64,
|
59 |
)
|
60 |
return strat
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -61,7 +61,11 @@ class RemappedParameters(BaseModel):
|
|
61 |
class PretrainingDataset(BaseModel):
|
62 |
"""pretraining dataset configuration subset"""
|
63 |
|
|
|
64 |
path: Optional[str] = None
|
|
|
|
|
|
|
65 |
|
66 |
|
67 |
class UserDefinedPrompterType(BaseModel):
|
@@ -448,7 +452,7 @@ class AxolotlInputConfig(
|
|
448 |
dataset_shard_idx: Optional[int] = None
|
449 |
|
450 |
pretraining_dataset: Optional[ # type: ignore
|
451 |
-
conlist(Union[
|
452 |
] = Field(
|
453 |
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
|
454 |
)
|
|
|
61 |
class PretrainingDataset(BaseModel):
|
62 |
"""pretraining dataset configuration subset"""
|
63 |
|
64 |
+
name: Optional[str] = None
|
65 |
path: Optional[str] = None
|
66 |
+
split: Optional[str] = "train"
|
67 |
+
text_column: Optional[str] = "text"
|
68 |
+
type: Optional[str] = "pretrain"
|
69 |
|
70 |
|
71 |
class UserDefinedPrompterType(BaseModel):
|
|
|
452 |
dataset_shard_idx: Optional[int] = None
|
453 |
|
454 |
pretraining_dataset: Optional[ # type: ignore
|
455 |
+
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
|
456 |
] = Field(
|
457 |
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
|
458 |
)
|
src/axolotl/utils/data.py
CHANGED
@@ -82,12 +82,15 @@ def prepare_dataset(cfg, tokenizer):
|
|
82 |
)
|
83 |
else:
|
84 |
path = cfg.pretraining_dataset
|
|
|
85 |
name = None
|
86 |
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
87 |
cfg.pretraining_dataset[0], dict
|
88 |
):
|
89 |
path = cfg.pretraining_dataset[0]["path"]
|
90 |
name = cfg.pretraining_dataset[0]["name"]
|
|
|
|
|
91 |
|
92 |
ds_wrapper_partial = functools.partial(
|
93 |
get_dataset_wrapper,
|
@@ -98,7 +101,7 @@ def prepare_dataset(cfg, tokenizer):
|
|
98 |
)
|
99 |
|
100 |
train_dataset = wrap_pretraining_dataset(
|
101 |
-
load_dataset(path, streaming=True, split=
|
102 |
tokenizer,
|
103 |
cfg,
|
104 |
ds_wrapper_partial,
|
@@ -831,14 +834,23 @@ def wrap_pretraining_dataset(
|
|
831 |
else:
|
832 |
LOG.debug("NOT shuffling merged pretraining datasets")
|
833 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
834 |
dataset = dataset.map(
|
835 |
encode,
|
836 |
batched=True,
|
837 |
batch_size=buffer_size,
|
838 |
# input_columns="text",
|
839 |
-
|
840 |
-
# a different length than the encoded/tokenized column
|
841 |
-
remove_columns=dataset.features.keys(),
|
842 |
)
|
843 |
return dataset
|
844 |
|
|
|
82 |
)
|
83 |
else:
|
84 |
path = cfg.pretraining_dataset
|
85 |
+
split = "train"
|
86 |
name = None
|
87 |
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
88 |
cfg.pretraining_dataset[0], dict
|
89 |
):
|
90 |
path = cfg.pretraining_dataset[0]["path"]
|
91 |
name = cfg.pretraining_dataset[0]["name"]
|
92 |
+
if "split" in cfg.pretraining_dataset[0]:
|
93 |
+
split = cfg.pretraining_dataset[0]["split"]
|
94 |
|
95 |
ds_wrapper_partial = functools.partial(
|
96 |
get_dataset_wrapper,
|
|
|
101 |
)
|
102 |
|
103 |
train_dataset = wrap_pretraining_dataset(
|
104 |
+
load_dataset(path, streaming=True, split=split, name=name),
|
105 |
tokenizer,
|
106 |
cfg,
|
107 |
ds_wrapper_partial,
|
|
|
834 |
else:
|
835 |
LOG.debug("NOT shuffling merged pretraining datasets")
|
836 |
|
837 |
+
# remove all the existing columns after mapping since they end up having
|
838 |
+
# a different length than the encoded/tokenized column
|
839 |
+
# this is empty during streaming/pretraining
|
840 |
+
remove_columns = []
|
841 |
+
if dataset.features is None:
|
842 |
+
for first_row in dataset:
|
843 |
+
remove_columns = first_row.keys()
|
844 |
+
break
|
845 |
+
else:
|
846 |
+
remove_columns = dataset.features.keys()
|
847 |
+
|
848 |
dataset = dataset.map(
|
849 |
encode,
|
850 |
batched=True,
|
851 |
batch_size=buffer_size,
|
852 |
# input_columns="text",
|
853 |
+
remove_columns=remove_columns,
|
|
|
|
|
854 |
)
|
855 |
return dataset
|
856 |
|