winglian committed · commit b029a11 · unverified · 2 parents: fa8bd14 e3df3a9

Merge pull request #34 from OpenAccess-AI-Collective/dev-unstable
.github/workflows/base.yml CHANGED
@@ -11,6 +11,15 @@ jobs:
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
     runs-on: self-hosted
+    strategy:
+      matrix:
+        include:
+          - cuda: cu118
+            cuda_version: 11.8.0
+            pytorch: 2.0.0
+          - cuda: cu117
+            cuda_version: 11.7.0
+            pytorch: 1.13.1
     steps:
       - name: Checkout
         uses: actions/checkout@v3
@@ -32,7 +41,11 @@ jobs:
           context: .
           file: ./docker/Dockerfile-base
           push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}
+          tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
           labels: ${{ steps.metadata.outputs.labels }}
           cache-from: type=gha
           cache-to: type=gha,mode=max
+          build-args: |
+            CUDA_VERSION=${{ matrix.cuda_version }}
+            CUDA=${{ matrix.cuda }}
+            PYTORCH_VERSION=${{ matrix.pytorch }}
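With the matrix above, the base image is built once per CUDA/PyTorch pair and the tag suffix keeps the variants apart. As a rough illustration (assuming the metadata step resolves to a tag such as `main-base`, which is not shown in this diff), the pushed tags would come out as `main-base-cu118-2.0.0` and `main-base-cu117-1.13.1`.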
.github/workflows/main.yml CHANGED
@@ -10,6 +10,15 @@ jobs:
   build-axolotl:
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      matrix:
+        include:
+          - cuda: cu118
+            cuda_version: 11.8.0
+            pytorch: 2.0.0
+          - cuda: cu117
+            cuda_version: 11.7.0
+            pytorch: 1.13.1
     runs-on: self-hosted
     steps:
       - name: Checkout
@@ -31,10 +40,10 @@ jobs:
         with:
          context: .
           build-args: |
-            BASE_TAG=${{ github.ref_name }}-base
+            BASE_TAG=${{ github.ref_name }}-base-${{ matrix.cuda }}-${{ matrix.pytorch }}
           file: ./docker/Dockerfile
           push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}
+          tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
           labels: ${{ steps.metadata.outputs.labels }}
           cache-from: type=gha
           cache-to: type=gha,mode=max
@@ -42,6 +51,15 @@ jobs:
     needs: build-axolotl
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      matrix:
+        include:
+          - cuda: cu118
+            cuda_version: 11.8.0
+            pytorch: 2.0.0
+          - cuda: cu117
+            cuda_version: 11.7.0
+            pytorch: 1.13.1
     runs-on: self-hosted
     steps:
       - name: Checkout
@@ -63,10 +81,10 @@ jobs:
         with:
          context: .
           build-args: |
-            BASE_TAG=${{ github.ref_name }}
+            BASE_TAG=${{ github.ref_name }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
           file: ./docker/Dockerfile-runpod
           push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}
+          tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
           labels: ${{ steps.metadata.outputs.labels }}
           cache-from: type=gha
           cache-to: type=gha,mode=max
docker/Dockerfile-base CHANGED
@@ -1,6 +1,7 @@
 ARG CUDA_VERSION="11.8.0"
 ARG CUDNN_VERSION="8"
 ARG UBUNTU_VERSION="22.04"
+ARG MAX_JOBS=4

 FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder

@@ -39,6 +40,14 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"

 RUN git clone https://github.com/HazyResearch/flash-attention.git && \
     cd flash-attention && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/fused_dense_lib && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/xentropy && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/rotary && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/layer_norm && \
     python3 setup.py bdist_wheel

 FROM base-builder AS deepspeed-builder
@@ -60,8 +69,12 @@ RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --g
 RUN mkdir /workspace/wheels
 COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
 COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels

-RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl
+RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy-*.whl wheels/rotary-*.whl wheels/dropout_layer_norm-*.whl
 RUN git lfs install --skip-repo
 RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
     "accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
scripts/finetune.py CHANGED
@@ -31,7 +31,7 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 def choose_device(cfg):
     def get_device():
         if torch.cuda.is_available():
-            return "cuda"
+            return f"cuda:{cfg.local_rank}"
         else:
             try:
                 if torch.backends.mps.is_available():
@@ -131,7 +131,8 @@ def train(
     # then overwrite the value
     cfg_keys = dict(cfg).keys()
     for k in kwargs:
-        if k in cfg_keys:
+        # if not strict, allow writing to cfg even if it's not in the yml already
+        if k in cfg_keys or cfg.strict is False:
             # handle booleans
             if isinstance(cfg[k], bool):
                 cfg[k] = bool(kwargs[k])
@@ -169,6 +170,15 @@ def train(
         inference=("inference" in kwargs),
     )

+    if "merge_lora" in kwargs and cfg.adapter is not None:
+        logging.info("running merge of LoRA with base model")
+        model = model.merge_and_unload()
+
+        if cfg.local_rank == 0:
+            logging.info("saving merged model")
+            model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+        return
+
     if "inference" in kwargs:
         logging.info("calling do_inference function")
         do_inference(cfg, model, tokenizer)
@@ -216,6 +226,8 @@ def train(
     )

     logging.info("Starting trainer...")
+    if cfg.group_by_length:
+        logging.info("hang tight... sorting dataset for group_by_length")
     resume_from_checkpoint = cfg.resume_from_checkpoint
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
         possible_checkpoints = [
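A hedged usage note on the new merge path: assuming extra keyword arguments reach train(**kwargs) from the command line (the same mechanism the existing "inference" check relies on), running the script with a merge_lora argument while an adapter is configured calls merge_and_unload() on the PEFT model, saves the merged weights under {output_dir}/merged on rank 0, and returns before the trainer is ever built.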
src/axolotl/datasets.py CHANGED
@@ -106,7 +106,7 @@ class ConstantLengthDataset(IterableDataset):
                     }
                 else:
                     logging.warning(
-                        "dropping batch due to tensor size mismatch"
+                        f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
                     )
                     buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                     buffer_len = 0
src/axolotl/prompt_strategies/__init__.py ADDED
@@ -0,0 +1,13 @@
+import importlib
+
+def load(strategy, tokenizer, cfg):
+    try:
+        load_fn = "load"
+        if strategy.split(".")[-1].startswith("load_"):
+            load_fn = strategy.split(".")[-1]
+            strategy = ".".join(strategy.split(".")[:-1])
+        m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
+        fn = getattr(m, load_fn)
+        return fn(tokenizer, cfg)
+    except:
+        pass
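A minimal usage sketch of this loader (illustrative only; it assumes a tokenizer and cfg object are already constructed, and only uses module and function names added elsewhere in this commit):

from axolotl.prompt_strategies import load

# "alpaca_chat" -> axolotl.prompt_strategies.alpaca_chat.load(tokenizer, cfg)
strategy = load("alpaca_chat", tokenizer, cfg)

# "creative_acr.load_critique" -> creative_acr.load_critique(tokenizer, cfg),
# because a trailing "load_*" segment is treated as the entry-point name
critique = load("creative_acr.load_critique", tokenizer, cfg)

# anything that fails to import falls into the bare except and returns None,
# so callers can fall back to the built-in dataset-type handling
assert load("not_a_real_strategy", tokenizer, cfg) is None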
src/axolotl/prompt_strategies/alpaca_chat.py ADDED
@@ -0,0 +1,8 @@
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+def load(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.chat.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
src/axolotl/prompt_strategies/alpaca_instruct.py ADDED
@@ -0,0 +1,8 @@
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+def load(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.instruct.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
src/axolotl/prompt_strategies/creative_acr.py ADDED
@@ -0,0 +1,137 @@
+from typing import Union, Generator
+
+import yaml
+from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
+
+
+class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        question = prompt["instruction"]
+        answer = prompt["revision"]  # don't use prompt[answer], that's data we don't want in the dataset
+        return (
+            question,
+            "",
+            answer,
+        )
+
+
+class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
+refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
+prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
+creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
+comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
+Then for each criteria, give a detailed explanation of why the response doesn't meet the criteria (don't reference the score), what it did right, and how it could be improved for that criteria.
+
+Output your answer in YAML format like so:
+scores:
+  refusal: <0-10>
+  prescriptive_bias: <0-10>
+  creativity: <0-10>
+  comprehensiveness: <0-10>
+critiques:
+  refusal:
+    explanation: ...
+    improvements: ...
+  prescriptive_bias:
+    explanation: ...
+    improvements: ...
+  creativity:
+    explanation: ...
+    improvements: ...
+  comprehensiveness:
+    explanation: ...
+    improvements: ...
+
+Question: {question}
+Answer: {answer}
+"""
+
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
+        critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
+        evaluation = scores + critiques
+        question = prompt["instruction"]
+        answer = prompt["answer"]
+        return (
+            self.user_prompt.format(question=question, answer=answer),
+            "",
+            evaluation,
+        )
+
+
+class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    user_prompt = """Definitions:
+refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
+prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
+creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
+comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
+
+Given the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response.
+
+Question: {question}
+Answer: {answer}
+Evaluation:
+{evaluation}
+"""
+
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
+        critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
+        evaluation = scores + critiques
+        question = prompt["instruction"]
+        answer = prompt["answer"]
+        return (
+            self.user_prompt.format(question=question, answer=answer, evaluation=evaluation),
+            "",
+            prompt["revision"],
+        )
+
+
+class CreativePrompterBase:
+    system_prompt = ""
+    prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
+
+    def build_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,
+        output: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        if self.system_prompt:
+            res = f"{self.system_prompt}\nUSER: {instruction}\nASSISTANT:"
+        else:
+            res = f"USER: {instruction}\nASSISTANT:"
+        if output:
+            res = f"{res}{output}"
+        yield res
+
+
+class CreativeAnswerPrompter(CreativePrompterBase):
+    system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity."
+
+
+class CreativeCritiquePrompter(CreativePrompterBase):
+    system_prompt = ""
+
+
+class CreativeRevisePrompter(CreativePrompterBase):
+    system_prompt = ""
+
+
+def load_answer(tokenizer, cfg):
+    return CreativeAnsweringPromptTokenizingStrategy(
+        CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
+
+
+def load_critique(tokenizer, cfg):
+    return CreativeCritiquePromptTokenizingStrategy(
+        CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
+
+
+def load_revise(tokenizer, cfg):
+    return CreativeRevisePromptTokenizingStrategy(
+        CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
src/axolotl/prompt_strategies/pygmalion.py ADDED
@@ -0,0 +1,100 @@
+import copy
+import logging
+from collections import defaultdict
+from typing import Generator
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+
+IGNORE_TOKEN_ID = -100
+
+
+class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
+    bot_prefix_token_ids = []
+
+    def __init__(self, prompter, tokenizer, *args, **kwargs):
+        super().__init__(prompter, tokenizer)
+        res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
+        self.bot_prefix_token_ids = res["input_ids"]
+
+    def tokenize_prompt(self, prompt):
+        result = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+        }
+        current_len = 0
+        for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
+            role, message = part
+            if role == "system":
+                prefix = "<|system|>"
+                # this should include a bos token, no eos token, strip trailing "\n<START>"
+                if message.endswith("\n<START>"):
+                    message = message[:-8]
+                res = self._tokenize(prefix + "Persona: " + message.strip(), add_eos_token=False, strip_bos_token=False)
+                # everything from this is masked out from the labels
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+            elif role == "human":
+                prefix = "<|user|>"
+                res = self._tokenize(prefix + " " + message.strip(), add_eos_token=False, strip_bos_token=True)
+                # everything from this is masked out from the labels
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+            elif role == "bot":
+                prefix = "<|model|>"
+                res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True)
+                # mask out the prefix token, rest is not masked out from labels
+                # make sure we create the labels first, otherwise we get incorrect lengths
+                labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])][len(self.bot_prefix_token_ids):]
+            else:
+                logging.warning(f"unknown role in conversation: {role}")
+                res = defaultdict(lambda: [])
+            input_ids = res["input_ids"]
+            input_len = len(input_ids)
+            result["input_ids"][current_len : current_len + input_len] = input_ids
+            result["attention_mask"][current_len : current_len + input_len] = [
+                1 if x != self.tokenizer.pad_token_id else 0
+                for x in input_ids
+            ]
+            result["labels"][current_len : current_len + input_len] = labels
+            current_len += input_len
+        return result
+
+    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.sequence_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.sequence_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(self.tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        if (
+            result["input_ids"][0] == self.tokenizer.bos_token_id
+            and strip_bos_token
+        ):
+            result["input_ids"] = result["input_ids"][1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
+        result["labels"] = result["input_ids"].copy()
+        return result
+
+
+class PygmalionPrompter:
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
+        for msg in source:
+            yield msg["role"], msg["value"]
+
+
+def load(tokenizer, cfg):
+    return PygmalionPromptTokenizingStrategy(
+        PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
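To make the masking comments above concrete, here is an illustrative sketch (not a tested example) of what tokenize_prompt produces for a small conversation, assuming tokenizer is an already-loaded Hugging Face tokenizer and the classes above are in scope:

strategy = PygmalionPromptTokenizingStrategy(
    PygmalionPrompter(), tokenizer, False, 2048
)
sample = {
    "conversations": [
        {"role": "system", "value": "A terse assistant.\n<START>"},
        {"role": "human", "value": "Hi there"},
        {"role": "bot", "value": "Hello!"},
    ]
}
tokenized = strategy.tokenize_prompt(sample)
# labels for the system and human segments are all IGNORE_TOKEN_ID (-100);
# the bot segment keeps its token ids as labels, except for the <|model|>
# prefix tokens, which are masked out as well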
src/axolotl/prompt_tokenizers.py CHANGED
@@ -1,5 +1,7 @@
 import abc
 import copy
+import functools
+import logging

 from transformers import PreTrainedTokenizer

@@ -33,6 +35,20 @@ class PromptTokenizingStrategy(abc.ABC):
     def tokenize_prompt(self, prompt):
         pass

+    @functools.cache
+    def _get_user_token(self):
+        id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
+        if isinstance(id_or_ids, (int,)):
+            return id_or_ids
+        return False
+
+    @functools.cache
+    def _get_assistant_token(self):
+        id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
+        if isinstance(id_or_ids, (int,)):
+            return id_or_ids
+        return False
+

 class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
@@ -63,7 +79,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
                 response,
             )))

-    def _tokenize(self, prompt, add_eos_token=True):
+    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         result = self.tokenizer(
             prompt,
             truncation=True,
@@ -79,6 +95,13 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)

+        if (
+            result["input_ids"][0] == self.tokenizer.bos_token_id
+            and strip_bos_token
+        ):
+            result["input_ids"] = result["input_ids"][1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
         result["labels"] = result["input_ids"].copy()
         return result

@@ -239,23 +262,35 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
             "labels": [],
         }
         current_len = 0
+        user_token = self._get_user_token()
+        assistant_token = self._get_assistant_token()
         try:
-            for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"], self.tokenizer)):
-                if i == 0:
+            for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
+                if isinstance(part, tuple):
+                    if part[0] == "USER:":
+                        part = part[0] + part[1] if not user_token else part[1]
+                        # this is still the user query, we should
+                        res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
+                        if user_token:
+                            res["input_ids"] = [user_token, *res["input_ids"]]
+                        # everything from this is masked out from the labels
+                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                    elif part[0] == "ASSISTANT:":
+                        # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
+                        part = part[0] + part[1] if not assistant_token else part[1]
+                        # this should be the assistant response, should end with an eos token
+                        res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
+                        if assistant_token:
+                            res["input_ids"] = [assistant_token, *res["input_ids"]]
+                        # not masked out from labels
+                        labels = copy.deepcopy(res["input_ids"])
+                    else:
+                        logging.warning("unhandled role: " + part[0])
+                else:
                     # this is only ever the first part, should include the bos token and the user query
                     res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=False)
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                elif i % 2 == 0:
-                    # this is still the user query, we should
-                    res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
-                    # everything from this is masked out from the labels
-                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                else:
-                    # this should be the assistant response, should end with an eos token
-                    res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
-                    # not masked out from labels
-                    labels = copy.deepcopy(res["input_ids"])
                 input_ids = res["input_ids"]
                 input_len = len(input_ids)
                 result["input_ids"][current_len : current_len + input_len] = input_ids
src/axolotl/prompters.py CHANGED
@@ -1,15 +1,34 @@
 import copy
 import dataclasses
+import logging
 from enum import auto, Enum
 from typing import List, Tuple, Any, Union, Generator

 IGNORE_TOKEN_ID = -100


+class PromptStyle(Enum):
+    instruct = "instruct"
+    chat = "chat"
+
 class AlpacaPrompter:
-    prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
-    prompt_no_input = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
-    response_split = "### Response:"
+    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
+    prompt_style = None
+
+    def __init__(self, prompt_style="instruct"):
+        self.prompt_style = prompt_style
+        self.match_prompt_style()
+
+    def match_prompt_style(self):
+        if self.prompt_style == PromptStyle.instruct.value:
+            self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n"
+            self.response_split = "### Response:"
+        if self.prompt_style == PromptStyle.chat.value:
+            self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
+            self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
+            self.response_split = "ASSISTANT:"

     def build_prompt(
         self,
@@ -36,7 +55,7 @@ class JeopardyPrompter(AlpacaPrompter):


 class MultipleChoiceExplainPrompter(AlpacaPrompter):
-    prompt_input = "Choose the answer that best answers the question. Explain your reasoning.\n\n### Question:\n{instruction}\n\n### Choices:\n{input}\n\n### Response:\n"
+    system_prompt = "Choose the answer that best answers the question. Explain your reasoning."


 class MultipleChoiceConcisePrompter(AlpacaPrompter):
@@ -64,11 +83,30 @@ class NomicGPT4AllPrompter(AlpacaPrompter):


 class ReflectAlpacaPrompter:
-    prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
-    prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n"
-    agent_label = "{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
+    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
+    system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
+
+    prompt_input = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+    prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n"
+    agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
     response_split = "### Response:"

+    def __init__(self, prompt_style="instruct"):
+        self.prompt_style = prompt_style
+        self.match_prompt_style()
+
+    def match_prompt_style(self):
+        if self.prompt_style == PromptStyle.instruct.value:
+            self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n"
+            self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
+            self.response_split = "### Final Response:"
+        if self.prompt_style == PromptStyle.chat.value:
+            self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
+            self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
+            self.agent_label = "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
+            self.response_split = "ASSISTANT:"
+
     def build_prompt(
         self,
         instruction: str,
@@ -118,13 +156,13 @@ class Conversation:
     def get_prompt(self) -> Generator[str, None, None]:
         seps = [self.sep, self.sep2]
         preamble = self.system + seps[0]
+        yield preamble
         for i, (role, message) in enumerate(self.messages):
             if message:
-                yield preamble + role + ": " + message + seps[i % 2]
+                yield (role + ":", " " + message)
             else:
-                yield role + ":"
-            if i == 0:
-                preamble = ""
+                logging.warning("role with empty message: " + role)
+                yield (role + ":", )

     def copy(self):
         return Conversation(
@@ -154,7 +192,17 @@ conv_vicuna_v1_1 = Conversation(


 class ShareGPTPrompter:
-    def build_prompt(self, source, tokenizer, sequence_len=2048) -> Generator[str, None, None]:
+    def __init__(self, prompt_style=None):
+        if prompt_style != PromptStyle.chat.value:
+            raise Exception(f"unsupported prompt_style for ShareGPTPrompter({prompt_style})")
+
+    # def match_prompt_style(self):
+    #     if self.prompt_style == PromptStyle.chat.value:
+    #         self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
+    #         self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
+    #         self.response_split = "ASSISTANT:"
+
+    def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
         # ignore the system prompt if provided
         if source[0]["from"] == "system":
             source.pop(0)
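A quick sketch of the two styles the reworked AlpacaPrompter can now emit (the template strings come from match_prompt_style above; the instruction value is illustrative):

from axolotl.prompters import AlpacaPrompter, PromptStyle

instruct = AlpacaPrompter(PromptStyle.instruct.value)
chat = AlpacaPrompter(PromptStyle.chat.value)

# instruct style: "...### Instruction:\nName three colors.\n\n### Response:\n"
# chat style:     "...USER: Name three colors.\nASSISTANT:"
print(instruct.prompt_no_input.format(instruction="Name three colors."))
print(chat.prompt_no_input.format(instruction="Name three colors."))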
src/axolotl/utils/data.py CHANGED
@@ -7,11 +7,13 @@ from datasets import (
     load_dataset,
     IterableDataset,
     Dataset,
-    concatenate_datasets,
+    concatenate_datasets, DatasetDict,
 )
 from huggingface_hub import hf_hub_download
+from transformers import PreTrainedTokenizerBase

 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
+from axolotl.prompt_strategies import load
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     GPTeacherPromptTokenizingStrategy,
@@ -35,13 +37,15 @@ from axolotl.prompters import (
 )


-def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
+def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict:
+    tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
             (
                 str(cfg.sequence_len)
                 + "@"
                 + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+                + "|" + tokenizer_name
             ).encode("utf-8")
         ).hexdigest()
     )
@@ -50,8 +54,17 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
         if cfg.dataset_prepared_path
         else Path(default_dataset_prepared_path) / ds_hash
     )
+    dataset = None
+    try:
+        if cfg.push_dataset_to_hub:
+            dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True)
+            dataset = dataset["train"]
+    except:
+        pass

-    if any(prepared_ds_path.glob("*")):
+    if dataset:
+        ...
+    elif any(prepared_ds_path.glob("*")):
         logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
@@ -63,7 +76,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
         ds = None
         ds_from_hub = False
         try:
-            load_dataset(d.path, streaming=True)
+            load_dataset(d.path, streaming=True, use_auth_token=True)
            ds_from_hub = True
         except FileNotFoundError:
             pass
@@ -71,82 +84,88 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
         # prefer local dataset, even if hub exists
         if Path(d.path).exists():
             ds: IterableDataset = load_dataset(
-                "json", data_files=d.path, streaming=True, split=None
+                "json", data_files=d.path, streaming=False, split=None
             )
         elif ds_from_hub:
             if d.data_files:
-                ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
+                ds = load_dataset(d.path, streaming=False, data_files=d.data_files, use_auth_token=True)
             else:
-                ds = load_dataset(d.path, streaming=True)
+                ds = load_dataset(d.path, streaming=False, use_auth_token=True)
         else:
             fp = hf_hub_download(
                 repo_id=d.path, repo_type="dataset", filename=d.data_files
             )
-            ds = load_dataset("json", data_files=fp, streaming=True, split=None)
+            ds = load_dataset("json", data_files=fp, streaming=False, split=None)
         if not ds:
             raise Exception("unhandled dataset load")
-
-        if d.type == "alpaca":
+        d_type = d.type
+        d_type_split = d_type.split(":")
+        d_base_type = d_type_split[0]
+        d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
+        if (ds_strategy := load(d.type, tokenizer, cfg)):
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            datasets.append(ds_wrapper)
+        elif d_base_type == "alpaca":
             ds_strategy = AlpacaPromptTokenizingStrategy(
-                AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
-        elif d.type == "explainchoice":
+        elif d_base_type == "explainchoice":
             ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-                MultipleChoiceExplainPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                MultipleChoiceExplainPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
-        elif d.type == "concisechoice":
+        elif d_base_type == "concisechoice":
             ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-                MultipleChoiceConcisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                MultipleChoiceConcisePrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
            )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
-        elif d.type == "summarizetldr":
+        elif d_base_type == "summarizetldr":
             ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
-                SummarizeTLDRPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                SummarizeTLDRPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
-        elif d.type == "jeopardy":
+        elif d_base_type == "jeopardy":
             ds_strategy = JeopardyPromptTokenizingStrategy(
-                JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                JeopardyPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
-        elif d.type == "oasst":
+        elif d_base_type == "oasst":
             ds_strategy = OpenAssistantPromptTokenizingStrategy(
-                AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
-        elif d.type == "gpteacher":
+        elif d_base_type == "gpteacher":
             ds_strategy = GPTeacherPromptTokenizingStrategy(
-                GPTeacherPrompter(),
+                GPTeacherPrompter(d_prompt_style),
                 tokenizer,
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
-        elif d.type == "reflection":
+        elif d_base_type == "reflection":
             ds_strategy = AlpacaReflectionPTStrategy(
-                ReflectAlpacaPrompter(),
+                ReflectAlpacaPrompter(d_prompt_style),
                 tokenizer,
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
-        elif d.type == "sharegpt":
+        elif d_base_type == "sharegpt":
             ds_strategy = ShareGPTPromptTokenizingStrategy(
-                ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                ShareGPTPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
-        elif d.type == "completion":
+        elif d_base_type == "completion":
             ds_strategy = CompletionPromptTokenizingStrategy(
                 CompletionPrompter(),
                 tokenizer,
@@ -168,11 +187,16 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
             f"Saving merged prepared dataset to disk... {prepared_ds_path}"
         )
         dataset.save_to_disk(prepared_ds_path)
+        if cfg.push_dataset_to_hub:
+            logging.info(
+                f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
+            )
+            dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)

     return dataset


-def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
+def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset):
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
@@ -180,16 +204,19 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
         max_packed_sequence_len, cfg.sequence_len
     )  # make sure we don't accidentally set it larger than sequence_len

+    tokenizer_name = tokenizer.__class__.__name__
     if cfg.max_packed_sequence_len is not None:
         # see if we can go ahead and load the stacked dataset
-
+        seed = f"@{str(cfg.seed)}" if cfg.seed else ""
         ds_hash = str(
             md5(
                 (
                     str(cfg.sequence_len)
                     + "@"
                     + str(max_packed_sequence_len)
+                    + seed
                     + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+                    + "|" + tokenizer_name
                 ).encode("utf-8")
             ).hexdigest()
         )
@@ -199,17 +226,38 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
             else Path(default_dataset_prepared_path) / ds_hash
         )

-        if any(prepared_ds_path.glob("*")):
+        dataset = None
+        try:
+            if cfg.push_dataset_to_hub:
+                logging.info(
+                    f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
+                )
+                dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True)
+                dataset = dataset["train"]
+        except:
+            pass
+
+        if dataset:
+            ...
+        elif any(prepared_ds_path.glob("*")):
             logging.info(
                 f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
             )
             dataset = load_from_disk(str(prepared_ds_path))
             logging.info("Prepared packed dataset loaded from disk...")
+            if cfg.push_dataset_to_hub:
+                logging.info(
+                    f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
+                )
+                dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
         else:
             dataset = load_tokenized_prepared_datasets(
                 tokenizer, cfg, default_dataset_prepared_path
            )

+            if cfg.seed:
+                dataset = dataset.shuffle(seed=cfg.seed)
+
            constant_len_dataset = ConstantLengthDataset(
                tokenizer,
                [dataset],
@@ -237,6 +285,11 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 f"Saving packed prepared dataset to disk... {prepared_ds_path}"
             )
             dataset.save_to_disk(prepared_ds_path)
+            if cfg.push_dataset_to_hub:
+                logging.info(
+                    f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
+                )
+                dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
     else:
         dataset = load_tokenized_prepared_datasets(
             tokenizer, cfg, default_dataset_prepared_path
src/axolotl/utils/models.py CHANGED
@@ -126,6 +126,32 @@ def load_model(
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
+    # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
+    #     This is a WIP, still an issue with the backward pass
+    #     RuntimeError: grad can be implicitly created only for scalar outputs
+    #     TODO: try config.sequence_parallel = False
+    #     # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
+    #     # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
+    #     # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
+    #     from flash_attn.utils.pretrained import state_dict_from_pretrained
+    #     from flash_attn.models.gpt import GPTLMHeadModel
+    #     from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
+    #     from transformers import GPTNeoXConfig
+    #     config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
+    #     config.use_flash_attn = True
+    #     config.fused_bias_fc = True
+    #     config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
+    #     config.activation_function = "gelu_fast"
+    #     config.fused_dropout_add_ln = True
+    #     # config.residual_in_fp32 = True
+    #
+    #     model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
+    #         base_model,
+    #         config,
+    #         dtype=torch_dtype,
+    #         device=cfg.device,
+    #     )
+    #     model.train()  # sets to train instead of eval mode
     elif model_type:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
@@ -194,7 +220,7 @@ def load_model(
         for k, v in cfg.special_tokens.items():
             tokenizer.add_special_tokens({k: v})
     if cfg.tokens:
-        tokenizer.add_tokens(cfg.tokens)
+        tokenizer.add_tokens(list(cfg.tokens))

     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
@@ -266,7 +292,8 @@ def load_llama_adapter(model, cfg):
         task_type="CAUSAL_LM",
     )

-    if cfg.peft_model_dir:
+    if cfg.lora_model_dir:
+        logging.info("Loading pretrained LoRA")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
@@ -307,7 +334,7 @@ def load_lora(model, cfg):
             model,
             cfg.lora_model_dir,
             device_map=cfg.device_map,
-            torch_dtype=torch.float16,
+            # torch_dtype=torch.float16,
         )
     else:
         model = get_peft_model(model, lora_config)
src/axolotl/utils/trainer.py CHANGED
@@ -9,13 +9,31 @@ import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback
+from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names

 from axolotl.utils.schedulers import InterpolatingLogScheduler
 from axolotl.utils.callbacks import SavePeftModelCallback


+class OneCycleLRSchedulerTrainer(Trainer):
+    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
+        optimizer = self.optimizer if optimizer is None else optimizer
+        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
+        num_training_steps = num_training_steps
+        pct_start = num_warmup_steps / num_training_steps
+
+        lr_scheduler = OneCycleLR(
+            optimizer,
+            max_lr=self.args.learning_rate,
+            total_steps=num_training_steps,
+            pct_start=pct_start,
+            div_factor=6,
+        )
+
+        return lr_scheduler
+
+
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -119,6 +137,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         cfg.optimizer == "adamw_bnb_8bit"
         and not cfg.load_4bit
        and not "deepspeed" in training_arguments_kwargs
+        and not cfg.fsdp
     ):
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
@@ -157,7 +176,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             cfg.learning_rate,
             total_steps=total_num_steps,
             epochs=cfg.num_epochs,
-            div_factor=10,
+            div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6,
             **lr_scheduler_kwargs,
         )
     elif cfg.lr_scheduler == "log_sweep":
@@ -182,7 +201,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             cfg.early_stopping_patience,
         )
         callbacks.append(early_stop_cb)
-
+
     if cfg.local_rank == 0 and cfg.adapter == 'lora':  # only save in rank 0
         callbacks.append(SavePeftModelCallback)

@@ -194,7 +213,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         data_collator_kwargs["pad_to_multiple_of"] = 8

-    trainer = transformers.Trainer(
+    trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer
+    trainer = trainer_cls(
         model=model,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,