optionally be able to specify alpaca or chat style prompts
Browse files- docker/Dockerfile-base +14 -1
- scripts/finetune.py +14 -2
- src/axolotl/prompt_tokenizers.py +47 -13
- src/axolotl/prompters.py +60 -12
- src/axolotl/utils/data.py +60 -23
- src/axolotl/utils/models.py +28 -2
docker/Dockerfile-base
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
ARG CUDA_VERSION="11.8.0"
|
2 |
ARG CUDNN_VERSION="8"
|
3 |
ARG UBUNTU_VERSION="22.04"
|
|
|
4 |
|
5 |
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder
|
6 |
|
@@ -39,6 +40,14 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
|
39 |
|
40 |
RUN git clone https://github.com/HazyResearch/flash-attention.git && \
|
41 |
cd flash-attention && \
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
python3 setup.py bdist_wheel
|
43 |
|
44 |
FROM base-builder AS deepspeed-builder
|
@@ -60,8 +69,12 @@ RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --g
|
|
60 |
RUN mkdir /workspace/wheels
|
61 |
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
62 |
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
|
|
|
|
|
|
|
|
|
63 |
|
64 |
-
RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl
|
65 |
RUN git lfs install --skip-repo
|
66 |
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
|
67 |
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
|
|
1 |
ARG CUDA_VERSION="11.8.0"
|
2 |
ARG CUDNN_VERSION="8"
|
3 |
ARG UBUNTU_VERSION="22.04"
|
4 |
+
ARG MAX_JOBS=4
|
5 |
|
6 |
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder
|
7 |
|
|
|
40 |
|
41 |
RUN git clone https://github.com/HazyResearch/flash-attention.git && \
|
42 |
cd flash-attention && \
|
43 |
+
python3 setup.py bdist_wheel && \
|
44 |
+
cd csrc/fused_dense_lib && \
|
45 |
+
python3 setup.py bdist_wheel && \
|
46 |
+
cd csrc/xentropy && \
|
47 |
+
python3 setup.py bdist_wheel && \
|
48 |
+
cd csrc/rotary && \
|
49 |
+
python3 setup.py bdist_wheel && \
|
50 |
+
cd csrc/layer_norm && \
|
51 |
python3 setup.py bdist_wheel
|
52 |
|
53 |
FROM base-builder AS deepspeed-builder
|
|
|
69 |
RUN mkdir /workspace/wheels
|
70 |
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
71 |
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
|
72 |
+
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
|
73 |
+
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy-*.whl wheels
|
74 |
+
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary-*.whl wheels
|
75 |
+
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels
|
76 |
|
77 |
+
RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xeontropy-*.whl wheels/rotary-*.whl wheels/dropout_layer_norm-*.whl
|
78 |
RUN git lfs install --skip-repo
|
79 |
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
|
80 |
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
scripts/finetune.py
CHANGED
@@ -31,7 +31,7 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
|
31 |
def choose_device(cfg):
|
32 |
def get_device():
|
33 |
if torch.cuda.is_available():
|
34 |
-
return "cuda"
|
35 |
else:
|
36 |
try:
|
37 |
if torch.backends.mps.is_available():
|
@@ -131,7 +131,8 @@ def train(
|
|
131 |
# then overwrite the value
|
132 |
cfg_keys = dict(cfg).keys()
|
133 |
for k in kwargs:
|
134 |
-
if
|
|
|
135 |
# handle booleans
|
136 |
if isinstance(cfg[k], bool):
|
137 |
cfg[k] = bool(kwargs[k])
|
@@ -169,6 +170,15 @@ def train(
|
|
169 |
inference=("inference" in kwargs),
|
170 |
)
|
171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
if "inference" in kwargs:
|
173 |
logging.info("calling do_inference function")
|
174 |
do_inference(cfg, model, tokenizer)
|
@@ -216,6 +226,8 @@ def train(
|
|
216 |
)
|
217 |
|
218 |
logging.info("Starting trainer...")
|
|
|
|
|
219 |
resume_from_checkpoint = cfg.resume_from_checkpoint
|
220 |
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
221 |
possible_checkpoints = [
|
|
|
31 |
def choose_device(cfg):
|
32 |
def get_device():
|
33 |
if torch.cuda.is_available():
|
34 |
+
return f"cuda:{cfg.local_rank}"
|
35 |
else:
|
36 |
try:
|
37 |
if torch.backends.mps.is_available():
|
|
|
131 |
# then overwrite the value
|
132 |
cfg_keys = dict(cfg).keys()
|
133 |
for k in kwargs:
|
134 |
+
# if not strict, allow writing to cfg even if it's not in the yml already
|
135 |
+
if k in cfg_keys or cfg.strict is False:
|
136 |
# handle booleans
|
137 |
if isinstance(cfg[k], bool):
|
138 |
cfg[k] = bool(kwargs[k])
|
|
|
170 |
inference=("inference" in kwargs),
|
171 |
)
|
172 |
|
173 |
+
if "merge_lora" in kwargs and cfg.adapter is not None:
|
174 |
+
print("running merge of LoRA with base model")
|
175 |
+
model = model.merge_and_unload()
|
176 |
+
|
177 |
+
if cfg.local_rank == 0:
|
178 |
+
print("saving merged model")
|
179 |
+
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
180 |
+
return
|
181 |
+
|
182 |
if "inference" in kwargs:
|
183 |
logging.info("calling do_inference function")
|
184 |
do_inference(cfg, model, tokenizer)
|
|
|
226 |
)
|
227 |
|
228 |
logging.info("Starting trainer...")
|
229 |
+
if cfg.group_by_length:
|
230 |
+
logging.info("hang tight... sorting dataset for group_by_length")
|
231 |
resume_from_checkpoint = cfg.resume_from_checkpoint
|
232 |
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
233 |
possible_checkpoints = [
|
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import abc
|
2 |
import copy
|
|
|
|
|
3 |
|
4 |
from transformers import PreTrainedTokenizer
|
5 |
|
@@ -33,6 +35,20 @@ class PromptTokenizingStrategy(abc.ABC):
|
|
33 |
def tokenize_prompt(self, prompt):
|
34 |
pass
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
38 |
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
@@ -63,7 +79,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
63 |
response,
|
64 |
)))
|
65 |
|
66 |
-
def _tokenize(self, prompt, add_eos_token=True):
|
67 |
result = self.tokenizer(
|
68 |
prompt,
|
69 |
truncation=True,
|
@@ -79,6 +95,13 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
79 |
result["input_ids"].append(self.tokenizer.eos_token_id)
|
80 |
result["attention_mask"].append(1)
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
result["labels"] = result["input_ids"].copy()
|
83 |
return result
|
84 |
|
@@ -239,23 +262,34 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
239 |
"labels": [],
|
240 |
}
|
241 |
current_len = 0
|
|
|
|
|
242 |
try:
|
243 |
-
for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"]
|
244 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
# this is only ever the first part, should include the bos token and the user query
|
246 |
res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=False)
|
247 |
# everything from this is masked out from the labels
|
248 |
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
|
249 |
-
elif i % 2 == 0:
|
250 |
-
# this is still the user query, we should
|
251 |
-
res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
|
252 |
-
# everything from this is masked out from the labels
|
253 |
-
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
|
254 |
-
else:
|
255 |
-
# this should be the assistent response, should end with an eos token
|
256 |
-
res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
|
257 |
-
# not masked out from labels
|
258 |
-
labels = copy.deepcopy(res["input_ids"])
|
259 |
input_ids = res["input_ids"]
|
260 |
input_len = len(input_ids)
|
261 |
result["input_ids"][current_len : current_len + input_len] = input_ids
|
|
|
1 |
import abc
|
2 |
import copy
|
3 |
+
import functools
|
4 |
+
import logging
|
5 |
|
6 |
from transformers import PreTrainedTokenizer
|
7 |
|
|
|
35 |
def tokenize_prompt(self, prompt):
|
36 |
pass
|
37 |
|
38 |
+
@functools.cache
|
39 |
+
def _get_user_token(self):
|
40 |
+
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
41 |
+
if type(id_or_ids, (int,)):
|
42 |
+
return id_or_ids
|
43 |
+
return False
|
44 |
+
|
45 |
+
@functools.cache
|
46 |
+
def _get_assistant_token(self):
|
47 |
+
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
48 |
+
if type(id_or_ids, (int,)):
|
49 |
+
return id_or_ids
|
50 |
+
return False
|
51 |
+
|
52 |
|
53 |
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
54 |
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
|
|
79 |
response,
|
80 |
)))
|
81 |
|
82 |
+
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
83 |
result = self.tokenizer(
|
84 |
prompt,
|
85 |
truncation=True,
|
|
|
95 |
result["input_ids"].append(self.tokenizer.eos_token_id)
|
96 |
result["attention_mask"].append(1)
|
97 |
|
98 |
+
if (
|
99 |
+
result["input_ids"][0] == self.tokenizer.bos_token_id
|
100 |
+
and strip_bos_token
|
101 |
+
):
|
102 |
+
result["input_ids"] = result["input_ids"][1:]
|
103 |
+
result["attention_mask"] = result["attention_mask"][1:]
|
104 |
+
|
105 |
result["labels"] = result["input_ids"].copy()
|
106 |
return result
|
107 |
|
|
|
262 |
"labels": [],
|
263 |
}
|
264 |
current_len = 0
|
265 |
+
user_token = self._get_user_token()
|
266 |
+
assistant_token = self._get_assistant_token()
|
267 |
try:
|
268 |
+
for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
|
269 |
+
if isinstance(part, tuple):
|
270 |
+
if part[0] == "USER:":
|
271 |
+
part = part[0] + part[1] if not user_token else part[1]
|
272 |
+
# this is still the user query, we should
|
273 |
+
res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
|
274 |
+
if user_token:
|
275 |
+
res = [user_token, *res]
|
276 |
+
# everything from this is masked out from the labels
|
277 |
+
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
|
278 |
+
elif part[0] == "ASSISTANT:":
|
279 |
+
part = part[0] + part[1] if not assistant_token else part[1]
|
280 |
+
# this should be the assistent response, should end with an eos token
|
281 |
+
res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
|
282 |
+
if assistant_token:
|
283 |
+
res = [assistant_token, *res]
|
284 |
+
# not masked out from labels
|
285 |
+
labels = copy.deepcopy(res["input_ids"])
|
286 |
+
else:
|
287 |
+
logging.warning("unhandled role: " + part[0])
|
288 |
+
else:
|
289 |
# this is only ever the first part, should include the bos token and the user query
|
290 |
res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=False)
|
291 |
# everything from this is masked out from the labels
|
292 |
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
input_ids = res["input_ids"]
|
294 |
input_len = len(input_ids)
|
295 |
result["input_ids"][current_len : current_len + input_len] = input_ids
|
src/axolotl/prompters.py
CHANGED
@@ -1,15 +1,34 @@
|
|
1 |
import copy
|
2 |
import dataclasses
|
|
|
3 |
from enum import auto, Enum
|
4 |
from typing import List, Tuple, Any, Union, Generator
|
5 |
|
6 |
IGNORE_TOKEN_ID = -100
|
7 |
|
8 |
|
|
|
|
|
|
|
|
|
9 |
class AlpacaPrompter:
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def build_prompt(
|
15 |
self,
|
@@ -36,7 +55,7 @@ class JeopardyPrompter(AlpacaPrompter):
|
|
36 |
|
37 |
|
38 |
class MultipleChoiceExplainPrompter(AlpacaPrompter):
|
39 |
-
|
40 |
|
41 |
|
42 |
class MultipleChoiceConcisePrompter(AlpacaPrompter):
|
@@ -64,11 +83,30 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
|
|
64 |
|
65 |
|
66 |
class ReflectAlpacaPrompter:
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
70 |
response_split = "### Response:"
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
def build_prompt(
|
73 |
self,
|
74 |
instruction: str,
|
@@ -118,13 +156,13 @@ class Conversation:
|
|
118 |
def get_prompt(self) -> Generator[str, None, None]:
|
119 |
seps = [self.sep, self.sep2]
|
120 |
preamble = self.system + seps[0]
|
|
|
121 |
for i, (role, message) in enumerate(self.messages):
|
122 |
if message:
|
123 |
-
yield
|
124 |
else:
|
125 |
-
|
126 |
-
|
127 |
-
preamble = ""
|
128 |
|
129 |
def copy(self):
|
130 |
return Conversation(
|
@@ -154,7 +192,17 @@ conv_vicuna_v1_1 = Conversation(
|
|
154 |
|
155 |
|
156 |
class ShareGPTPrompter:
|
157 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
# ignore the system prompt if provided
|
159 |
if source[0]["from"] == "system":
|
160 |
source.pop(0)
|
|
|
1 |
import copy
|
2 |
import dataclasses
|
3 |
+
import logging
|
4 |
from enum import auto, Enum
|
5 |
from typing import List, Tuple, Any, Union, Generator
|
6 |
|
7 |
IGNORE_TOKEN_ID = -100
|
8 |
|
9 |
|
10 |
+
class PromptStyle(Enum):
|
11 |
+
instruct = "instruct"
|
12 |
+
chat = "chat"
|
13 |
+
|
14 |
class AlpacaPrompter:
|
15 |
+
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
16 |
+
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
17 |
+
prompt_style = None
|
18 |
+
|
19 |
+
def __init__(self, prompt_style="instruct"):
|
20 |
+
self.prompt_style = prompt_style
|
21 |
+
self.match_prompt_style()
|
22 |
+
|
23 |
+
def match_prompt_style(self):
|
24 |
+
if self.prompt_style == PromptStyle.instruct.value:
|
25 |
+
self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
26 |
+
self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n"
|
27 |
+
self.response_split = "### Response:"
|
28 |
+
if self.prompt_style == PromptStyle.chat.value:
|
29 |
+
self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
30 |
+
self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
31 |
+
self.response_split = "ASSISTANT:"
|
32 |
|
33 |
def build_prompt(
|
34 |
self,
|
|
|
55 |
|
56 |
|
57 |
class MultipleChoiceExplainPrompter(AlpacaPrompter):
|
58 |
+
system_prompt = "Choose the answer that best answers the question. Explain your reasoning."
|
59 |
|
60 |
|
61 |
class MultipleChoiceConcisePrompter(AlpacaPrompter):
|
|
|
83 |
|
84 |
|
85 |
class ReflectAlpacaPrompter:
|
86 |
+
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
87 |
+
system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
88 |
+
|
89 |
+
prompt_input = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
90 |
+
prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n"
|
91 |
+
agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
|
92 |
response_split = "### Response:"
|
93 |
|
94 |
+
def __init__(self, prompt_style="instruct"):
|
95 |
+
self.prompt_style = prompt_style
|
96 |
+
self.match_prompt_style()
|
97 |
+
|
98 |
+
def match_prompt_style(self):
|
99 |
+
if self.prompt_style == PromptStyle.instruct.value:
|
100 |
+
self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
101 |
+
self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n"
|
102 |
+
self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
|
103 |
+
self.response_split = "### Final Response:"
|
104 |
+
if self.prompt_style == PromptStyle.chat.value:
|
105 |
+
self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
106 |
+
self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
107 |
+
self.agent_label = "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
|
108 |
+
self.response_split = "ASSISTANT:"
|
109 |
+
|
110 |
def build_prompt(
|
111 |
self,
|
112 |
instruction: str,
|
|
|
156 |
def get_prompt(self) -> Generator[str, None, None]:
|
157 |
seps = [self.sep, self.sep2]
|
158 |
preamble = self.system + seps[0]
|
159 |
+
yield preamble
|
160 |
for i, (role, message) in enumerate(self.messages):
|
161 |
if message:
|
162 |
+
yield (role + ":", " " + message)
|
163 |
else:
|
164 |
+
logging.warning("role with empty message: " + role)
|
165 |
+
yield (role + ":", )
|
|
|
166 |
|
167 |
def copy(self):
|
168 |
return Conversation(
|
|
|
192 |
|
193 |
|
194 |
class ShareGPTPrompter:
|
195 |
+
def __init__(self, prompt_style=None):
|
196 |
+
if prompt_style != PromptStyle.chat.value:
|
197 |
+
raise Exception(f"unsupported prompt_style for ShareGPTPrompter({prompt_style})")
|
198 |
+
|
199 |
+
# def match_prompt_style(self):
|
200 |
+
# if self.prompt_style == PromptStyle.chat.value:
|
201 |
+
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
202 |
+
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
203 |
+
# self.response_split = "ASSISTANT:"
|
204 |
+
|
205 |
+
def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
|
206 |
# ignore the system prompt if provided
|
207 |
if source[0]["from"] == "system":
|
208 |
source.pop(0)
|
src/axolotl/utils/data.py
CHANGED
@@ -50,8 +50,16 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|
50 |
if cfg.dataset_prepared_path
|
51 |
else Path(default_dataset_prepared_path) / ds_hash
|
52 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
if
|
|
|
|
|
55 |
logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
56 |
dataset = load_from_disk(str(prepared_ds_path))
|
57 |
logging.info("Prepared dataset loaded from disk...")
|
@@ -85,68 +93,71 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|
85 |
ds = load_dataset("json", data_files=fp, streaming=True, split=None)
|
86 |
if not ds:
|
87 |
raise Exception("unhandled dataset load")
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
90 |
ds_strategy = AlpacaPromptTokenizingStrategy(
|
91 |
-
AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
92 |
)
|
93 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
94 |
datasets.append(ds_wrapper)
|
95 |
-
elif
|
96 |
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
97 |
-
MultipleChoiceExplainPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
98 |
)
|
99 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
100 |
datasets.append(ds_wrapper)
|
101 |
-
elif
|
102 |
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
103 |
-
MultipleChoiceConcisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
104 |
)
|
105 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
106 |
datasets.append(ds_wrapper)
|
107 |
-
elif
|
108 |
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
109 |
-
SummarizeTLDRPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
110 |
)
|
111 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
112 |
datasets.append(ds_wrapper)
|
113 |
-
elif
|
114 |
ds_strategy = JeopardyPromptTokenizingStrategy(
|
115 |
-
JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
116 |
)
|
117 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
118 |
datasets.append(ds_wrapper)
|
119 |
-
elif
|
120 |
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
121 |
-
AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
122 |
)
|
123 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
124 |
datasets.append(ds_wrapper)
|
125 |
-
elif
|
126 |
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
127 |
-
GPTeacherPrompter(),
|
128 |
tokenizer,
|
129 |
cfg.train_on_inputs,
|
130 |
cfg.sequence_len,
|
131 |
)
|
132 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
133 |
datasets.append(ds_wrapper)
|
134 |
-
elif
|
135 |
ds_strategy = AlpacaReflectionPTStrategy(
|
136 |
-
ReflectAlpacaPrompter(),
|
137 |
tokenizer,
|
138 |
cfg.train_on_inputs,
|
139 |
cfg.sequence_len,
|
140 |
)
|
141 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
142 |
datasets.append(ds_wrapper)
|
143 |
-
elif
|
144 |
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
145 |
-
ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
146 |
)
|
147 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
148 |
datasets.append(ds_wrapper)
|
149 |
-
elif
|
150 |
ds_strategy = CompletionPromptTokenizingStrategy(
|
151 |
CompletionPrompter(),
|
152 |
tokenizer,
|
@@ -168,6 +179,11 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|
168 |
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
169 |
)
|
170 |
dataset.save_to_disk(prepared_ds_path)
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
return dataset
|
173 |
|
@@ -182,13 +198,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
182 |
|
183 |
if cfg.max_packed_sequence_len is not None:
|
184 |
# see if we can go ahead and load the stacked dataset
|
185 |
-
|
186 |
ds_hash = str(
|
187 |
md5(
|
188 |
(
|
189 |
str(cfg.sequence_len)
|
190 |
+ "@"
|
191 |
+ str(max_packed_sequence_len)
|
|
|
192 |
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
193 |
).encode("utf-8")
|
194 |
).hexdigest()
|
@@ -199,7 +216,19 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
199 |
else Path(default_dataset_prepared_path) / ds_hash
|
200 |
)
|
201 |
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
logging.info(
|
204 |
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
205 |
)
|
@@ -210,6 +239,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
210 |
tokenizer, cfg, default_dataset_prepared_path
|
211 |
)
|
212 |
|
|
|
|
|
|
|
213 |
constant_len_dataset = ConstantLengthDataset(
|
214 |
tokenizer,
|
215 |
[dataset],
|
@@ -237,6 +269,11 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
237 |
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
238 |
)
|
239 |
dataset.save_to_disk(prepared_ds_path)
|
|
|
|
|
|
|
|
|
|
|
240 |
else:
|
241 |
dataset = load_tokenized_prepared_datasets(
|
242 |
tokenizer, cfg, default_dataset_prepared_path
|
|
|
50 |
if cfg.dataset_prepared_path
|
51 |
else Path(default_dataset_prepared_path) / ds_hash
|
52 |
)
|
53 |
+
dataset = None
|
54 |
+
try:
|
55 |
+
if cfg.push_dataset_to_hub:
|
56 |
+
dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True)
|
57 |
+
except:
|
58 |
+
pass
|
59 |
|
60 |
+
if dataset:
|
61 |
+
...
|
62 |
+
elif any(prepared_ds_path.glob("*")):
|
63 |
logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
64 |
dataset = load_from_disk(str(prepared_ds_path))
|
65 |
logging.info("Prepared dataset loaded from disk...")
|
|
|
93 |
ds = load_dataset("json", data_files=fp, streaming=True, split=None)
|
94 |
if not ds:
|
95 |
raise Exception("unhandled dataset load")
|
96 |
+
d_type = d.type
|
97 |
+
d_type_split = d.type.split(":")
|
98 |
+
d_base_type = d_type_split[0]
|
99 |
+
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
100 |
+
if d_base_type == "alpaca":
|
101 |
ds_strategy = AlpacaPromptTokenizingStrategy(
|
102 |
+
AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
103 |
)
|
104 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
105 |
datasets.append(ds_wrapper)
|
106 |
+
elif d_base_type == "explainchoice":
|
107 |
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
108 |
+
MultipleChoiceExplainPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
109 |
)
|
110 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
111 |
datasets.append(ds_wrapper)
|
112 |
+
elif d_base_type == "concisechoice":
|
113 |
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
114 |
+
MultipleChoiceConcisePrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
115 |
)
|
116 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
117 |
datasets.append(ds_wrapper)
|
118 |
+
elif d_base_type == "summarizetldr":
|
119 |
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
120 |
+
SummarizeTLDRPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
121 |
)
|
122 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
123 |
datasets.append(ds_wrapper)
|
124 |
+
elif d_base_type == "jeopardy":
|
125 |
ds_strategy = JeopardyPromptTokenizingStrategy(
|
126 |
+
JeopardyPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
127 |
)
|
128 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
129 |
datasets.append(ds_wrapper)
|
130 |
+
elif d_base_type == "oasst":
|
131 |
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
132 |
+
AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
133 |
)
|
134 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
135 |
datasets.append(ds_wrapper)
|
136 |
+
elif d_base_type == "gpteacher":
|
137 |
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
138 |
+
GPTeacherPrompter(d_prompt_style),
|
139 |
tokenizer,
|
140 |
cfg.train_on_inputs,
|
141 |
cfg.sequence_len,
|
142 |
)
|
143 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
144 |
datasets.append(ds_wrapper)
|
145 |
+
elif d_base_type == "reflection":
|
146 |
ds_strategy = AlpacaReflectionPTStrategy(
|
147 |
+
ReflectAlpacaPrompter(d_prompt_style),
|
148 |
tokenizer,
|
149 |
cfg.train_on_inputs,
|
150 |
cfg.sequence_len,
|
151 |
)
|
152 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
153 |
datasets.append(ds_wrapper)
|
154 |
+
elif d_base_type == "sharegpt":
|
155 |
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
156 |
+
ShareGPTPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
157 |
)
|
158 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
159 |
datasets.append(ds_wrapper)
|
160 |
+
elif d_base_type == "completion":
|
161 |
ds_strategy = CompletionPromptTokenizingStrategy(
|
162 |
CompletionPrompter(),
|
163 |
tokenizer,
|
|
|
179 |
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
180 |
)
|
181 |
dataset.save_to_disk(prepared_ds_path)
|
182 |
+
if cfg.push_dataset_to_hub:
|
183 |
+
logging.info(
|
184 |
+
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
185 |
+
)
|
186 |
+
dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
|
187 |
|
188 |
return dataset
|
189 |
|
|
|
198 |
|
199 |
if cfg.max_packed_sequence_len is not None:
|
200 |
# see if we can go ahead and load the stacked dataset
|
201 |
+
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
202 |
ds_hash = str(
|
203 |
md5(
|
204 |
(
|
205 |
str(cfg.sequence_len)
|
206 |
+ "@"
|
207 |
+ str(max_packed_sequence_len)
|
208 |
+
+ seed
|
209 |
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
210 |
).encode("utf-8")
|
211 |
).hexdigest()
|
|
|
216 |
else Path(default_dataset_prepared_path) / ds_hash
|
217 |
)
|
218 |
|
219 |
+
dataset = None
|
220 |
+
try:
|
221 |
+
if cfg.push_dataset_to_hub:
|
222 |
+
logging.info(
|
223 |
+
f"checkking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
224 |
+
)
|
225 |
+
dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True)
|
226 |
+
except:
|
227 |
+
pass
|
228 |
+
|
229 |
+
if dataset:
|
230 |
+
...
|
231 |
+
elif any(prepared_ds_path.glob("*")):
|
232 |
logging.info(
|
233 |
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
234 |
)
|
|
|
239 |
tokenizer, cfg, default_dataset_prepared_path
|
240 |
)
|
241 |
|
242 |
+
if cfg.seed:
|
243 |
+
dataset = dataset.shuffle(seed=cfg.seed)
|
244 |
+
|
245 |
constant_len_dataset = ConstantLengthDataset(
|
246 |
tokenizer,
|
247 |
[dataset],
|
|
|
269 |
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
270 |
)
|
271 |
dataset.save_to_disk(prepared_ds_path)
|
272 |
+
if cfg.push_dataset_to_hub:
|
273 |
+
logging.info(
|
274 |
+
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
275 |
+
)
|
276 |
+
dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
|
277 |
else:
|
278 |
dataset = load_tokenized_prepared_datasets(
|
279 |
tokenizer, cfg, default_dataset_prepared_path
|
src/axolotl/utils/models.py
CHANGED
@@ -126,6 +126,32 @@ def load_model(
|
|
126 |
torch_dtype=torch_dtype,
|
127 |
device_map=cfg.device_map,
|
128 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
elif model_type:
|
130 |
model = getattr(transformers, model_type).from_pretrained(
|
131 |
base_model,
|
@@ -266,7 +292,7 @@ def load_llama_adapter(model, cfg):
|
|
266 |
task_type="CAUSAL_LM",
|
267 |
)
|
268 |
|
269 |
-
if cfg.
|
270 |
model = PeftModel.from_pretrained(
|
271 |
model,
|
272 |
cfg.lora_model_dir,
|
@@ -307,7 +333,7 @@ def load_lora(model, cfg):
|
|
307 |
model,
|
308 |
cfg.lora_model_dir,
|
309 |
device_map=cfg.device_map,
|
310 |
-
torch_dtype=torch.float16,
|
311 |
)
|
312 |
else:
|
313 |
model = get_peft_model(model, lora_config)
|
|
|
126 |
torch_dtype=torch_dtype,
|
127 |
device_map=cfg.device_map,
|
128 |
)
|
129 |
+
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
130 |
+
# This is a WIP, still an issue with the backward pass
|
131 |
+
# RuntimeError: grad can be implicitly created only for scalar outputs
|
132 |
+
# TODO: try config.sequence_parallel = False
|
133 |
+
# # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
|
134 |
+
# # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
|
135 |
+
# # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
|
136 |
+
# from flash_attn.utils.pretrained import state_dict_from_pretrained
|
137 |
+
# from flash_attn.models.gpt import GPTLMHeadModel
|
138 |
+
# from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
|
139 |
+
# from transformers import GPTNeoXConfig
|
140 |
+
# config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
|
141 |
+
# config.use_flash_attn = True
|
142 |
+
# config.fused_bias_fc = True
|
143 |
+
# config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
|
144 |
+
# config.activation_function = "gelu_fast"
|
145 |
+
# config.fused_dropout_add_ln = True
|
146 |
+
# # config.residual_in_fp32 = True
|
147 |
+
#
|
148 |
+
# model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
|
149 |
+
# base_model,
|
150 |
+
# config,
|
151 |
+
# dtype=torch_dtype,
|
152 |
+
# device=cfg.device,
|
153 |
+
# )
|
154 |
+
# model.train() # sets to train instead of eval mode
|
155 |
elif model_type:
|
156 |
model = getattr(transformers, model_type).from_pretrained(
|
157 |
base_model,
|
|
|
292 |
task_type="CAUSAL_LM",
|
293 |
)
|
294 |
|
295 |
+
if cfg.lora_model_dir:
|
296 |
model = PeftModel.from_pretrained(
|
297 |
model,
|
298 |
cfg.lora_model_dir,
|
|
|
333 |
model,
|
334 |
cfg.lora_model_dir,
|
335 |
device_map=cfg.device_map,
|
336 |
+
# torch_dtype=torch.float16,
|
337 |
)
|
338 |
else:
|
339 |
model = get_peft_model(model, lora_config)
|