Merge pull request #9 from winglian/dev
- FAQS.md +4 -0
- configs/galactica_1_3B.yml +41 -0
- configs/llama_13B_alpaca.yml +39 -0
- configs/llama_65B_alpaca.yml +4 -1
- configs/stability_3b.yml +33 -0
- ds_config.json +5 -5
- scripts/finetune.py +1 -1
- src/axolotl/datasets.py +9 -5
- src/axolotl/prompters.py +4 -0
- src/axolotl/utils/data.py +33 -12
- src/axolotl/utils/models.py +31 -8
- src/axolotl/utils/trainer.py +26 -4
FAQS.md
ADDED
@@ -0,0 +1,4 @@
+# FAQs
+
+- Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
+- Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
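The DeepSpeed answer ties into the trainer changes further down in this PR: when `ACCELERATE_USE_DEEPSPEED=true` is exported and more than one GPU is visible, `setup_trainer` passes either `cfg.deepspeed` or a default `./ds_config.json` to the training arguments. A minimal config sketch (the `deepspeed` key comes from the trainer diff below; the path value is illustrative):

```yaml
# Assumes ACCELERATE_USE_DEEPSPEED=true is exported before launching.
# If this key is omitted, the trainer falls back to guessing ./ds_config.json.
deepspeed: ./ds_config.json
```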
configs/galactica_1_3B.yml
ADDED
@@ -0,0 +1,41 @@
+base_model: facebook/galactica-1.3b
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+load_in_8bit: false
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+adapter:
+lora_model_dir:
+sequence_len: 1024
+max_packed_sequence_len: 1024
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./lora-llama-alpaca
+batch_size: 32
+micro_batch_size: 16
+num_epochs: 3
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: false
+tf32: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+special_tokens:
+  pad_token: "[PAD]"
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
configs/llama_13B_alpaca.yml
ADDED
@@ -0,0 +1,39 @@
+base_model: huggyllama/llama-13b
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: true
+datasets:
+  - path: anon8231489123/ShareGPT_Vicuna_unfiltered
+    data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
+    type: sharegpt
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.002
+adapter:
+lora_model_dir:
+sequence_len: 2048
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./llama-13b-sharegpt
+batch_size: 64
+micro_batch_size: 2
+warmup_steps: 1000
+save_steps:
+eval_steps:
+num_epochs: 5
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+early_stopping_patience: 5
+resume_from_checkpoint:
+local_rank:
configs/llama_65B_alpaca.yml
CHANGED
@@ -5,7 +5,8 @@ load_in_8bit: true
 datasets:
   - path: data/alpaca_data_gpt4.jsonl
     type: alpaca
-  - path:
+  - path: anon8231489123/ShareGPT_Vicuna_unfiltered
+    data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
     type: sharegpt
   - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
     type: gpteacher
@@ -30,6 +31,8 @@ wandb_log_model: checkpoint
 output_dir: ./lora-llama-alpaca
 batch_size: 128
 micro_batch_size: 16
+warmup_steps: 1000
+save_steps:
 num_epochs: 5
 learning_rate: 0.00003
 train_on_inputs: false
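Taken together with the data.py changes further below, a `datasets` entry can now point at a local JSONL file, a Hub dataset repo, or a specific file inside a Hub dataset repo via `data_files` (the key that triggers the new `hf_hub_download` path). A sketch of the three forms, using paths that already appear in this PR:

```yaml
datasets:
  # local JSONL on disk, loaded via the "json" loader
  - path: data/alpaca_data_gpt4.jsonl
    type: alpaca
  # dataset repo on the Hugging Face Hub
  - path: tatsu-lab/alpaca
    type: alpaca
  # a single file inside a Hub dataset repo, fetched with hf_hub_download
  - path: anon8231489123/ShareGPT_Vicuna_unfiltered
    data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
    type: sharegpt
```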
configs/stability_3b.yml
ADDED
@@ -0,0 +1,33 @@
+base_model: stabilityai/stablelm-base-alpha-3b
+load_in_8bit: true
+datasets:
+  - path: vicgalle/alpaca-gpt4
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.04
+adapter:
+lora_model_dir:
+sequence_len: 4096
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project: stable-llama-3b
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./stable-llama-3b
+batch_size: 128
+micro_batch_size: 16
+num_epochs: 1
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+early_stopping_patience: 3
+resume_from_checkpoint:
+local_rank:
ds_config.json
CHANGED
@@ -11,11 +11,10 @@
     "min_loss_scale": 1
   },
   "scheduler": {
-    "type": "
+    "type": "OneCycle",
     "params": {
-      "
-      "
-      "warmup_num_steps": "auto"
+      "cycle_min_lr": 1e-7,
+      "cycle_max_lr": 1e-4
     }
   },
   "zero_optimization": {
@@ -25,7 +24,8 @@
     "allgather_bucket_size": 5e8,
     "contiguous_gradients": true,
     "reduce_bucket_size": "auto",
-    "reduce_scatter": true
+    "reduce_scatter": true,
+    "stage3_gather_16bit_weights_on_model_save": true
   },
   "gradient_accumulation_steps": "auto",
   "gradient_clipping": "auto",
scripts/finetune.py
CHANGED
@@ -159,7 +159,7 @@ def train(
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     choose_device(cfg)
-    cfg.ddp = cfg.world_size != 1
+    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
         cfg.gradient_accumulation_steps = (
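With this change, `ddp` becomes an explicit config option instead of being derived purely from `WORLD_SIZE`. A hedged sketch of pinning it in a YAML config (the key name comes from `cfg.ddp` above; the value is illustrative):

```yaml
# Force DDP off even when WORLD_SIZE > 1; leave unset to keep the old auto-detection.
ddp: false
```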
src/axolotl/datasets.py
CHANGED
@@ -1,3 +1,4 @@
+import logging
 from typing import List
 
 import torch
@@ -92,11 +93,14 @@ class ConstantLengthDataset(IterableDataset):
                             : self.seq_length
                         ]
                         labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                        yield {
-                            "input_ids": input_ids,
-                            "labels": labels,
-                            "attention_mask": attention_mask,
-                        }
+                        if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
+                            yield {
+                                "input_ids": input_ids,
+                                "labels": labels,
+                                "attention_mask": attention_mask,
+                            }
+                        else:
+                            logging.warning("dropping batch due to tensor size mismatch")
                     buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                     buffer_len = 0
 
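The size check above guards the packed examples produced by `ConstantLengthDataset`, which only runs when `max_packed_sequence_len` is set (see the data.py change below). A sketch of the related config keys, with values borrowed from the galactica config above:

```yaml
# sequence_len also feeds the prepared-dataset hash; max_packed_sequence_len
# enables ConstantLengthDataset packing when present (assumption: omit it to skip packing).
sequence_len: 1024
max_packed_sequence_len: 1024
```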
src/axolotl/prompters.py
CHANGED
@@ -128,6 +128,10 @@ conv_vicuna_v1_1 = Conversation(
 
 class ShareGPTPrompter:
     def build_prompt(self, source, tokenizer):
+        # ignore the system prompt if provided
+        if source[0]["from"] == "system":
+            source.pop(0)
+
         if len(source) < 2:
             # If there isn't a back and forth conversation, ignore it
             # also happens on the data splitting leaving empty conversations
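For context, `source` here is one ShareGPT-style conversation, so a record like the following now has its leading system turn dropped before prompt building. This is a hypothetical sample; only the `from: system` check is confirmed by the diff, while the `conversations`/`value` field names follow the usual ShareGPT layout:

```yaml
# hypothetical ShareGPT-style record; the system turn is popped by build_prompt
conversations:
  - from: system
    value: "You are a helpful assistant."
  - from: human
    value: "What is LoRA?"
  - from: gpt
    value: "LoRA is a parameter-efficient fine-tuning method..."
```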
src/axolotl/utils/data.py
CHANGED
@@ -2,7 +2,8 @@ import logging
 from hashlib import md5
 from pathlib import Path
 
-from datasets import load_from_disk, load_dataset, IterableDataset, Dataset
+from datasets import load_from_disk, load_dataset, IterableDataset, Dataset, concatenate_datasets
+from huggingface_hub import hf_hub_download
 
 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
 from axolotl.prompt_tokenizers import (
@@ -30,7 +31,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     ds_hash = str(
         md5(
             (
-                str(
+                str(cfg.sequence_len)
                 + "@"
                 + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
             ).encode("utf-8")
@@ -43,13 +44,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
 
     if any(prepared_ds_path.glob("*")):
-        logging.info("Loading prepared dataset from disk...")
+        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
     else:
+        logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
         logging.info("Loading raw datasets...")
         datasets = []
         for d in cfg.datasets:
+            ds = None
             ds_from_hub = False
             try:
                 load_dataset(d.path, streaming=True)
@@ -63,8 +66,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                     "json", data_files=d.path, streaming=True, split=None
                 )
             elif ds_from_hub:
-
+                if d.data_files:
+                    ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
+                else:
+                    ds = load_dataset(d.path, streaming=True)
             else:
+                fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
+                ds = load_dataset("json", data_files=fp, streaming=True, split=None)
+            if not ds:
                 raise Exception("unhandled dataset load")
 
             if d.type == "alpaca":
@@ -105,20 +114,32 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 datasets.append(ds_wrapper)
             else:
                 logging.error(f"unhandled prompt tokenization strategy: {d.type}")
+        logging.info("tokenizing, merging, and shuffling master dataset")
+
+        samples = []
+        for d in datasets:
+            samples = samples + [i for i in d]
+        dataset = Dataset.from_list(samples).shuffle(seed=42)
+        if cfg.local_rank == 0:
+            logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
+            dataset.save_to_disk(prepared_ds_path)
+
+    if cfg.max_packed_sequence_len is not None:
         constant_len_dataset = ConstantLengthDataset(
             tokenizer,
-
+            [dataset],
             seq_length=max_packed_sequence_len,
         )
-        logging.info("
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
-            test_size=cfg.val_set_size, shuffle=True, seed=42
-        )
+        logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
+        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
 
-
-
-
+    if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
+        logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
+        dataset = dataset.shard(num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx)
 
+    dataset = dataset.train_test_split(
+        test_size=cfg.val_set_size, shuffle=False
+    )
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"]
 
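The sharding block above introduces two new config keys. An example of slicing the prepared dataset into quarters and training on the first quarter (key names from the diff; the values are illustrative):

```yaml
# split the merged dataset into 4 shards and keep only shard 0 for this run
dataset_shard_num: 4
dataset_shard_idx: 0
```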
src/axolotl/utils/models.py
CHANGED
@@ -7,11 +7,16 @@ import torch
 import transformers
 from transformers import (
     AutoModelForCausalLM,
-    LlamaForCausalLM,
-    LlamaTokenizer,
     AutoTokenizer,
     PreTrainedModel,
 )
+try:
+    from transformers import (
+        LlamaForCausalLM,
+        LlamaTokenizer,
+    )
+except:
+    logging.warning("This version of transformers does not support Llama. Consider upgrading.")
 
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
@@ -70,7 +75,7 @@ def load_model(
         snapshot_download_kwargs = {}
         if cfg.base_model_ignore_patterns:
             snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
-        cache_model_path = Path(snapshot_download(base_model, **
+        cache_model_path = Path(snapshot_download(base_model, **snapshot_download_kwargs))
         files = (
             list(cache_model_path.glob("*.pt"))
             + list(cache_model_path.glob("*.safetensors"))
@@ -95,15 +100,29 @@ def load_model(
             else True,
         )
         load_in_8bit = False
-    elif is_llama_derived_model:
-
+    elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        if not cfg.load_in_8bit:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                device_map=cfg.device_map,
+            )
+        else:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch_dtype,
+                device_map=cfg.device_map,
+            )
+
+    elif model_type:
+        model = getattr(transformers, model_type).from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
     else:
-        model = 
+        model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
             torch_dtype=torch_dtype,
@@ -123,7 +142,7 @@ def load_model(
 
     if not tokenizer:
         try:
-            if is_llama_derived_model:
+            if is_llama_derived_model and "LlamaTokenizer" in globals():
                 tokenizer = LlamaTokenizer.from_pretrained(model)
             else:
                 tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
@@ -142,13 +161,17 @@ def load_model(
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+    if cfg.special_tokens:
+        for k, v in cfg.special_tokens.items():
+            setattr(tokenizer, k, v)
+
     if load_in_8bit and not cfg.load_4bit:
         logging.info("converting model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
     model, lora_config = load_adapter(model, cfg, adapter)
 
-    if cfg.ddp:
+    if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
 
     if cfg.load_4bit:
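The new `cfg.special_tokens` loop is what consumes the `special_tokens` block introduced in configs/galactica_1_3B.yml above; each key is set directly on the tokenizer via `setattr`. For example:

```yaml
# applied one-by-one as tokenizer attributes (pad_token, bos_token, ...)
special_tokens:
  pad_token: "[PAD]"
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
```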
src/axolotl/utils/trainer.py
CHANGED
@@ -1,5 +1,9 @@
 import math
+import os
+from pathlib import Path
+
 import bitsandbytes as bnb
+import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
@@ -12,7 +16,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
     warmup_steps = cfg.warmup_steps if cfg.warmup_steps else min(int(0.03 * total_num_steps), 100)
-    logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)
+    logging_steps = cfg.logging_steps if cfg.logging_steps else max(min(int(0.005 * total_num_steps), 10), 1)
     save_steps = eval_steps = cfg.save_steps if cfg.save_steps else min(int(0.05 * total_num_steps), 200)
 
     training_arguments_kwargs = {}
@@ -26,6 +30,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.gradient_checkpointing is not None:
         training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
 
+    # deepspeed
+    if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1:
+        if cfg.deepspeed:
+            training_arguments_kwargs["deepspeed"] = cfg.deepspeed
+        else:
+            # make a guess here
+            # TODO search Path("./") for one
+            training_arguments_kwargs["deepspeed"] = "./ds_config.json"
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
@@ -37,7 +50,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
-        load_best_model_at_end=True if cfg.val_set_size > 0 else False,
+        load_best_model_at_end=True if cfg.val_set_size > 0 and save_steps % eval_steps == 0 else False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
@@ -47,7 +60,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     trainer_kwargs = {}
 
-    if cfg.
+    if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
         optimizer_grouped_parameters = [
@@ -94,13 +107,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         trainer_kwargs["callbacks"] = [early_stop_cb]
 
+    data_collator_kwargs = {
+        "padding": True,
+    }
+    if cfg.collator_pad_to_longest:
+        data_collator_kwargs["padding"] = "longest"
+    else:
+        data_collator_kwargs["pad_to_multiple_of"] = 8
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
         data_collator=transformers.DataCollatorForSeq2Seq(
-            tokenizer,
+            tokenizer,
+            return_tensors="pt",
+            **data_collator_kwargs,
         ),
         **trainer_kwargs,
     )
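Several of the new `cfg` lookups above correspond to config keys that were not previously honored. A hedged sketch pulling them together (key names from the diff; values illustrative):

```yaml
# override the auto-derived logging cadence
logging_steps: 10
# pad each batch to its longest sample instead of to a multiple of 8
collator_pad_to_longest: true
# the custom 8-bit Adam parameter grouping only applies to this optimizer value
# and is skipped under 4-bit loading or DeepSpeed
optimizer: adam8bit
# explicit DeepSpeed config path, used when ACCELERATE_USE_DEEPSPEED=true
deepspeed: ./ds_config.json
```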