Tweaks to data loading, 8-bit Adam, Accelerate, and DeepSpeed
- configs/llama_13B_alpaca.yml +39 -0
- src/axolotl/utils/data.py +19 -10
- src/axolotl/utils/models.py +13 -6
- src/axolotl/utils/trainer.py +16 -3
configs/llama_13B_alpaca.yml
ADDED
@@ -0,0 +1,39 @@
+base_model: huggyllama/llama-13b
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: true
+datasets:
+  - path: anon8231489123/ShareGPT_Vicuna_unfiltered
+    data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
+    type: sharegpt
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.002
+adapter:
+lora_model_dir:
+sequence_len: 2048
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./llama-13b-sharegpt
+batch_size: 64
+micro_batch_size: 2
+warmup_steps: 1000
+save_steps:
+eval_steps:
+num_epochs: 5
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+early_stopping_patience: 5
+resume_from_checkpoint:
+local_rank:
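Note: as a quick sanity check, a minimal sketch of loading a config like the one above with PyYAML. This is illustrative only and not a claim about axolotl's own config loader; the key names come from the YAML, the loading code is an assumption.

import yaml

# Load the training config added in this commit (assumes PyYAML is installed).
with open("configs/llama_13B_alpaca.yml") as f:
    cfg = yaml.safe_load(f)

# Keys left blank in the YAML (adapter:, save_steps:, wandb_project:, ...) parse
# as None, which is what lets the trainer fall back to computed defaults such as
# `cfg.save_steps if cfg.save_steps else min(int(0.05 * total_num_steps), 200)`.
print(cfg["base_model"])           # huggyllama/llama-13b
print(cfg["datasets"][0]["type"])  # sharegpt
print(cfg["adapter"])              # None (no LoRA adapter configured)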
src/axolotl/utils/data.py
CHANGED
@@ -2,7 +2,7 @@ import logging
 from hashlib import md5
 from pathlib import Path
 
-from datasets import load_from_disk, load_dataset, IterableDataset, Dataset
+from datasets import load_from_disk, load_dataset, IterableDataset, Dataset, concatenate_datasets
 from huggingface_hub import hf_hub_download
 
 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
@@ -44,10 +44,11 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
 
     if any(prepared_ds_path.glob("*")):
-        logging.info("Loading prepared dataset from disk...")
+        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
     else:
+        logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
         logging.info("Loading raw datasets...")
         datasets = []
         for d in cfg.datasets:
@@ -113,18 +114,26 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 datasets.append(ds_wrapper)
             else:
                 logging.error(f"unhandled prompt tokenization strategy: {d.type}")
-        constant_len_dataset = ConstantLengthDataset(
-            tokenizer,
-            datasets,
-            seq_length=max_packed_sequence_len,
-        )
-        logging.info("merging, packing, shuffling, and splitting master dataset")
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset]).shuffle(seed=42)
+        logging.info("merging and shuffling master dataset")
 
+        dataset = concatenate_datasets(datasets).shuffle(seed=42)
         if cfg.local_rank == 0:
-            logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
+            logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
             dataset.save_to_disk(prepared_ds_path)
 
+    if cfg.max_packed_sequence_len is not None:
+        constant_len_dataset = ConstantLengthDataset(
+            tokenizer,
+            [dataset],
+            seq_length=max_packed_sequence_len,
+        )
+        logging.info("packing master dataset")
+        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+
+    if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
+        logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
+        dataset = dataset.shard(num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx)
+
     dataset = dataset.train_test_split(
         test_size=cfg.val_set_size, shuffle=False
     )
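Note: the reworked flow above merges and shuffles the raw datasets before saving, packs into fixed-length sequences only when max_packed_sequence_len is set, and can optionally keep a single shard. A toy sketch of the same 🤗 datasets calls on made-up data (the dataset contents are dummies; concatenate_datasets, shuffle, shard, and train_test_split are the APIs used in the diff):

from datasets import Dataset, concatenate_datasets

# stand-ins for the tokenized prompt datasets built per cfg.datasets entry
ds_a = Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]})
ds_b = Dataset.from_dict({"input_ids": [[5, 6], [7, 8]]})

# merge and shuffle the "master" dataset, as load_prepare_datasets now does
dataset = concatenate_datasets([ds_a, ds_b]).shuffle(seed=42)

# optional sharding: keep only shard `index` out of `num_shards` pieces
dataset = dataset.shard(num_shards=2, index=0)

# final split; shuffle=False preserves the earlier seeded shuffle
split = dataset.train_test_split(test_size=0.5, shuffle=False)
print(split["train"].num_rows, split["test"].num_rows)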
src/axolotl/utils/models.py
CHANGED
@@ -101,12 +101,19 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
-        model = LlamaForCausalLM.from_pretrained(
-            base_model,
-            load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=torch_dtype,
-            device_map=cfg.device_map,
-        )
+        if not cfg.load_in_8bit:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                device_map=cfg.device_map,
+            )
+        else:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch_dtype,
+                device_map=cfg.device_map,
+            )
+
     elif model_type:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
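Note: the branch above only passes quantization kwargs when load_in_8bit is enabled. A hedged sketch of the same pattern with stock 🤗 Transformers; the model id, dtype, and device_map value are placeholders rather than values taken from this commit (8-bit loading additionally requires bitsandbytes and accelerate to be installed):

import torch
from transformers import AutoModelForCausalLM

load_in_8bit = True  # stands in for cfg.load_in_8bit

if not load_in_8bit:
    # full-precision path: no quantization arguments at all
    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-13b",
        device_map="auto",
    )
else:
    # 8-bit path: weights are quantized via bitsandbytes; torch_dtype applies
    # to the modules kept in higher precision (norms, lm_head, etc.)
    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-13b",
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )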
src/axolotl/utils/trainer.py
CHANGED
@@ -1,5 +1,9 @@
 import math
+import os
+from pathlib import Path
+
 import bitsandbytes as bnb
+import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
@@ -12,7 +16,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
     warmup_steps = cfg.warmup_steps if cfg.warmup_steps else min(int(0.03 * total_num_steps), 100)
-    logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)
+    logging_steps = cfg.logging_steps if cfg.logging_steps else max(min(int(0.005 * total_num_steps), 10), 1)
     save_steps = eval_steps = cfg.save_steps if cfg.save_steps else min(int(0.05 * total_num_steps), 200)
 
     training_arguments_kwargs = {}
@@ -26,6 +30,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.gradient_checkpointing is not None:
         training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
 
+    # deepspeed
+    if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1:
+        if cfg.deepspeed:
+            training_arguments_kwargs["deepspeed"] = cfg.deepspeed
+        else:
+            # make a guess here
+            # TODO search Path("./") for one
+            training_arguments_kwargs["deepspeed"] = "./ds_config.json"
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
@@ -37,7 +50,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
-        load_best_model_at_end=True if cfg.val_set_size > 0 else False,
+        load_best_model_at_end=True if cfg.val_set_size > 0 and save_steps % eval_steps == 0 else False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
@@ -47,7 +60,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     trainer_kwargs = {}
 
-    if cfg.
+    if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
         optimizer_grouped_parameters = [
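Note: the tightened condition above builds the manual bitsandbytes optimizer only when the optimizer is adam8bit, 4-bit loading is off, and DeepSpeed is not configured (Accelerate exports ACCELERATE_USE_DEEPSPEED=true when it drives DeepSpeed, which the new guard higher up checks). For context, a hedged sketch of an 8-bit Adam setup over weight-decay parameter groups, continuing the optimizer_grouped_parameters lines above; the helper name and hyperparameter values are illustrative, not from this commit:

import bitsandbytes as bnb
from torch import nn
from transformers.trainer_pt_utils import get_parameter_names

def build_adam8bit(model, lr=3e-5, weight_decay=0.0):
    # parameters eligible for weight decay: everything except LayerNorms and biases
    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
            "weight_decay": 0.0,
        },
    ]
    # 8-bit Adam keeps optimizer state quantized, cutting optimizer memory roughly 4x
    return bnb.optim.Adam8bit(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.999))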