winglian commited on
Commit
a6b37bd
·
unverified ·
1 Parent(s): b752080

revert multipack batch sampler changes (#1672)

Browse files

* revert multipack batch sampler changes

* fix default val for drop_last

src/axolotl/utils/samplers/multipack.py CHANGED
@@ -1,64 +1,105 @@
 
1
  """
2
  Multipack Batch Sampler
3
  """
4
  import logging
5
- from concurrent.futures import ProcessPoolExecutor
6
- from multiprocessing import cpu_count
 
7
 
8
  import numba
9
  import numpy as np
10
- from torch.utils.data import BatchSampler
11
 
12
  LOG = logging.getLogger("axolotl.utils.samplers.multipack")
13
 
14
 
15
- # First-fit-decreasing bin packing.
16
  @numba.njit
17
- def pack_group(items, group_offset, bin_capacity, max_items_per_bin):
18
- idxs = np.argsort(items)[::-1]
19
- sorted_items = items[idxs]
20
- num_bins = len(items)
21
- bins = np.full(num_bins, bin_capacity, dtype=np.int32)
22
- bin_counts = np.zeros(num_bins, dtype=np.int32)
23
- group_packing = np.full((num_bins, max_items_per_bin), -1, dtype=np.int32)
24
-
25
- for idx, item in enumerate(sorted_items):
26
- global_idx = idxs[idx] + group_offset
27
-
28
- placed = False
29
- for i in range(num_bins):
30
- if bins[i] >= item and bin_counts[i] < max_items_per_bin:
31
- bins[i] -= item
32
- group_packing[i, bin_counts[i]] = global_idx
33
- bin_counts[i] += 1
34
- placed = True
35
  break
36
 
37
- if not placed:
38
- raise ValueError(
39
- f"Item could not be packed. Try increasing cfg.sample_packing_bin_size ({max_items_per_bin})."
40
- )
41
 
42
- return group_packing
43
 
44
 
45
- def pack(items, bin_capacity, group_size, max_items_per_bin):
46
- num_items = len(items)
47
- num_processes = max(1, min(num_items // group_size, cpu_count()))
48
- tasks = [
49
- (items[i : i + group_size], i, bin_capacity, max_items_per_bin)
50
- for i in range(0, num_items, group_size)
51
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- packed_bins = []
54
- with ProcessPoolExecutor(max_workers=num_processes) as executor:
55
- for group_packing in executor.map(pack_group, *zip(*tasks)):
56
- for bin_pack in group_packing:
57
- filtered_pack = bin_pack[bin_pack != -1]
58
- if filtered_pack.size > 0:
59
- packed_bins.append(filtered_pack.tolist())
60
 
61
- return packed_bins
62
 
63
 
64
  class MultipackBatchSampler(BatchSampler):
@@ -68,63 +109,95 @@ class MultipackBatchSampler(BatchSampler):
68
 
69
  def __init__(
70
  self,
71
- sampler,
72
- lengths,
73
- batch_max_len,
74
- batch_size,
75
- group_size=100_000,
76
- bin_size=200,
77
- drop_last=False,
78
  ):
79
- self.sampler = sampler
80
- self.lengths = np.array(lengths, dtype=np.int32)
81
- self.batch_max_len = batch_max_len
82
  self.batch_size = batch_size
83
- self.group_size = group_size if group_size is not None else 100_000
84
- self.bin_size = bin_size if bin_size is not None else 200
85
- self.drop_last = drop_last
86
 
87
- self._efficiency = None
88
- self._batches = None
89
 
90
- def efficiency(self):
91
- if self._efficiency is None:
92
- self._batches = self._pack_batches()
93
- return self._efficiency
94
-
95
- def _pack_batches(self):
96
- # Get possibly shuffled indices from sampler.
97
- sample_idxs = np.arange(len(self.sampler))
98
- lengths = self.lengths[sample_idxs]
99
-
100
- pack_idxs = pack(
101
- lengths,
102
- self.batch_max_len,
103
- self.group_size,
104
- self.bin_size,
105
- )
106
 
107
- used_tokens = self.lengths.sum()
108
- available_tokens = len(pack_idxs) * self.batch_max_len
109
- self._efficiency = used_tokens / available_tokens
 
 
 
 
 
 
 
 
 
 
110
 
111
- # Wrap packs into batches.
112
- batch_idxs = [
113
- pack_idxs[i : i + self.batch_size]
114
- for i in range(0, len(pack_idxs), self.batch_size)
 
 
115
  ]
116
 
117
- # Drop last batch if needed.
118
- if self.drop_last and len(batch_idxs[-1]) < self.batch_size:
119
- batch_idxs = batch_idxs[:-1]
 
120
 
121
- return batch_idxs
122
 
123
  def __iter__(self):
124
- self._batches = self._pack_batches()
125
- return iter(self._batches)
 
 
 
 
 
 
 
126
 
127
  def __len__(self):
128
- if self._batches is None:
129
- self._batches = self._pack_batches()
130
- return len(self._batches)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
  """
3
  Multipack Batch Sampler
4
  """
5
  import logging
6
+ import math
7
+ import os
8
+ from typing import Any, Iterable, List, Union
9
 
10
  import numba
11
  import numpy as np
12
+ from torch.utils.data import BatchSampler, Sampler
13
 
14
  LOG = logging.getLogger("axolotl.utils.samplers.multipack")
15
 
16
 
 
17
  @numba.njit
18
+ def ffd_check(a: np.ndarray, c: int, n: int):
19
+ # First-fit-decreasing bin packing
20
+ # Check if a[] could fit in n bins with capacity c
21
+ # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
22
+
23
+ a = np.sort(a)[::-1]
24
+ bins = np.full((n,), c, dtype=a.dtype)
25
+ for size in a:
26
+ not_found = True
27
+ for idx in range(n):
28
+ if bins[idx] >= size:
29
+ bins[idx] -= size
30
+ not_found = False
 
 
 
 
 
31
  break
32
 
33
+ if not_found:
34
+ return False
 
 
35
 
36
+ return True
37
 
38
 
39
+ @numba.njit
40
+ def ffd_with_result(a: np.ndarray, c: int, start_index: int):
41
+ # First-fit-decreasing bin packing (with result return)
42
+
43
+ indices = np.argsort(a)[::-1]
44
+ a = a[indices]
45
+
46
+ bins: List[Any] = []
47
+ bins_result: List[Any] = []
48
+ for a_id, size in enumerate(a):
49
+ add_new = True
50
+ for idx in range(len(bins)):
51
+ if bins[idx] >= size:
52
+ bins[idx] -= size
53
+ bins_result[idx].append(indices[a_id] + start_index)
54
+ add_new = False
55
+ break
56
+
57
+ if add_new:
58
+ bins.append(c - size)
59
+ bins_result.append([indices[a_id] + start_index])
60
+
61
+ return bins_result
62
+
63
+
64
+ @numba.njit
65
+ def allocate(
66
+ lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
67
+ ):
68
+ # Dynamic batch allocator, similar to Multifit
69
+ # https://en.wikipedia.org/wiki/Multifit_algorithm
70
+ # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
71
+
72
+ s = 0
73
+ start_index = 0
74
+ result = []
75
+
76
+ while True:
77
+ # binary search [l, r)
78
+ left = 1
79
+ right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
80
+
81
+ while right - left > 1:
82
+ mid = (left + right) // 2
83
+ if ffd_check(lengths[start_index : start_index + mid], c, n):
84
+ left = mid
85
+ else:
86
+ right = mid
87
+
88
+ # use length l
89
+ batch = ffd_with_result(
90
+ lengths[start_index : start_index + left], c, start_index
91
+ )
92
+ assert len(batch) <= n
93
+ if len(batch) < n:
94
+ break
95
+
96
+ start_index += left
97
+ s = lengths_cumsum[start_index - 1]
98
 
99
+ # add local rank
100
+ result.append(batch[rank])
 
 
 
 
 
101
 
102
+ return result, s, len(result) * c * n
103
 
104
 
105
  class MultipackBatchSampler(BatchSampler):
 
109
 
110
  def __init__(
111
  self,
112
+ sampler: Union[Sampler[int], Iterable[int]],
113
+ batch_size: int,
114
+ batch_max_len: int,
115
+ lengths: np.ndarray,
116
+ packing_efficiency_estimate: float = 1.0,
117
+ drop_last: bool = False,
118
+ **kwargs,
119
  ):
120
+ super().__init__(sampler, batch_size, drop_last)
 
 
121
  self.batch_size = batch_size
122
+ self.batch_max_len = batch_max_len
123
+ self.lengths: np.ndarray = lengths
124
+ self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
125
 
126
+ assert isinstance(self.lengths, np.ndarray)
 
127
 
128
+ self.epoch = 0
129
+
130
+ # statistics
131
+ self.eff_total_used = 0
132
+ self.eff_total_slots = 0
133
+
134
+ def set_epoch(self, epoch: int):
135
+ self.epoch = epoch
 
 
 
 
 
 
 
 
136
 
137
+ def generate_batches(self, set_stats=False):
138
+ indices = [idx for idx in self.sampler]
139
+
140
+ lengths = self.lengths[indices]
141
+ lengths_cumsum = np.cumsum(lengths)
142
+
143
+ batches, total_used, total_slots = allocate(
144
+ lengths=lengths,
145
+ lengths_cumsum=lengths_cumsum,
146
+ rank=0,
147
+ c=self.batch_max_len,
148
+ n=1,
149
+ )
150
 
151
+ batches = [
152
+ [
153
+ [indices[b_idx] for b_idx in batch]
154
+ for batch in batches[i : i + self.batch_size]
155
+ ]
156
+ for i in range(0, len(batches), self.batch_size)
157
  ]
158
 
159
+ # statistics
160
+ if set_stats:
161
+ self.eff_total_used += total_used
162
+ self.eff_total_slots += total_slots
163
 
164
+ return batches
165
 
166
  def __iter__(self):
167
+ batches = self.generate_batches(set_stats=True)
168
+ return iter(batches)
169
+
170
+ def num_batches(self):
171
+ batches = self.generate_batches(set_stats=True)
172
+ return len(batches)
173
+
174
+ def efficiency(self):
175
+ return self.eff_total_used / self.eff_total_slots
176
 
177
  def __len__(self):
178
+ self.num_batches()
179
+ return self._len_est()
180
+
181
+ def _len_est(self):
182
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
183
+ lengths_sum = np.sum(self.lengths)
184
+ lengths_sum_per_device = lengths_sum // world_size
185
+ LOG.info(
186
+ f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
187
+ f"total_num_tokens per device: {lengths_sum_per_device}"
188
+ )
189
+
190
+ # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
191
+ return max(
192
+ 0,
193
+ (
194
+ world_size
195
+ * math.floor(
196
+ 0.99
197
+ * lengths_sum_per_device
198
+ / self.packing_efficiency_estimate
199
+ // (self.batch_max_len * self.batch_size)
200
+ )
201
+ - 1
202
+ ),
203
+ )