Attention mask and position id fixes for packing (#285)
* fix attention mask with packing
* set position ids and use block diagonal attn mask
* fix expand mask for multiple batch items, make sure we pad position_ids
* don't move masks to cpu
* use multi pack dataloader w random sampler
* add position_ids back
* more fixes for dataloader integration
* est total tokens, fix field loop
* more fixes, position_ids seems broken
* more fixes for sample packing
* use distributed sampler, avoid accelerate prepare
* use accelerator prepare for dataloader
* fix for position_ids w packing
* Update src/axolotl/utils/dataloader.py
* validation for sample packing and doc
* more fixes for 4k and optimizations
* optimized expand mask fn
* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data
* fix rounding of len of batches to int
* better handling so that all devices have the same dataloader len
* fix step calc for packing
* pass sample packing efficiency to training args
* add a test for the mask expansion for sequence packing
* only process eval dataset for packing if not None
* don't split batches when packing
* weighted CE losses
* weighted CEL fixes
* limit packing to sequences of max seq len
* seq_len_multiple for packing
* make sure the chunk size is an int
* sample_packing_seq_len_multiplier config
* use cumulative seq len with var len flash attn v2 w packing
* properly calculate max len
* fix flash-attn, xformers, packing, support chatml
* fix chatml system prompt for openorca, legacy tokenizer opts
* add chatml
* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test
* fix test and pylint checks
* more packing and dataset optimizations and fixes
* filter w multiple cpus
* more fixes and optimizations
* fixes and go back to distributed sampler since batch sampler won't work
* fix counts by accounting for num devices
* fix steps calculation
* previous accelerate is still most performant
* add numba to requirements.
* use custom distributed checks
* fix sampler to prevent overfit w new epochs
* let's not cleanup the cached datasets
* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier
* speed optimizations and set accelerate fsdp env vars
* optimize dataset concatenation?
* more optimizations for dataset handling
* fix import for annotation
* manual pre-commit fixes
* another sum optimization and bug fix for calc steps
* fix packing estimations
* fix formatting
* pylint problems
* add back flash attention branch for handling unpacked sequences separately
* Address PR feedback
* add optional sample packing config params to readme
- README.md +9 -1
- requirements.txt +2 -0
- scripts/finetune.py +25 -7
- src/axolotl/datasets.py +32 -21
- src/axolotl/monkeypatch/llama_attn_hijack_flash.py +21 -1
- src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +1 -0
- src/axolotl/monkeypatch/llama_expand_mask.py +52 -0
- src/axolotl/monkeypatch/utils.py +103 -0
- src/axolotl/prompt_strategies/alpaca_w_system.py +22 -1
- src/axolotl/prompters.py +11 -0
- src/axolotl/utils/collators.py +121 -0
- src/axolotl/utils/data.py +63 -18
- src/axolotl/utils/dataloader.py +288 -0
- src/axolotl/utils/distributed.py +41 -0
- src/axolotl/utils/models.py +18 -6
- src/axolotl/utils/trainer.py +273 -11
- src/axolotl/utils/validation.py +24 -0
- tests/monkeypatch/test_llama_attn_hijack_flash.py +30 -0
- tests/test_expand_mask.py +44 -0
- tests/test_packed_dataset.py +6 -2
- tests/test_prompt_tokenizers.py +9 -3
- tests/test_prompters.py +1 -1
- tests/test_validation.py +24 -0
--- a/README.md
+++ b/README.md
@@ -375,7 +375,14 @@ dataset_shard_idx:
 sequence_len: 2048
 # max sequence length to concatenate training samples together up to
 # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
+# FutureWarning: This will soon be DEPRECATED
 max_packed_sequence_len: 1024
+# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
+sample_packing:
+# you can set these packing optimizations AFTER starting a training at least once.
+# The trainer will provide recommended values for these values.
+sample_packing_eff_est:
+total_num_tokens:
 
 # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
 adapter: lora
@@ -421,6 +428,7 @@ learning_rate: 0.00003
 logging_steps:
 save_steps:
 eval_steps:
+save_total_limit:
 
 # save model as safetensors (require safetensors package)
 save_safetensors:
@@ -534,7 +542,7 @@ accelerate launch scripts/finetune.py configs/your_config.yml
 
 #### Multi-GPU
 
-
+You can optionally pre-tokenize dataset with the following before finetuning:
 ```bash
 CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
 ```

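For reference, a minimal config fragment exercising the new options might look like the following. The values are illustrative, not tuned recommendations; leave the estimate fields blank on a first run and copy in the values the trainer reports afterwards.

```yaml
sequence_len: 2048
sample_packing: true
# populated after at least one training run, from the values the trainer logs
sample_packing_eff_est:
total_num_tokens:
```
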
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,8 @@ einops
 xformers
 optimum
 hf_transfer
+numba
+numpy==1.24.4
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0

--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -21,9 +21,14 @@ from axolotl.logging_config import configure_logging
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import barrier, is_main_process
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.trainer import (
+    calculate_total_num_steps,
+    process_datasets_for_packing,
+    setup_trainer,
+)
 from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars
 
@@ -232,12 +237,25 @@ def train(
             cfg.pretraining_dataset,
             tokenizer,
             max_tokens=cfg.sequence_len,
-            seed=cfg.seed,
+            seed=cfg.seed or 42,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
        eval_dataset = None
 
+    if is_main_process():
+        # process on rank 0 first so it gets cached so other ranks load from cache
+        train_dataset, eval_dataset = process_datasets_for_packing(
+            cfg, train_dataset, eval_dataset
+        )
+    barrier()
+    if not is_main_process():
+        train_dataset, eval_dataset = process_datasets_for_packing(
+            cfg, train_dataset, eval_dataset
+        )
+    barrier()
+    total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+
     if cfg.debug or "debug" in kwargs:
         LOG.info("check_dataset_labels...")
         check_dataset_labels(
@@ -254,7 +272,7 @@ def train(
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
     # Load the model and tokenizer
-    LOG.info("loading model and peft_config...")
+    LOG.info("loading model and (optionally) peft_config...")
     model, peft_config = load_model(cfg, tokenizer)
 
     safe_serialization = cfg.save_safetensors is True
@@ -288,7 +306,9 @@ def train(
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
         return
 
-    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
+    trainer = setup_trainer(
+        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
+    )
 
     model.config.use_cache = False
 
@@ -347,14 +367,12 @@ def train(
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
-
+        trainer.save_model(cfg.output_dir)
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
-    # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
-
 
 if __name__ == "__main__":
     fire.Fire(train)

--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -5,7 +5,7 @@ import os
 from typing import List
 
 import torch
-from datasets import IterableDataset
+from datasets import Dataset, IterableDataset
 
 from .prompt_tokenizers import PromptTokenizingStrategy
 
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
 LOG = logging.getLogger("axolotl")
 
 
-class TokenizedPromptDataset(IterableDataset):
+class TokenizedPromptDataset(Dataset):
     """
-
+    Dataset that returns tokenized prompts from a stream of text files.
     Args:
         prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
         dataset (dataset.Dataset): Dataset with text files.
@@ -30,19 +30,18 @@ class TokenizedPromptDataset(IterableDataset):
         self,
         prompt_tokenizer: PromptTokenizingStrategy,
         dataset: IterableDataset,
+        **kwargs,
     ):
         self.prompt_tokenizer = prompt_tokenizer
-        self.dataset = dataset
-
-    def __iter__(self):
-        features = self.dataset.features.keys()
-        num_proc = os.cpu_count()
-        return iter(
-            self.dataset.map(
-                self.prompt_tokenizer.tokenize_prompt,
-                num_proc=num_proc,
-                remove_columns=features,
-            )
+        super().__init__(self.process(dataset).data, **kwargs)
+
+    def process(self, dataset):
+        features = dataset.features.keys()
+        num_proc = min(64, os.cpu_count())
+        return dataset.map(
+            self.prompt_tokenizer.tokenize_prompt,
+            num_proc=num_proc,
+            remove_columns=features,
         )
 
 
@@ -77,14 +76,21 @@ class ConstantLengthDataset(IterableDataset):
         self.tokens_dtype = torch.int64
 
     def __iter__(self):
-        buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            "position_ids": [],
+        }
         buffer_len = 0
         for dataset in self.datasets:
+            idx = 0
             iterator = iter(dataset)
             more_examples = True
             while more_examples:
                 try:
                     example = next(iterator)
+                    idx += 1
                 except StopIteration:
                     more_examples = False
                     example = None
@@ -106,6 +112,9 @@ class ConstantLengthDataset(IterableDataset):
                 attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
                     : self.seq_length
                 ]
+                position_ids = torch.cat(buffer["position_ids"], dim=-1)[
+                    : self.seq_length
+                ]
                 labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                 if labels.size() == input_ids.size() and (
                     attention_mask.size() == input_ids.size()
@@ -114,6 +123,7 @@ class ConstantLengthDataset(IterableDataset):
                         "input_ids": input_ids,
                         "labels": labels,
                         "attention_mask": attention_mask,
+                        "position_ids": position_ids,
                     }
                 else:
                     LOG.warning(
@@ -123,8 +133,10 @@ class ConstantLengthDataset(IterableDataset):
                     "input_ids": [],
                     "attention_mask": [],
                     "labels": [],
+                    "position_ids": [],
                 }
                 buffer_len = 0
+                idx = 1
 
             if example:
                 # FIXME
@@ -133,11 +145,6 @@ class ConstantLengthDataset(IterableDataset):
                 input_ids = example["input_ids"]
                 attention_mask = example["attention_mask"]
                 labels = example["labels"]
-                if (
-                    buffer["input_ids"]
-                    and input_ids[0] == self.tokenizer.bos_token_id
-                ):
-                    attention_mask[0] = 0
 
                 if add_concat_token:
                     input_ids.append(self.concat_token_id)
@@ -148,13 +155,17 @@ class ConstantLengthDataset(IterableDataset):
                     input_ids, dtype=self.tokens_dtype
                 )
                 attention_mask_with_concat = torch.tensor(
-                    attention_mask, dtype=self.tokens_dtype
+                    [idx * m for m in attention_mask], dtype=torch.int16
                 )
                 labels_with_concat = torch.tensor(
                     labels, dtype=self.tokens_dtype
                 )
+                position_ids = torch.arange(
+                    len(input_ids), dtype=self.tokens_dtype
+                )
 
                 buffer["input_ids"].append(input_ids_with_concat)
                 buffer["attention_mask"].append(attention_mask_with_concat)
                 buffer["labels"].append(labels_with_concat)
+                buffer["position_ids"].append(position_ids)
                 buffer_len += len(input_ids)

--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -8,9 +8,18 @@ import torch
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+except ImportError:
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
+    )
+
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+
 
 def forward(
     self,
@@ -79,6 +88,16 @@ def forward(
                 dtype=torch.int32,
                 device=qkv.device,
             )
+            output = flash_attn_varlen_qkvpacked_func(
+                qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+            )
+            output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+        elif position_ids.shape[0] == 1:
+            # special handling using sample packing
+            qkv = rearrange(qkv, "b s ... -> (b s) ...")
+            cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
+            cu_q_lens = cu_q_lens.squeeze()
+
             output = flash_attn_varlen_qkvpacked_func(
                 qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
             )
@@ -113,6 +132,7 @@ def forward(
             "b s (h d) -> b s h d",
             h=nheads,
         )
+
     return (
         self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
         None,

--- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
@@ -128,6 +128,7 @@ def xformers_forward(
             query_states,
             key_states,
             value_states,
+            # attn_bias=attention_mask,
             attn_bias=xformers.ops.LowerTriangularMask(),
         )
         attn_weights = None

--- /dev/null
+++ b/src/axolotl/monkeypatch/llama_expand_mask.py
@@ -0,0 +1,52 @@
+"""
+expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
+"""
+from typing import Optional
+
+import torch
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    This expansion handles packed sequences so that sequences share the same attention mask integer value
+    when they attend to each other within that sequence.
+    This expansion transforms the mask to lower triangular form to prevent future peeking.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    mask = mask.unsqueeze(1).unsqueeze(2)
+    mask = mask.expand(bsz, 1, tgt_len, src_len)
+
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+    binary_mask = torch.where(
+        mask != 0,
+        torch.tensor(1).to(dtype),
+        torch.tensor(0).to(dtype),
+    )
+
+    # Create a block-diagonal mask.
+    # we multiply by the binary mask so that 0's in the original mask are correctly excluded
+    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
+
+    # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
+    lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
+        mask.device
+    )
+
+    # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
+    masked_zero_one_mask = zero_one_mask * lower_triangular_ones
+    inverted_mask = 1.0 - masked_zero_one_mask
+
+    return inverted_mask.masked_fill(
+        inverted_mask.to(torch.bool), torch.finfo(dtype).min
+    )
+
+
+def hijack_expand_mask():
+    import transformers
+
+    transformers.models.llama.modeling_llama._expand_mask = (  # pylint: disable=protected-access
+        _expand_mask
+    )

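A quick way to sanity-check the patched mask (a sketch; it assumes this branch is installed so the new module path resolves): for a row that packs two sequences labelled 1 and 2 plus one padded slot, the expanded result is a causal, block-diagonal attention bias.

```python
import torch

from axolotl.monkeypatch.llama_expand_mask import _expand_mask

# one batch row packing two sequences (mask values 1 and 2) and a padded slot (0)
mask = torch.tensor([[1, 1, 2, 2, 0]])
bias = _expand_mask(mask, dtype=torch.float32)

print(bias.shape)  # torch.Size([1, 1, 5, 5])
# positions attending within their own sequence (and not into the future) are 0.0;
# cross-sequence, future, and padded positions get torch.finfo(torch.float32).min
```
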
--- /dev/null
+++ b/src/axolotl/monkeypatch/utils.py
@@ -0,0 +1,103 @@
+"""
+Shared utils for the monkeypatches
+"""
+import torch
+
+
+def get_cu_seqlens(attn_mask):
+    """generate a cumulative sequence length mask for flash attention using attn mask"""
+    if len(attn_mask.shape) == 1:
+        attn_mask = attn_mask.unsqueeze(0)
+
+    device = attn_mask.device
+    results = []
+    max_seq_lens = []
+
+    for row in attn_mask:
+        # Exclude zeros to avoid adding their positions to the mask
+        t_non_zeros = row[row != 0]
+        # Find where the sequence number changes (including the first position)
+        seq_change = torch.cat(
+            [
+                torch.tensor([1], dtype=torch.int32, device=device),
+                t_non_zeros[1:] != t_non_zeros[:-1],
+            ]
+        )
+        # Get the indices where the sequence changes
+        change_indices = torch.cat(
+            [
+                (seq_change == 1).nonzero(as_tuple=True)[0],
+                torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = change_indices[1:] - change_indices[:-1]
+        # Calculate the length of the final sequence or padding
+        final_seq_length = len(row) - change_indices[-1]
+        # Append the length of the final sequence or padding to seq_lengths
+        if final_seq_length.item():
+            seq_lengths = torch.cat(
+                [
+                    seq_lengths,
+                    torch.tensor(
+                        [final_seq_length.item()], dtype=torch.int32, device=device
+                    ),
+                ]
+            )
+        # Calculate the cumulative sequence lengths
+        cu_seqlens = torch.cat(
+            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
+        )
+        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        results.append(cu_seqlens)
+        max_seq_lens.append(max_seq_len)
+
+    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
+
+
+def get_cu_seqlens_from_pos_ids(position_ids):
+    """generate a cumulative sequence length mask for flash attention using pos ids"""
+    if len(position_ids.shape) == 1:
+        position_ids = position_ids.unsqueeze(0)
+
+    device = position_ids.device
+    results = []
+    max_seq_lens = []
+
+    for row in position_ids:
+        # Count the number of consecutive zeros from the right side
+        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
+
+        # Adjust the row to exclude padding
+        adjusted_row = row[:-padding_length] if padding_length else row.clone()
+
+        # Find where the position resets to 0 (indicating a new sequence)
+        seq_starts = torch.cat(
+            [
+                torch.tensor([True], dtype=torch.bool, device=device),
+                adjusted_row[1:] == 0,
+            ]
+        )
+        # Get the indices where the sequence starts
+        start_indices = torch.cat(
+            [
+                (seq_starts).nonzero(as_tuple=True)[0],
+                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = start_indices[1:] - start_indices[:-1]
+        # Calculate the cumulative sequence lengths
+        cu_seqlens = torch.cat(
+            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
+        )
+        # Append the padding length to the cumulative sequence lengths
+        if padding_length:
+            cu_seqlens = torch.cat(
+                [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
+            )
+        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        results.append(cu_seqlens)
+        max_seq_lens.append(max_seq_len)
+
+    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)

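As a concrete illustration (a sketch assuming this branch is installed), the per-sequence position_ids written by the packed dataset restart at 0 for each sample, which is exactly what `get_cu_seqlens_from_pos_ids` keys on to build the flash-attention boundaries:

```python
import torch

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

# three packed sequences of lengths 4, 3 and 2, followed by 3 padding positions
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 0, 0, 0]])
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)

print(cu_seqlens)  # [[0, 4, 7, 9, 12]] -- cumulative boundaries, with padding as the final span
print(max_seqlen)  # [4] -- longest span in the row
```
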
--- a/src/axolotl/prompt_strategies/alpaca_w_system.py
+++ b/src/axolotl/prompt_strategies/alpaca_w_system.py
@@ -66,7 +66,11 @@ class SystemDataPrompter(AlpacaPrompter):
     ) -> Generator[str, None, None]:
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
-        formatted_sys_prompt =
+        formatted_sys_prompt = (
+            self.system_format.format(system=system)
+            if system and self.system_format
+            else ""
+        )
         if input:
             res = formatted_sys_prompt + self.turn_format.format(
                 instruction=instruction, input=input
@@ -86,12 +90,20 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
     """
 
     def match_prompt_style(self):
+        # pylint: disable=duplicate-code
        if self.prompt_style == PromptStyle.INSTRUCT.value:
            self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
            self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
        if self.prompt_style == PromptStyle.CHAT.value:
            self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
            self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
+            self.system_format = "SYSTEM: {system}\n"
+        if self.prompt_style == PromptStyle.CHATML.value:
+            self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
+            self.turn_no_input_format = (
+                "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
+            )
+            self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
 
 
 class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
@@ -137,3 +149,12 @@ def load_open_orca(tokenizer, cfg):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+
+
+def load_open_orca_chatml(tokenizer, cfg):
+    return OpenOrcaPromptTokenizingStrategy(
+        OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )

--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -16,6 +16,7 @@ class PromptStyle(Enum):
 
     INSTRUCT = "instruct"
     CHAT = "chat"
+    CHATML = "chatml"
 
 
 class AlpacaPrompter:
@@ -25,6 +26,7 @@ class AlpacaPrompter:
 
     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
+    system_format: str
     turn_format: str
     turn_no_input_format: str
     prompt_style: Optional[PromptStyle] = None
@@ -34,14 +36,23 @@ class AlpacaPrompter:
         self.match_prompt_style()
 
     def match_prompt_style(self):
+        # pylint: disable=duplicate-code
         if self.prompt_style == PromptStyle.INSTRUCT.value:
             self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
             self.turn_no_input_format = (
                 "### Instruction:\n{instruction}\n\n### Response:\n"
             )
+            self.system_format = "### System:\n{system}\n\n"
         if self.prompt_style == PromptStyle.CHAT.value:
             self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
             self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
+            self.system_format = "SYSTEM: {system}\n"
+        if self.prompt_style == PromptStyle.CHATML.value:
+            self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
+            self.turn_no_input_format = (
+                "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
+            )
+            self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
 
     def build_prompt(
         self,

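To make the new chatml option concrete, here is roughly what a rendered prompt looks like using the format strings added above (illustrative strings only, not output captured from the prompter itself):

```python
system_format = "<|im_start|>system\n{system}<|im_end|>\n"
turn_no_input_format = "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"

prompt = system_format.format(system="You are a helpful assistant.")
prompt += turn_no_input_format.format(instruction="Name three axolotl colors.")
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Name three axolotl colors.<|im_end|>
# <|im_start|>assistant
```
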
--- /dev/null
+++ b/src/axolotl/utils/collators.py
@@ -0,0 +1,121 @@
+"""
+DataCollator for axolotl to pad labels and position_ids for packed sequences
+"""
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import numpy as np
+from transformers import PreTrainedTokenizerBase
+from transformers.utils import PaddingStrategy
+
+
+@dataclass
+class DataCollatorForSeq2Seq:
+    """
+    Data collator that will dynamically pad the inputs received, as well as the labels and position_ids
+
+    Args:
+        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
+            The tokenizer used for encoding the data.
+        model ([`PreTrainedModel`]):
+            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
+            prepare the *decoder_input_ids*
+
+            This is useful when using *label_smoothing* to avoid calculating loss twice.
+        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+
+            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
+              sequence is provided).
+            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+              acceptable input length for the model if that argument is not provided.
+            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
+        max_length (`int`, *optional*):
+            Maximum length of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (`int`, *optional*):
+            If set will pad the sequence to a multiple of the provided value.
+
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+            7.5 (Volta).
+        label_pad_token_id (`int`, *optional*, defaults to -100):
+            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
+        return_tensors (`str`):
+            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+    """
+
+    tokenizer: PreTrainedTokenizerBase
+    model: Optional[Any] = None
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+    label_pad_token_id: int = -100
+    position_pad_token_id: int = 0
+    return_tensors: str = "pt"
+
+    def __call__(self, features, return_tensors=None):
+        labels = None
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+
+        for feature_name, pad_token_id in [
+            ("labels", self.label_pad_token_id),
+            ("position_ids", self.position_pad_token_id),
+        ]:
+            feat = (
+                [feature[feature_name] for feature in features]
+                if feature_name in features[0].keys()
+                else None
+            )
+            labels = feat if feat and feature_name == "labels" else labels
+            # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
+            # same length to return tensors.
+            if feat is not None:
+                max_feature_length = max(len(l) for l in feat)  # noqa: E741
+                if self.pad_to_multiple_of is not None:
+                    max_feature_length = (
+                        (max_feature_length + self.pad_to_multiple_of - 1)
+                        // self.pad_to_multiple_of
+                        * self.pad_to_multiple_of
+                    )
+
+                padding_side = self.tokenizer.padding_side
+                for feature in features:
+                    remainder = [pad_token_id] * (
+                        max_feature_length - len(feature[feature_name])
+                    )
+                    if isinstance(feature[feature_name], list):
+                        feature[feature_name] = (
+                            feature[feature_name] + remainder
+                            if padding_side == "right"
+                            else remainder + feature[feature_name]
+                        )
+                    elif padding_side == "right":
+                        feature[feature_name] = np.concatenate(
+                            [feature[feature_name], remainder]
+                        ).astype(np.int64)
+                    else:
+                        feature[feature_name] = np.concatenate(
+                            [remainder, feature[feature_name]]
+                        ).astype(np.int64)
+
+        features = self.tokenizer.pad(
+            features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors=return_tensors,
+        )
+
+        # prepare decoder_input_ids
+        if (
+            labels is not None
+            and self.model is not None
+            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
+        ):
+            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
+                labels=features["labels"]
+            )
+            features["decoder_input_ids"] = decoder_input_ids
+
+        return features

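A minimal usage sketch of the collator: labels are padded with -100 and position_ids with 0 before `tokenizer.pad` handles the remaining keys. Any tokenizer with a pad token works; the gpt2 tokenizer here is just an assumption for illustration.

```python
from transformers import AutoTokenizer

from axolotl.utils.collators import DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

collator = DataCollatorForSeq2Seq(tokenizer, padding=True)
features = [
    {"input_ids": [10, 11, 12], "attention_mask": [1, 1, 1], "labels": [10, 11, 12], "position_ids": [0, 1, 2]},
    {"input_ids": [20, 21], "attention_mask": [1, 1], "labels": [20, 21], "position_ids": [0, 1]},
]
batch = collator(features)
print(batch["labels"])        # shorter row right-padded with -100
print(batch["position_ids"])  # shorter row right-padded with 0
```
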
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -1,13 +1,19 @@
 """Module containing data utilities"""
 import functools
-import itertools
+import hashlib
 import logging
 from hashlib import md5
 from pathlib import Path
-from typing import
+from typing import Tuple, Union
 
 import torch
-from datasets import
+from datasets import (
+    Dataset,
+    DatasetDict,
+    concatenate_datasets,
+    load_dataset,
+    load_from_disk,
+)
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
 
@@ -35,6 +41,7 @@ from axolotl.prompters import (
     ShareGPTPrompter,
     SummarizeTLDRPrompter,
 )
+from axolotl.utils.distributed import barrier, is_main_process
 
 LOG = logging.getLogger("axolotl")
 
@@ -109,6 +116,7 @@ def load_tokenized_prepared_datasets(
                 local_path = Path(d.path)
                 if local_path.exists():
                     if local_path.is_dir():
+                        # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
                         ds = load_dataset(
                             d.path,
                             name=d.name,
@@ -262,20 +270,12 @@ def load_tokenized_prepared_datasets(
                 raise ValueError(
                     f"unhandled prompt tokenization strategy: {d.type} {suffix}"
                 )
-        LOG.info("
-
-
-
-
-
-            while True:
-                chunk = list(itertools.islice(d_iter, chunk_size))
-                if not chunk:
-                    break
-                samples.extend(chunk)
-
-        LOG.info("shuffle")
-        dataset = Dataset.from_list(samples).shuffle(seed=seed)
+        LOG.info("merging datasets")
+        dataset = concatenate_datasets(datasets)
+
+        if len(datasets) > 1:
+            LOG.info("shuffle merged datasets")
+            dataset = dataset.shuffle(seed=seed)
     if cfg.local_rank == 0:
         LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
         dataset.save_to_disk(prepared_ds_path)
@@ -374,6 +374,7 @@ def load_prepare_datasets(
         dataset = Dataset.from_list(list(constant_len_dataset))
 
         # filter out bad data
+        # TODO convert to dataset.filter(...)
         dataset = Dataset.from_list(
             [
                 d
@@ -413,7 +414,51 @@ def load_prepare_datasets(
         )
 
     if cfg.val_set_size:
-
+        # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
+        to_hash_train = (
+            dataset._fingerprint  # pylint: disable=protected-access
+            + "|"
+            + str(cfg.val_set_size)
+            + "|"
+            + "train"
+            + "|"
+            + str(cfg.seed or 42)
+        )
+        to_hash_test = (
+            dataset._fingerprint  # pylint: disable=protected-access
+            + "|"
+            + str(cfg.val_set_size)
+            + "|"
+            + "test"
+            + "|"
+            + str(cfg.seed or 42)
+        )
+        train_fingerprint = hashlib.md5(
+            to_hash_train.encode(), usedforsecurity=False
+        ).hexdigest()
+        test_fingerprint = hashlib.md5(
+            to_hash_test.encode(), usedforsecurity=False
+        ).hexdigest()
+
+        if is_main_process():
+            dataset = dataset.train_test_split(
+                test_size=cfg.val_set_size,
+                shuffle=False,
+                seed=cfg.seed or 42,
+                train_new_fingerprint=train_fingerprint,
+                test_new_fingerprint=test_fingerprint,
+            )
+        barrier()
+        if not is_main_process():
+            dataset = dataset.train_test_split(
+                test_size=cfg.val_set_size,
+                shuffle=False,
+                seed=cfg.seed or 42,
+                train_new_fingerprint=train_fingerprint,
+                test_new_fingerprint=test_fingerprint,
+            )
+        barrier()
+
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
     else:

--- /dev/null
+++ b/src/axolotl/utils/dataloader.py
@@ -0,0 +1,288 @@
+# pylint: skip-file
+import hashlib
+import itertools
+import logging
+import math
+from typing import Any, Callable, List, Union
+
+import numba
+import numpy as np
+from torch.utils.data import DistributedSampler, Sampler
+
+LOG = logging.getLogger("axolotl.utils.dataloader")
+
+
+@numba.njit
+def ffd_check(a: np.ndarray, c: int, n: int):
+    # First-fit-decreasing bin packing
+    # Check if a[] could fit in n bins with capacity c
+    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
+
+    a = np.sort(a)[::-1]
+    bins = np.full((n,), c, dtype=a.dtype)
+    for size in a:
+        not_found = True
+        for idx in range(n):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                not_found = False
+                break
+
+        if not_found:
+            return False
+
+    return True
+
+
+@numba.njit
+def ffd_with_result(a: np.ndarray, c: int, start_index: int):
+    # First-fit-decreasing bin packing (with result return)
+
+    indices = np.argsort(a)[::-1]
+    a = a[indices]
+
+    bins: List[Any] = []
+    bins_result: List[Any] = []
+    for a_id, size in enumerate(a):
+        add_new = True
+        for idx in range(len(bins)):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                bins_result[idx].append(indices[a_id] + start_index)
+                add_new = False
+                break
+
+        if add_new:
+            bins.append(c - size)
+            bins_result.append([indices[a_id] + start_index])
+
+    return bins_result, len(a)
+
+
+@numba.njit
+def allocate(
+    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
+):
+    """
+    :param lengths: array of lengths of each sample
+    :param lengths_cumsum: cumulative sum of consecutive lengths
+    :param rank: rank for this process
+    :param c: length of tokens per batch
+    :param n: number of ranks
+    :return:
+    """
+    # Dynamic batch allocator, similar to Multifit
+    # https://en.wikipedia.org/wiki/Multifit_algorithm
+    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
+
+    s = 0
+    start_index = 0
+    result = []
+    result_totseqs = []
+
+    while True:
+        # binary search [left, right)
+        left = 1
+        right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
+
+        while right - left > 1:
+            mid = (left + right) // 2
+            if ffd_check(lengths[start_index : start_index + mid], c, n):
+                left = mid
+            else:
+                right = mid
+
+        # use length left
+        batch, tot_seqs = ffd_with_result(
+            lengths[start_index : start_index + left], c, start_index
+        )
+        if len(batch) < n:
+            break
+
+        start_index += left
+        s = lengths_cumsum[start_index - 1]
+
+        # add local rank
+        result.append(batch[rank])
+        # add total seqs for all ranks
+        result_totseqs.append(tot_seqs)
+        # yield batch[rank], tot_seqs, s, len(result) * c * n
+    return result, result_totseqs, s, len(result) * c * n
+
+
+def chunk(iterable, n):
+    """
+    Chunk data into tuples of length n
+    """
+    # batched('ABCDEFG', 3) --> ABC DEF G
+    if n < 1:
+        raise ValueError("n must be at least one")
+    it = iter(iterable)
+    while batch := tuple(itertools.islice(it, n)):
+        yield batch
+
+
+def hash_indices(lst: List[int]) -> str:
+    # Convert the list of integers to a string representation
+    concatenated = ",".join(map(str, lst))
+
+    # Generate the hash
+    sha256 = hashlib.sha256()
+    sha256.update(concatenated.encode())
+
+    return sha256.hexdigest()
+
+
+class MultipackDistributedDataloader:
+    """Unpadded data loading using Multipack.
+    Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
+    Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
+    """
+
+    def __init__(
+        self,
+        dataset: Any,
+        collate_fn: Callable,
+        seq_max_length: int = 2048,
+        batch_size: int = 1,
+        sampler: Union[Sampler, DistributedSampler] = None,
+        packing_efficiency_estimate: float = 1.0,
+        sample_packing_seq_len_multiplier: int = 1,
+        device_count: int = 1,
+    ):
+        # Dataset
+        self.dataset = dataset
+        self.lengths = (
+            dataset.data.column("position_ids")
+            .to_pandas()
+            .apply(lambda x: x[-1] + 1)
+            .values
+        )
+        assert isinstance(self.lengths, np.ndarray)
+        assert batch_size % sample_packing_seq_len_multiplier == 0
+        assert batch_size >= sample_packing_seq_len_multiplier
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
+        self.seq_max_length = seq_max_length
+        self.batch_max_length = batch_size * seq_max_length
+        self.collate_fn = collate_fn
+
+        self.num_replicas = 1
+        self.rank = 0
+
+        # statistics
+        self.eff_total_used = 0
+        self.eff_total_slots = 0
+        self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+        self.device_count = device_count
+
+    def generate_batches(self, set_stats=False):
+        LOG.info("generating packed batches")
+        if self.sampler:
+            indices = [idx for idx in self.sampler]
+        else:
+            indices = range(0, len(self.dataset))
+
+        LOG.info(hash_indices(indices))
+        lengths = self.lengths[indices]
+        lengths_cumsum = np.cumsum(lengths)
+
+        batches, totseqs, total_used, total_slots = allocate(
+            lengths=lengths,
+            lengths_cumsum=lengths_cumsum,
+            rank=self.rank,
+            # c=self.batch_max_length,
+            c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
+            n=self.num_replicas,
+        )
+
+        batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
+
+        # statistics
+        if set_stats:
+            self.eff_total_used += total_used
+            self.eff_total_slots += total_slots
+
+        return batches, totseqs
+
+    def __iter__(self):
+        if hasattr(self.sampler, "set_epoch"):
+            new_epoch = self.sampler.epoch + 1
+            self.sampler.set_epoch(new_epoch)
+            LOG.info(f"calling sampler.set_epoch({new_epoch})")
+        all_batches, _ = self.generate_batches(set_stats=True)
+        features = self.dataset.features.keys()
+        len_remaining = self._len_est()
+        for batches in chunk(
+            all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
+        ):
+            chunked_data = []
+            attn_mask_cum_idx = 0
+            for batch in batches:
+                concatenated = {}
+                batched_data = [self.dataset[batch_idx] for batch_idx in batch]
+                for feature in features:
+                    if feature == "attention_mask":
+                        arrays = [
+                            (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
+                            for idx, item in enumerate(batched_data)
+                            if feature in item
+                        ]
+                        attn_mask_cum_idx += len(batched_data)
+                        concatenated[feature] = np.concatenate(arrays)
+                    else:
+                        arrays = [
+                            np.array(item[feature])
+                            for item in batched_data
+                            if feature in item
+                        ]
+                        concatenated[feature] = np.concatenate(arrays)
+                chunked_data.append(concatenated)
+            yield self.collate_fn(chunked_data)
+            len_remaining -= 1
+            if not len_remaining:
+                return
+
+    def _len_est(self):
+        lengths_sum = np.sum(self.lengths)
+        lengths_sum_per_device = lengths_sum // self.device_count
+        LOG.info(
+            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+            f"total_num_tokens per device: {lengths_sum_per_device}"
+        )
+
+        # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
+        return (
+            math.floor(
+                0.99
+                * lengths_sum_per_device
+                / self.packing_efficiency_estimate
+                // self.seq_max_length
+                // self.batch_size
+            )
+            - 1
+        )
+
+    def __len__(self):
+        # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
+        # the same share of total tokens
+        # if not self.eff_total_used:
+        #     batches, _ = self.generate_batches(set_stats=True)
+        # LOG.info(
+        #     f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+        #     f"actual packing efficiency: {self.efficiency()}"
+        # )
+        return max(1, self._len_est())
+
+    def len_w_stats(self):
+        if not self.eff_total_used:
+            batches, _ = self.generate_batches(set_stats=True)
+        LOG.info(
+            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+            f"actual packing efficiency: {self.efficiency()}"
+        )
+        return max(1, self._len_est())
+
+    def efficiency(self):
+        return self.eff_total_used / self.eff_total_slots

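To see what the allocator is doing, here is a tiny pure-Python mirror of the first-fit-decreasing check (illustrative only, not the numba-compiled implementation above): given sample lengths, can they be packed into n bins of c tokens each?

```python
import numpy as np


def ffd_fits(lengths, capacity, n_bins):
    # sort descending, drop each item into the first bin with room
    bins = [capacity] * n_bins
    for size in sorted(lengths, reverse=True):
        for i, free in enumerate(bins):
            if free >= size:
                bins[i] -= size
                break
        else:
            return False
    return True


lengths = np.array([1800, 900, 700, 600, 400, 300])
print(ffd_fits(lengths, capacity=2048, n_bins=2))  # False: 4700 tokens > 2 * 2048
print(ffd_fits(lengths, capacity=2048, n_bins=3))  # True: e.g. [1800], [900, 700, 400], [600, 300]
```
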
@@ -0,0 +1,41 @@
```python
"""
utility helpers for distributed checks
"""
import torch.distributed as dist
from accelerate import Accelerator

accelerate = None  # pylint: disable=invalid-name


def load_accelerate():
    global accelerate  # pylint: disable=global-statement
    accelerate = Accelerator()


def is_distributed():
    """
    Check if distributed training is initialized.
    """
    global accelerate  # pylint: disable=global-statement
    if not accelerate:
        accelerate = Accelerator()
    return dist.is_available() and dist.is_initialized()


def barrier():
    """
    Acts as a barrier to wait for all processes. This ensures that all processes
    reach the barrier before proceeding further.
    """
    if is_distributed():
        dist.barrier()


def is_main_process():
    """
    Check if the current process is the main process.
    If not in distributed mode, always return True.
    """
    if not is_distributed():
        return True
    return dist.get_rank() == 0
```
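A minimal usage sketch for these helpers; `preprocess_on_rank_zero` is a hypothetical placeholder, not something added by this PR:

```python
from axolotl.utils.distributed import barrier, is_main_process


def preprocess_on_rank_zero():
    # hypothetical placeholder for work that should only run once,
    # e.g. tokenizing and caching a dataset
    ...


if is_main_process():
    preprocess_on_rank_zero()
# every rank waits here until rank 0 has finished, then all ranks proceed
barrier()
```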
```diff
@@ -37,20 +37,26 @@ def load_tokenizer(
     tokenizer_type,
     cfg,
 ):
+    tokenizer_kwargs = {}
     use_fast = True  # this is the default
     if cfg.tokenizer_use_fast is not None:
         use_fast = cfg.tokenizer_use_fast
+    if cfg.tokenizer_legacy is not None:
+        # True is the default w/ https://github.com/huggingface/transformers/pull/25224
+        tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
     if tokenizer_type:
         tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
             use_fast=use_fast,
+            **tokenizer_kwargs,
         )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
             use_fast=use_fast,
+            **tokenizer_kwargs,
         )

     LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -90,8 +96,10 @@ def load_model(

     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
-    cfg.is_llama_derived_model = "llama" in base_model or (
-        cfg.model_type and "llama" in cfg.model_type.lower()
+    cfg.is_llama_derived_model = (
+        "llama" in base_model
+        or (cfg.model_type and "llama" in cfg.model_type.lower())
+        or cfg.is_llama_derived_model
     )

     if cfg.is_llama_derived_model and cfg.flash_attention:
@@ -136,6 +144,14 @@ def load_model(
         LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()

+    if cfg.is_llama_derived_model and (
+        cfg.max_packed_sequence_len or cfg.sample_packing
+    ):
+        from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
+
+        LOG.info("patching _expand_mask")
+        hijack_expand_mask()
+
     if cfg.bf16 or cfg.bfloat16:
         torch_dtype = torch.bfloat16
     elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
@@ -228,7 +244,6 @@ def load_model(
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
-            device_map="auto" if cfg.world_size == 1 else cfg.device_map,
             **model_kwargs,
         )
     # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -263,7 +278,6 @@ def load_model(
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
-            device_map=cfg.device_map,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -294,7 +308,6 @@ def load_model(
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
-            device_map=cfg.device_map,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -308,7 +321,6 @@ def load_model(
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
-            device_map=cfg.device_map,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
```
```diff
@@ -1,19 +1,23 @@
 """Module containing the Trainer class and related functions"""
-
 import importlib
 import logging
 import math
 import os
 import sys
+from contextlib import contextmanager
 from dataclasses import dataclass, field
+from functools import partial
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

 import bitsandbytes as bnb
+import numpy as np
 import torch.cuda
 import transformers
+from datasets import Dataset, set_caching_enabled
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names

@@ -22,6 +26,8 @@ from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
+from axolotl.utils.collators import DataCollatorForSeq2Seq
+from axolotl.utils.dataloader import MultipackDistributedDataloader
 from axolotl.utils.schedulers import (
     InterpolatingLogScheduler,
     get_cosine_schedule_with_quadratic_warmup,
@@ -30,6 +36,68 @@ from axolotl.utils.schedulers import (
 LOG = logging.getLogger("axolotl")


+@torch.jit.script
+def weighted_cross_entropy(
+    logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
+):
+    # Flatten the logits, labels, and weights tensors
+    logits = logits.view(
+        -1, logits.size(-1)
+    )  # logits becomes of shape [batch_size*sequence_length, vocab_size]
+    labels = labels.view(-1)  # labels becomes of shape [batch_size*sequence_length]
+    weights = weights.view(-1)  # weights becomes of shape [batch_size*sequence_length]
+
+    # Compute the unweighted cross entropy loss
+    losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
+
+    # Apply the weights to the losses and compute their sum
+    return (weights * losses).sum()
+
+
+@torch.jit.script
+def create_weighted_mask(labels: torch.Tensor):
+    # Check if the tensor is 2D. If not, unsqueeze it to make it 2D
+    if len(labels.shape) == 1:
+        labels = labels.unsqueeze(0)
+
+    weights = torch.zeros_like(labels).float()
+    for i in range(labels.shape[0]):
+        mask = labels[i] != -100
+
+        # Create a tensor to track group ids
+        group_ids = torch.zeros_like(labels[i]).int()
+        curr_group_id = 0
+
+        for j in range(1, len(labels[i])):
+            if mask[j] and not mask[j - 1]:  # switch from masked to unmasked label
+                curr_group_id += 1  # start new group
+            group_ids[j] = (
+                curr_group_id if mask[j] else 0
+            )  # assign group id if unmasked label
+
+        # Count only unmasked labels in each group
+        group_counts = torch.bincount(group_ids[mask])
+
+        mask_weights = torch.zeros_like(labels[i]).float()
+        mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
+
+        weights[i] = mask_weights
+
+    return weights.squeeze()  # squeeze the output to match the input dimension
+
+
+def trainer_weighted_loss(model_output, labels, shift_labels=True):
+    logits = (
+        model_output["logits"] if isinstance(model_output, dict) else model_output[0]
+    )
+    if shift_labels:
+        logits = logits[..., :-1, :].contiguous()
+        labels = labels[..., 1:].contiguous()
+
+    weights = create_weighted_mask(labels)
+    return weighted_cross_entropy(logits, labels, weights)
+
+
```
@dataclass
|
102 |
class AxolotlTrainingArguments(TrainingArguments):
|
103 |
"""
|
|
|
108 |
default=False,
|
109 |
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
110 |
)
|
111 |
+
sample_packing: bool = field(
|
112 |
+
default=False,
|
113 |
+
metadata={"help": "Use sample packing for efficient training."},
|
114 |
+
)
|
115 |
+
sample_packing_efficiency: float = field(
|
116 |
+
default=1.0,
|
117 |
+
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
118 |
+
)
|
119 |
+
max_seq_length: int = field(
|
120 |
+
default=2048,
|
121 |
+
metadata={"help": "The maximum sequence length the model can handle"},
|
122 |
+
)
|
123 |
+
sample_packing_seq_len_multiplier: int = field(
|
124 |
+
default=1,
|
125 |
+
metadata={"help": "the multiplier for the max len for packed sequences"},
|
126 |
+
)
|
127 |
|
128 |
|
129 |
class AxolotlTrainer(Trainer):
|
|
|
161 |
return super().create_scheduler(num_training_steps, optimizer)
|
162 |
return self.lr_scheduler
|
163 |
|
164 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
165 |
+
if self.args.world_size > 1 and self.args.sample_packing:
|
166 |
+
return DistributedSampler(
|
167 |
+
self.train_dataset,
|
168 |
+
num_replicas=self.args.world_size,
|
169 |
+
rank=self.args.process_index,
|
170 |
+
seed=self.args.seed,
|
171 |
+
)
|
172 |
+
return super()._get_train_sampler()
|
173 |
+
|
174 |
+
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
175 |
+
if self.args.sample_packing:
|
176 |
+
train_sampler = self._get_train_sampler()
|
177 |
+
return self.accelerator.prepare(
|
178 |
+
MultipackDistributedDataloader(
|
179 |
+
self.train_dataset,
|
180 |
+
batch_size=self._train_batch_size,
|
181 |
+
seq_max_length=self.args.max_seq_length,
|
182 |
+
collate_fn=self.data_collator,
|
183 |
+
sampler=train_sampler,
|
184 |
+
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
185 |
+
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
186 |
+
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
187 |
+
)
|
188 |
+
)
|
189 |
+
return super().get_train_dataloader()
|
190 |
+
|
191 |
+
def get_eval_dataloader(
|
192 |
+
self, eval_dataset: Optional[Dataset] = None
|
193 |
+
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
194 |
+
if self.args.sample_packing:
|
195 |
+
eval_dataset = (
|
196 |
+
eval_dataset if eval_dataset is not None else self.eval_dataset
|
197 |
+
)
|
198 |
+
eval_sampler = self._get_eval_sampler(eval_dataset)
|
199 |
+
return self.accelerator.prepare(
|
200 |
+
MultipackDistributedDataloader(
|
201 |
+
eval_dataset,
|
202 |
+
batch_size=self.args.eval_batch_size,
|
203 |
+
seq_max_length=self.args.max_seq_length,
|
204 |
+
collate_fn=self.data_collator,
|
205 |
+
sampler=eval_sampler,
|
206 |
+
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
207 |
+
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
208 |
+
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
209 |
+
)
|
210 |
+
)
|
211 |
+
return super().get_eval_dataloader(eval_dataset)
|
212 |
+
|
213 |
+
def compute_loss(self, model, inputs, return_outputs=False):
|
214 |
+
# use one's weighted cross entropy loss calc
|
215 |
+
# if self.args.sample_packing:
|
216 |
+
# labels = inputs.pop("labels")
|
217 |
+
# outputs = model(**inputs)
|
218 |
+
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
219 |
+
# return (loss, outputs) if return_outputs else loss
|
220 |
+
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
221 |
+
|
222 |
|
223 |
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
224 |
"""
|
|
|
249 |
return self.lr_scheduler
|
250 |
|
251 |
|
252 |
+
def add_position_ids(sample):
|
253 |
+
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
254 |
+
return sample
|
255 |
+
|
256 |
+
|
257 |
+
def drop_long_seq(sample, sequence_len=2048):
|
258 |
+
return len(sample["input_ids"]) <= sequence_len
|
259 |
+
|
260 |
+
|
261 |
+
@contextmanager
|
262 |
+
def disable_datasets_caching():
|
263 |
+
try:
|
264 |
+
set_caching_enabled(False)
|
265 |
+
yield
|
266 |
+
finally:
|
267 |
+
set_caching_enabled(True)
|
268 |
+
|
269 |
+
|
270 |
+
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
271 |
+
if cfg.sample_packing:
|
272 |
+
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
273 |
+
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
|
274 |
+
add_position_ids, num_proc=os.cpu_count()
|
275 |
+
)
|
276 |
+
if eval_dataset:
|
277 |
+
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
|
278 |
+
add_position_ids, num_proc=os.cpu_count()
|
279 |
+
)
|
280 |
+
return train_dataset, eval_dataset
|
281 |
+
|
282 |
+
|
283 |
+
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
284 |
+
if cfg.sample_packing:
|
285 |
+
# we have to drop anything longer then sequence len otherwise
|
286 |
+
# flash attention with position ids fails
|
287 |
+
if not cfg.total_num_tokens:
|
288 |
+
LOG.info("calculating total_num_tokens")
|
289 |
+
total_num_tokens = np.sum(
|
290 |
+
train_dataset.data.column("input_ids")
|
291 |
+
.to_pandas()
|
292 |
+
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
293 |
+
.values
|
294 |
+
)
|
295 |
+
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
|
296 |
+
cfg.total_num_tokens = total_num_tokens
|
297 |
+
|
298 |
+
if cfg.sample_packing_eff_est:
|
299 |
+
total_num_steps = (
|
300 |
+
# match count to len est in dataloader
|
301 |
+
(
|
302 |
+
math.floor(
|
303 |
+
0.99
|
304 |
+
* cfg.total_num_tokens
|
305 |
+
/ cfg.sample_packing_eff_est
|
306 |
+
/ cfg.sequence_len
|
307 |
+
// cfg.batch_size
|
308 |
+
// int(os.environ.get("WORLD_SIZE", 1))
|
309 |
+
)
|
310 |
+
- 1
|
311 |
+
)
|
312 |
+
* cfg.num_epochs
|
313 |
+
)
|
314 |
+
LOG.info(
|
315 |
+
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
|
316 |
+
)
|
317 |
+
else:
|
318 |
+
sampler = RandomSampler(train_dataset)
|
319 |
+
data_loader = MultipackDistributedDataloader(
|
320 |
+
train_dataset,
|
321 |
+
batch_size=cfg.micro_batch_size,
|
322 |
+
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
323 |
+
collate_fn=DataCollatorForSeq2Seq(
|
324 |
+
tokenizer,
|
325 |
+
return_tensors="pt",
|
326 |
+
padding="longest",
|
327 |
+
),
|
328 |
+
sampler=sampler,
|
329 |
+
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
330 |
+
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
331 |
+
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
332 |
+
)
|
333 |
+
data_loader_len = data_loader.len_w_stats()
|
334 |
+
actual_eff = data_loader.efficiency()
|
335 |
+
LOG.info(f"data_loader_len: {data_loader_len}")
|
336 |
+
total_num_steps = int(
|
337 |
+
math.floor(
|
338 |
+
data_loader_len
|
339 |
+
* cfg.micro_batch_size
|
340 |
+
* cfg.num_epochs
|
341 |
+
// cfg.batch_size
|
342 |
+
)
|
343 |
+
)
|
344 |
+
LOG.info(
|
345 |
+
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
|
346 |
+
)
|
347 |
+
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
|
348 |
+
else:
|
349 |
+
total_num_steps = int(
|
350 |
+
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
351 |
+
)
|
352 |
+
LOG.info(f"total_num_steps: {total_num_steps}")
|
353 |
+
return total_num_steps
|
354 |
+
|
355 |
+
|
356 |
+
def setup_fsdp_envs(cfg):
|
357 |
+
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
358 |
+
if cfg.fsdp_config.fsdp_sync_module_states:
|
359 |
+
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
|
360 |
+
if cfg.fsdp_config.fsdp_state_dict_type:
|
361 |
+
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
|
362 |
+
|
363 |
+
|
364 |
+
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
365 |
+
if cfg.fsdp:
|
366 |
+
setup_fsdp_envs(cfg)
|
367 |
warmup_steps = (
|
368 |
cfg.warmup_steps
|
369 |
if cfg.warmup_steps is not None
|
|
|
443 |
if cfg.save_safetensors:
|
444 |
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
445 |
|
446 |
+
if cfg.sample_packing_eff_est:
|
447 |
+
training_arguments_kwargs[
|
448 |
+
"sample_packing_efficiency"
|
449 |
+
] = cfg.sample_packing_eff_est
|
450 |
+
|
451 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
452 |
+
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
453 |
+
max_seq_length=cfg.sequence_len,
|
454 |
per_device_train_batch_size=cfg.micro_batch_size,
|
455 |
per_device_eval_batch_size=cfg.eval_batch_size
|
456 |
if cfg.eval_batch_size is not None
|
|
|
464 |
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
465 |
save_steps=cfg.save_steps,
|
466 |
output_dir=cfg.output_dir,
|
467 |
+
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
468 |
load_best_model_at_end=(
|
469 |
cfg.load_best_model_at_end is not False
|
470 |
and cfg.val_set_size > 0
|
|
|
482 |
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
483 |
else "cosine",
|
484 |
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
485 |
+
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
486 |
+
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
487 |
**training_arguments_kwargs,
|
488 |
)
|
489 |
|
|
|
578 |
if cfg.collator_pad_to_longest:
|
579 |
data_collator_kwargs["padding"] = "longest"
|
580 |
else:
|
581 |
+
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
582 |
+
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
583 |
+
data_collator_kwargs["pad_to_multiple_of"] = 64
|
584 |
|
585 |
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
|
|
|
|
586 |
from axolotl.monkeypatch.llama_landmark_attn import (
|
587 |
add_mem_tokens,
|
588 |
get_mem_id,
|
|
|
610 |
train_dataset=train_dataset,
|
611 |
eval_dataset=eval_dataset,
|
612 |
args=training_args,
|
613 |
+
data_collator=DataCollatorForSeq2Seq(
|
614 |
tokenizer,
|
615 |
return_tensors="pt",
|
616 |
**data_collator_kwargs,
|
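To make the dataset-prep helpers concrete, a toy sample (the token ids below are invented) looks like this after `add_position_ids`:

```python
import torch

# positions restart from 0 within each sample, which is what later lets packed
# sequences be told apart via their position ids
sample = {"input_ids": [1, 3087, 338, 263, 1243, 2]}
sample = add_position_ids(sample)
print(sample["position_ids"])  # tensor([0, 1, 2, 3, 4, 5])

assert drop_long_seq(sample, sequence_len=2048)  # short enough to keep
```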
```diff
@@ -8,6 +8,19 @@ LOG = logging.getLogger("axolotl")


 def validate_config(cfg):
+    if cfg.max_packed_sequence_len and cfg.sample_packing:
+        raise ValueError(
+            "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
+        )
+    if cfg.max_packed_sequence_len:
+        LOG.warning(
+            str(
+                PendingDeprecationWarning(
+                    "max_packed_sequence_len will be deprecated in favor of sample_packing"
+                )
+            )
+        )
+
     if cfg.gradient_accumulation_steps and cfg.batch_size:
         raise ValueError(
             "please set only one of gradient_accumulation_steps or batch_size"
@@ -104,6 +117,17 @@ def validate_config(cfg):
             + "point to its path, and remove model_revision from the config."
         )

+    if cfg.sample_packing and cfg.sdp_attention:
+        # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
+        raise ValueError(
+            "sample_packing not compatible with sdp_attention. Use flash_attention"
+        )
+
+    if cfg.sample_packing and cfg.xformers_attention:
+        raise ValueError(
+            "sample_packing not compatible with xformers_attention. Use flash_attention"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
```
@@ -0,0 +1,30 @@
```python
"""
Unit tests for the monkeypatch utils
"""
import unittest

import torch

from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids


class TestMonkeyPatchUtils(unittest.TestCase):
    """
    Unit test class for monkeypatch utils
    """

    def test_get_cu_seqlens_1d(self):
        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
        target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
        self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res))

    def test_get_cu_seqlens_from_pos_ids_1d(self):
        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
        target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
        self.assertTrue(
            torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
        )


if __name__ == "__main__":
    unittest.main()
```
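Both fixtures describe the same packing, four sequences of lengths 4, 3, 5 and 2 followed by two padding tokens, so both code paths should produce the same cumulative boundaries:

```python
# attention_mask numbers each packed sequence; position_ids restart at 0 for each one
attn_mask    = [1, 1, 1, 1,  2, 2, 2,  3, 3, 3, 3, 3,  4, 4,  0, 0]
position_ids = [0, 1, 2, 3,  0, 1, 2,  0, 1, 2, 3, 4,  0, 1,  0, 0]
# cumulative sequence lengths (cu_seqlens): [0, 4, 7, 12, 14, 16]
```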
@@ -0,0 +1,44 @@
```python
"""
Unit tests for the monkey patch for expand mask to handle packed sequences
"""
import unittest

import torch

from axolotl.monkeypatch.llama_expand_mask import _expand_mask


class TestExpandMask(unittest.TestCase):
    """
    Test class for attention mask expansion for packed sequences
    """

    def test_output(self):
        mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
        dtype = torch.float32
        expected_output = torch.tensor(
            [
                [
                    [
                        [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
                        [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
                        [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
                        [-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
                    ]
                ],
                [
                    [
                        [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
                        [-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
                        [-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
                        [-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
                    ]
                ],
            ]
        )
        # Check that the output matches the expected output
        self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))


if __name__ == "__main__":
    unittest.main()
```
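The behaviour asserted here, causal attention confined to each packed segment with padding masked out entirely, can be sketched as below. This is an illustrative re-implementation of the masking rule, not the actual patched `_expand_mask`:

```python
import torch

def block_diagonal_causal_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask: (bsz, seq_len) where 0 marks padding and k>0 marks the k-th packed sequence
    bsz, seq_len = mask.shape
    min_value = torch.finfo(dtype).min
    same_seq = mask.unsqueeze(1) == mask.unsqueeze(2)        # same packed segment
    query_is_real = (mask != 0).unsqueeze(2)                  # padding queries see nothing
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed = same_seq & query_is_real & causal
    expanded = torch.full((bsz, 1, seq_len, seq_len), min_value, dtype=dtype)
    return expanded.masked_fill(allowed.unsqueeze(1), 0.0)

# produces the same block-diagonal pattern as expected_output in the test above
print(block_diagonal_causal_mask(torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]]), torch.float32))
```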
```diff
@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
             }
         )

-    def
+    def test_increments_attention(self):
         prompter = AlpacaPrompter("chat")
         strat = AlpacaPromptTokenizingStrategy(
             prompter,
@@ -55,10 +55,14 @@ class TestPacking(unittest.TestCase):
         # first example doesn't have mask reset
         assert example["input_ids"][0] == self.tokenizer.bos_token_id
         assert example["attention_mask"][0] == 1
+        assert example["position_ids"][0] == 0
+        assert example["position_ids"][1] == 1

         # but subsequent one does
         assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
-        assert example["attention_mask"][next_bos_index] ==
+        assert example["attention_mask"][next_bos_index] == 2
+        assert example["position_ids"][next_bos_index] == 0
+        assert example["position_ids"][next_bos_index + 1] == 1


 if __name__ == "__main__":
```
```diff
@@ -134,9 +134,15 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
             "output": "Hi! How can I help?",
         }
         example = strat.tokenize_prompt(sample)
-        assert example["input_ids"][0:
-
-
+        assert example["input_ids"][0:5] == [
+            1,
+            28962,
+            1254,
+            12665,
+            29901,
+        ]  # "<s>SYSTEM:"
+        assert example["input_ids"][5:7] == [671, 20118]  # " use cot"
+        assert example["input_ids"][8] == 11889  # USER


 class Llama2ChatTokenizationTest(unittest.TestCase):
```
```diff
@@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase):
             )
         )
         assert "use cot" in res
-        assert res.startswith("
+        assert res.startswith("SYSTEM:")
         assert "### Instruction:" not in res
         assert "### Input:" not in res
         assert "alpacas" in res
```
```diff
@@ -313,3 +313,27 @@ class ValidationTest(unittest.TestCase):
         )

         validate_config(cfg)
+
+    def test_packing(self):
+        cfg = DictDefault(
+            {
+                "max_packed_sequence_len": 2048,
+            }
+        )
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "max_packed_sequence_len will be deprecated in favor of sample_packing"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "max_packed_sequence_len": 2048,
+                "sample_packing": True,
+            }
+        )
+        regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
```