Merge pull request #92 from OpenAccess-AI-Collective/flash-optimum
Files changed:

- README.md (+2, -0)
- examples/pythia-12b/README.md (+9, -0)
- examples/pythia-12b/config.yml (+49, -0)
- requirements.txt (+1, -0)
- scripts/finetune.py (+35, -12)
- src/axolotl/utils/callbacks.py (+38, -1)
- src/axolotl/utils/data.py (+124, -4)
- src/axolotl/utils/models.py (+16, -3)
- src/axolotl/utils/trainer.py (+7, -1)
- src/axolotl/utils/validation.py (+33, -1)
- tests/test_validation.py (+51, -0)
README.md CHANGED

```diff
@@ -421,6 +421,8 @@ optimizer:
 # specify weight decay
 weight_decay:

+# whether to bettertransformers
+flash_optimum:
 # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
 xformers_attention:
 # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
```
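For reference, the new `flash_optimum` flag is meant to be paired with `float16` rather than AMP. A minimal sketch (not part of the diff) showing the combination as the config dict that `validate_config` receives; it mirrors `examples/pythia-12b/config.yml` below and the checks added in `src/axolotl/utils/validation.py`:

```python
# Minimal sketch, assuming the config keys used elsewhere in this PR;
# nothing here is new API, it just mirrors examples/pythia-12b/config.yml.
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config

cfg = DictDefault(
    {
        "flash_optimum": True,  # wrap the model with BetterTransformer
        "float16": True,        # load weights in fp16 without AMP
        "fp16": False,          # AMP is rejected when flash_optimum is set
        "bf16": False,
    }
)
validate_config(cfg)  # warns or raises if the combination is unsupported
```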
examples/pythia-12b/README.md ADDED

```diff
@@ -0,0 +1,9 @@
+# Pythia 12B
+
+- Single-GPU A100 only (?)
+
+```shell
+python scripts/finetune.py examples/pythia-12b/config.yml
+```
+
+⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️
```
examples/pythia-12b/config.yml ADDED

```diff
@@ -0,0 +1,49 @@
+base_model: EleutherAI/pythia-12b-deduped
+base_model_config: EleutherAI/pythia-12b-deduped
+base_model_ignore_patterns: pytorch*  # prefer safetensors
+model_type: GPTNeoXForCausalLM
+tokenizer_type: AutoTokenizer
+load_in_8bit: false
+load_in_4bit: false
+gptq: false
+device_map: auto
+datasets:
+  - path: vicgalle/alpaca-gpt4
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+adapter:
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len: 2048
+lora_r: 64
+lora_alpha: 32
+lora_dropout: 0.0
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./pythia-12b
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 5
+learning_rate: 0.00003
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+train_on_inputs: false
+group_by_length: false
+bf16: false
+fp16: false
+float16: true
+tf32: true
+flash_optimum: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+gradient_checkpointing: true
+fsdp:
+fsdp_config:
+collator_pad_to_longest: true
```
requirements.txt CHANGED

```diff
@@ -11,6 +11,7 @@ sentencepiece
 wandb
 einops
 xformers
+optimum
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
```
scripts/finetune.py CHANGED

```diff
@@ -12,13 +12,14 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
+
+# add src to the pythonpath so we don't need to pip install this
+from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer

-from axolotl.utils.data import load_prepare_datasets
+from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
-
-# add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
@@ -217,9 +218,20 @@ def train(
     if (
         check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ):  # don't need to load dataset for these
-        train_dataset, eval_dataset = load_prepare_datasets(
-            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-        )
+        if not cfg.pretraining_dataset:
+            train_dataset, eval_dataset = load_prepare_datasets(
+                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+            )
+        else:
+            train_dataset = load_pretraining_dataset(
+                cfg.pretraining_dataset,
+                tokenizer,
+                max_tokens=cfg.sequence_len,
+                seed=cfg.seed,
+            )
+            # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
+            train_dataset = train_dataset.with_format("torch")
+            eval_dataset = None

     if cfg.debug or "debug" in kwargs:
         logging.info("check_dataset_labels...")
@@ -285,12 +297,15 @@ def train(

     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
+
+        def terminate_handler(_, __, model):
+            if cfg.flash_optimum:
+                model = BetterTransformer.reverse(model)
+            model.save_pretrained(cfg.output_dir)
+            sys.exit(0)
+
         signal.signal(
-            signal.SIGINT,
-            lambda signal, frame: (
-                model.save_pretrained(cfg.output_dir),
-                sys.exit(0),
-            ),
+            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
         )

     logging.info("Starting trainer...")
@@ -313,13 +328,21 @@

     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
-    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    if cfg.flash_optimum:
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=True, enable_mem_efficient=True
+        ):
+            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    else:
+        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

     logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.local_rank == 0:
+        if cfg.flash_optimum:
+            model = BetterTransformer.reverse(model)
         model.save_pretrained(cfg.output_dir)

     # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
```
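Taken together, the finetune.py changes follow a transform, train-under-SDPA, then reverse-before-save pattern. A minimal sketch of that flow (assuming `model` is an already-loaded Hugging Face causal LM and `trainer` is a `transformers.Trainer`; the calls simply mirror the diff above):

```python
# Minimal sketch of the BetterTransformer round-trip wired up in finetune.py.
# Assumes `model` is a loaded HF causal LM and `trainer` a transformers.Trainer.
import torch
from optimum.bettertransformer import BetterTransformer

model = BetterTransformer.transform(model)  # swap attention for fused/SDPA modules

# run training under PyTorch's scaled-dot-product-attention kernel selection context
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=True, enable_mem_efficient=True
):
    trainer.train()

# a BetterTransformer-wrapped model can't be saved directly, so undo the transform
# before save_pretrained (the SIGINT handler and rank-0 save above do the same)
model = BetterTransformer.reverse(model)
model.save_pretrained("./output")
```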
src/axolotl/utils/callbacks.py CHANGED

```diff
@@ -2,13 +2,14 @@

 import os

+from optimum.bettertransformer import BetterTransformer
 from transformers import (
     TrainerCallback,
     TrainerControl,
     TrainerState,
     TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy


 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
         kwargs["model"].save_pretrained(peft_model_path)

         return control
+
+
+class SaveBetterTransformerModelCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods
+    """Callback to save the BetterTransformer wrapped model"""
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        # Save
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_steps > 0
+            and state.global_step % args.save_steps == 0
+        ):
+            control.should_save = True
+
+        if control.should_save:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+            )
+
+            model = BetterTransformer.reverse(kwargs["model"])
+            model.save_pretrained(checkpoint_folder)
+            # FIXME - need to cleanup old checkpoints
+
+            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
+            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
+            control.should_save = False
+        return control
```
src/axolotl/utils/data.py CHANGED

```diff
@@ -1,10 +1,11 @@
 """Module containing data utilities"""
-
+import functools
 import logging
 from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union

+import torch
 from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
@@ -394,8 +395,127 @@ def load_prepare_datasets(
             index=cfg.dataset_shard_idx,
         )

-    dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
-    train_dataset = dataset["train"]
-    eval_dataset = dataset["test"]
+    if cfg.val_set_size:
+        dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
+        train_dataset = dataset["train"]
+        eval_dataset = dataset["test"]
+    else:
+        train_dataset = dataset
+        eval_dataset = None

     return train_dataset, eval_dataset
+
+
+def encode_pretraining(tokenizer, max_tokens, examples):
+    res = tokenizer(
+        examples["text"],
+        truncation=True,
+        max_length=max_tokens - 2,
+        add_special_tokens=True,
+    )
+    # Convert to PyTorch tensors
+    input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
+    attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    new_input_ids = []
+    new_attention_mask = []
+    # Append EOS and PAD tokens to input_ids, and correct attention_mask
+    for i, _ in enumerate(input_ids):
+        input_ids[i] = torch.cat(
+            (
+                input_ids[i],
+                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
+            ),
+            dim=0,
+        )
+        attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
+
+    # Concatenate tokens so that their lengths are less than max_tokens
+    buffer_input_ids = torch.tensor([], dtype=torch.long)
+    buffer_attention_mask = torch.tensor([], dtype=torch.long)
+
+    for ids, mask in zip(input_ids, attention_mask):
+        if buffer_input_ids.numel() == max_tokens:
+            new_input_ids.append(buffer_input_ids)
+            new_attention_mask.append(buffer_attention_mask)
+            buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_attention_mask = torch.tensor([], dtype=torch.long)
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+        elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+        else:
+            buffer_input_ids = torch.cat(
+                (
+                    buffer_input_ids,
+                    torch.full(
+                        (max_tokens - buffer_input_ids.numel(),),
+                        tokenizer.pad_token_id,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            buffer_attention_mask = torch.cat(
+                (
+                    buffer_attention_mask,
+                    torch.full(
+                        (max_tokens - buffer_attention_mask.numel(),),
+                        0,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            new_input_ids.append(buffer_input_ids)
+            new_attention_mask.append(buffer_attention_mask)
+            buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_attention_mask = torch.tensor([], dtype=torch.long)
+
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+
+    if buffer_input_ids.numel() > 0:  # for any leftover tokens
+        while buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size
+            buffer_input_ids = torch.cat(
+                (
+                    buffer_input_ids,
+                    torch.full(
+                        (max_tokens - buffer_input_ids.numel(),),
+                        tokenizer.pad_token_id,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            buffer_attention_mask = torch.cat(
+                (
+                    buffer_attention_mask,
+                    torch.full(
+                        (max_tokens - buffer_attention_mask.numel(),),
+                        0,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+        new_input_ids.append(buffer_input_ids)
+        new_attention_mask.append(buffer_attention_mask)
+
+    ret = {
+        "input_ids": [seq.tolist() for seq in new_input_ids],
+        "labels": [seq.tolist() for seq in new_input_ids],
+        "attention_mask": [seq.tolist() for seq in new_attention_mask],
+    }
+
+    logging.debug(len(ret["input_ids"]))
+    return ret
+
+
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
+    encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+    dataset = load_dataset(path, streaming=True, split="train")
+    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
+    # TODO dynamically figure out which columns/features to remove
+    dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
+    return dataset
```
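A rough usage sketch for the new streaming-pretraining helpers. The dataset path and pad-token fallback below are illustrative assumptions, not part of this PR; the hard-coded `remove_columns=["text", "meta"]` means the dataset must expose exactly those two columns:

```python
# Rough usage sketch for load_pretraining_dataset / encode_pretraining.
# The dataset path is an illustrative assumption; any streaming text dataset
# with "text" and "meta" columns fits the hard-coded remove_columns.
from transformers import AutoTokenizer

from axolotl.utils.data import load_pretraining_dataset

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-12b-deduped")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # encode_pretraining needs a pad token

dataset = load_pretraining_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample",  # illustrative path
    tokenizer,
    max_tokens=2048,
    seed=42,
)
dataset = dataset.with_format("torch")  # as scripts/finetune.py does for streaming data
sample = next(iter(dataset))
print(len(sample["input_ids"]))  # sequences are packed/padded to max_tokens
```

In the training script this path is gated on `cfg.pretraining_dataset`, and `eval_dataset` is left as `None` for the streamed case.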
src/axolotl/utils/models.py CHANGED

```diff
@@ -10,8 +10,9 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
+from optimum.bettertransformer import BetterTransformer
 from transformers import PreTrainedModel  # noqa: F401
-from transformers import (
+from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -121,9 +122,9 @@ def load_model(
         logging.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()

-    if cfg.bf16:
+    if cfg.bf16 or cfg.bfloat16:
         torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16:
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         torch_dtype = torch.float16
     else:
         torch_dtype = torch.float32
@@ -287,6 +288,15 @@ def load_model(
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)

+    if (
+        hasattr(model.config, "max_position_embeddings")
+        and cfg.sequence_len >= model.config.max_position_embeddings
+    ):
+        logging.warning(
+            f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
+        )
+        model.config.max_position_embeddings = cfg.sequence_len
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -332,6 +342,9 @@ def load_model(
         logging.warning("there are no parameters that require gradient updates")
     model.config.use_cache = False

+    if cfg.flash_optimum:
+        model = BetterTransformer.transform(model)
+
     # TODO resume_from_checkpoint handling
     return model, lora_config

```
src/axolotl/utils/trainer.py CHANGED

```diff
@@ -16,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names

-from axolotl.utils.callbacks import SavePeftModelCallback
+from axolotl.utils.callbacks import (
+    SaveBetterTransformerModelCallback,
+    SavePeftModelCallback,
+)
 from axolotl.utils.schedulers import InterpolatingLogScheduler


@@ -228,6 +231,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     ]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)

+    if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        callbacks.append(SaveBetterTransformerModelCallback)
+
     data_collator_kwargs = {
         "padding": True,
     }
```
src/axolotl/utils/validation.py CHANGED

```diff
@@ -2,6 +2,8 @@

 import logging

+import torch
+

 def validate_config(cfg):
     if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -62,7 +64,37 @@ def validate_config(cfg):
     ) and cfg.gradient_checkpointing:
         raise ValueError("gradient_checkpointing is not supported for MPT models")

+    if cfg.flash_optimum is True:
+        if cfg.adapter:
+            logging.warning(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+            )
+        if cfg.fp16 or cfg.bf16:
+            raise ValueError("AMP is not supported with BetterTransformer")
+        if cfg.float16 is not True and cfg.bloat16 is not True:
+            logging.warning(
+                "You should probably set bfloat16 or float16 to true to "
+                "load the model in float16 for BetterTransformers"
+            )
+        if int(torch.__version__.split(".")[0]) < 2:
+            logging.warning("torch>=2.0.0 required")
+            raise ValueError(
+                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
+            )
+
+    if cfg.pretraining_dataset and cfg.group_by_length:
+        logging.warning(
+            "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
-    # no 8bit
+    # no 8bit adaAmw w bf16
+
+    # GPT-NeoX
+    # evals broken when extending context len
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
+    # attention_mask = causal_mask + attention_mask
+    # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
```
tests/test_validation.py CHANGED

```diff
@@ -212,3 +212,54 @@ class ValidationTest(unittest.TestCase):

         with pytest.raises(ValueError, match=regex_exp):
             validate_config(cfg)
+
+    def test_flash_optimum(self):
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "adapter": "lora",
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "probably set bfloat16 or float16" in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "fp16": True,
+            }
+        )
+        regex_exp = r".*AMP is not supported.*"
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "bf16": True,
+            }
+        )
+        regex_exp = r".*AMP is not supported.*"
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
```
|