winglian commited on
Commit
1a6309c
·
unverified ·
1 Parent(s): 105d0b3

cleanup the old multipack dataloader (#841)

Browse files
src/axolotl/core/trainer_builder.py CHANGED
@@ -11,7 +11,7 @@ from abc import abstractmethod
11
  from dataclasses import dataclass, field
12
  from functools import partial
13
  from pathlib import Path
14
- from typing import Optional, Union
15
 
16
  import torch
17
  import transformers
@@ -31,7 +31,6 @@ from axolotl.utils.callbacks import (
31
  log_prediction_callback_factory,
32
  )
33
  from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
34
- from axolotl.utils.dataloader import MultipackDistributedDataloader
35
  from axolotl.utils.samplers import MultipackBatchSampler
36
  from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
37
 
@@ -215,9 +214,7 @@ class AxolotlTrainer(Trainer):
215
  )
216
  return super().get_train_dataloader()
217
 
218
- def get_eval_dataloader(
219
- self, eval_dataset: Optional[Dataset] = None
220
- ) -> Union[DataLoader, MultipackDistributedDataloader]:
221
  if self.args.sample_packing and self.args.eval_sample_packing is not False:
222
  eval_dataset = (
223
  eval_dataset if eval_dataset is not None else self.eval_dataset
@@ -260,7 +257,7 @@ class AxolotlTrainer(Trainer):
260
  def get_bench_dataloader(
261
  self,
262
  bench_dataset: Dataset,
263
- ) -> Union[DataLoader, MultipackDistributedDataloader]:
264
  dataloader_params = {
265
  "batch_size": self.args.eval_batch_size,
266
  "collate_fn": self.bench_data_collator,
 
11
  from dataclasses import dataclass, field
12
  from functools import partial
13
  from pathlib import Path
14
+ from typing import Optional
15
 
16
  import torch
17
  import transformers
 
31
  log_prediction_callback_factory,
32
  )
33
  from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
 
34
  from axolotl.utils.samplers import MultipackBatchSampler
35
  from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
36
 
 
214
  )
215
  return super().get_train_dataloader()
216
 
217
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
 
 
218
  if self.args.sample_packing and self.args.eval_sample_packing is not False:
219
  eval_dataset = (
220
  eval_dataset if eval_dataset is not None else self.eval_dataset
 
257
  def get_bench_dataloader(
258
  self,
259
  bench_dataset: Dataset,
260
+ ) -> DataLoader:
261
  dataloader_params = {
262
  "batch_size": self.args.eval_batch_size,
263
  "collate_fn": self.bench_data_collator,
src/axolotl/prompters.py CHANGED
@@ -22,7 +22,13 @@ class PromptStyle(Enum):
22
  CHATML = "chatml"
23
 
24
 
25
- class AlpacaPrompter:
 
 
 
 
 
 
26
  """
27
  Base class for alpaca prompters
28
  """
@@ -159,7 +165,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
159
  """
160
 
161
 
162
- class ReflectAlpacaPrompter:
163
  """
164
  Prompter for ReflectAlpaca
165
  """
@@ -254,7 +260,7 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
254
  )
255
 
256
 
257
- class ShareGPTPrompter: # pylint: disable=too-few-public-methods
258
  """
259
  A prompter that generates prompts for the ShareGPT
260
  """
@@ -349,7 +355,7 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
349
  )
350
 
351
 
352
- class UnsupportedPrompter:
353
  """
354
  A dummy class for custom prompters
355
  """
 
22
  CHATML = "chatml"
23
 
24
 
25
+ class Prompter:
26
+ """
27
+ Base prompter class for all prompters
28
+ """
29
+
30
+
31
+ class AlpacaPrompter(Prompter):
32
  """
33
  Base class for alpaca prompters
34
  """
 
165
  """
166
 
167
 
168
+ class ReflectAlpacaPrompter(Prompter):
169
  """
170
  Prompter for ReflectAlpaca
171
  """
 
260
  )
261
 
262
 
263
+ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
264
  """
265
  A prompter that generates prompts for the ShareGPT
266
  """
 
355
  )
356
 
357
 
358
+ class UnsupportedPrompter(Prompter):
359
  """
360
  A dummy class for custom prompters
361
  """
src/axolotl/utils/data.py CHANGED
@@ -3,7 +3,7 @@ import functools
3
  import hashlib
4
  import logging
5
  from pathlib import Path
6
- from typing import Any, Dict, List, Tuple, Union
7
 
8
  import torch
9
  from datasets import (
@@ -34,6 +34,7 @@ from axolotl.prompters import (
34
  JeopardyPrompter,
35
  MultipleChoiceConcisePrompter,
36
  MultipleChoiceExplainPrompter,
 
37
  ReflectAlpacaPrompter,
38
  SummarizeTLDRPrompter,
39
  UnsupportedPrompter,
@@ -90,7 +91,7 @@ def prepare_dataset(cfg, tokenizer):
90
 
91
  def load_tokenized_prepared_datasets(
92
  tokenizer, cfg, default_dataset_prepared_path
93
- ) -> DatasetDict:
94
  tokenizer_name = tokenizer.__class__.__name__
95
  ds_hash = str(
96
  md5(
@@ -302,7 +303,7 @@ def load_prepare_datasets(
302
  tokenizer: PreTrainedTokenizerBase,
303
  cfg,
304
  default_dataset_prepared_path,
305
- ) -> Tuple[Dataset, Dataset, List[Any]]:
306
  max_packed_sequence_len = (
307
  cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
308
  )
@@ -311,7 +312,7 @@ def load_prepare_datasets(
311
  ) # make sure we don't accidentally set it larger than sequence_len
312
 
313
  tokenizer_name = tokenizer.__class__.__name__
314
- prompters = []
315
  if cfg.max_packed_sequence_len is not None:
316
  # see if we can go ahead and load the stacked dataset
317
  seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -445,14 +446,13 @@ def load_prepare_datasets(
445
  train_fingerprint = md5(to_hash_train)
446
  test_fingerprint = md5(to_hash_test)
447
 
448
- with zero_first(is_main_process()):
449
- dataset = dataset.train_test_split(
450
- test_size=cfg.val_set_size,
451
- shuffle=False,
452
- seed=cfg.seed or 42,
453
- train_new_fingerprint=train_fingerprint,
454
- test_new_fingerprint=test_fingerprint,
455
- )
456
 
457
  train_dataset = dataset["train"]
458
  eval_dataset = dataset["test"]
 
3
  import hashlib
4
  import logging
5
  from pathlib import Path
6
+ from typing import Dict, List, Tuple, Union
7
 
8
  import torch
9
  from datasets import (
 
34
  JeopardyPrompter,
35
  MultipleChoiceConcisePrompter,
36
  MultipleChoiceExplainPrompter,
37
+ Prompter,
38
  ReflectAlpacaPrompter,
39
  SummarizeTLDRPrompter,
40
  UnsupportedPrompter,
 
91
 
92
  def load_tokenized_prepared_datasets(
93
  tokenizer, cfg, default_dataset_prepared_path
94
+ ) -> Tuple[DatasetDict, List[Prompter]]:
95
  tokenizer_name = tokenizer.__class__.__name__
96
  ds_hash = str(
97
  md5(
 
303
  tokenizer: PreTrainedTokenizerBase,
304
  cfg,
305
  default_dataset_prepared_path,
306
+ ) -> Tuple[Dataset, Dataset, List[Prompter]]:
307
  max_packed_sequence_len = (
308
  cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
309
  )
 
312
  ) # make sure we don't accidentally set it larger than sequence_len
313
 
314
  tokenizer_name = tokenizer.__class__.__name__
315
+ prompters: List[Prompter] = []
316
  if cfg.max_packed_sequence_len is not None:
317
  # see if we can go ahead and load the stacked dataset
318
  seed = f"@{str(cfg.seed)}" if cfg.seed else ""
 
446
  train_fingerprint = md5(to_hash_train)
447
  test_fingerprint = md5(to_hash_test)
448
 
449
+ dataset = dataset.train_test_split(
450
+ test_size=cfg.val_set_size,
451
+ shuffle=False,
452
+ seed=cfg.seed or 42,
453
+ train_new_fingerprint=train_fingerprint,
454
+ test_new_fingerprint=test_fingerprint,
455
+ )
 
456
 
457
  train_dataset = dataset["train"]
458
  eval_dataset = dataset["test"]
src/axolotl/utils/dataloader.py DELETED
@@ -1,342 +0,0 @@
1
- # pylint: skip-file
2
- import hashlib
3
- import itertools
4
- import logging
5
- import math
6
- import time
7
- from queue import Queue
8
- from threading import Thread
9
- from typing import Any, Callable, List, Union
10
-
11
- import numba
12
- import numpy as np
13
- from torch.utils.data import DistributedSampler, Sampler
14
-
15
- LOG = logging.getLogger("axolotl.utils.dataloader")
16
-
17
-
18
- @numba.njit
19
- def ffd_check(a: np.ndarray, c: int, n: int):
20
- # First-fit-decreasing bin packing
21
- # Check if a[] could fit in n bins with capacity c
22
- # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
23
-
24
- a = np.sort(a)[::-1]
25
- bins = np.full((n,), c, dtype=a.dtype)
26
- for size in a:
27
- not_found = True
28
- for idx in range(n):
29
- if bins[idx] >= size:
30
- bins[idx] -= size
31
- not_found = False
32
- break
33
-
34
- if not_found:
35
- return False
36
-
37
- return True
38
-
39
-
40
- @numba.njit
41
- def ffd_with_result(a: np.ndarray, c: int, start_index: int):
42
- # First-fit-decreasing bin packing (with result return)
43
-
44
- indices = np.argsort(a)[::-1]
45
- a = a[indices]
46
-
47
- bins: List[Any] = []
48
- bins_result: List[Any] = []
49
- for a_id, size in enumerate(a):
50
- add_new = True
51
- for idx in range(len(bins)):
52
- if bins[idx] >= size:
53
- bins[idx] -= size
54
- bins_result[idx].append(indices[a_id] + start_index)
55
- add_new = False
56
- break
57
-
58
- if add_new:
59
- bins.append(c - size)
60
- bins_result.append([indices[a_id] + start_index])
61
-
62
- return bins_result, len(a)
63
-
64
-
65
- @numba.njit
66
- def allocate(
67
- lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
68
- ):
69
- """
70
- :param lengths: array of lengths of each sample
71
- :param lengths_cumsum: cumulative sum of consecutive lengths
72
- :param rank: rank for this process
73
- :param c: length of tokens per batch
74
- :param n: number of ranks
75
- :return:
76
- """
77
- # Dynamic batch allocator, similar to Multifit
78
- # https://en.wikipedia.org/wiki/Multifit_algorithm
79
- # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
80
-
81
- s = 0
82
- start_index = 0
83
- result = []
84
- result_totseqs = []
85
-
86
- while True:
87
- # binary search [left, right)
88
- left = 1
89
- right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
90
-
91
- while right - left > 1:
92
- mid = (left + right) // 2
93
- if ffd_check(lengths[start_index : start_index + mid], c, n):
94
- left = mid
95
- else:
96
- right = mid
97
-
98
- # use length left
99
- batch, tot_seqs = ffd_with_result(
100
- lengths[start_index : start_index + left], c, start_index
101
- )
102
- if len(batch) < n:
103
- break
104
-
105
- start_index += left
106
- s = lengths_cumsum[start_index - 1]
107
-
108
- # add local rank
109
- result.append(batch[rank])
110
- # add total seqs for all ranks
111
- result_totseqs.append(tot_seqs)
112
- # yield batch[rank], tot_seqs, s, len(result) * c * n
113
- return result, result_totseqs, s, len(result) * c * n
114
-
115
-
116
- def chunk(iterable, n):
117
- """
118
- Chunk data into tuples of length n
119
- """
120
- # batched('ABCDEFG', 3) --> ABC DEF G
121
- if n < 1:
122
- raise ValueError("n must be at least one")
123
- it = iter(iterable)
124
- while batch := tuple(itertools.islice(it, n)):
125
- yield batch
126
-
127
-
128
- def hash_indices(lst: List[int]) -> str:
129
- # Convert the list of integers to a string representation
130
- concatenated = ",".join(map(str, lst))
131
-
132
- # Generate the hash
133
- sha256 = hashlib.sha256()
134
- sha256.update(concatenated.encode())
135
-
136
- return sha256.hexdigest()
137
-
138
-
139
- class MultipackDistributedDataloader:
140
- """Unpadded data loading using Multipack.
141
- Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
142
- Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
143
- """
144
-
145
- def __init__(
146
- self,
147
- dataset: Any,
148
- collate_fn: Callable,
149
- seq_max_length: int = 2048,
150
- batch_size: int = 1,
151
- sampler: Union[Sampler, DistributedSampler] = None,
152
- packing_efficiency_estimate: float = 1.0,
153
- sample_packing_seq_len_multiplier: int = 1,
154
- device_count: int = 1,
155
- prefetch_max: int = 1000,
156
- num_epochs: int = 1,
157
- ):
158
- # Dataset
159
- self.dataset = dataset
160
- self.lengths = (
161
- dataset.data.column("position_ids")
162
- .to_pandas()
163
- .apply(lambda x: x[-1] + 1)
164
- .values
165
- )
166
- assert isinstance(self.lengths, np.ndarray)
167
- assert batch_size % sample_packing_seq_len_multiplier == 0
168
- assert batch_size >= sample_packing_seq_len_multiplier
169
- self.sampler = sampler
170
- self.batch_size = batch_size
171
- self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
172
- self.seq_max_length = seq_max_length
173
- self.batch_max_length = batch_size * seq_max_length
174
- self.collate_fn = collate_fn
175
- self.num_epochs = num_epochs
176
-
177
- self.num_replicas = 1
178
- self.rank = 0
179
-
180
- # statistics
181
- self.eff_total_used = 0
182
- self.eff_total_slots = 0
183
- self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
184
- self.device_count = device_count
185
-
186
- # maxsize is maximum number of samples in queue
187
- self.prefetch_max = prefetch_max
188
- self.queue: Queue = Queue(maxsize=prefetch_max)
189
- self.thread = None
190
-
191
- def _worker(self):
192
- LOG.info(
193
- f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
194
- )
195
- for epoch in range(self.num_epochs):
196
- for sample in self._internal_batch_generator():
197
- while True:
198
- if self.queue.full():
199
- time.sleep(1)
200
- else:
201
- break
202
- self.queue.put(sample)
203
-
204
- # stop the queue when epoch is done
205
- self.queue.put(None)
206
-
207
- def __iter__(self):
208
- if hasattr(self.sampler, "set_epoch"):
209
- new_epoch = self.sampler.epoch + 1
210
- self.sampler.set_epoch(new_epoch)
211
- LOG.info(f"calling sampler.set_epoch({new_epoch})")
212
-
213
- if self.thread is None:
214
- self.thread = Thread(target=self._worker, daemon=True)
215
- self.thread.start()
216
-
217
- while True:
218
- item = self.queue.get()
219
-
220
- if item is None:
221
- break
222
- yield item
223
-
224
- def generate_batches(self, set_stats=False):
225
- LOG.info("generating packed batches")
226
- if self.sampler:
227
- indices = [idx for idx in self.sampler]
228
- else:
229
- indices = range(0, len(self.dataset))
230
-
231
- LOG.info(hash_indices(indices))
232
- lengths = self.lengths[indices]
233
- lengths_cumsum = np.cumsum(lengths)
234
-
235
- batches, totseqs, total_used, total_slots = allocate(
236
- lengths=lengths,
237
- lengths_cumsum=lengths_cumsum,
238
- rank=self.rank,
239
- # c=self.batch_max_length,
240
- c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
241
- n=self.num_replicas,
242
- )
243
-
244
- batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
245
-
246
- # statistics
247
- if set_stats:
248
- self.eff_total_used += total_used
249
- self.eff_total_slots += total_slots
250
-
251
- return batches, totseqs
252
-
253
- def _internal_batch_generator(self):
254
- all_batches, _ = self.generate_batches(set_stats=True)
255
- features = self.dataset.features.keys()
256
- len_remaining = self._len_est()
257
- for batches in chunk(
258
- all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
259
- ):
260
- chunked_data = []
261
- attn_mask_cum_idx = 0
262
- for batch in batches:
263
- concatenated = {}
264
- batched_data = [self.dataset[batch_idx] for batch_idx in batch]
265
- for feature in features:
266
- if feature == "length":
267
- continue
268
- if feature == "attention_mask":
269
- arrays = [
270
- (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
271
- for idx, item in enumerate(batched_data)
272
- if feature in item
273
- ]
274
- attn_mask_cum_idx += len(batched_data)
275
- concatenated[feature] = np.concatenate(arrays)
276
- else:
277
- arrays = [
278
- np.array(item[feature])
279
- for item in batched_data
280
- if feature in item
281
- ]
282
- concatenated[feature] = np.concatenate(arrays)
283
- chunked_data.append(concatenated)
284
- yield self.collate_fn(chunked_data)
285
- len_remaining -= 1
286
- if not len_remaining:
287
- return
288
- # yield a no-op for cases where we don't have any data left to pack
289
- for i in range(0, len_remaining):
290
- yield self.collate_fn(
291
- [
292
- {
293
- "input_ids": [0],
294
- "labels": [-100],
295
- "attention_mask": [True],
296
- "position_ids": [0],
297
- }
298
- ]
299
- )
300
-
301
- def _len_est(self):
302
- lengths_sum = np.sum(self.lengths)
303
- lengths_sum_per_device = lengths_sum // self.device_count
304
- LOG.info(
305
- f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
306
- f"total_num_tokens per device: {lengths_sum_per_device}"
307
- )
308
-
309
- # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
310
- return (
311
- math.floor(
312
- 0.99
313
- * lengths_sum_per_device
314
- / self.packing_efficiency_estimate
315
- // self.seq_max_length
316
- // self.batch_size
317
- )
318
- - 1
319
- )
320
-
321
- def __len__(self):
322
- # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
323
- # the same share of total tokens
324
- # if not self.eff_total_used:
325
- # batches, _ = self.generate_batches(set_stats=True)
326
- # LOG.info(
327
- # f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
328
- # f"actual packing efficiency: {self.efficiency()}"
329
- # )
330
- return max(1, self._len_est())
331
-
332
- def len_w_stats(self):
333
- if not self.eff_total_used:
334
- batches, _ = self.generate_batches(set_stats=True)
335
- LOG.info(
336
- f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
337
- f"actual packing efficiency: {self.efficiency()}"
338
- )
339
- return max(1, self._len_est())
340
-
341
- def efficiency(self):
342
- return self.eff_total_used / self.eff_total_slots