Apply isort then black
- scripts/alpaca_json_to_jsonl.py +2 -5
- scripts/finetune.py +12 -8
- setup.py +1 -1
- src/axolotl/datasets.py +6 -2
- src/axolotl/flash_attn.py +31 -11
- src/axolotl/prompt_strategies/alpaca_chat.py +1 -0
- src/axolotl/prompt_strategies/creative_acr.py +26 -8
- src/axolotl/prompt_tokenizers.py +10 -3
- src/axolotl/prompters.py +5 -3
- src/axolotl/utils/callbacks.py +4 -3
- src/axolotl/utils/data.py +34 -23
- src/axolotl/utils/models.py +13 -23
- src/axolotl/utils/tokenization.py +1 -0
- src/axolotl/utils/trainer.py +8 -3
- tests/test_validation.py +1 -1
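
The diffs below are mechanical: isort first sorts and regroups the imports, then black reflows whatever no longer fits its style, which is where the one-argument-per-line call sites come from. As a reference, here is a minimal sketch of that two-step pipeline using the tools' Python APIs (assuming isort>=5 and black are installed; the sample source string is illustrative, not taken from this commit):

import black
import isort

SOURCE = (
    "from typing import Optional, Union\n"
    "import sys\n"
    "import os\n"
    "\n"
    "print(sys.argv, os.sep, Optional, Union)\n"
)

# Step 1: isort sorts and groups the imports (alphabetized, stdlib grouped).
sorted_source = isort.code(SOURCE)
# Step 2: black normalizes quoting, spacing, and line wrapping on the result.
formatted = black.format_str(sorted_source, mode=black.Mode())
print(formatted)

The CLI equivalent, run from the repository root, is "isort ." followed by "black .".
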
scripts/alpaca_json_to_jsonl.py
CHANGED
@@ -2,23 +2,20 @@
 
 import os
 import sys
-
-from typing import Optional, Union
 from pathlib import Path
+from typing import Optional, Union
 
 import fire
 
-
 from axolotl.convert import (
     FileReader,
-    StdoutWriter,
     FileWriter,
     JsonlSerializer,
     JsonParser,
     JsonToJsonlConverter,
 )
 
-
 # add src to the pythonpath so we don't need to pip install this
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")

scripts/finetune.py
CHANGED
@@ -7,20 +7,20 @@ import random
 import signal
 import sys
 from pathlib import Path
-from typing import ...
+from typing import Any, Dict, List, Optional, Union
 
 import fire
 import torch
 import yaml
 
-# add src to the pythonpath so we don't need to pip install this
-from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.validation import validate_config
-from axolotl.utils.dict import DictDefault
-
 from axolotl.utils.data import load_prepare_datasets
+from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
+
+# add src to the pythonpath so we don't need to pip install this
+from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -242,7 +242,10 @@ def train(
     if cfg.local_rank == 0:
         signal.signal(
             signal.SIGINT,
-            lambda signal, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)),
+            lambda signal, frame: (
+                model.save_pretrained(cfg.output_dir),
+                sys.exit(0),
+            ),
         )
 
     logging.info("Starting trainer...")
@@ -255,7 +258,8 @@ def train(
         ]
         if len(possible_checkpoints) > 0:
             sorted_paths = sorted(
-                possible_checkpoints, key=lambda path: int(path.split("-")[-1])
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
             )
             resume_from_checkpoint = sorted_paths[-1]
             logging.info(

setup.py
CHANGED
@@ -1,6 +1,6 @@
 """setup.py for axolotl"""
 
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 install_requires = []
 with open("./requirements.txt", encoding="utf-8") as requirements_file:

src/axolotl/datasets.py
CHANGED
@@ -5,8 +5,8 @@ from typing import List
 
 import torch
 from datasets import IterableDataset
-from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
 
+from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
 
 # We want this to be a wrapper for an existing dataset that we have loaded
 # lets use the concept of middlewares to wrap each dataset, for example
@@ -114,7 +114,11 @@ class ConstantLengthDataset(IterableDataset):
                 logging.warning(
                     f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
                 )
-                buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+                buffer = {
+                    "input_ids": [],
+                    "attention_mask": [],
+                    "labels": [],
+                }
                 buffer_len = 0
 
             if example:

src/axolotl/flash_attn.py
CHANGED
@@ -5,14 +5,11 @@
 from typing import Optional, Tuple
 
 import torch
-
 import transformers
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
-
 from einops import rearrange
-
+from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
 
 def forward(
@@ -75,7 +72,11 @@ def forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
         max_s = q_len
         cu_q_lens = torch.arange(
-            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+            0,
+            (bsz + 1) * q_len,
+            step=q_len,
+            dtype=torch.int32,
+            device=qkv.device,
         )
         output = flash_attn_unpadded_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
@@ -88,25 +89,44 @@ def forward(
         x = rearrange(qkv, "b s three h d -> b s (three h d)")
         x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
         x_unpad = rearrange(
-            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+            x_unpad,
+            "nnz (three h d) -> nnz three h d",
+            three=3,
+            h=nheads,
         )
         output_unpad = flash_attn_unpadded_qkvpacked_func(
-            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+            x_unpad,
+            cu_q_lens,
+            max_s,
+            0.0,
+            softmax_scale=None,
+            causal=True,
         )
         output = rearrange(
             pad_input(
-                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+                rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+                indices,
+                bsz,
+                q_len,
             ),
             "b s (h d) -> b s h d",
             h=nheads,
         )
-    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+    return (
+        self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
+        None,
+        None,
+    )
 
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
-    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+    self,
+    attention_mask,
+    input_shape,
+    inputs_embeds,
+    past_key_values_length,
 ):  # pylint: disable=unused-argument
     # [bsz, seq_len]
     return attention_mask

src/axolotl/prompt_strategies/alpaca_chat.py
CHANGED
@@ -1,6 +1,7 @@
 """Module containing the AlpacaQAPromptTokenizingStrategy class"""
 
 from typing import Tuple
+
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     InstructionPromptTokenizingStrategy,

src/axolotl/prompt_strategies/creative_acr.py
CHANGED
@@ -1,8 +1,9 @@
 """Module loading the CreativePromptTokenizingStrategy and similar classes"""
 
-from typing import Tuple, Union
+from typing import Generator, Tuple, Union
 
 import yaml
+
 from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
 
 
@@ -61,10 +62,14 @@ Answer: {answer}
 
     def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         scores = yaml.dump(
-            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["scores"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         critiques = yaml.dump(
-            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["critiques"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         evaluation = scores + critiques
         question = prompt["instruction"]
@@ -97,10 +102,14 @@ Evaluation:
 
     def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         scores = yaml.dump(
-            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["scores"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         critiques = yaml.dump(
-            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["critiques"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         evaluation = scores + critiques
         question = prompt["instruction"]
@@ -165,17 +174,26 @@ class CreativeRevisePrompter(CreativePrompterBase):
 
 def load_answer(tokenizer, cfg):
     return CreativeAnsweringPromptTokenizingStrategy(
-        CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        CreativeAnswerPrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
 
 
 def load_critique(tokenizer, cfg):
     return CreativeCritiquePromptTokenizingStrategy(
-        CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        CreativeCritiquePrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
    )
 
 
 def load_revise(tokenizer, cfg):
     return CreativeRevisePromptTokenizingStrategy(
-        CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        CreativeRevisePrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )

src/axolotl/prompt_tokenizers.py
CHANGED
@@ -347,7 +347,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 part = part[0] + part[1] if not user_token else part[1]
                 # this is still the user query, we should
                 res = self._tokenize(
-                    part.strip(), add_eos_token=False, strip_bos_token=True
+                    part.strip(),
+                    add_eos_token=False,
+                    strip_bos_token=True,
                 )
                 if user_token:
                     res["input_ids"] = [user_token, *res["input_ids"]]
@@ -358,10 +360,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 part = part[0] + part[1] if not assistant_token else part[1]
                 # this should be the assistent response, should end with an eos token
                 res = self._tokenize(
-                    part.strip(), add_eos_token=True, strip_bos_token=True
+                    part.strip(),
+                    add_eos_token=True,
+                    strip_bos_token=True,
                 )
                 if assistant_token:
-                    res["input_ids"] = [assistant_token, *res["input_ids"]]
+                    res["input_ids"] = [
+                        assistant_token,
+                        *res["input_ids"],
+                    ]
                 # not masked out from labels
                 labels = copy.deepcopy(res["input_ids"])
             else:

src/axolotl/prompters.py
CHANGED
@@ -2,8 +2,8 @@
 
 import dataclasses
 import logging
-from enum import auto, Enum
-from typing import List, Optional, Union
+from enum import Enum, auto
+from typing import Generator, List, Optional, Union
 
 IGNORE_TOKEN_ID = -100
 
@@ -203,7 +203,9 @@ class ReflectAlpacaPrompter:
             res = self.prompt_no_input.format(instruction=instruction)
         if output and reflection and corrected:
             label = self.agent_label.format(
-                output=output, reflection=reflection, corrected=corrected
+                output=output,
+                reflection=reflection,
+                corrected=corrected,
             )
             res = f"{res}{label}"
         yield res

src/axolotl/utils/callbacks.py
CHANGED
@@ -4,9 +4,9 @@ import os
 
 from transformers import (
     TrainerCallback,
-    TrainingArguments,
-    TrainerState,
     TrainerControl,
+    TrainerState,
+    TrainingArguments,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
@@ -22,7 +22,8 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
         **kwargs,
     ):
         checkpoint_folder = os.path.join(
-            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+            args.output_dir,
+            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
         )
 
         peft_model_path = os.path.join(checkpoint_folder, "adapter_model")

src/axolotl/utils/data.py
CHANGED
@@ -5,38 +5,33 @@ from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
 
-from datasets import (
-    load_from_disk,
-    load_dataset,
-    Dataset,
-    DatasetDict,
-)
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
 
-from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
+from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
 from axolotl.prompt_strategies import load
 from axolotl.prompt_tokenizers import (
+    AlpacaMultipleChoicePromptTokenizingStrategy,
     AlpacaPromptTokenizingStrategy,
+    AlpacaReflectionPTStrategy,
+    CompletionPromptTokenizingStrategy,
     GPTeacherPromptTokenizingStrategy,
+    JeopardyPromptTokenizingStrategy,
     OpenAssistantPromptTokenizingStrategy,
-    AlpacaReflectionPTStrategy,
     ShareGPTPromptTokenizingStrategy,
-    JeopardyPromptTokenizingStrategy,
-    CompletionPromptTokenizingStrategy,
-    AlpacaMultipleChoicePromptTokenizingStrategy,
     SummarizeTLDRPromptTokenizingStrategy,
 )
 from axolotl.prompters import (
     AlpacaPrompter,
+    CompletionPrompter,
     GPTeacherPrompter,
-    ReflectAlpacaPrompter,
-    ShareGPTPrompter,
     JeopardyPrompter,
-    CompletionPrompter,
+    MultipleChoiceConcisePrompter,
     MultipleChoiceExplainPrompter,
+    ReflectAlpacaPrompter,
+    ShareGPTPrompter,
     SummarizeTLDRPrompter,
-    MultipleChoiceConcisePrompter,
 )
 
 
@@ -67,7 +62,8 @@ def load_tokenized_prepared_datasets(
         try:
             if cfg.push_dataset_to_hub:
                 dataset = load_dataset(
-                    f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
+                    f"{cfg.push_dataset_to_hub}/{ds_hash}",
+                    use_auth_token=use_auth_token,
                 )
                 dataset = dataset["train"]
         except Exception:  # pylint: disable=broad-except
@@ -88,7 +84,11 @@ def load_tokenized_prepared_datasets(
             ds: Union[Dataset, DatasetDict] = None
             ds_from_hub = False
             try:
-                load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
+                load_dataset(
+                    d.path,
+                    streaming=True,
+                    use_auth_token=use_auth_token,
+                )
                 ds_from_hub = True
             except FileNotFoundError:
                 pass
@@ -96,7 +96,10 @@ def load_tokenized_prepared_datasets(
             # prefer local dataset, even if hub exists
             if Path(d.path).exists():
                 ds = load_dataset(
-                    "json", data_files=d.path, streaming=False, split=None
+                    "json",
+                    data_files=d.path,
+                    streaming=False,
+                    split=None,
                 )
             elif ds_from_hub:
                 if d.data_files:
@@ -108,11 +111,15 @@ def load_tokenized_prepared_datasets(
                     )
                 else:
                     ds = load_dataset(
-                        d.path, streaming=False, use_auth_token=use_auth_token
+                        d.path,
+                        streaming=False,
+                        use_auth_token=use_auth_token,
                     )
             else:
                 fp = hf_hub_download(
-                    repo_id=d.path, repo_type="dataset", filename=d.data_files
+                    repo_id=d.path,
+                    repo_type="dataset",
+                    filename=d.data_files,
                 )
                 ds = load_dataset("json", data_files=fp, streaming=False, split=None)
             if not ds:
@@ -249,7 +256,9 @@ def load_tokenized_prepared_datasets(
 
 
 def load_prepare_datasets(
-    tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
+    tokenizer: PreTrainedTokenizerBase,
+    cfg,
+    default_dataset_prepared_path,
 ) -> Tuple[Dataset, Dataset]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
@@ -353,7 +362,8 @@ def load_prepare_datasets(
                 f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
             )
             dataset.push_to_hub(
-                f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+                f"{cfg.push_dataset_to_hub}/{ds_hash}",
+                private=True,
             )
     else:
         dataset = load_tokenized_prepared_datasets(
@@ -365,7 +375,8 @@ def load_prepare_datasets(
             f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
         )
         dataset = dataset.shard(
-            num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
+            num_shards=cfg.dataset_shard_num,
+            index=cfg.dataset_shard_idx,
         )
 
     dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)

src/axolotl/utils/models.py
CHANGED
@@ -5,23 +5,17 @@ import logging
 import math
 import os
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    PreTrainedModel,
-    AutoConfig,
-    BitsAndBytesConfig,
-)
+from transformers import AutoModelForCausalLM  # noqa: F401
+from transformers import PreTrainedModel  # noqa: F401
+from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
 
 try:
-    from transformers import (
-        LlamaForCausalLM,
-    )
+    from transformers import LlamaForCausalLM
 except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
@@ -31,9 +25,10 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
     from peft import PeftConfig  # noqa: F401
-    from axolotl.utils.dict import DictDefault  # noqa: F401
     from transformers import PreTrainedTokenizer  # noqa: F401
 
+    from axolotl.utils.dict import DictDefault  # noqa: F401
+
 
 def load_tokenizer(
     base_model_config,
@@ -56,7 +51,10 @@ def load_tokenizer(
     logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
     logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
 
-    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
+    if tokenizer.__class__.__name__ in [
+        "LlamaTokenizer",
+        "LlamaTokenizerFast",
+    ]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -312,11 +310,7 @@ def load_adapter(model, cfg, adapter):
 
 def load_llama_adapter(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
-    from peft import (
-        AdaptionPromptConfig,
-        get_peft_model,
-        PeftModel,
-    )
+    from peft import AdaptionPromptConfig, PeftModel, get_peft_model
 
     peft_config = AdaptionPromptConfig(
         adapter_layers=cfg.peft_adapter.layers,  # layers (L)
@@ -361,11 +355,7 @@ def find_all_linear_names(bits, model):
 def load_lora(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
-    from peft import (
-        LoraConfig,
-        get_peft_model,
-        PeftModel,
-    )
+    from peft import LoraConfig, PeftModel, get_peft_model
 
     lora_target_modules = list(cfg.lora_target_modules or [])
 
src/axolotl/utils/tokenization.py
CHANGED
@@ -2,6 +2,7 @@
 
 
 import logging
+
 from termcolor import colored
 
 

src/axolotl/utils/trainer.py
CHANGED
@@ -15,8 +15,8 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names
 
-from axolotl.utils.schedulers import InterpolatingLogScheduler
 from axolotl.utils.callbacks import SavePeftModelCallback
+from axolotl.utils.schedulers import InterpolatingLogScheduler
 
 
 class OneCycleLRSchedulerTrainer(Trainer):
@@ -29,7 +29,9 @@ class OneCycleLRSchedulerTrainer(Trainer):
         self.lr_scheduler = None
 
     def create_scheduler(
-        self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None
+        self,
+        num_training_steps: int,
+        optimizer: Optional[torch.optim.Optimizer] = None,
     ):
         optimizer = self.optimizer if optimizer is None else optimizer
         num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
@@ -216,7 +218,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]:  # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter in [
+        "lora",
+        "qlora",
+    ]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {

tests/test_validation.py
CHANGED
@@ -4,8 +4,8 @@ import unittest
 
 import pytest
 
-from axolotl.utils.validation import validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.validation import validate_config
 
 
 class ValidationTest(unittest.TestCase):