winglian committed
Commit e0fcef4 · unverified · 1 Parent(s): c2b64e4

refactor utils.data module for line count linter (#1476)

src/axolotl/utils/data/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """
+ Data processing modules
+ """
+ from axolotl.utils.data.dpo import load_prepare_dpo_datasets  # noqa: F401
+ from axolotl.utils.data.pretraining import (  # noqa: F401
+     encode_pretraining,
+     wrap_pretraining_dataset,
+ )
+ from axolotl.utils.data.sft import (  # noqa: F401
+     get_dataset_wrapper,
+     load_prepare_datasets,
+     load_tokenized_prepared_datasets,
+     prepare_dataset,
+ )
+ from axolotl.utils.data.utils import md5  # noqa: F401
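The new package __init__ above re-exports the public helpers from the split-out submodules, so downstream code that imported from axolotl.utils.data before the refactor keeps working. A minimal sketch of that unchanged import surface (the sample string passed to md5 is purely illustrative):

# Hypothetical usage sketch, not part of the commit: the re-exports keep the
# pre-refactor import path intact for callers.
from axolotl.utils.data import (
    load_prepare_dpo_datasets,
    md5,
    prepare_dataset,
    wrap_pretraining_dataset,
)

print(md5("re-exported helper still importable"))  # 32-character hex digest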
src/axolotl/utils/data/dpo.py ADDED
@@ -0,0 +1,114 @@
+ """data handling specific to DPO"""
+
+ import logging
+ from pathlib import Path
+ from typing import Any, List
+
+ import yaml
+ from datasets import concatenate_datasets, load_dataset, load_from_disk
+
+ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
+ from axolotl.prompt_strategies.dpo import load as load_dpo
+ from axolotl.utils.data.utils import md5
+ from axolotl.utils.dict import DictDefault
+ from axolotl.utils.distributed import is_main_process, zero_first
+
+ LOG = logging.getLogger("axolotl")
+
+
+ def _get_path(ds_hash, cfg):
+     prepared_ds_path = (
+         Path(cfg.dataset_prepared_path) / ds_hash
+         if cfg.dataset_prepared_path
+         else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+     )
+
+     return prepared_ds_path
+
+
+ def _load_preprocessed_ds(cfg, sub_cfg):
+     ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
+     prepared_ds_path = _get_path(ds_hash, cfg)
+     dataset = None
+
+     # pylint: disable=duplicate-code
+     if (
+         cfg.dataset_prepared_path
+         and any(prepared_ds_path.glob("*"))
+         and not cfg.is_preprocess
+     ):
+         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+         dataset = load_from_disk(str(prepared_ds_path))
+
+     return dataset
+
+
+ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
+     ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
+     prepared_ds_path = _get_path(ds_hash, cfg)
+
+     if cfg.is_preprocess and is_main_process():
+         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+         dataset.save_to_disk(str(prepared_ds_path))
+
+
+ def load_prepare_dpo_datasets(cfg):
+     def load_split(dataset_cfgs, _cfg):
+         split_datasets: List[Any] = []
+         for i, ds_cfg in enumerate(dataset_cfgs):
+             if ds_cfg["ds_type"] == "json":
+                 for data_file in ds_cfg["data_files"]:
+                     data_files = {ds_cfg["split"]: data_file}
+                     ds = load_dataset(  # pylint: disable=invalid-name
+                         "json",
+                         data_files=data_files,
+                         split=ds_cfg["split"],
+                     )
+                     split_datasets.insert(i, ds)
+             else:
+                 ds = load_dataset(  # pylint: disable=invalid-name
+                     ds_cfg["path"],
+                     split=ds_cfg["split"],
+                 )
+                 split_datasets.insert(i, ds)
+
+         for i, data_set in enumerate(split_datasets):
+             _type = dataset_cfgs[i]["type"]
+             if _type:
+                 if isinstance(_type, DictDefault):
+                     _type = "user_defined.default"
+                 ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
+                 split_datasets[i] = data_set.map(
+                     ds_transform_fn,
+                     desc="Mapping RL Dataset",
+                 )
+             else:
+                 # If no `type` is provided, assume the dataset is already in the expected format with
+                 # "prompt", "chosen" and "rejected" already preprocessed
+                 split_datasets[i] = data_set
+
+         return concatenate_datasets(split_datasets)
+
+     with zero_first(is_main_process()):
+         train_is_preprocessed = False
+         eval_is_preprocessed = False
+         if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
+             train_is_preprocessed = True
+         else:
+             train_dataset = load_split(cfg.datasets, cfg)
+
+         eval_dataset = None
+         if cfg.test_datasets:
+             if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
+                 eval_is_preprocessed = True
+             else:
+                 eval_dataset = load_split(cfg.test_datasets, cfg)
+         if not eval_dataset:
+             eval_dataset = None
+
+         if not train_is_preprocessed:
+             _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
+         if eval_dataset and not eval_is_preprocessed:
+             _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
+
+     return train_dataset, eval_dataset
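For orientation, _load_preprocessed_ds and _save_preprocessed_ds above key the prepared-dataset directory on an md5 of the YAML dump of the relevant datasets sub-config. A rough, hedged illustration of that cache key (the dataset entry shown is made up):

# Illustration only: how the cache directory name is derived from the config.
import yaml

from axolotl.utils.data.utils import md5

sub_cfg = [{"path": "my-org/my-dpo-pairs", "split": "train", "type": "chatml.intel"}]  # hypothetical entry
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
# The prepared dataset is then read from / written to:
#   <cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH>/<ds_hash>
print(ds_hash)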
src/axolotl/utils/data/pretraining.py ADDED
@@ -0,0 +1,232 @@
+ """data handling specific to pretraining"""
+
+ import functools
+ import logging
+ from collections import defaultdict
+ from typing import Callable, Dict, List, Optional
+
+ import torch
+ from datasets import Dataset
+ from torch.utils.data import RandomSampler
+ from transformers import PreTrainedTokenizerBase
+
+ from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
+ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+ from axolotl.utils.trainer import process_pretraining_datasets_for_packing
+
+ LOG = logging.getLogger("axolotl")
+
+
+ def encode_pretraining(
+     tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+ ) -> Dict[str, List]:
+     res = tokenizer(
+         examples,
+         truncation=True,
+         max_length=max_tokens - 2,
+         add_special_tokens=True,
+     )
+     # Convert to PyTorch tensors
+     input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
+     attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+     new_input_ids = []
+     new_attention_mask = []
+     # Append EOS and PAD tokens to input_ids, and correct attention_mask
+     for i, _ in enumerate(input_ids):
+         input_ids[i] = torch.cat(
+             (
+                 input_ids[i],
+                 torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
+             ),
+             dim=0,
+         )
+         attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
+
+     # Concatenate tokens so that their lengths are less than max_tokens
+     buffer_input_ids = torch.tensor([], dtype=torch.long)
+     buffer_attention_mask = torch.tensor([], dtype=torch.long)
+
+     for ids, mask in zip(input_ids, attention_mask):
+         if buffer_input_ids.numel() == max_tokens:
+             new_input_ids.append(buffer_input_ids)
+             new_attention_mask.append(buffer_attention_mask)
+             buffer_input_ids = torch.tensor([], dtype=torch.long)
+             buffer_attention_mask = torch.tensor([], dtype=torch.long)
+             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+         elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
+             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+         else:
+             buffer_input_ids = torch.cat(
+                 (
+                     buffer_input_ids,
+                     torch.full(
+                         (max_tokens - buffer_input_ids.numel(),),
+                         tokenizer.pad_token_id,
+                         dtype=torch.long,
+                     ),
+                 ),
+                 dim=0,
+             )
+             buffer_attention_mask = torch.cat(
+                 (
+                     buffer_attention_mask,
+                     torch.full(
+                         (max_tokens - buffer_attention_mask.numel(),),
+                         0,
+                         dtype=torch.long,
+                     ),
+                 ),
+                 dim=0,
+             )
+             new_input_ids.append(buffer_input_ids)
+             new_attention_mask.append(buffer_attention_mask)
+             buffer_input_ids = torch.tensor([], dtype=torch.long)
+             buffer_attention_mask = torch.tensor([], dtype=torch.long)
+
+             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+
+     if buffer_input_ids.numel() > 0:  # for any leftover tokens
+         while buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size
+             buffer_input_ids = torch.cat(
+                 (
+                     buffer_input_ids,
+                     torch.full(
+                         (max_tokens - buffer_input_ids.numel(),),
+                         tokenizer.pad_token_id,
+                         dtype=torch.long,
+                     ),
+                 ),
+                 dim=0,
+             )
+             buffer_attention_mask = torch.cat(
+                 (
+                     buffer_attention_mask,
+                     torch.full(
+                         (max_tokens - buffer_attention_mask.numel(),),
+                         0,
+                         dtype=torch.long,
+                     ),
+                 ),
+                 dim=0,
+             )
+         new_input_ids.append(buffer_input_ids)
+         new_attention_mask.append(buffer_attention_mask)
+
+     ret = {
+         "input_ids": [seq.tolist() for seq in new_input_ids],
+         "labels": [seq.tolist() for seq in new_input_ids],
+         "attention_mask": [seq.tolist() for seq in new_attention_mask],
+     }
+
+     LOG.debug(len(ret["input_ids"]))
+     return ret
+
+
+ def wrap_pretraining_dataset(
+     dataset,
+     tokenizer,
+     cfg,
+     ds_wrapper_fn,
+     max_tokens=2048,
+     batch_size=1,
+     seed=42,
+     buffer_size=10_000,
+ ):
+     if cfg.sample_packing:
+         collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
+             tokenizer,
+             return_tensors="pt",
+             padding=True,
+             pad_to_multiple_of=max_tokens * batch_size,
+             multipack_attn=cfg.pretrain_multipack_attn,
+         )
+         encode = functools.partial(
+             encode_packed_pretraining,
+             collate_fn,
+             ds_wrapper_fn,
+             max_seq_length=max_tokens,
+             batch_size=batch_size,
+             multipack_attn=cfg.pretrain_multipack_attn,
+         )
+         # set this to 1 so downstream data_loader doesn't try to increase the batch again
+         cfg.micro_batch_size = 1
+     else:
+         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+
+     if cfg.shuffle_merged_datasets:
+         dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
+     else:
+         LOG.debug("NOT shuffling merged pretraining datasets")
+
+     # remove all the existing columns after mapping since they end up having
+     # a different length than the encoded/tokenized column
+     # this is empty during streaming/pretraining
+     remove_columns = []
+     if dataset.features is None:
+         for first_row in dataset:
+             remove_columns = first_row.keys()
+             break
+     else:
+         remove_columns = dataset.features.keys()
+
+     dataset = dataset.map(
+         encode,
+         batched=True,
+         batch_size=buffer_size,
+         # input_columns="text",
+         remove_columns=remove_columns,
+     )
+     return dataset
+
+
+ def encode_packed_pretraining(
+     collate_fn,
+     ds_wrapper: Callable,
+     examples: Dict[str, List],
+     max_seq_length: int = 2048,
+     batch_size: int = 4,
+     multipack_attn: Optional[bool] = False,
+ ) -> Dict[str, List]:
+     # pylint: disable=duplicate-code
+     # tokenize all the examples
+     # rows get split with stride (overlap)
+     train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
+
+     train_dataset = process_pretraining_datasets_for_packing(
+         train_dataset,
+         max_seq_length,
+         skip_position_ids=not multipack_attn,
+     )
+
+     sampler = MultipackBatchSampler(
+         RandomSampler(train_dataset),
+         batch_size=1,
+         drop_last=True,
+         batch_max_len=batch_size * max_seq_length,
+         lengths=get_dataset_lengths(train_dataset),
+     )
+
+     chunked_data = defaultdict(list)
+
+     for batch in sampler:
+         for data in batch:
+             features = train_dataset[data]
+             if "num_truncated_tokens" in features:
+                 del features["num_truncated_tokens"]
+             if "num_truncated_tokens" in features:
+                 del features["num_truncated_tokens"]
+             if "overflow_to_sample_mapping" in features:
+                 del features["overflow_to_sample_mapping"]
+             if "labels" not in features:
+                 features["labels"] = features["input_ids"].copy()
+             collated_features = collate_fn(features)
+
+             for feature in features.keys():
+                 if feature == "length":
+                     continue
+                 chunked_data[feature].append(collated_features[feature].squeeze(0))
+
+     return chunked_data
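encode_pretraining above packs tokenized examples into rows of exactly max_tokens, appending EOS/PAD per example and padding whatever is left in the buffer. A hedged usage sketch, assuming a gpt2 tokenizer with its pad token set to EOS (both assumptions, not part of the commit):

# Hedged sketch: exercise encode_pretraining with a small toy batch.
from transformers import AutoTokenizer

from axolotl.utils.data.pretraining import encode_pretraining

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer for the example
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

batch = encode_pretraining(tokenizer, 128, ["hello world", "lorem ipsum dolor sit amet"])
# Every packed row is padded out to max_tokens; labels mirror input_ids.
assert all(len(seq) == 128 for seq in batch["input_ids"])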
src/axolotl/utils/{data.py → data/sft.py} RENAMED
@@ -1,14 +1,10 @@
- """Module containing data utilities"""
+ """data handling specific to SFT"""
 
  import functools
- import hashlib
  import logging
- from collections import defaultdict
  from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import List, Optional, Tuple, Union
 
- import torch
- import yaml
  from datasets import (
      Dataset,
      DatasetDict,
@@ -18,13 +14,11 @@ from datasets import (
  )
  from huggingface_hub import hf_hub_download
  from huggingface_hub.utils import HFValidationError
- from torch.utils.data import RandomSampler
  from transformers import PreTrainedTokenizerBase
 
  from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
  from axolotl.datasets import TokenizedPromptDataset
  from axolotl.prompt_strategies import load
- from axolotl.prompt_strategies.dpo import load as load_dpo
  from axolotl.prompt_tokenizers import (
      AlpacaMultipleChoicePromptTokenizingStrategy,
      AlpacaPromptTokenizingStrategy,
@@ -45,26 +39,18 @@ from axolotl.prompters (
      SummarizeTLDRPrompter,
      UnsupportedPrompter,
  )
- from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
+ from axolotl.utils.data.pretraining import wrap_pretraining_dataset
+ from axolotl.utils.data.utils import md5
  from axolotl.utils.dict import DictDefault
  from axolotl.utils.distributed import is_main_process, zero_first
- from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
  from axolotl.utils.trainer import (
      calculate_total_num_steps,
      process_datasets_for_packing,
-     process_pretraining_datasets_for_packing,
  )
 
  LOG = logging.getLogger("axolotl")
 
 
- def md5(to_hash: str, encoding: str = "utf-8") -> str:
-     try:
-         return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
-     except TypeError:
-         return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec
-
-
  def prepare_dataset(cfg, tokenizer):
      prompters = []
      if not cfg.pretraining_dataset:
@@ -182,6 +168,7 @@ def load_tokenized_prepared_datasets(
          except Exception:  # pylint: disable=broad-except # nosec
              pass
 
+     # pylint: disable=duplicate-code
      if dataset:
          ...
      elif (
@@ -691,315 +678,3 @@ def get_dataset_wrapper(
          )
 
      return dataset_wrapper, dataset_prompter
-
-
- def encode_pretraining(
-     tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
- ) -> Dict[str, List]:
-     res = tokenizer(
-         examples,
-         truncation=True,
-         max_length=max_tokens - 2,
-         add_special_tokens=True,
-     )
-     # Convert to PyTorch tensors
-     input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
-     attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
-     new_input_ids = []
-     new_attention_mask = []
-     # Append EOS and PAD tokens to input_ids, and correct attention_mask
-     for i, _ in enumerate(input_ids):
-         input_ids[i] = torch.cat(
-             (
-                 input_ids[i],
-                 torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
-             ),
-             dim=0,
-         )
-         attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
-
-     # Concatenate tokens so that their lengths are less than max_tokens
-     buffer_input_ids = torch.tensor([], dtype=torch.long)
-     buffer_attention_mask = torch.tensor([], dtype=torch.long)
-
-     for ids, mask in zip(input_ids, attention_mask):
-         if buffer_input_ids.numel() == max_tokens:
-             new_input_ids.append(buffer_input_ids)
-             new_attention_mask.append(buffer_attention_mask)
-             buffer_input_ids = torch.tensor([], dtype=torch.long)
-             buffer_attention_mask = torch.tensor([], dtype=torch.long)
-             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
-             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
-         elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
-             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
-             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
-         else:
-             buffer_input_ids = torch.cat(
-                 (
-                     buffer_input_ids,
-                     torch.full(
-                         (max_tokens - buffer_input_ids.numel(),),
-                         tokenizer.pad_token_id,
-                         dtype=torch.long,
-                     ),
-                 ),
-                 dim=0,
-             )
-             buffer_attention_mask = torch.cat(
-                 (
-                     buffer_attention_mask,
-                     torch.full(
-                         (max_tokens - buffer_attention_mask.numel(),),
-                         0,
-                         dtype=torch.long,
-                     ),
-                 ),
-                 dim=0,
-             )
-             new_input_ids.append(buffer_input_ids)
-             new_attention_mask.append(buffer_attention_mask)
-             buffer_input_ids = torch.tensor([], dtype=torch.long)
-             buffer_attention_mask = torch.tensor([], dtype=torch.long)
-
-             buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
-             buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
-
-     if buffer_input_ids.numel() > 0:  # for any leftover tokens
-         while buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size
-             buffer_input_ids = torch.cat(
-                 (
-                     buffer_input_ids,
-                     torch.full(
-                         (max_tokens - buffer_input_ids.numel(),),
-                         tokenizer.pad_token_id,
-                         dtype=torch.long,
-                     ),
-                 ),
-                 dim=0,
-             )
-             buffer_attention_mask = torch.cat(
-                 (
-                     buffer_attention_mask,
-                     torch.full(
-                         (max_tokens - buffer_attention_mask.numel(),),
-                         0,
-                         dtype=torch.long,
-                     ),
-                 ),
-                 dim=0,
-             )
-         new_input_ids.append(buffer_input_ids)
-         new_attention_mask.append(buffer_attention_mask)
-
-     ret = {
-         "input_ids": [seq.tolist() for seq in new_input_ids],
-         "labels": [seq.tolist() for seq in new_input_ids],
-         "attention_mask": [seq.tolist() for seq in new_attention_mask],
-     }
-
-     LOG.debug(len(ret["input_ids"]))
-     return ret
-
-
- def wrap_pretraining_dataset(
-     dataset,
-     tokenizer,
-     cfg,
-     ds_wrapper_fn,
-     max_tokens=2048,
-     batch_size=1,
-     seed=42,
-     buffer_size=10_000,
- ):
-     if cfg.sample_packing:
-         collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
-             tokenizer,
-             return_tensors="pt",
-             padding=True,
-             pad_to_multiple_of=max_tokens * batch_size,
-             multipack_attn=cfg.pretrain_multipack_attn,
-         )
-         encode = functools.partial(
-             encode_packed_pretraining,
-             collate_fn,
-             ds_wrapper_fn,
-             max_seq_length=max_tokens,
-             batch_size=batch_size,
-             multipack_attn=cfg.pretrain_multipack_attn,
-         )
-         # set this to 1 so downstream data_loader doesn't try to increase the batch again
-         cfg.micro_batch_size = 1
-     else:
-         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
-
-     if cfg.shuffle_merged_datasets:
-         dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
-     else:
-         LOG.debug("NOT shuffling merged pretraining datasets")
-
-     # remove all the existing columns after mapping since they end up having
-     # a different length than the encoded/tokenized column
-     # this is empty during streaming/pretraining
-     remove_columns = []
-     if dataset.features is None:
-         for first_row in dataset:
-             remove_columns = first_row.keys()
-             break
-     else:
-         remove_columns = dataset.features.keys()
-
-     dataset = dataset.map(
-         encode,
-         batched=True,
-         batch_size=buffer_size,
-         # input_columns="text",
-         remove_columns=remove_columns,
-     )
-     return dataset
-
-
- def encode_packed_pretraining(
-     collate_fn,
-     ds_wrapper: Callable,
-     examples: Dict[str, List],
-     max_seq_length: int = 2048,
-     batch_size: int = 4,
-     multipack_attn: Optional[bool] = False,
- ) -> Dict[str, List]:
-     # pylint: disable=duplicate-code
-     # tokenize all the examples
-     # rows get split with stride (overlap)
-     train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
-
-     train_dataset = process_pretraining_datasets_for_packing(
-         train_dataset,
-         max_seq_length,
-         skip_position_ids=not multipack_attn,
-     )
-
-     sampler = MultipackBatchSampler(
-         RandomSampler(train_dataset),
-         batch_size=1,
-         drop_last=True,
-         batch_max_len=batch_size * max_seq_length,
-         lengths=get_dataset_lengths(train_dataset),
-     )
-
-     chunked_data = defaultdict(list)
-
-     for batch in sampler:
-         for data in batch:
-             features = train_dataset[data]
-             if "num_truncated_tokens" in features:
-                 del features["num_truncated_tokens"]
-             if "num_truncated_tokens" in features:
-                 del features["num_truncated_tokens"]
-             if "overflow_to_sample_mapping" in features:
-                 del features["overflow_to_sample_mapping"]
-             if "labels" not in features:
-                 features["labels"] = features["input_ids"].copy()
-             collated_features = collate_fn(features)
-
-             for feature in features.keys():
-                 if feature == "length":
-                     continue
-                 chunked_data[feature].append(collated_features[feature].squeeze(0))
-
-     return chunked_data
-
-
- def _get_path(ds_hash, cfg):
-     prepared_ds_path = (
-         Path(cfg.dataset_prepared_path) / ds_hash
-         if cfg.dataset_prepared_path
-         else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
-     )
-
-     return prepared_ds_path
-
-
- def _load_preprocessed_ds(cfg, sub_cfg):
-     ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
-     prepared_ds_path = _get_path(ds_hash, cfg)
-     dataset = None
-
-     if (
-         cfg.dataset_prepared_path
-         and any(prepared_ds_path.glob("*"))
-         and not cfg.is_preprocess
-     ):
-         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
-         dataset = load_from_disk(str(prepared_ds_path))
-
-     return dataset
-
-
- def _save_preprocessed_ds(cfg, sub_cfg, dataset):
-     ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
-     prepared_ds_path = _get_path(ds_hash, cfg)
-
-     if cfg.is_preprocess and is_main_process():
-         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
-         dataset.save_to_disk(str(prepared_ds_path))
-
-
- def load_prepare_dpo_datasets(cfg):
-     def load_split(dataset_cfgs, _cfg):
-         split_datasets: List[Any] = []
-         for i, ds_cfg in enumerate(dataset_cfgs):
-             if ds_cfg["ds_type"] == "json":
-                 for data_file in ds_cfg["data_files"]:
-                     data_files = {ds_cfg["split"]: data_file}
-                     ds = load_dataset(  # pylint: disable=invalid-name
-                         "json",
-                         data_files=data_files,
-                         split=ds_cfg["split"],
-                     )
-                     split_datasets.insert(i, ds)
-             else:
-                 ds = load_dataset(  # pylint: disable=invalid-name
-                     ds_cfg["path"],
-                     split=ds_cfg["split"],
-                 )
-                 split_datasets.insert(i, ds)
-
-         for i, data_set in enumerate(split_datasets):
-             _type = dataset_cfgs[i]["type"]
-             if _type:
-                 if isinstance(_type, DictDefault):
-                     _type = "user_defined.default"
-                 ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
-                 split_datasets[i] = data_set.map(
-                     ds_transform_fn,
-                     desc="Mapping RL Dataset",
-                 )
-             else:
-                 # If no `type` is provided, assume the dataset is already in the expected format with
-                 # "prompt", "chosen" and "rejected" already preprocessed
-                 split_datasets[i] = data_set
-
-         return concatenate_datasets(split_datasets)
-
-     with zero_first(is_main_process()):
-         train_is_preprocessed = False
-         eval_is_preprocessed = False
-         if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
-             train_is_preprocessed = True
-         else:
-             train_dataset = load_split(cfg.datasets, cfg)
-
-         eval_dataset = None
-         if cfg.test_datasets:
-             if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
-                 eval_is_preprocessed = True
-             else:
-                 eval_dataset = load_split(cfg.test_datasets, cfg)
-         if not eval_dataset:
-             eval_dataset = None
-
-         if not train_is_preprocessed:
-             _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
-         if eval_dataset and not eval_is_preprocessed:
-             _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
-
-     return train_dataset, eval_dataset
src/axolotl/utils/data/utils.py ADDED
@@ -0,0 +1,10 @@
+ """data handling helpers"""
+
+ import hashlib
+
+
+ def md5(to_hash: str, encoding: str = "utf-8") -> str:
+     try:
+         return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
+     except TypeError:
+         return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec
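md5 here fingerprints dataset configs for cache paths rather than doing anything security-sensitive; the try/except handles interpreters where hashlib.md5 lacks the usedforsecurity keyword (added in Python 3.9). A quick hedged example:

# Hedged example: the digest is used as a cache key, not a security primitive.
from axolotl.utils.data.utils import md5

fingerprint = md5("datasets:\n- path: my-org/my-data\n")  # hypothetical YAML snippet
print(fingerprint)  # 32-character hexadecimal digest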