winglian commited on
Commit
eea2731
·
1 Parent(s): 1db46a9

add streaming dataset support for pretraining datasets

Browse files
README.md CHANGED
@@ -410,6 +410,8 @@ optimizer:
410
  # specify weight decay
411
  weight_decay:
412
 
 
 
413
  # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
414
  xformers_attention:
415
  # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
 
410
  # specify weight decay
411
  weight_decay:
412
 
413
+ # whether to bettertransformers
414
+ flash_optimum:
415
  # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
416
  xformers_attention:
417
  # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
scripts/finetune.py CHANGED
@@ -14,7 +14,6 @@ import torch
14
  import yaml
15
 
16
  # add src to the pythonpath so we don't need to pip install this
17
- from datasets import Dataset
18
  from optimum.bettertransformer import BetterTransformer
19
  from transformers import GenerationConfig, TextStreamer
20
 
@@ -208,14 +207,11 @@ def train(
208
  tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
209
  )
210
  else:
211
- if cfg.pretraining_dataset is True:
212
- pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
213
- else:
214
- pretraining_dataset = cfg.pretraining_dataset
215
  train_dataset = load_pretraining_dataset(
216
- pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
217
  )
218
- train_dataset = Dataset.from_list(list(train_dataset))
 
219
  eval_dataset = None
220
 
221
  if cfg.debug or "debug" in kwargs:
@@ -262,19 +258,6 @@ def train(
262
  model.save_pretrained(cfg.output_dir)
263
  return
264
 
265
- if cfg.debug:
266
- logging.info("check_dataset_labels...")
267
- check_dataset_labels(
268
- train_dataset.select(
269
- [random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec
270
- ),
271
- tokenizer,
272
- )
273
-
274
- if prepare_ds_only:
275
- logging.info("Finished preparing dataset. Exiting...")
276
- return
277
-
278
  model.train()
279
 
280
  trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
14
  import yaml
15
 
16
  # add src to the pythonpath so we don't need to pip install this
 
17
  from optimum.bettertransformer import BetterTransformer
18
  from transformers import GenerationConfig, TextStreamer
19
 
 
207
  tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
208
  )
209
  else:
 
 
 
 
210
  train_dataset = load_pretraining_dataset(
211
+ cfg.pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
212
  )
213
+ # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
214
+ train_dataset = train_dataset.with_format("torch")
215
  eval_dataset = None
216
 
217
  if cfg.debug or "debug" in kwargs:
 
258
  model.save_pretrained(cfg.output_dir)
259
  return
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  model.train()
262
 
263
  trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
src/axolotl/utils/data.py CHANGED
@@ -1,12 +1,12 @@
1
  """Module containing data utilities"""
2
-
3
  import logging
4
  from hashlib import md5
5
  from pathlib import Path
6
  from typing import List, Tuple, Union
7
 
8
  import torch
9
- from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
10
  from huggingface_hub import hf_hub_download
11
  from transformers import PreTrainedTokenizerBase
12
 
@@ -399,32 +399,116 @@ def load_prepare_datasets(
399
  return train_dataset, eval_dataset
400
 
401
 
402
- class PretrainingDatasetWrapper(IterableDataset):
403
- """
404
- Wrapper for pretraining dataset that avoids loading the dataset into memory
405
- """
406
-
407
- def __init__(self, tokenizer, dataset_path, max_tokens=2048):
408
- self.tokenizer = tokenizer
409
- self.dataset_path = dataset_path
410
- self.max_tokens = max_tokens
411
-
412
- def __iter__(self):
413
- buffer = []
414
- for sample in load_dataset(
415
- self.dataset_path,
416
- )["train"].shuffle():
417
- buffer += self.tokenizer(sample["text"])["input_ids"]
418
- buffer += [self.tokenizer.eos_token_id]
419
- while len(buffer) > self.max_tokens:
420
- input_ids = torch.tensor(buffer[: self.max_tokens])
421
- yield {
422
- "input_ids": input_ids,
423
- "attention_mask": torch.ones(input_ids.size()),
424
- "labels": input_ids,
425
- }
426
- buffer = buffer[self.max_tokens :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
 
428
 
429
  def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
430
- return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens)
 
 
 
 
 
 
1
  """Module containing data utilities"""
2
+ import functools
3
  import logging
4
  from hashlib import md5
5
  from pathlib import Path
6
  from typing import List, Tuple, Union
7
 
8
  import torch
9
+ from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
10
  from huggingface_hub import hf_hub_download
11
  from transformers import PreTrainedTokenizerBase
12
 
 
399
  return train_dataset, eval_dataset
400
 
401
 
402
+ def encode_pretraining(tokenizer, max_tokens, examples):
403
+ res = tokenizer(
404
+ examples["text"],
405
+ truncation=True,
406
+ max_length=max_tokens - 2,
407
+ add_special_tokens=True,
408
+ )
409
+ # Convert to PyTorch tensors
410
+ input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
411
+ attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
412
+ new_input_ids = []
413
+ new_attention_mask = []
414
+ # Append EOS and PAD tokens to input_ids, and correct attention_mask
415
+ for i, _ in enumerate(input_ids):
416
+ input_ids[i] = torch.cat(
417
+ (
418
+ input_ids[i],
419
+ torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
420
+ ),
421
+ dim=0,
422
+ )
423
+ attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
424
+
425
+ # Concatenate tokens so that their lengths are less than max_tokens
426
+ buffer_input_ids = torch.tensor([], dtype=torch.long)
427
+ buffer_attention_mask = torch.tensor([], dtype=torch.long)
428
+
429
+ for ids, mask in zip(input_ids, attention_mask):
430
+ if buffer_input_ids.numel() == max_tokens:
431
+ new_input_ids.append(buffer_input_ids)
432
+ new_attention_mask.append(buffer_attention_mask)
433
+ buffer_input_ids = torch.tensor([], dtype=torch.long)
434
+ buffer_attention_mask = torch.tensor([], dtype=torch.long)
435
+ buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
436
+ buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
437
+ elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
438
+ buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
439
+ buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
440
+ else:
441
+ buffer_input_ids = torch.cat(
442
+ (
443
+ buffer_input_ids,
444
+ torch.full(
445
+ (max_tokens - buffer_input_ids.numel(),),
446
+ tokenizer.pad_token_id,
447
+ dtype=torch.long,
448
+ ),
449
+ ),
450
+ dim=0,
451
+ )
452
+ buffer_attention_mask = torch.cat(
453
+ (
454
+ buffer_attention_mask,
455
+ torch.full(
456
+ (max_tokens - buffer_attention_mask.numel(),),
457
+ 0,
458
+ dtype=torch.long,
459
+ ),
460
+ ),
461
+ dim=0,
462
+ )
463
+ new_input_ids.append(buffer_input_ids)
464
+ new_attention_mask.append(buffer_attention_mask)
465
+ buffer_input_ids = torch.tensor([], dtype=torch.long)
466
+ buffer_attention_mask = torch.tensor([], dtype=torch.long)
467
+
468
+ buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
469
+ buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
470
+
471
+ if buffer_input_ids.numel() > 0: # for any leftover tokens
472
+ while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
473
+ buffer_input_ids = torch.cat(
474
+ (
475
+ buffer_input_ids,
476
+ torch.full(
477
+ (max_tokens - buffer_input_ids.numel(),),
478
+ tokenizer.pad_token_id,
479
+ dtype=torch.long,
480
+ ),
481
+ ),
482
+ dim=0,
483
+ )
484
+ buffer_attention_mask = torch.cat(
485
+ (
486
+ buffer_attention_mask,
487
+ torch.full(
488
+ (max_tokens - buffer_attention_mask.numel(),),
489
+ 0,
490
+ dtype=torch.long,
491
+ ),
492
+ ),
493
+ dim=0,
494
+ )
495
+ new_input_ids.append(buffer_input_ids)
496
+ new_attention_mask.append(buffer_attention_mask)
497
+
498
+ ret = {
499
+ "input_ids": [seq.tolist() for seq in new_input_ids],
500
+ "labels": [seq.tolist() for seq in new_input_ids],
501
+ "attention_mask": [seq.tolist() for seq in new_attention_mask],
502
+ }
503
+
504
+ logging.debug(len(ret["input_ids"]))
505
+ return ret
506
 
507
 
508
  def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
509
+ encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
510
+ dataset = load_dataset(path, streaming=True, split="train")
511
+ dataset = dataset.shuffle(seed=42, buffer_size=10_000)
512
+ # TODO dynamically figure out which columns/features to remove
513
+ dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
514
+ return dataset
src/axolotl/utils/validation.py CHANGED
@@ -77,6 +77,11 @@ def validate_config(cfg):
77
  f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
78
  )
79
 
 
 
 
 
 
80
  # TODO
81
  # MPT 7b
82
  # https://github.com/facebookresearch/bitsandbytes/issues/25
 
77
  f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
78
  )
79
 
80
+ if cfg.pretraining_dataset and cfg.group_by_length:
81
+ logging.warning(
82
+ "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
83
+ )
84
+
85
  # TODO
86
  # MPT 7b
87
  # https://github.com/facebookresearch/bitsandbytes/issues/25
tests/test_validation.py CHANGED
@@ -198,3 +198,54 @@ class ValidationTest(unittest.TestCase):
198
  )
199
 
200
  validate_config(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  )
199
 
200
  validate_config(cfg)
201
+
202
+ def test_flash_optimum(self):
203
+ cfg = DictDefault(
204
+ {
205
+ "flash_optimum": True,
206
+ "adapter": "lora",
207
+ }
208
+ )
209
+
210
+ with self._caplog.at_level(logging.WARNING):
211
+ validate_config(cfg)
212
+ assert any(
213
+ "BetterTransformers probably doesn't work with PEFT adapters"
214
+ in record.message
215
+ for record in self._caplog.records
216
+ )
217
+
218
+ cfg = DictDefault(
219
+ {
220
+ "flash_optimum": True,
221
+ }
222
+ )
223
+
224
+ with self._caplog.at_level(logging.WARNING):
225
+ validate_config(cfg)
226
+ assert any(
227
+ "probably set bfloat16 or float16" in record.message
228
+ for record in self._caplog.records
229
+ )
230
+
231
+ cfg = DictDefault(
232
+ {
233
+ "flash_optimum": True,
234
+ "fp16": True,
235
+ }
236
+ )
237
+ regex_exp = r".*AMP is not supported.*"
238
+
239
+ with pytest.raises(ValueError, match=regex_exp):
240
+ validate_config(cfg)
241
+
242
+ cfg = DictDefault(
243
+ {
244
+ "flash_optimum": True,
245
+ "bf16": True,
246
+ }
247
+ )
248
+ regex_exp = r".*AMP is not supported.*"
249
+
250
+ with pytest.raises(ValueError, match=regex_exp):
251
+ validate_config(cfg)