winglian committed on
Commit 2bb0b78 · unverified · 1 Parent(s): a276c9c

Attention mask and position id fixes for packing (#285)


* fix attention mask with packing

* set position ids and use block diagonal attn mask

* fix expand mask for multiple batch items, make sure we pad position_ids

* don't move masks to cpu

* use multipack dataloader with random sampler

* add position_ids back

* more fixes for dataloader integration

* est total tokens, fix field loop

* more fixes, position_ids seems broken

* more fixes for sample packing

* use distributed sampler, avoid accelerate prepare

* use accelerator prepare for dataloader

* fix for position_ids w packing

* Update src/axolotl/utils/dataloader.py

* validation for sample packing and doc

* more fixes for 4k and optimizations

* optimized expand mask fn

* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data

* fix rounding of len of batches to int

* better handling so that all devices have the same dataloader len

* fix step calc for packing

* pass sample packing efficiency to training args

* add a test for the mask expansion for sequence packing

* only process eval dataset for packing if not None

* don't split batches when packing

* weighted CE losses

* weighted CEL fixes

* limit packing to sequences of max seq len

* seq_len_multiple for packing

* make sure the chunk size is an int

* sample_packing_seq_len_multiplier config

* use cumulative seq len with var len flash attn v2 w packing

* properly calculate max len

* fix flash-attn, xformers, packing, support chatml

* fix chatml system prompt for openorca, legacy tokenizer opts

* add chatml

* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test

* fix test and pylint checks

* more packing and dataset optimizations and fixes

* filter w multiple cpus

* more fixes and optimizations

* fixes and go back to distributed sampler since batch sampler won't work

* fix counts by accounting for num devices

* fix steps calculation

* previous accelerate is still most performant

* add numba to requirements.

* use custom distributed checks

* fix sampler to prevent overfit w new epochs

* let's not cleanup the cached datasets

* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier

* speed optimizations and set accelerate fsdp env vars

* optimize dataset concatenation?

* more optimizations for dataset handling

* fix import for annotation

* manual pre-commit fixes

* another sum optimization and bug fix for calc steps

* fix packing estimations

* fix formatting

* pylint problems

* add back flash attention branch for handling unpacked sequences separately

* Address PR feedback

* add optional sample packing config params to readme

README.md CHANGED
@@ -375,7 +375,14 @@ dataset_shard_idx:
375
  sequence_len: 2048
376
  # max sequence length to concatenate training samples together up to
377
  # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
 
378
  max_packed_sequence_len: 1024
379
 
380
  # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
381
  adapter: lora
@@ -421,6 +428,7 @@ learning_rate: 0.00003
421
  logging_steps:
422
  save_steps:
423
  eval_steps:
 
424
 
425
  # save model as safetensors (require safetensors package)
426
  save_safetensors:
@@ -534,7 +542,7 @@ accelerate launch scripts/finetune.py configs/your_config.yml
534
 
535
  #### Multi-GPU
536
 
537
- It is recommended to pre-tokenize dataset with the following before finetuning:
538
  ```bash
539
  CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
540
  ```
 
375
  sequence_len: 2048
376
  # max sequence length to concatenate training samples together up to
377
  # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
378
+ # FutureWarning: This will soon be DEPRECATED
379
  max_packed_sequence_len: 1024
380
+ # use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
381
+ sample_packing:
382
+ # you can set these packing optimizations AFTER starting a training at least once.
383
+ # The trainer will provide recommended values for these values.
384
+ sample_packing_eff_est:
385
+ total_num_tokens:
386
 
387
  # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
388
  adapter: lora
 
428
  logging_steps:
429
  save_steps:
430
  eval_steps:
431
+ save_total_limit:
432
 
433
  # save model as safetensors (require safetensors package)
434
  save_safetensors:
 
542
 
543
  #### Multi-GPU
544
 
545
+ You can optionally pre-tokenize dataset with the following before finetuning:
546
  ```bash
547
  CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
548
  ```
requirements.txt CHANGED
@@ -13,6 +13,8 @@ einops
13
  xformers
14
  optimum
15
  hf_transfer
 
 
16
  # qlora things
17
  bert-score==0.3.13
18
  evaluate==0.4.0
 
13
  xformers
14
  optimum
15
  hf_transfer
16
+ numba
17
+ numpy==1.24.4
18
  # qlora things
19
  bert-score==0.3.13
20
  evaluate==0.4.0
scripts/finetune.py CHANGED
@@ -21,9 +21,14 @@ from axolotl.logging_config import configure_logging
21
  from axolotl.utils.bench import log_gpu_memory_usage
22
  from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
23
  from axolotl.utils.dict import DictDefault
 
24
  from axolotl.utils.models import load_model, load_tokenizer
25
  from axolotl.utils.tokenization import check_dataset_labels
26
- from axolotl.utils.trainer import setup_trainer
27
  from axolotl.utils.validation import validate_config
28
  from axolotl.utils.wandb import setup_wandb_env_vars
29
 
@@ -232,12 +237,25 @@ def train(
232
  cfg.pretraining_dataset,
233
  tokenizer,
234
  max_tokens=cfg.sequence_len,
235
- seed=cfg.seed,
236
  )
237
  # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
238
  train_dataset = train_dataset.with_format("torch")
239
  eval_dataset = None
240
 
241
  if cfg.debug or "debug" in kwargs:
242
  LOG.info("check_dataset_labels...")
243
  check_dataset_labels(
@@ -254,7 +272,7 @@ def train(
254
  log_gpu_memory_usage(LOG, "baseline", cfg.device)
255
 
256
  # Load the model and tokenizer
257
- LOG.info("loading model and peft_config...")
258
  model, peft_config = load_model(cfg, tokenizer)
259
 
260
  safe_serialization = cfg.save_safetensors is True
@@ -288,7 +306,9 @@ def train(
288
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
289
  return
290
 
291
- trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
 
292
 
293
  model.config.use_cache = False
294
 
@@ -347,14 +367,12 @@ def train(
347
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
348
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
349
  if cfg.fsdp:
350
- model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
351
  elif cfg.local_rank == 0:
352
  if cfg.flash_optimum:
353
  model = BetterTransformer.reverse(model)
354
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
355
 
356
- # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
357
-
358
 
359
  if __name__ == "__main__":
360
  fire.Fire(train)
 
21
  from axolotl.utils.bench import log_gpu_memory_usage
22
  from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
23
  from axolotl.utils.dict import DictDefault
24
+ from axolotl.utils.distributed import barrier, is_main_process
25
  from axolotl.utils.models import load_model, load_tokenizer
26
  from axolotl.utils.tokenization import check_dataset_labels
27
+ from axolotl.utils.trainer import (
28
+ calculate_total_num_steps,
29
+ process_datasets_for_packing,
30
+ setup_trainer,
31
+ )
32
  from axolotl.utils.validation import validate_config
33
  from axolotl.utils.wandb import setup_wandb_env_vars
34
 
 
237
  cfg.pretraining_dataset,
238
  tokenizer,
239
  max_tokens=cfg.sequence_len,
240
+ seed=cfg.seed or 42,
241
  )
242
  # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
243
  train_dataset = train_dataset.with_format("torch")
244
  eval_dataset = None
245
 
246
+ if is_main_process():
247
+ # process on rank 0 first so it gets cached so other ranks load from cache
248
+ train_dataset, eval_dataset = process_datasets_for_packing(
249
+ cfg, train_dataset, eval_dataset
250
+ )
251
+ barrier()
252
+ if not is_main_process():
253
+ train_dataset, eval_dataset = process_datasets_for_packing(
254
+ cfg, train_dataset, eval_dataset
255
+ )
256
+ barrier()
257
+ total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
258
+
259
  if cfg.debug or "debug" in kwargs:
260
  LOG.info("check_dataset_labels...")
261
  check_dataset_labels(
 
272
  log_gpu_memory_usage(LOG, "baseline", cfg.device)
273
 
274
  # Load the model and tokenizer
275
+ LOG.info("loading model and (optionally) peft_config...")
276
  model, peft_config = load_model(cfg, tokenizer)
277
 
278
  safe_serialization = cfg.save_safetensors is True
 
306
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
307
  return
308
 
309
+ trainer = setup_trainer(
310
+ cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
311
+ )
312
 
313
  model.config.use_cache = False
314
 
 
367
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
368
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
369
  if cfg.fsdp:
370
+ trainer.save_model(cfg.output_dir)
371
  elif cfg.local_rank == 0:
372
  if cfg.flash_optimum:
373
  model = BetterTransformer.reverse(model)
374
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
375
 
 
 
376
 
377
  if __name__ == "__main__":
378
  fire.Fire(train)
src/axolotl/datasets.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  from typing import List
6
 
7
  import torch
8
- from datasets import IterableDataset
9
 
10
  from .prompt_tokenizers import PromptTokenizingStrategy
11
 
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
18
  LOG = logging.getLogger("axolotl")
19
 
20
 
21
- class TokenizedPromptDataset(IterableDataset):
22
  """
23
- Iterable dataset that returns tokenized prompts from a stream of text files.
24
  Args:
25
  prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
26
  dataset (dataset.Dataset): Dataset with text files.
@@ -30,19 +30,18 @@ class TokenizedPromptDataset(IterableDataset):
30
  self,
31
  prompt_tokenizer: PromptTokenizingStrategy,
32
  dataset: IterableDataset,
 
33
  ):
34
  self.prompt_tokenizer = prompt_tokenizer
35
- self.dataset = dataset
36
-
37
- def __iter__(self):
38
- features = self.dataset.features.keys()
39
- num_proc = os.cpu_count()
40
- return iter(
41
- self.dataset.map(
42
- self.prompt_tokenizer.tokenize_prompt,
43
- num_proc=num_proc,
44
- remove_columns=features,
45
- )
46
  )
47
 
48
 
@@ -77,14 +76,21 @@ class ConstantLengthDataset(IterableDataset):
77
  self.tokens_dtype = torch.int64
78
 
79
  def __iter__(self):
80
- buffer = {"input_ids": [], "attention_mask": [], "labels": []}
81
  buffer_len = 0
82
  for dataset in self.datasets:
 
83
  iterator = iter(dataset)
84
  more_examples = True
85
  while more_examples:
86
  try:
87
  example = next(iterator)
 
88
  except StopIteration:
89
  more_examples = False
90
  example = None
@@ -106,6 +112,9 @@ class ConstantLengthDataset(IterableDataset):
106
  attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
107
  : self.seq_length
108
  ]
 
 
 
109
  labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
110
  if labels.size() == input_ids.size() and (
111
  attention_mask.size() == input_ids.size()
@@ -114,6 +123,7 @@ class ConstantLengthDataset(IterableDataset):
114
  "input_ids": input_ids,
115
  "labels": labels,
116
  "attention_mask": attention_mask,
 
117
  }
118
  else:
119
  LOG.warning(
@@ -123,8 +133,10 @@ class ConstantLengthDataset(IterableDataset):
123
  "input_ids": [],
124
  "attention_mask": [],
125
  "labels": [],
 
126
  }
127
  buffer_len = 0
 
128
 
129
  if example:
130
  # FIXME
@@ -133,11 +145,6 @@ class ConstantLengthDataset(IterableDataset):
133
  input_ids = example["input_ids"]
134
  attention_mask = example["attention_mask"]
135
  labels = example["labels"]
136
- if (
137
- buffer["input_ids"]
138
- and input_ids[0] == self.tokenizer.bos_token_id
139
- ):
140
- attention_mask[0] = 0
141
 
142
  if add_concat_token:
143
  input_ids.append(self.concat_token_id)
@@ -148,13 +155,17 @@ class ConstantLengthDataset(IterableDataset):
148
  input_ids, dtype=self.tokens_dtype
149
  )
150
  attention_mask_with_concat = torch.tensor(
151
- attention_mask, dtype=self.tokens_dtype
152
  )
153
  labels_with_concat = torch.tensor(
154
  labels, dtype=self.tokens_dtype
155
  )
 
 
 
156
 
157
  buffer["input_ids"].append(input_ids_with_concat)
158
  buffer["attention_mask"].append(attention_mask_with_concat)
159
  buffer["labels"].append(labels_with_concat)
 
160
  buffer_len += len(input_ids)
 
5
  from typing import List
6
 
7
  import torch
8
+ from datasets import Dataset, IterableDataset
9
 
10
  from .prompt_tokenizers import PromptTokenizingStrategy
11
 
 
18
  LOG = logging.getLogger("axolotl")
19
 
20
 
21
+ class TokenizedPromptDataset(Dataset):
22
  """
23
+ Dataset that returns tokenized prompts from a stream of text files.
24
  Args:
25
  prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
26
  dataset (dataset.Dataset): Dataset with text files.
 
30
  self,
31
  prompt_tokenizer: PromptTokenizingStrategy,
32
  dataset: IterableDataset,
33
+ **kwargs,
34
  ):
35
  self.prompt_tokenizer = prompt_tokenizer
36
+ super().__init__(self.process(dataset).data, **kwargs)
37
+
38
+ def process(self, dataset):
39
+ features = dataset.features.keys()
40
+ num_proc = min(64, os.cpu_count())
41
+ return dataset.map(
42
+ self.prompt_tokenizer.tokenize_prompt,
43
+ num_proc=num_proc,
44
+ remove_columns=features,
 
 
45
  )
46
 
47
 
 
76
  self.tokens_dtype = torch.int64
77
 
78
  def __iter__(self):
79
+ buffer = {
80
+ "input_ids": [],
81
+ "attention_mask": [],
82
+ "labels": [],
83
+ "position_ids": [],
84
+ }
85
  buffer_len = 0
86
  for dataset in self.datasets:
87
+ idx = 0
88
  iterator = iter(dataset)
89
  more_examples = True
90
  while more_examples:
91
  try:
92
  example = next(iterator)
93
+ idx += 1
94
  except StopIteration:
95
  more_examples = False
96
  example = None
 
112
  attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
113
  : self.seq_length
114
  ]
115
+ position_ids = torch.cat(buffer["position_ids"], dim=-1)[
116
+ : self.seq_length
117
+ ]
118
  labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
119
  if labels.size() == input_ids.size() and (
120
  attention_mask.size() == input_ids.size()
 
123
  "input_ids": input_ids,
124
  "labels": labels,
125
  "attention_mask": attention_mask,
126
+ "position_ids": position_ids,
127
  }
128
  else:
129
  LOG.warning(
 
133
  "input_ids": [],
134
  "attention_mask": [],
135
  "labels": [],
136
+ "position_ids": [],
137
  }
138
  buffer_len = 0
139
+ idx = 1
140
 
141
  if example:
142
  # FIXME
 
145
  input_ids = example["input_ids"]
146
  attention_mask = example["attention_mask"]
147
  labels = example["labels"]
 
 
 
 
 
148
 
149
  if add_concat_token:
150
  input_ids.append(self.concat_token_id)
 
155
  input_ids, dtype=self.tokens_dtype
156
  )
157
  attention_mask_with_concat = torch.tensor(
158
+ [idx * m for m in attention_mask], dtype=torch.int16
159
  )
160
  labels_with_concat = torch.tensor(
161
  labels, dtype=self.tokens_dtype
162
  )
163
+ position_ids = torch.arange(
164
+ len(input_ids), dtype=self.tokens_dtype
165
+ )
166
 
167
  buffer["input_ids"].append(input_ids_with_concat)
168
  buffer["attention_mask"].append(attention_mask_with_concat)
169
  buffer["labels"].append(labels_with_concat)
170
+ buffer["position_ids"].append(position_ids)
171
  buffer_len += len(input_ids)
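
For illustration only (not part of the commit diff): a sketch of the kind of packed example the updated `ConstantLengthDataset` yields, with made-up token ids. Each concatenated sequence keeps its own attention-mask value (1, 2, ...) and its position_ids restart at 0, which is what the block-diagonal mask and cu_seqlens code below rely on.

```python
import torch

# two tokenized examples packed into one training sample (token ids are hypothetical)
packed_example = {
    "input_ids": torch.tensor([31, 32, 33, 2, 41, 42, 2]),
    # mask value = 1 for tokens of the first sequence, 2 for the second,
    # matching the `idx * m` marking done in ConstantLengthDataset.__iter__
    "attention_mask": torch.tensor([1, 1, 1, 1, 2, 2, 2], dtype=torch.int16),
    "labels": torch.tensor([31, 32, 33, 2, 41, 42, 2]),
    # position_ids restart at 0 for every packed sequence
    "position_ids": torch.tensor([0, 1, 2, 3, 0, 1, 2]),
}
```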
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -8,9 +8,18 @@ import torch
8
  import transformers
9
  from einops import rearrange
10
  from flash_attn.bert_padding import pad_input, unpad_input
11
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
12
  from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
13
 
 
 
14
 
15
  def forward(
16
  self,
@@ -79,6 +88,16 @@ def forward(
79
  dtype=torch.int32,
80
  device=qkv.device,
81
  )
82
  output = flash_attn_varlen_qkvpacked_func(
83
  qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
84
  )
@@ -113,6 +132,7 @@ def forward(
113
  "b s (h d) -> b s h d",
114
  h=nheads,
115
  )
 
116
  return (
117
  self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
118
  None,
 
8
  import transformers
9
  from einops import rearrange
10
  from flash_attn.bert_padding import pad_input, unpad_input
11
+
12
+ try:
13
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
14
+ except ImportError:
15
+ from flash_attn.flash_attn_interface import (
16
+ flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
17
+ )
18
+
19
  from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
20
 
21
+ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
22
+
23
 
24
  def forward(
25
  self,
 
88
  dtype=torch.int32,
89
  device=qkv.device,
90
  )
91
+ output = flash_attn_varlen_qkvpacked_func(
92
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
93
+ )
94
+ output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
95
+ elif position_ids.shape[0] == 1:
96
+ # special handling using sample packing
97
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
98
+ cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
99
+ cu_q_lens = cu_q_lens.squeeze()
100
+
101
  output = flash_attn_varlen_qkvpacked_func(
102
  qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
103
  )
 
132
  "b s (h d) -> b s h d",
133
  h=nheads,
134
  )
135
+
136
  return (
137
  self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
138
  None,
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -128,6 +128,7 @@ def xformers_forward(
128
  query_states,
129
  key_states,
130
  value_states,
 
131
  attn_bias=xformers.ops.LowerTriangularMask(),
132
  )
133
  attn_weights = None
 
128
  query_states,
129
  key_states,
130
  value_states,
131
+ # attn_bias=attention_mask,
132
  attn_bias=xformers.ops.LowerTriangularMask(),
133
  )
134
  attn_weights = None
src/axolotl/monkeypatch/llama_expand_mask.py ADDED
@@ -0,0 +1,52 @@
1
+ """
2
+ expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
3
+ """
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+
9
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
10
+ """
11
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
12
+ This expansion handles packed sequences so that sequences share the same attention mask integer value
13
+ when they attend to each other within that sequence.
14
+ This expansion transforms the mask to lower triangular form to prevent future peeking.
15
+ """
16
+ bsz, src_len = mask.size()
17
+ tgt_len = tgt_len if tgt_len is not None else src_len
18
+
19
+ mask = mask.unsqueeze(1).unsqueeze(2)
20
+ mask = mask.expand(bsz, 1, tgt_len, src_len)
21
+
22
+ # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
23
+ binary_mask = torch.where(
24
+ mask != 0,
25
+ torch.tensor(1).to(dtype),
26
+ torch.tensor(0).to(dtype),
27
+ )
28
+
29
+ # Create a block-diagonal mask.
30
+ # we multiply by the binary mask so that 0's in the original mask are correctly excluded
31
+ zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
32
+
33
+ # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
34
+ lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
35
+ mask.device
36
+ )
37
+
38
+ # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
39
+ masked_zero_one_mask = zero_one_mask * lower_triangular_ones
40
+ inverted_mask = 1.0 - masked_zero_one_mask
41
+
42
+ return inverted_mask.masked_fill(
43
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
44
+ )
45
+
46
+
47
+ def hijack_expand_mask():
48
+ import transformers
49
+
50
+ transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access
51
+ _expand_mask
52
+ )
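
For illustration only (not part of the commit diff): a minimal sketch of what the patched `_expand_mask` produces for one packed batch item, assuming the module is importable as added above. Positions may only attend within their own packed sequence and never to the future; pad positions attend to nothing.

```python
import torch

from axolotl.monkeypatch.llama_expand_mask import _expand_mask

# one batch item: two packed sequences (mask values 1 and 2) plus one pad token (0)
mask = torch.tensor([[1, 1, 2, 2, 0]])
expanded = _expand_mask(mask, dtype=torch.float32)

# 0.0 marks allowed attention, torch.finfo(float32).min marks blocked attention
allowed = (expanded == 0.0).long()[0, 0]
print(allowed)
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 0]])
```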
src/axolotl/monkeypatch/utils.py ADDED
@@ -0,0 +1,103 @@
1
+ """
2
+ Shared utils for the monkeypatches
3
+ """
4
+ import torch
5
+
6
+
7
+ def get_cu_seqlens(attn_mask):
8
+ """generate a cumulative sequence length mask for flash attention using attn mask"""
9
+ if len(attn_mask.shape) == 1:
10
+ attn_mask = attn_mask.unsqueeze(0)
11
+
12
+ device = attn_mask.device
13
+ results = []
14
+ max_seq_lens = []
15
+
16
+ for row in attn_mask:
17
+ # Exclude zeros to avoid adding their positions to the mask
18
+ t_non_zeros = row[row != 0]
19
+ # Find where the sequence number changes (including the first position)
20
+ seq_change = torch.cat(
21
+ [
22
+ torch.tensor([1], dtype=torch.int32, device=device),
23
+ t_non_zeros[1:] != t_non_zeros[:-1],
24
+ ]
25
+ )
26
+ # Get the indices where the sequence changes
27
+ change_indices = torch.cat(
28
+ [
29
+ (seq_change == 1).nonzero(as_tuple=True)[0],
30
+ torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
31
+ ]
32
+ )
33
+ # Calculate the sequence lengths
34
+ seq_lengths = change_indices[1:] - change_indices[:-1]
35
+ # Calculate the length of the final sequence or padding
36
+ final_seq_length = len(row) - change_indices[-1]
37
+ # Append the length of the final sequence or padding to seq_lengths
38
+ if final_seq_length.item():
39
+ seq_lengths = torch.cat(
40
+ [
41
+ seq_lengths,
42
+ torch.tensor(
43
+ [final_seq_length.item()], dtype=torch.int32, device=device
44
+ ),
45
+ ]
46
+ )
47
+ # Calculate the cumulative sequence lengths
48
+ cu_seqlens = torch.cat(
49
+ [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
50
+ )
51
+ max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
52
+ results.append(cu_seqlens)
53
+ max_seq_lens.append(max_seq_len)
54
+
55
+ return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
56
+
57
+
58
+ def get_cu_seqlens_from_pos_ids(position_ids):
59
+ """generate a cumulative sequence length mask for flash attention using pos ids"""
60
+ if len(position_ids.shape) == 1:
61
+ position_ids = position_ids.unsqueeze(0)
62
+
63
+ device = position_ids.device
64
+ results = []
65
+ max_seq_lens = []
66
+
67
+ for row in position_ids:
68
+ # Count the number of consecutive zeros from the right side
69
+ padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
70
+
71
+ # Adjust the row to exclude padding
72
+ adjusted_row = row[:-padding_length] if padding_length else row.clone()
73
+
74
+ # Find where the position resets to 0 (indicating a new sequence)
75
+ seq_starts = torch.cat(
76
+ [
77
+ torch.tensor([True], dtype=torch.bool, device=device),
78
+ adjusted_row[1:] == 0,
79
+ ]
80
+ )
81
+ # Get the indices where the sequence starts
82
+ start_indices = torch.cat(
83
+ [
84
+ (seq_starts).nonzero(as_tuple=True)[0],
85
+ torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
86
+ ]
87
+ )
88
+ # Calculate the sequence lengths
89
+ seq_lengths = start_indices[1:] - start_indices[:-1]
90
+ # Calculate the cumulative sequence lengths
91
+ cu_seqlens = torch.cat(
92
+ [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
93
+ )
94
+ # Append the padding length to the cumulative sequence lengths
95
+ if padding_length:
96
+ cu_seqlens = torch.cat(
97
+ [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
98
+ )
99
+ max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
100
+ results.append(cu_seqlens)
101
+ max_seq_lens.append(max_seq_len)
102
+
103
+ return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
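
For illustration only (not part of the commit diff): a sketch of `get_cu_seqlens_from_pos_ids` on one packed row, assuming the import path added above. The cumulative sequence lengths are what the flash-attention varlen kernel consumes, and trailing padding is counted as its own final segment.

```python
import torch

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

# one packed row: a 4-token sequence, a 3-token sequence, then 2 pad positions
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 0]])
cu_seqlens, max_seq_len = get_cu_seqlens_from_pos_ids(position_ids)

print(cu_seqlens)   # tensor([[0, 4, 7, 9]], dtype=torch.int32)
print(max_seq_len)  # tensor([4]) -> longest single sequence in the row
```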
src/axolotl/prompt_strategies/alpaca_w_system.py CHANGED
@@ -66,7 +66,11 @@ class SystemDataPrompter(AlpacaPrompter):
66
  ) -> Generator[str, None, None]:
67
  # returns the full prompt from instruction and optional input
68
  # if a label (=response, =output) is provided, it's also appended.
69
- formatted_sys_prompt = f"### System:\n{system}\n\n" if system else ""
 
 
 
 
70
  if input:
71
  res = formatted_sys_prompt + self.turn_format.format(
72
  instruction=instruction, input=input
@@ -86,12 +90,20 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
86
  """
87
 
88
  def match_prompt_style(self):
 
89
  if self.prompt_style == PromptStyle.INSTRUCT.value:
90
  self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
91
  self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
92
  if self.prompt_style == PromptStyle.CHAT.value:
93
  self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
94
  self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
95
 
96
 
97
  class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
@@ -137,3 +149,12 @@ def load_open_orca(tokenizer, cfg):
137
  cfg.train_on_inputs,
138
  cfg.sequence_len,
139
  )
 
66
  ) -> Generator[str, None, None]:
67
  # returns the full prompt from instruction and optional input
68
  # if a label (=response, =output) is provided, it's also appended.
69
+ formatted_sys_prompt = (
70
+ self.system_format.format(system=system)
71
+ if system and self.system_format
72
+ else ""
73
+ )
74
  if input:
75
  res = formatted_sys_prompt + self.turn_format.format(
76
  instruction=instruction, input=input
 
90
  """
91
 
92
  def match_prompt_style(self):
93
+ # pylint: disable=duplicate-code
94
  if self.prompt_style == PromptStyle.INSTRUCT.value:
95
  self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
96
  self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
97
  if self.prompt_style == PromptStyle.CHAT.value:
98
  self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
99
  self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
100
+ self.system_format = "SYSTEM: {system}\n"
101
+ if self.prompt_style == PromptStyle.CHATML.value:
102
+ self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
103
+ self.turn_no_input_format = (
104
+ "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
105
+ )
106
+ self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
107
 
108
 
109
  class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
 
149
  cfg.train_on_inputs,
150
  cfg.sequence_len,
151
  )
152
+
153
+
154
+ def load_open_orca_chatml(tokenizer, cfg):
155
+ return OpenOrcaPromptTokenizingStrategy(
156
+ OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
157
+ tokenizer,
158
+ cfg.train_on_inputs,
159
+ cfg.sequence_len,
160
+ )
src/axolotl/prompters.py CHANGED
@@ -16,6 +16,7 @@ class PromptStyle(Enum):
16
 
17
  INSTRUCT = "instruct"
18
  CHAT = "chat"
 
19
 
20
 
21
  class AlpacaPrompter:
@@ -25,6 +26,7 @@ class AlpacaPrompter:
25
 
26
  system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
27
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
 
28
  turn_format: str
29
  turn_no_input_format: str
30
  prompt_style: Optional[PromptStyle] = None
@@ -34,14 +36,23 @@ class AlpacaPrompter:
34
  self.match_prompt_style()
35
 
36
  def match_prompt_style(self):
 
37
  if self.prompt_style == PromptStyle.INSTRUCT.value:
38
  self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
39
  self.turn_no_input_format = (
40
  "### Instruction:\n{instruction}\n\n### Response:\n"
41
  )
 
42
  if self.prompt_style == PromptStyle.CHAT.value:
43
  self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
44
  self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
45
 
46
  def build_prompt(
47
  self,
 
16
 
17
  INSTRUCT = "instruct"
18
  CHAT = "chat"
19
+ CHATML = "chatml"
20
 
21
 
22
  class AlpacaPrompter:
 
26
 
27
  system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
28
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
29
+ system_format: str
30
  turn_format: str
31
  turn_no_input_format: str
32
  prompt_style: Optional[PromptStyle] = None
 
36
  self.match_prompt_style()
37
 
38
  def match_prompt_style(self):
39
+ # pylint: disable=duplicate-code
40
  if self.prompt_style == PromptStyle.INSTRUCT.value:
41
  self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
42
  self.turn_no_input_format = (
43
  "### Instruction:\n{instruction}\n\n### Response:\n"
44
  )
45
+ self.system_format = "### System:\n{system}\n\n"
46
  if self.prompt_style == PromptStyle.CHAT.value:
47
  self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
48
  self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
49
+ self.system_format = "SYSTEM: {system}\n"
50
+ if self.prompt_style == PromptStyle.CHATML.value:
51
+ self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
52
+ self.turn_no_input_format = (
53
+ "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
54
+ )
55
+ self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
56
 
57
  def build_prompt(
58
  self,
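
For illustration only (not part of the commit diff): what the new chatml prompt style renders to, using the `system_format` / `turn_no_input_format` templates added above; the system and instruction strings here are made up.

```python
system_format = "<|im_start|>system\n{system}<|im_end|>\n"
turn_no_input_format = (
    "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)

prompt = system_format.format(
    system="You are a helpful assistant."
) + turn_no_input_format.format(instruction="Name the capital of France.")
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Name the capital of France.<|im_end|>
# <|im_start|>assistant
```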
src/axolotl/utils/collators.py ADDED
@@ -0,0 +1,121 @@
1
+ """
2
+ DataCollator for axolotl to pad labels and position_ids for packed sequences
3
+ """
4
+ from dataclasses import dataclass
5
+ from typing import Any, Optional, Union
6
+
7
+ import numpy as np
8
+ from transformers import PreTrainedTokenizerBase
9
+ from transformers.utils import PaddingStrategy
10
+
11
+
12
+ @dataclass
13
+ class DataCollatorForSeq2Seq:
14
+ """
15
+ Data collator that will dynamically pad the inputs received, as well as the labels and position_ids
16
+
17
+ Args:
18
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
19
+ The tokenizer used for encoding the data.
20
+ model ([`PreTrainedModel`]):
21
+ The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
22
+ prepare the *decoder_input_ids*
23
+
24
+ This is useful when using *label_smoothing* to avoid calculating loss twice.
25
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
26
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
27
+ among:
28
+
29
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
30
+ sequence is provided).
31
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
32
+ acceptable input length for the model if that argument is not provided.
33
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
34
+ max_length (`int`, *optional*):
35
+ Maximum length of the returned list and optionally padding length (see above).
36
+ pad_to_multiple_of (`int`, *optional*):
37
+ If set will pad the sequence to a multiple of the provided value.
38
+
39
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
40
+ 7.5 (Volta).
41
+ label_pad_token_id (`int`, *optional*, defaults to -100):
42
+ The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
43
+ return_tensors (`str`):
44
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
45
+ """
46
+
47
+ tokenizer: PreTrainedTokenizerBase
48
+ model: Optional[Any] = None
49
+ padding: Union[bool, str, PaddingStrategy] = True
50
+ max_length: Optional[int] = None
51
+ pad_to_multiple_of: Optional[int] = None
52
+ label_pad_token_id: int = -100
53
+ position_pad_token_id: int = 0
54
+ return_tensors: str = "pt"
55
+
56
+ def __call__(self, features, return_tensors=None):
57
+ labels = None
58
+ if return_tensors is None:
59
+ return_tensors = self.return_tensors
60
+
61
+ for feature_name, pad_token_id in [
62
+ ("labels", self.label_pad_token_id),
63
+ ("position_ids", self.position_pad_token_id),
64
+ ]:
65
+ feat = (
66
+ [feature[feature_name] for feature in features]
67
+ if feature_name in features[0].keys()
68
+ else None
69
+ )
70
+ labels = feat if feat and feature_name == "labels" else labels
71
+ # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
72
+ # same length to return tensors.
73
+ if feat is not None:
74
+ max_feature_length = max(len(l) for l in feat) # noqa: E741
75
+ if self.pad_to_multiple_of is not None:
76
+ max_feature_length = (
77
+ (max_feature_length + self.pad_to_multiple_of - 1)
78
+ // self.pad_to_multiple_of
79
+ * self.pad_to_multiple_of
80
+ )
81
+
82
+ padding_side = self.tokenizer.padding_side
83
+ for feature in features:
84
+ remainder = [pad_token_id] * (
85
+ max_feature_length - len(feature[feature_name])
86
+ )
87
+ if isinstance(feature[feature_name], list):
88
+ feature[feature_name] = (
89
+ feature[feature_name] + remainder
90
+ if padding_side == "right"
91
+ else remainder + feature[feature_name]
92
+ )
93
+ elif padding_side == "right":
94
+ feature[feature_name] = np.concatenate(
95
+ [feature[feature_name], remainder]
96
+ ).astype(np.int64)
97
+ else:
98
+ feature[feature_name] = np.concatenate(
99
+ [remainder, feature[feature_name]]
100
+ ).astype(np.int64)
101
+
102
+ features = self.tokenizer.pad(
103
+ features,
104
+ padding=self.padding,
105
+ max_length=self.max_length,
106
+ pad_to_multiple_of=self.pad_to_multiple_of,
107
+ return_tensors=return_tensors,
108
+ )
109
+
110
+ # prepare decoder_input_ids
111
+ if (
112
+ labels is not None
113
+ and self.model is not None
114
+ and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
115
+ ):
116
+ decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
117
+ labels=features["labels"]
118
+ )
119
+ features["decoder_input_ids"] = decoder_input_ids
120
+
121
+ return features
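
For illustration only (not part of the commit diff): a sketch of the new collator batching ragged packed samples. The tokenizer checkpoint is an assumption; any tokenizer with a pad token works. Labels are padded with -100 and position_ids with 0, as the defaults above describe.

```python
from transformers import AutoTokenizer

from axolotl.utils.collators import DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # assumed checkpoint
tokenizer.pad_token = tokenizer.eos_token
collator = DataCollatorForSeq2Seq(tokenizer, padding=True, return_tensors="pt")

features = [
    {"input_ids": [1, 5, 6], "attention_mask": [1, 1, 1], "labels": [1, 5, 6], "position_ids": [0, 1, 2]},
    {"input_ids": [1, 7], "attention_mask": [1, 1], "labels": [1, 7], "position_ids": [0, 1]},
]
batch = collator(features)
print(batch["labels"])        # shorter row padded with -100 to the batch max length
print(batch["position_ids"])  # shorter row padded with 0
```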
src/axolotl/utils/data.py CHANGED
@@ -1,13 +1,19 @@
1
  """Module containing data utilities"""
2
  import functools
3
- import itertools
4
  import logging
5
  from hashlib import md5
6
  from pathlib import Path
7
- from typing import List, Tuple, Union
8
 
9
  import torch
10
- from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
11
  from huggingface_hub import hf_hub_download
12
  from transformers import PreTrainedTokenizerBase
13
 
@@ -35,6 +41,7 @@ from axolotl.prompters import (
35
  ShareGPTPrompter,
36
  SummarizeTLDRPrompter,
37
  )
 
38
 
39
  LOG = logging.getLogger("axolotl")
40
 
@@ -109,6 +116,7 @@ def load_tokenized_prepared_datasets(
109
  local_path = Path(d.path)
110
  if local_path.exists():
111
  if local_path.is_dir():
 
112
  ds = load_dataset(
113
  d.path,
114
  name=d.name,
@@ -262,20 +270,12 @@ def load_tokenized_prepared_datasets(
262
  raise ValueError(
263
  f"unhandled prompt tokenization strategy: {d.type} {suffix}"
264
  )
265
- LOG.info("tokenizing, merging, and shuffling master dataset")
266
-
267
- samples: List[int] = []
268
- chunk_size = 1000
269
- for d in datasets:
270
- d_iter = iter(d)
271
- while True:
272
- chunk = list(itertools.islice(d_iter, chunk_size))
273
- if not chunk:
274
- break
275
- samples.extend(chunk)
276
-
277
- LOG.info("shuffle")
278
- dataset = Dataset.from_list(samples).shuffle(seed=seed)
279
  if cfg.local_rank == 0:
280
  LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
281
  dataset.save_to_disk(prepared_ds_path)
@@ -374,6 +374,7 @@ def load_prepare_datasets(
374
  dataset = Dataset.from_list(list(constant_len_dataset))
375
 
376
  # filter out bad data
 
377
  dataset = Dataset.from_list(
378
  [
379
  d
@@ -413,7 +414,51 @@ def load_prepare_datasets(
413
  )
414
 
415
  if cfg.val_set_size:
416
- dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
417
  train_dataset = dataset["train"]
418
  eval_dataset = dataset["test"]
419
  else:
 
1
  """Module containing data utilities"""
2
  import functools
3
+ import hashlib
4
  import logging
5
  from hashlib import md5
6
  from pathlib import Path
7
+ from typing import Tuple, Union
8
 
9
  import torch
10
+ from datasets import (
11
+ Dataset,
12
+ DatasetDict,
13
+ concatenate_datasets,
14
+ load_dataset,
15
+ load_from_disk,
16
+ )
17
  from huggingface_hub import hf_hub_download
18
  from transformers import PreTrainedTokenizerBase
19
 
 
41
  ShareGPTPrompter,
42
  SummarizeTLDRPrompter,
43
  )
44
+ from axolotl.utils.distributed import barrier, is_main_process
45
 
46
  LOG = logging.getLogger("axolotl")
47
 
 
116
  local_path = Path(d.path)
117
  if local_path.exists():
118
  if local_path.is_dir():
119
+ # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
120
  ds = load_dataset(
121
  d.path,
122
  name=d.name,
 
270
  raise ValueError(
271
  f"unhandled prompt tokenization strategy: {d.type} {suffix}"
272
  )
273
+ LOG.info("merging datasets")
274
+ dataset = concatenate_datasets(datasets)
275
+
276
+ if len(datasets) > 1:
277
+ LOG.info("shuffle merged datasets")
278
+ dataset = dataset.shuffle(seed=seed)
279
  if cfg.local_rank == 0:
280
  LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
281
  dataset.save_to_disk(prepared_ds_path)
 
374
  dataset = Dataset.from_list(list(constant_len_dataset))
375
 
376
  # filter out bad data
377
+ # TODO convert to dataset.filter(...)
378
  dataset = Dataset.from_list(
379
  [
380
  d
 
414
  )
415
 
416
  if cfg.val_set_size:
417
+ # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
418
+ to_hash_train = (
419
+ dataset._fingerprint # pylint: disable=protected-access
420
+ + "|"
421
+ + str(cfg.val_set_size)
422
+ + "|"
423
+ + "train"
424
+ + "|"
425
+ + str(cfg.seed or 42)
426
+ )
427
+ to_hash_test = (
428
+ dataset._fingerprint # pylint: disable=protected-access
429
+ + "|"
430
+ + str(cfg.val_set_size)
431
+ + "|"
432
+ + "test"
433
+ + "|"
434
+ + str(cfg.seed or 42)
435
+ )
436
+ train_fingerprint = hashlib.md5(
437
+ to_hash_train.encode(), usedforsecurity=False
438
+ ).hexdigest()
439
+ test_fingerprint = hashlib.md5(
440
+ to_hash_test.encode(), usedforsecurity=False
441
+ ).hexdigest()
442
+
443
+ if is_main_process():
444
+ dataset = dataset.train_test_split(
445
+ test_size=cfg.val_set_size,
446
+ shuffle=False,
447
+ seed=cfg.seed or 42,
448
+ train_new_fingerprint=train_fingerprint,
449
+ test_new_fingerprint=test_fingerprint,
450
+ )
451
+ barrier()
452
+ if not is_main_process():
453
+ dataset = dataset.train_test_split(
454
+ test_size=cfg.val_set_size,
455
+ shuffle=False,
456
+ seed=cfg.seed or 42,
457
+ train_new_fingerprint=train_fingerprint,
458
+ test_new_fingerprint=test_fingerprint,
459
+ )
460
+ barrier()
461
+
462
  train_dataset = dataset["train"]
463
  eval_dataset = dataset["test"]
464
  else:
src/axolotl/utils/dataloader.py ADDED
@@ -0,0 +1,288 @@
1
+ # pylint: skip-file
2
+ import hashlib
3
+ import itertools
4
+ import logging
5
+ import math
6
+ from typing import Any, Callable, List, Union
7
+
8
+ import numba
9
+ import numpy as np
10
+ from torch.utils.data import DistributedSampler, Sampler
11
+
12
+ LOG = logging.getLogger("axolotl.utils.dataloader")
13
+
14
+
15
+ @numba.njit
16
+ def ffd_check(a: np.ndarray, c: int, n: int):
17
+ # First-fit-decreasing bin packing
18
+ # Check if a[] could fit in n bins with capacity c
19
+ # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
20
+
21
+ a = np.sort(a)[::-1]
22
+ bins = np.full((n,), c, dtype=a.dtype)
23
+ for size in a:
24
+ not_found = True
25
+ for idx in range(n):
26
+ if bins[idx] >= size:
27
+ bins[idx] -= size
28
+ not_found = False
29
+ break
30
+
31
+ if not_found:
32
+ return False
33
+
34
+ return True
35
+
36
+
37
+ @numba.njit
38
+ def ffd_with_result(a: np.ndarray, c: int, start_index: int):
39
+ # First-fit-decreasing bin packing (with result return)
40
+
41
+ indices = np.argsort(a)[::-1]
42
+ a = a[indices]
43
+
44
+ bins: List[Any] = []
45
+ bins_result: List[Any] = []
46
+ for a_id, size in enumerate(a):
47
+ add_new = True
48
+ for idx in range(len(bins)):
49
+ if bins[idx] >= size:
50
+ bins[idx] -= size
51
+ bins_result[idx].append(indices[a_id] + start_index)
52
+ add_new = False
53
+ break
54
+
55
+ if add_new:
56
+ bins.append(c - size)
57
+ bins_result.append([indices[a_id] + start_index])
58
+
59
+ return bins_result, len(a)
60
+
61
+
62
+ @numba.njit
63
+ def allocate(
64
+ lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
65
+ ):
66
+ """
67
+ :param lengths: array of lengths of each sample
68
+ :param lengths_cumsum: cumulative sum of consecutive lengths
69
+ :param rank: rank for this process
70
+ :param c: length of tokens per batch
71
+ :param n: number of ranks
72
+ :return:
73
+ """
74
+ # Dynamic batch allocator, similar to Multifit
75
+ # https://en.wikipedia.org/wiki/Multifit_algorithm
76
+ # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
77
+
78
+ s = 0
79
+ start_index = 0
80
+ result = []
81
+ result_totseqs = []
82
+
83
+ while True:
84
+ # binary search [left, right)
85
+ left = 1
86
+ right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
87
+
88
+ while right - left > 1:
89
+ mid = (left + right) // 2
90
+ if ffd_check(lengths[start_index : start_index + mid], c, n):
91
+ left = mid
92
+ else:
93
+ right = mid
94
+
95
+ # use length left
96
+ batch, tot_seqs = ffd_with_result(
97
+ lengths[start_index : start_index + left], c, start_index
98
+ )
99
+ if len(batch) < n:
100
+ break
101
+
102
+ start_index += left
103
+ s = lengths_cumsum[start_index - 1]
104
+
105
+ # add local rank
106
+ result.append(batch[rank])
107
+ # add total seqs for all ranks
108
+ result_totseqs.append(tot_seqs)
109
+ # yield batch[rank], tot_seqs, s, len(result) * c * n
110
+ return result, result_totseqs, s, len(result) * c * n
111
+
112
+
113
+ def chunk(iterable, n):
114
+ """
115
+ Chunk data into tuples of length n
116
+ """
117
+ # batched('ABCDEFG', 3) --> ABC DEF G
118
+ if n < 1:
119
+ raise ValueError("n must be at least one")
120
+ it = iter(iterable)
121
+ while batch := tuple(itertools.islice(it, n)):
122
+ yield batch
123
+
124
+
125
+ def hash_indices(lst: List[int]) -> str:
126
+ # Convert the list of integers to a string representation
127
+ concatenated = ",".join(map(str, lst))
128
+
129
+ # Generate the hash
130
+ sha256 = hashlib.sha256()
131
+ sha256.update(concatenated.encode())
132
+
133
+ return sha256.hexdigest()
134
+
135
+
136
+ class MultipackDistributedDataloader:
137
+ """Unpadded data loading using Multipack.
138
+ Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
139
+ Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
140
+ """
141
+
142
+ def __init__(
143
+ self,
144
+ dataset: Any,
145
+ collate_fn: Callable,
146
+ seq_max_length: int = 2048,
147
+ batch_size: int = 1,
148
+ sampler: Union[Sampler, DistributedSampler] = None,
149
+ packing_efficiency_estimate: float = 1.0,
150
+ sample_packing_seq_len_multiplier: int = 1,
151
+ device_count: int = 1,
152
+ ):
153
+ # Dataset
154
+ self.dataset = dataset
155
+ self.lengths = (
156
+ dataset.data.column("position_ids")
157
+ .to_pandas()
158
+ .apply(lambda x: x[-1] + 1)
159
+ .values
160
+ )
161
+ assert isinstance(self.lengths, np.ndarray)
162
+ assert batch_size % sample_packing_seq_len_multiplier == 0
163
+ assert batch_size >= sample_packing_seq_len_multiplier
164
+ self.sampler = sampler
165
+ self.batch_size = batch_size
166
+ self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
167
+ self.seq_max_length = seq_max_length
168
+ self.batch_max_length = batch_size * seq_max_length
169
+ self.collate_fn = collate_fn
170
+
171
+ self.num_replicas = 1
172
+ self.rank = 0
173
+
174
+ # statistics
175
+ self.eff_total_used = 0
176
+ self.eff_total_slots = 0
177
+ self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
178
+ self.device_count = device_count
179
+
180
+ def generate_batches(self, set_stats=False):
181
+ LOG.info("generating packed batches")
182
+ if self.sampler:
183
+ indices = [idx for idx in self.sampler]
184
+ else:
185
+ indices = range(0, len(self.dataset))
186
+
187
+ LOG.info(hash_indices(indices))
188
+ lengths = self.lengths[indices]
189
+ lengths_cumsum = np.cumsum(lengths)
190
+
191
+ batches, totseqs, total_used, total_slots = allocate(
192
+ lengths=lengths,
193
+ lengths_cumsum=lengths_cumsum,
194
+ rank=self.rank,
195
+ # c=self.batch_max_length,
196
+ c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
197
+ n=self.num_replicas,
198
+ )
199
+
200
+ batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
201
+
202
+ # statistics
203
+ if set_stats:
204
+ self.eff_total_used += total_used
205
+ self.eff_total_slots += total_slots
206
+
207
+ return batches, totseqs
208
+
209
+ def __iter__(self):
210
+ if hasattr(self.sampler, "set_epoch"):
211
+ new_epoch = self.sampler.epoch + 1
212
+ self.sampler.set_epoch(new_epoch)
213
+ LOG.info(f"calling sampler.set_epoch({new_epoch})")
214
+ all_batches, _ = self.generate_batches(set_stats=True)
215
+ features = self.dataset.features.keys()
216
+ len_remaining = self._len_est()
217
+ for batches in chunk(
218
+ all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
219
+ ):
220
+ chunked_data = []
221
+ attn_mask_cum_idx = 0
222
+ for batch in batches:
223
+ concatenated = {}
224
+ batched_data = [self.dataset[batch_idx] for batch_idx in batch]
225
+ for feature in features:
226
+ if feature == "attention_mask":
227
+ arrays = [
228
+ (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
229
+ for idx, item in enumerate(batched_data)
230
+ if feature in item
231
+ ]
232
+ attn_mask_cum_idx += len(batched_data)
233
+ concatenated[feature] = np.concatenate(arrays)
234
+ else:
235
+ arrays = [
236
+ np.array(item[feature])
237
+ for item in batched_data
238
+ if feature in item
239
+ ]
240
+ concatenated[feature] = np.concatenate(arrays)
241
+ chunked_data.append(concatenated)
242
+ yield self.collate_fn(chunked_data)
243
+ len_remaining -= 1
244
+ if not len_remaining:
245
+ return
246
+
247
+ def _len_est(self):
248
+ lengths_sum = np.sum(self.lengths)
249
+ lengths_sum_per_device = lengths_sum // self.device_count
250
+ LOG.info(
251
+ f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
252
+ f"total_num_tokens per device: {lengths_sum_per_device}"
253
+ )
254
+
255
+ # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
256
+ return (
257
+ math.floor(
258
+ 0.99
259
+ * lengths_sum_per_device
260
+ / self.packing_efficiency_estimate
261
+ // self.seq_max_length
262
+ // self.batch_size
263
+ )
264
+ - 1
265
+ )
266
+
267
+ def __len__(self):
268
+ # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
269
+ # the same share of total tokens
270
+ # if not self.eff_total_used:
271
+ # batches, _ = self.generate_batches(set_stats=True)
272
+ # LOG.info(
273
+ # f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
274
+ # f"actual packing efficiency: {self.efficiency()}"
275
+ # )
276
+ return max(1, self._len_est())
277
+
278
+ def len_w_stats(self):
279
+ if not self.eff_total_used:
280
+ batches, _ = self.generate_batches(set_stats=True)
281
+ LOG.info(
282
+ f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
283
+ f"actual packing efficiency: {self.efficiency()}"
284
+ )
285
+ return max(1, self._len_est())
286
+
287
+ def efficiency(self):
288
+ return self.eff_total_used / self.eff_total_slots
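
For illustration only (not part of the commit diff): the first-fit-decreasing check the multipack dataloader builds on, assuming `ffd_check` is importable from the new module. It answers whether a set of sample lengths fits into `n` bins of capacity `c` tokens.

```python
import numpy as np

from axolotl.utils.dataloader import ffd_check

lengths = np.array([1800, 1200, 900, 700, 300, 100])

print(ffd_check(lengths, 2048, 3))  # True: the six samples pack into three 2048-token bins
print(ffd_check(lengths, 2048, 2))  # False: 5000 total tokens cannot fit into two bins
```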
src/axolotl/utils/distributed.py ADDED
@@ -0,0 +1,41 @@
1
+ """
2
+ utility helpers for distributed checks
3
+ """
4
+ import torch.distributed as dist
5
+ from accelerate import Accelerator
6
+
7
+ accelerate = None # pylint: disable=invalid-name
8
+
9
+
10
+ def load_accelerate():
11
+ global accelerate # pylint: disable=global-statement
12
+ accelerate = Accelerator()
13
+
14
+
15
+ def is_distributed():
16
+ """
17
+ Check if distributed training is initialized.
18
+ """
19
+ global accelerate # pylint: disable=global-statement
20
+ if not accelerate:
21
+ accelerate = Accelerator()
22
+ return dist.is_available() and dist.is_initialized()
23
+
24
+
25
+ def barrier():
26
+ """
27
+ Acts as a barrier to wait for all processes. This ensures that all processes
28
+ reach the barrier before proceeding further.
29
+ """
30
+ if is_distributed():
31
+ dist.barrier()
32
+
33
+
34
+ def is_main_process():
35
+ """
36
+ Check if the current process is the main process.
37
+ If not in distributed mode, always return True.
38
+ """
39
+ if not is_distributed():
40
+ return True
41
+ return dist.get_rank() == 0
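
For illustration only (not part of the commit diff): the rank-0-first pattern these helpers enable, mirroring how scripts/finetune.py now processes datasets for packing so the cache is written once and other ranks read it. The `prepare` function is a hypothetical stand-in.

```python
from axolotl.utils.distributed import barrier, is_main_process


def prepare(value):
    # hypothetical stand-in for dataset.map(...) work that populates a shared cache
    return value * 2


data = 21
if is_main_process():
    data = prepare(data)   # rank 0 does the work first and warms the cache
barrier()                  # everyone waits until rank 0 is done
if not is_main_process():
    data = prepare(data)   # remaining ranks redo the call, hitting the cached result
barrier()
```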
src/axolotl/utils/models.py CHANGED
@@ -37,20 +37,26 @@ def load_tokenizer(
37
  tokenizer_type,
38
  cfg,
39
  ):
 
40
  use_fast = True # this is the default
41
  if cfg.tokenizer_use_fast is not None:
42
  use_fast = cfg.tokenizer_use_fast
 
 
 
43
  if tokenizer_type:
44
  tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
45
  tokenizer_config,
46
  trust_remote_code=cfg.trust_remote_code or False,
47
  use_fast=use_fast,
 
48
  )
49
  else:
50
  tokenizer = AutoTokenizer.from_pretrained(
51
  tokenizer_config,
52
  trust_remote_code=cfg.trust_remote_code or False,
53
  use_fast=use_fast,
 
54
  )
55
 
56
  LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -90,8 +96,10 @@ def load_model(
90
 
91
  # TODO refactor as a kwarg
92
  load_in_8bit = cfg.load_in_8bit
93
- cfg.is_llama_derived_model = "llama" in base_model or (
94
- cfg.model_type and "llama" in cfg.model_type.lower()
 
 
95
  )
96
 
97
  if cfg.is_llama_derived_model and cfg.flash_attention:
@@ -136,6 +144,14 @@ def load_model(
136
  LOG.info("patching with xpos rope")
137
  replace_llama_rope_with_xpos_rope()
138
 
139
  if cfg.bf16 or cfg.bfloat16:
140
  torch_dtype = torch.bfloat16
141
  elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
@@ -228,7 +244,6 @@ def load_model(
228
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
229
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
230
  torch_dtype=torch_dtype,
231
- device_map="auto" if cfg.world_size == 1 else cfg.device_map,
232
  **model_kwargs,
233
  )
234
  # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -263,7 +278,6 @@ def load_model(
263
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
264
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
265
  torch_dtype=torch_dtype,
266
- device_map=cfg.device_map,
267
  trust_remote_code=cfg.trust_remote_code or False,
268
  **model_kwargs,
269
  )
@@ -294,7 +308,6 @@ def load_model(
294
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
295
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
296
  torch_dtype=torch_dtype,
297
- device_map=cfg.device_map,
298
  trust_remote_code=cfg.trust_remote_code or False,
299
  **model_kwargs,
300
  )
@@ -308,7 +321,6 @@ def load_model(
308
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
309
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
310
  torch_dtype=torch_dtype,
311
- device_map=cfg.device_map,
312
  trust_remote_code=cfg.trust_remote_code or False,
313
  **model_kwargs,
314
  )
 
37
  tokenizer_type,
38
  cfg,
39
  ):
40
+ tokenizer_kwargs = {}
41
  use_fast = True # this is the default
42
  if cfg.tokenizer_use_fast is not None:
43
  use_fast = cfg.tokenizer_use_fast
44
+ if cfg.tokenizer_legacy is not None:
45
+ # True is the default w/ https://github.com/huggingface/transformers/pull/25224
46
+ tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
47
  if tokenizer_type:
48
  tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
49
  tokenizer_config,
50
  trust_remote_code=cfg.trust_remote_code or False,
51
  use_fast=use_fast,
52
+ **tokenizer_kwargs,
53
  )
54
  else:
55
  tokenizer = AutoTokenizer.from_pretrained(
56
  tokenizer_config,
57
  trust_remote_code=cfg.trust_remote_code or False,
58
  use_fast=use_fast,
59
+ **tokenizer_kwargs,
60
  )
61
 
62
  LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
 
96
 
97
  # TODO refactor as a kwarg
98
  load_in_8bit = cfg.load_in_8bit
99
+ cfg.is_llama_derived_model = (
100
+ "llama" in base_model
101
+ or (cfg.model_type and "llama" in cfg.model_type.lower())
102
+ or cfg.is_llama_derived_model
103
  )
104
 
105
  if cfg.is_llama_derived_model and cfg.flash_attention:
 
144
  LOG.info("patching with xpos rope")
145
  replace_llama_rope_with_xpos_rope()
146
 
147
+ if cfg.is_llama_derived_model and (
148
+ cfg.max_packed_sequence_len or cfg.sample_packing
149
+ ):
150
+ from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
151
+
152
+ LOG.info("patching _expand_mask")
153
+ hijack_expand_mask()
154
+
155
  if cfg.bf16 or cfg.bfloat16:
156
  torch_dtype = torch.bfloat16
157
  elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
 
244
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
245
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
246
  torch_dtype=torch_dtype,
 
247
  **model_kwargs,
248
  )
249
  # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
 
278
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
279
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
280
  torch_dtype=torch_dtype,
 
281
  trust_remote_code=cfg.trust_remote_code or False,
282
  **model_kwargs,
283
  )
 
308
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
309
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
310
  torch_dtype=torch_dtype,
 
311
  trust_remote_code=cfg.trust_remote_code or False,
312
  **model_kwargs,
313
  )
 
321
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
322
  load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
323
  torch_dtype=torch_dtype,
 
324
  trust_remote_code=cfg.trust_remote_code or False,
325
  **model_kwargs,
326
  )
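A minimal sketch (not axolotl code; token ids are made up) of what a packed training example looks like after these changes: the attention_mask carries an incrementing segment id per packed sample and position_ids restart at 0 for every sample, which is what the patched _expand_mask above and the flash-attention path consume, and what tests/test_packed_dataset.py below asserts.

    samples = [[1, 306, 626, 263], [1, 763, 611]]  # two tokenized samples, BOS assumed to be token id 1

    input_ids, attention_mask, position_ids = [], [], []
    for seg_id, ids in enumerate(samples, start=1):
        input_ids.extend(ids)
        attention_mask.extend([seg_id] * len(ids))  # segment id instead of plain 1s
        position_ids.extend(range(len(ids)))        # positions restart per packed sample

    print(input_ids)       # [1, 306, 626, 263, 1, 763, 611]
    print(attention_mask)  # [1, 1, 1, 1, 2, 2, 2]
    print(position_ids)    # [0, 1, 2, 3, 0, 1, 2]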
src/axolotl/utils/trainer.py CHANGED
@@ -1,19 +1,23 @@
1
  """Module containing the Trainer class and related functions"""
2
-
3
  import importlib
4
  import logging
5
  import math
6
  import os
7
  import sys
 
8
  from dataclasses import dataclass, field
 
9
  from pathlib import Path
10
- from typing import Optional
11
 
12
  import bitsandbytes as bnb
 
13
  import torch.cuda
14
  import transformers
 
15
  from torch import nn
16
  from torch.optim.lr_scheduler import OneCycleLR
 
17
  from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
18
  from transformers.trainer_pt_utils import get_parameter_names
19
 
@@ -22,6 +26,8 @@ from axolotl.utils.callbacks import (
22
  SaveBetterTransformerModelCallback,
23
  SavePeftModelCallback,
24
  )
 
 
25
  from axolotl.utils.schedulers import (
26
  InterpolatingLogScheduler,
27
  get_cosine_schedule_with_quadratic_warmup,
@@ -30,6 +36,68 @@ from axolotl.utils.schedulers import (
30
  LOG = logging.getLogger("axolotl")
31
 
32

33
  @dataclass
34
  class AxolotlTrainingArguments(TrainingArguments):
35
  """
@@ -40,6 +108,22 @@ class AxolotlTrainingArguments(TrainingArguments):
40
  default=False,
41
  metadata={"help": "Use quadratic warmup for cosine scheduling."},
42
  )
43
 
44
 
45
  class AxolotlTrainer(Trainer):
@@ -77,6 +161,64 @@ class AxolotlTrainer(Trainer):
77
  return super().create_scheduler(num_training_steps, optimizer)
78
  return self.lr_scheduler
79

80
 
81
  class OneCycleLRSchedulerTrainer(AxolotlTrainer):
82
  """
@@ -107,10 +249,121 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
107
  return self.lr_scheduler
108
 
109
 
110
- def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
111
- total_num_steps = int(
112
- math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
113
- )
114
  warmup_steps = (
115
  cfg.warmup_steps
116
  if cfg.warmup_steps is not None
@@ -190,7 +443,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
190
  if cfg.save_safetensors:
191
  training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
192

193
  training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
 
 
194
  per_device_train_batch_size=cfg.micro_batch_size,
195
  per_device_eval_batch_size=cfg.eval_batch_size
196
  if cfg.eval_batch_size is not None
@@ -204,7 +464,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
204
  eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
205
  save_steps=cfg.save_steps,
206
  output_dir=cfg.output_dir,
207
- save_total_limit=3,
208
  load_best_model_at_end=(
209
  cfg.load_best_model_at_end is not False
210
  and cfg.val_set_size > 0
@@ -222,6 +482,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
222
  if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
223
  else "cosine",
224
  weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
 
 
225
  **training_arguments_kwargs,
226
  )
227
 
@@ -316,11 +578,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
316
  if cfg.collator_pad_to_longest:
317
  data_collator_kwargs["padding"] = "longest"
318
  else:
319
- data_collator_kwargs["pad_to_multiple_of"] = 8
 
 
320
 
321
  if cfg.is_llama_derived_model and cfg.landmark_attention:
322
- from functools import partial
323
-
324
  from axolotl.monkeypatch.llama_landmark_attn import (
325
  add_mem_tokens,
326
  get_mem_id,
@@ -348,7 +610,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
348
  train_dataset=train_dataset,
349
  eval_dataset=eval_dataset,
350
  args=training_args,
351
- data_collator=transformers.DataCollatorForSeq2Seq(
352
  tokenizer,
353
  return_tensors="pt",
354
  **data_collator_kwargs,
 
1
  """Module containing the Trainer class and related functions"""
 
2
  import importlib
3
  import logging
4
  import math
5
  import os
6
  import sys
7
+ from contextlib import contextmanager
8
  from dataclasses import dataclass, field
9
+ from functools import partial
10
  from pathlib import Path
11
+ from typing import Optional, Union
12
 
13
  import bitsandbytes as bnb
14
+ import numpy as np
15
  import torch.cuda
16
  import transformers
17
+ from datasets import Dataset, set_caching_enabled
18
  from torch import nn
19
  from torch.optim.lr_scheduler import OneCycleLR
20
+ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
21
  from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
22
  from transformers.trainer_pt_utils import get_parameter_names
23
 
 
26
  SaveBetterTransformerModelCallback,
27
  SavePeftModelCallback,
28
  )
29
+ from axolotl.utils.collators import DataCollatorForSeq2Seq
30
+ from axolotl.utils.dataloader import MultipackDistributedDataloader
31
  from axolotl.utils.schedulers import (
32
  InterpolatingLogScheduler,
33
  get_cosine_schedule_with_quadratic_warmup,
 
36
  LOG = logging.getLogger("axolotl")
37
 
38
 
39
+ @torch.jit.script
40
+ def weighted_cross_entropy(
41
+ logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
42
+ ):
43
+ # Flatten the logits, labels, and weights tensors
44
+ logits = logits.view(
45
+ -1, logits.size(-1)
46
+ ) # logits becomes of shape [batch_size*sequence_length, vocab_size]
47
+ labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length]
48
+ weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length]
49
+
50
+ # Compute the unweighted cross entropy loss
51
+ losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
52
+
53
+ # Apply the weights to the losses and compute their sum
54
+ return (weights * losses).sum()
55
+
56
+
57
+ @torch.jit.script
58
+ def create_weighted_mask(labels: torch.Tensor):
59
+ # Check if the tensor is 2D. If not, unsqueeze it to make it 2D
60
+ if len(labels.shape) == 1:
61
+ labels = labels.unsqueeze(0)
62
+
63
+ weights = torch.zeros_like(labels).float()
64
+ for i in range(labels.shape[0]):
65
+ mask = labels[i] != -100
66
+
67
+ # Create a tensor to track group ids
68
+ group_ids = torch.zeros_like(labels[i]).int()
69
+ curr_group_id = 0
70
+
71
+ for j in range(1, len(labels[i])):
72
+ if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
73
+ curr_group_id += 1 # start new group
74
+ group_ids[j] = (
75
+ curr_group_id if mask[j] else 0
76
+ ) # assign group id if unmasked label
77
+
78
+ # Count only unmasked labels in each group
79
+ group_counts = torch.bincount(group_ids[mask])
80
+
81
+ mask_weights = torch.zeros_like(labels[i]).float()
82
+ mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
83
+
84
+ weights[i] = mask_weights
85
+
86
+ return weights.squeeze() # squeeze the output to match the input dimension
87
+
88
+
89
+ def trainer_weighted_loss(model_output, labels, shift_labels=True):
90
+ logits = (
91
+ model_output["logits"] if isinstance(model_output, dict) else model_output[0]
92
+ )
93
+ if shift_labels:
94
+ logits = logits[..., :-1, :].contiguous()
95
+ labels = labels[..., 1:].contiguous()
96
+
97
+ weights = create_weighted_mask(labels)
98
+ return weighted_cross_entropy(logits, labels, weights)
99
+
100
+
101
  @dataclass
102
  class AxolotlTrainingArguments(TrainingArguments):
103
  """
 
108
  default=False,
109
  metadata={"help": "Use quadratic warmup for cosine scheduling."},
110
  )
111
+ sample_packing: bool = field(
112
+ default=False,
113
+ metadata={"help": "Use sample packing for efficient training."},
114
+ )
115
+ sample_packing_efficiency: float = field(
116
+ default=1.0,
117
+ metadata={"help": "Sample packing efficiency for calculating batch length."},
118
+ )
119
+ max_seq_length: int = field(
120
+ default=2048,
121
+ metadata={"help": "The maximum sequence length the model can handle"},
122
+ )
123
+ sample_packing_seq_len_multiplier: int = field(
124
+ default=1,
125
+ metadata={"help": "the multiplier for the max len for packed sequences"},
126
+ )
127
 
128
 
129
  class AxolotlTrainer(Trainer):
 
161
  return super().create_scheduler(num_training_steps, optimizer)
162
  return self.lr_scheduler
163
 
164
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
165
+ if self.args.world_size > 1 and self.args.sample_packing:
166
+ return DistributedSampler(
167
+ self.train_dataset,
168
+ num_replicas=self.args.world_size,
169
+ rank=self.args.process_index,
170
+ seed=self.args.seed,
171
+ )
172
+ return super()._get_train_sampler()
173
+
174
+ def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
175
+ if self.args.sample_packing:
176
+ train_sampler = self._get_train_sampler()
177
+ return self.accelerator.prepare(
178
+ MultipackDistributedDataloader(
179
+ self.train_dataset,
180
+ batch_size=self._train_batch_size,
181
+ seq_max_length=self.args.max_seq_length,
182
+ collate_fn=self.data_collator,
183
+ sampler=train_sampler,
184
+ packing_efficiency_estimate=self.args.sample_packing_efficiency,
185
+ sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
186
+ device_count=int(os.environ.get("WORLD_SIZE", 1)),
187
+ )
188
+ )
189
+ return super().get_train_dataloader()
190
+
191
+ def get_eval_dataloader(
192
+ self, eval_dataset: Optional[Dataset] = None
193
+ ) -> Union[DataLoader, MultipackDistributedDataloader]:
194
+ if self.args.sample_packing:
195
+ eval_dataset = (
196
+ eval_dataset if eval_dataset is not None else self.eval_dataset
197
+ )
198
+ eval_sampler = self._get_eval_sampler(eval_dataset)
199
+ return self.accelerator.prepare(
200
+ MultipackDistributedDataloader(
201
+ eval_dataset,
202
+ batch_size=self.args.eval_batch_size,
203
+ seq_max_length=self.args.max_seq_length,
204
+ collate_fn=self.data_collator,
205
+ sampler=eval_sampler,
206
+ packing_efficiency_estimate=self.args.sample_packing_efficiency,
207
+ sample_packing_seq_len_multiplier=self.args.eval_batch_size,
208
+ device_count=int(os.environ.get("WORLD_SIZE", 1)),
209
+ )
210
+ )
211
+ return super().get_eval_dataloader(eval_dataset)
212
+
213
+ def compute_loss(self, model, inputs, return_outputs=False):
214
+ # use one's weighted cross entropy loss calc
215
+ # if self.args.sample_packing:
216
+ # labels = inputs.pop("labels")
217
+ # outputs = model(**inputs)
218
+ # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
219
+ # return (loss, outputs) if return_outputs else loss
220
+ return super().compute_loss(model, inputs, return_outputs=return_outputs)
221
+
222
 
223
  class OneCycleLRSchedulerTrainer(AxolotlTrainer):
224
  """
 
249
  return self.lr_scheduler
250
 
251
 
252
+ def add_position_ids(sample):
253
+ sample["position_ids"] = torch.arange(len(sample["input_ids"]))
254
+ return sample
255
+
256
+
257
+ def drop_long_seq(sample, sequence_len=2048):
258
+ return len(sample["input_ids"]) <= sequence_len
259
+
260
+
261
+ @contextmanager
262
+ def disable_datasets_caching():
263
+ try:
264
+ set_caching_enabled(False)
265
+ yield
266
+ finally:
267
+ set_caching_enabled(True)
268
+
269
+
270
+ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
271
+ if cfg.sample_packing:
272
+ drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
273
+ train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
274
+ add_position_ids, num_proc=os.cpu_count()
275
+ )
276
+ if eval_dataset:
277
+ eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
278
+ add_position_ids, num_proc=os.cpu_count()
279
+ )
280
+ return train_dataset, eval_dataset
281
+
282
+
283
+ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
284
+ if cfg.sample_packing:
285
+ # we have to drop anything longer then sequence len otherwise
286
+ # flash attention with position ids fails
287
+ if not cfg.total_num_tokens:
288
+ LOG.info("calculating total_num_tokens")
289
+ total_num_tokens = np.sum(
290
+ train_dataset.data.column("input_ids")
291
+ .to_pandas()
292
+ .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
293
+ .values
294
+ )
295
+ LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
296
+ cfg.total_num_tokens = total_num_tokens
297
+
298
+ if cfg.sample_packing_eff_est:
299
+ total_num_steps = (
300
+ # match count to len est in dataloader
301
+ (
302
+ math.floor(
303
+ 0.99
304
+ * cfg.total_num_tokens
305
+ / cfg.sample_packing_eff_est
306
+ / cfg.sequence_len
307
+ // cfg.batch_size
308
+ // int(os.environ.get("WORLD_SIZE", 1))
309
+ )
310
+ - 1
311
+ )
312
+ * cfg.num_epochs
313
+ )
314
+ LOG.info(
315
+ f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
316
+ )
317
+ else:
318
+ sampler = RandomSampler(train_dataset)
319
+ data_loader = MultipackDistributedDataloader(
320
+ train_dataset,
321
+ batch_size=cfg.micro_batch_size,
322
+ seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
323
+ collate_fn=DataCollatorForSeq2Seq(
324
+ tokenizer,
325
+ return_tensors="pt",
326
+ padding="longest",
327
+ ),
328
+ sampler=sampler,
329
+ packing_efficiency_estimate=cfg.sample_packing_eff_est,
330
+ sample_packing_seq_len_multiplier=cfg.micro_batch_size,
331
+ device_count=int(os.environ.get("WORLD_SIZE", 1)),
332
+ )
333
+ data_loader_len = data_loader.len_w_stats()
334
+ actual_eff = data_loader.efficiency()
335
+ LOG.info(f"data_loader_len: {data_loader_len}")
336
+ total_num_steps = int(
337
+ math.floor(
338
+ data_loader_len
339
+ * cfg.micro_batch_size
340
+ * cfg.num_epochs
341
+ // cfg.batch_size
342
+ )
343
+ )
344
+ LOG.info(
345
+ f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
346
+ )
347
+ cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
348
+ else:
349
+ total_num_steps = int(
350
+ math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
351
+ )
352
+ LOG.info(f"total_num_steps: {total_num_steps}")
353
+ return total_num_steps
354
+
355
+
356
+ def setup_fsdp_envs(cfg):
357
+ os.environ["ACCELERATE_USE_FSDP"] = "true"
358
+ if cfg.fsdp_config.fsdp_sync_module_states:
359
+ os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
360
+ if cfg.fsdp_config.fsdp_state_dict_type:
361
+ os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
362
+
363
+
364
+ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
365
+ if cfg.fsdp:
366
+ setup_fsdp_envs(cfg)
367
  warmup_steps = (
368
  cfg.warmup_steps
369
  if cfg.warmup_steps is not None
 
443
  if cfg.save_safetensors:
444
  training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
445
 
446
+ if cfg.sample_packing_eff_est:
447
+ training_arguments_kwargs[
448
+ "sample_packing_efficiency"
449
+ ] = cfg.sample_packing_eff_est
450
+
451
  training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
452
+ # max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
453
+ max_seq_length=cfg.sequence_len,
454
  per_device_train_batch_size=cfg.micro_batch_size,
455
  per_device_eval_batch_size=cfg.eval_batch_size
456
  if cfg.eval_batch_size is not None
 
464
  eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
465
  save_steps=cfg.save_steps,
466
  output_dir=cfg.output_dir,
467
+ save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
468
  load_best_model_at_end=(
469
  cfg.load_best_model_at_end is not False
470
  and cfg.val_set_size > 0
 
482
  if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
483
  else "cosine",
484
  weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
485
+ sample_packing=cfg.sample_packing if cfg.sample_packing else False,
486
+ sample_packing_seq_len_multiplier=cfg.micro_batch_size,
487
  **training_arguments_kwargs,
488
  )
489
 
 
578
  if cfg.collator_pad_to_longest:
579
  data_collator_kwargs["padding"] = "longest"
580
  else:
581
+ # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
582
+ # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
583
+ data_collator_kwargs["pad_to_multiple_of"] = 64
584
 
585
  if cfg.is_llama_derived_model and cfg.landmark_attention:
 
 
586
  from axolotl.monkeypatch.llama_landmark_attn import (
587
  add_mem_tokens,
588
  get_mem_id,
 
610
  train_dataset=train_dataset,
611
  eval_dataset=eval_dataset,
612
  args=training_args,
613
+ data_collator=DataCollatorForSeq2Seq(
614
  tokenizer,
615
  return_tensors="pt",
616
  **data_collator_kwargs,
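When sample_packing_eff_est is already known, the step estimate above discounts the token total by 0.99, divides by the packing efficiency, the sequence length, the effective batch size, and the device count, then drops one batch per epoch as slack. A rough worked example with assumed numbers (these are not values from this commit):

    import math
    import os

    total_num_tokens = 50_000_000     # assumed dataset size in tokens
    sample_packing_eff_est = 0.92     # assumed packing efficiency
    sequence_len = 2048
    batch_size = 32                   # effective batch size (cfg.batch_size)
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    num_epochs = 3

    total_num_steps = (
        math.floor(
            0.99
            * total_num_tokens
            / sample_packing_eff_est
            / sequence_len
            // batch_size
            // world_size
        )
        - 1
    ) * num_epochs
    print(total_num_steps)  # 2457 for these numbers with WORLD_SIZE unset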
src/axolotl/utils/validation.py CHANGED
@@ -8,6 +8,19 @@ LOG = logging.getLogger("axolotl")
8
 
9
 
10
  def validate_config(cfg):
11
  if cfg.gradient_accumulation_steps and cfg.batch_size:
12
  raise ValueError(
13
  "please set only one of gradient_accumulation_steps or batch_size"
@@ -104,6 +117,17 @@ def validate_config(cfg):
104
  + "point to its path, and remove model_revision from the config."
105
  )
106

107
  # TODO
108
  # MPT 7b
109
  # https://github.com/facebookresearch/bitsandbytes/issues/25
 
8
 
9
 
10
  def validate_config(cfg):
11
+ if cfg.max_packed_sequence_len and cfg.sample_packing:
12
+ raise ValueError(
13
+ "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
14
+ )
15
+ if cfg.max_packed_sequence_len:
16
+ LOG.warning(
17
+ str(
18
+ PendingDeprecationWarning(
19
+ "max_packed_sequence_len will be deprecated in favor of sample_packing"
20
+ )
21
+ )
22
+ )
23
+
24
  if cfg.gradient_accumulation_steps and cfg.batch_size:
25
  raise ValueError(
26
  "please set only one of gradient_accumulation_steps or batch_size"
 
117
  + "point to its path, and remove model_revision from the config."
118
  )
119
 
120
+ if cfg.sample_packing and cfg.sdp_attention:
121
+ # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
122
+ raise ValueError(
123
+ "sample_packing not compatible with sdp_attention. Use flash_attention"
124
+ )
125
+
126
+ if cfg.sample_packing and cfg.xformers_attention:
127
+ raise ValueError(
128
+ "sample_packing not compatible with xformers_attention. Use flash_attention"
129
+ )
130
+
131
  # TODO
132
  # MPT 7b
133
  # https://github.com/facebookresearch/bitsandbytes/issues/25
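The new checks can be exercised directly; a quick sketch (the DictDefault import path is an assumption here, mirroring how the validation tests construct configs):

    from axolotl.utils.dict import DictDefault  # assumed import path
    from axolotl.utils.validation import validate_config

    # warns that max_packed_sequence_len will be deprecated in favor of sample_packing
    validate_config(DictDefault({"max_packed_sequence_len": 2048}))

    # raises ValueError: use flash_attention with sample_packing instead
    validate_config(DictDefault({"sample_packing": True, "xformers_attention": True}))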
tests/monkeypatch/test_llama_attn_hijack_flash.py ADDED
@@ -0,0 +1,30 @@
1
+ """
2
+ Unit tests for the monkeypatch utils
3
+ """
4
+ import unittest
5
+
6
+ import torch
7
+
8
+ from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
9
+
10
+
11
+ class TestMonkeyPatchUtils(unittest.TestCase):
12
+ """
13
+ Unit test class for monkeypatch utils
14
+ """
15
+
16
+ def test_get_cu_seqlens_1d(self):
17
+ attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
18
+ target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
19
+ self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res))
20
+
21
+ def test_get_cu_seqlens_from_pos_ids_1d(self):
22
+ position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
23
+ target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
24
+ self.assertTrue(
25
+ torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
26
+ )
27
+
28
+
29
+ if __name__ == "__main__":
30
+ unittest.main()
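A simplified sketch of the boundary logic these assertions pin down: treat every position id that falls back to 0 after a non-zero value as the start of a new packed sub-sequence (so trailing padding collapses into the final chunk), then emit cumulative lengths. This only illustrates the expected output above; the real get_cu_seqlens_from_pos_ids returns additional values, which is why the test indexes [0].

    import torch

    def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor) -> torch.Tensor:
        # position_ids: shape (1, seq_len); positions restart at 0 for each packed sample
        pos = position_ids.squeeze(0)
        starts = [0] + [
            i for i in range(1, len(pos)) if pos[i] == 0 and pos[i - 1] != 0
        ]
        lengths = torch.diff(torch.tensor(starts + [len(pos)], dtype=torch.int32))
        return torch.cumsum(
            torch.cat([torch.zeros(1, dtype=torch.int32), lengths]), dim=0
        ).to(torch.int32)

    position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
    print(cu_seqlens_from_pos_ids_sketch(position_ids))
    # tensor([ 0,  4,  7, 12, 14, 16], dtype=torch.int32)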
tests/test_expand_mask.py ADDED
@@ -0,0 +1,44 @@
1
+ """
2
+ Unit tests for the monkey patch for expand mask to handle packed sequences
3
+ """
4
+ import unittest
5
+
6
+ import torch
7
+
8
+ from axolotl.monkeypatch.llama_expand_mask import _expand_mask
9
+
10
+
11
+ class TestExpandMask(unittest.TestCase):
12
+ """
13
+ Test class for attention mask expansion for packed sequences
14
+ """
15
+
16
+ def test_output(self):
17
+ mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
18
+ dtype = torch.float32
19
+ expected_output = torch.tensor(
20
+ [
21
+ [
22
+ [
23
+ [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
24
+ [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
25
+ [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
26
+ [-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
27
+ ]
28
+ ],
29
+ [
30
+ [
31
+ [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
32
+ [-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
33
+ [-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
34
+ [-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
35
+ ]
36
+ ],
37
+ ]
38
+ )
39
+ # Check that the output matches the expected output
40
+ self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
41
+
42
+
43
+ if __name__ == "__main__":
44
+ unittest.main()
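For reference, the expected tensor above is a causal, block-diagonal mask keyed on segment ids: a query may attend to a key only if both carry the same non-zero segment id and the key is not in the future, while padding (segment id 0) attends to nothing. A sketch that reproduces these semantics (names and construction are illustrative, not the patched implementation):

    import torch

    def expand_packed_mask_sketch(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # mask: (bsz, seq_len) of per-sample segment ids (1, 2, ...); 0 marks padding
        bsz, seq_len = mask.shape
        same_segment = mask.unsqueeze(1) == mask.unsqueeze(2)              # (bsz, seq, seq)
        not_padding = (mask != 0).unsqueeze(1) & (mask != 0).unsqueeze(2)
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        allowed = same_segment & not_padding & causal
        out = torch.full((bsz, 1, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
        out[allowed.unsqueeze(1)] = 0.0
        return out

    mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
    print(expand_packed_mask_sketch(mask, torch.float32))  # matches the expected_output above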
tests/test_packed_dataset.py CHANGED
@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
27
  }
28
  )
29
 
30
- def test_resets_attention(self):
31
  prompter = AlpacaPrompter("chat")
32
  strat = AlpacaPromptTokenizingStrategy(
33
  prompter,
@@ -55,10 +55,14 @@ class TestPacking(unittest.TestCase):
55
  # first example doesn't have mask reset
56
  assert example["input_ids"][0] == self.tokenizer.bos_token_id
57
  assert example["attention_mask"][0] == 1
 
 
58
 
59
  # but subsequent one does
60
  assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
61
- assert example["attention_mask"][next_bos_index] == 0
 
 
62
 
63
 
64
  if __name__ == "__main__":
 
27
  }
28
  )
29
 
30
+ def test_increments_attention(self):
31
  prompter = AlpacaPrompter("chat")
32
  strat = AlpacaPromptTokenizingStrategy(
33
  prompter,
 
55
  # first example doesn't have mask reset
56
  assert example["input_ids"][0] == self.tokenizer.bos_token_id
57
  assert example["attention_mask"][0] == 1
58
+ assert example["position_ids"][0] == 0
59
+ assert example["position_ids"][1] == 1
60
 
61
  # but subsequent one does
62
  assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
63
+ assert example["attention_mask"][next_bos_index] == 2
64
+ assert example["position_ids"][next_bos_index] == 0
65
+ assert example["position_ids"][next_bos_index + 1] == 1
66
 
67
 
68
  if __name__ == "__main__":
tests/test_prompt_tokenizers.py CHANGED
@@ -134,9 +134,15 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
134
  "output": "Hi! How can I help?",
135
  }
136
  example = strat.tokenize_prompt(sample)
137
- assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "<s>### System:"
138
- assert example["input_ids"][5:7] == [1509, 20118] # "use cot"
139
- assert example["input_ids"][9] == 11889 # USER
140
 
141
 
142
  class Llama2ChatTokenizationTest(unittest.TestCase):
 
134
  "output": "Hi! How can I help?",
135
  }
136
  example = strat.tokenize_prompt(sample)
137
+ assert example["input_ids"][0:5] == [
138
+ 1,
139
+ 28962,
140
+ 1254,
141
+ 12665,
142
+ 29901,
143
+ ] # "<s>SYSTEM:"
144
+ assert example["input_ids"][5:7] == [671, 20118] # " use cot"
145
+ assert example["input_ids"][8] == 11889 # USER
146
 
147
 
148
  class Llama2ChatTokenizationTest(unittest.TestCase):
tests/test_prompters.py CHANGED
@@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase):
70
  )
71
  )
72
  assert "use cot" in res
73
- assert res.startswith("### System:")
74
  assert "### Instruction:" not in res
75
  assert "### Input:" not in res
76
  assert "alpacas" in res
 
70
  )
71
  )
72
  assert "use cot" in res
73
+ assert res.startswith("SYSTEM:")
74
  assert "### Instruction:" not in res
75
  assert "### Input:" not in res
76
  assert "alpacas" in res
tests/test_validation.py CHANGED
@@ -313,3 +313,27 @@ class ValidationTest(unittest.TestCase):
313
  )
314
 
315
  validate_config(cfg)
313
  )
314
 
315
  validate_config(cfg)
316
+
317
+ def test_packing(self):
318
+ cfg = DictDefault(
319
+ {
320
+ "max_packed_sequence_len": 2048,
321
+ }
322
+ )
323
+ with self._caplog.at_level(logging.WARNING):
324
+ validate_config(cfg)
325
+ assert any(
326
+ "max_packed_sequence_len will be deprecated in favor of sample_packing"
327
+ in record.message
328
+ for record in self._caplog.records
329
+ )
330
+
331
+ cfg = DictDefault(
332
+ {
333
+ "max_packed_sequence_len": 2048,
334
+ "sample_packing": True,
335
+ }
336
+ )
337
+ regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
338
+ with pytest.raises(ValueError, match=regex_exp):
339
+ validate_config(cfg)