winglian committed on
Commit 3369c4d · unverified · 2 Parent(s): 6e7d4d5 bc97f9c

Merge pull request #39 from OpenAccess-AI-Collective/dev

.github/workflows/base.yml CHANGED
@@ -11,6 +11,15 @@ jobs:
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
     runs-on: self-hosted
+    strategy:
+      matrix:
+        include:
+          - cuda: cu118
+            cuda_version: 11.8.0
+            pytorch: 2.0.0
+          - cuda: cu117
+            cuda_version: 11.7.0
+            pytorch: 1.13.1
     steps:
       - name: Checkout
         uses: actions/checkout@v3
@@ -32,7 +41,11 @@ jobs:
          context: .
          file: ./docker/Dockerfile-base
          push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}
+          tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
          labels: ${{ steps.metadata.outputs.labels }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
+          build-args: |
+            CUDA_VERSION=${{ matrix.cuda_version }}
+            CUDA=${{ matrix.cuda }}
+            PYTORCH_VERSION=${{ matrix.pytorch }}
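The new build matrix produces one base image per CUDA/PyTorch pair. A minimal sketch of how the matrix entries expand into tag suffixes, assuming an example metadata tag of `main-base` (the real value comes from the metadata step, not from this PR):

```python
# Illustrative only: how the matrix above expands ${{ steps.metadata.outputs.tags }}.
matrix = [
    {"cuda": "cu118", "cuda_version": "11.8.0", "pytorch": "2.0.0"},
    {"cuda": "cu117", "cuda_version": "11.7.0", "pytorch": "1.13.1"},
]
base_tag = "main-base"  # assumed example value produced by the docker metadata action
tags = [f"{base_tag}-{m['cuda']}-{m['pytorch']}" for m in matrix]
# -> ["main-base-cu118-2.0.0", "main-base-cu117-1.13.1"]
```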
.github/workflows/main.yml CHANGED
@@ -10,6 +10,15 @@ jobs:
   build-axolotl:
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      matrix:
+        include:
+          - cuda: cu118
+            cuda_version: 11.8.0
+            pytorch: 2.0.0
+          - cuda: cu117
+            cuda_version: 11.7.0
+            pytorch: 1.13.1
     runs-on: self-hosted
     steps:
       - name: Checkout
@@ -31,10 +40,10 @@ jobs:
        with:
          context: .
          build-args: |
-            BASE_TAG=${{ github.ref_name }}-base
+            BASE_TAG=${{ github.ref_name }}-base-${{ matrix.cuda }}-${{ matrix.pytorch }}
          file: ./docker/Dockerfile
          push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}
+          tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
          labels: ${{ steps.metadata.outputs.labels }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
@@ -42,6 +51,15 @@ jobs:
     needs: build-axolotl
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      matrix:
+        include:
+          - cuda: cu118
+            cuda_version: 11.8.0
+            pytorch: 2.0.0
+          - cuda: cu117
+            cuda_version: 11.7.0
+            pytorch: 1.13.1
     runs-on: self-hosted
     steps:
       - name: Checkout
@@ -63,10 +81,10 @@ jobs:
        with:
          context: .
          build-args: |
-            BASE_TAG=${{ github.ref_name }}
+            BASE_TAG=${{ github.ref_name }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
          file: ./docker/Dockerfile-runpod
          push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}
+          tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
          labels: ${{ steps.metadata.outputs.labels }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
README.md CHANGED
@@ -324,7 +324,7 @@ If you are inferencing a pretrained LORA, pass
 --lora_model_dir ./completed-model
 ```

-### Merge LORA to base (Dev branch 🔧 )
+### Merge LORA to base

 Add below flag to train command above

@@ -345,4 +345,4 @@ Please reduce any below

 Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).

-PRs are **greatly welcome**!
+PRs are **greatly welcome**!
docker/Dockerfile-base CHANGED
@@ -1,6 +1,7 @@
 ARG CUDA_VERSION="11.8.0"
 ARG CUDNN_VERSION="8"
 ARG UBUNTU_VERSION="22.04"
+ARG MAX_JOBS=4

 FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder

@@ -39,6 +40,14 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"

 RUN git clone https://github.com/HazyResearch/flash-attention.git && \
     cd flash-attention && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/fused_dense_lib && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/xentropy && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/rotary && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/layer_norm && \
     python3 setup.py bdist_wheel

 FROM base-builder AS deepspeed-builder
@@ -60,8 +69,12 @@ RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --g
 RUN mkdir /workspace/wheels
 COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
 COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels

-RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl
+RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy-*.whl wheels/rotary-*.whl wheels/dropout_layer_norm-*.whl
 RUN git lfs install --skip-repo
 RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
     "accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
docker/Dockerfile-runpod CHANGED
@@ -1,11 +1,14 @@
 ARG BASE_TAG=main
 FROM winglian/axolotl:$BASE_TAG

+COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
+
 RUN apt install --yes --no-install-recommends openssh-server tmux && \
     mkdir -p ~/.ssh && \
     chmod 700 ~/.ssh && \
     printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
-    chmod +x /workspace/axolotl/scripts/runpod-entrypoint.sh
+    chmod +x /workspace/axolotl/scripts/runpod-entrypoint.sh && \
+    chmod +x /root/runpod-entrypoint.sh

-ENTRYPOINT ["/workspace/axolotl/scripts/runpod-entrypoint.sh"]
+ENTRYPOINT ["/root/runpod-entrypoint.sh"]
 CMD ["sleep", "infinity"]
examples/replit-3b/config-lora.yml ADDED
@@ -0,0 +1,55 @@
+base_model: replit/replit-code-v1-3b
+base_model_config: replit/replit-code-v1-3b
+trust_remote_code: true
+load_in_8bit: false
+datasets:
+  - path: vicgalle/alpaca-gpt4
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+adapter: lora
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len:
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - Wqkv
+  - mlp_up
+  - mlp_down
+lora_fan_in_fan_out:
+wandb_project: lora-replit
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./lora-replit
+batch_size: 8
+micro_batch_size: 1
+num_epochs: 3
+optimizer:
+torchdistx_path:
+lr_scheduler:
+learning_rate: 0.00001
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+gradient_checkpointing:
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 20
+eval_steps: 50
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0
+fsdp:
+fsdp_config:
+#special_tokens:
requirements.txt CHANGED
@@ -1,12 +1,12 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
+bitsandbytes>=0.39.0
 attrdict
 fire
 PyYAML==6.0
 black
-bitsandbytes==0.37.2
 datasets
-accelerate
+accelerate>=0.19.0
 sentencepiece
 wandb
 einops
scripts/finetune.py CHANGED
@@ -1,7 +1,6 @@
 import importlib
 import logging
 import os
-import pathlib
 import random
 import signal
 import sys
@@ -10,12 +9,12 @@ from typing import Optional

 import fire
 import torch
-import transformers
 import yaml
 from attrdict import AttrDefault

 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
+from axolotl.utils.validation import validate_config

 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
@@ -33,7 +32,7 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 def choose_device(cfg):
     def get_device():
         if torch.cuda.is_available():
-            return "cuda"
+            return f"cuda:{cfg.local_rank}"
         else:
             try:
                 if torch.backends.mps.is_available():
@@ -69,7 +68,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     instruction = get_multi_line_input()
     if not instruction:
         return
-    prompt = prompter_module().build_prompt(instruction=instruction)
+    prompt: str = next(prompter_module().build_prompt(instruction=instruction))
     batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

     model.eval()
@@ -133,7 +132,8 @@ def train(
     # then overwrite the value
     cfg_keys = dict(cfg).keys()
     for k in kwargs:
-        if k in cfg_keys:
+        # if not strict, allow writing to cfg even if it's not in the yml already
+        if k in cfg_keys or cfg.strict is False:
             # handle booleans
             if isinstance(cfg[k], bool):
                 cfg[k] = bool(kwargs[k])
@@ -159,6 +159,8 @@ def train(
         cfg.fp16 = True
         cfg.bf16 = False

+    validate_config(cfg)
+
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and peft_config...")
     model, tokenizer, peft_config = load_model(
@@ -171,6 +173,15 @@ def train(
         inference=("inference" in kwargs),
     )

+    if "merge_lora" in kwargs and cfg.adapter is not None:
+        logging.info("running merge of LoRA with base model")
+        model = model.merge_and_unload()
+
+        if cfg.local_rank == 0:
+            logging.info("saving merged model")
+            model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+        return
+
     if "inference" in kwargs:
         logging.info("calling do_inference function")
         do_inference(cfg, model, tokenizer)
@@ -184,10 +195,6 @@ def train(
         tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
     )

-    if prepare_ds_only:
-        logging.info("Finished preparing dataset. Exiting...")
-        return
-
     if cfg.debug:
         logging.info("check_dataset_labels...")
         check_dataset_labels(
@@ -197,6 +204,10 @@ def train(
             tokenizer,
         )

+    if prepare_ds_only:
+        logging.info("Finished preparing dataset. Exiting...")
+        return
+
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)

     model.config.use_cache = False
@@ -218,6 +229,8 @@ def train(
     )

     logging.info("Starting trainer...")
+    if cfg.group_by_length:
+        logging.info("hang tight... sorting dataset for group_by_length")
     resume_from_checkpoint = cfg.resume_from_checkpoint
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
         possible_checkpoints = [
@@ -236,7 +249,9 @@ def train(
     logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-    model.save_pretrained(cfg.output_dir)
+    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
+    if cfg.local_rank == 0:
+        model.save_pretrained(cfg.output_dir)
     # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
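A short sketch of the merge path added above, assuming the LoRA-wrapped model returned by `load_model()` is a PEFT model; the helper name below is hypothetical and only mirrors the calls the diff uses. Since the script is driven by fire, an extra `--merge_lora` CLI flag is presumably what lands `merge_lora` in `kwargs`.

```python
# Illustrative only: the same PEFT calls the new merge_lora branch performs.
from pathlib import Path
from peft import PeftModel


def merge_lora_into_base(model: PeftModel, output_dir: str) -> None:
    """Fold the LoRA adapter weights back into the base model and save the result."""
    merged = model.merge_and_unload()  # PEFT API used in the diff above
    merged.save_pretrained(str(Path(output_dir) / "merged"))
```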
 
src/axolotl/datasets.py CHANGED
@@ -106,7 +106,7 @@ class ConstantLengthDataset(IterableDataset):
                     }
                 else:
                     logging.warning(
-                        "dropping batch due to tensor size mismatch"
+                        f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
                     )
                 buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                 buffer_len = 0
src/axolotl/prompt_strategies/__init__.py ADDED
@@ -0,0 +1,14 @@
+import importlib
+
+
+def load(strategy, tokenizer, cfg):
+    try:
+        load_fn = "load"
+        if strategy.split(".")[-1].startswith("load_"):
+            load_fn = strategy.split(".")[-1]
+            strategy = ".".join(strategy.split(".")[:-1])
+        m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
+        fn = getattr(m, load_fn)
+        return fn(tokenizer, cfg)
+    except:
+        pass
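A quick sketch of how the loader above resolves a strategy string; the call site and values below are illustrative, not part of this PR:

```python
# Illustrative only: strategy-string resolution performed by load() above.
#   "pygmalion"           -> module axolotl.prompt_strategies.pygmalion, function load
#   "alpaca_chat.load_qa" -> module axolotl.prompt_strategies.alpaca_chat, function load_qa
from axolotl.prompt_strategies import load

strategy = load("alpaca_chat.load_qa", tokenizer, cfg)  # tokenizer/cfg come from the trainer setup
if strategy is None:
    # the bare `except: pass` means unknown strategies silently return None,
    # so callers have to fall back to the built-in dataset handling
    ...
```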
src/axolotl/prompt_strategies/alpaca_chat.py ADDED
@@ -0,0 +1,32 @@
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    InstructionPromptTokenizingStrategy,
+)
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+def load(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.chat.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["question"],
+            "",
+            prompt["answer"],
+        )
+
+
+def load_qa(tokenizer, cfg):
+    return AlpacaQAPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.chat.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
src/axolotl/prompt_strategies/alpaca_instruct.py ADDED
@@ -0,0 +1,11 @@
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+def load(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.instruct),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
src/axolotl/prompt_strategies/creative_acr.py ADDED
@@ -0,0 +1,149 @@
+from typing import Union, Generator
+
+import yaml
+from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
+
+
+class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        question = prompt["instruction"]
+        answer = prompt[
+            "revision"
+        ]  # don't use prompt[answer], that's data we don't want in the dataset
+        return (
+            question,
+            "",
+            answer,
+        )
+
+
+class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
+refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
+prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
+creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
+comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
+Then for each criteria, give a detailed explanation of why the response doesn't meet the criteria (don't reference the score), what it did right, and how it could be improved for that criteria.
+
+Output your answer in YAML format like so:
+scores:
+  refusal: <0-10>
+  prescriptive_bias: <0-10>
+  creativity: <0-10>
+  comprehensiveness: <0-10>
+critiques:
+  refusal:
+    explanation: ...
+    improvements: ...
+  prescriptive_bias:
+    explanation: ...
+    improvements: ...
+  creativity:
+    explanation: ...
+    improvements: ...
+  comprehensiveness:
+    explanation: ...
+    improvements: ...
+
+Question: {question}
+Answer: {answer}
+"""
+
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        scores = yaml.dump(
+            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        critiques = yaml.dump(
+            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        evaluation = scores + critiques
+        question = prompt["instruction"]
+        answer = prompt["answer"]
+        return (
+            self.user_prompt.format(question=question, answer=answer),
+            "",
+            evaluation,
+        )
+
+
+class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    user_prompt = """Definitions:
+refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
+prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias.
+creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
+comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
+
+Given the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response.
+
+Question: {question}
+Answer: {answer}
+Evaluation:
+{evaluation}
+"""
+
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        scores = yaml.dump(
+            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        critiques = yaml.dump(
+            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        evaluation = scores + critiques
+        question = prompt["instruction"]
+        answer = prompt["answer"]
+        return (
+            self.user_prompt.format(
+                question=question, answer=answer, evaluation=evaluation
+            ),
+            "",
+            prompt["revision"],
+        )
+
+
+class CreativePrompterBase:
+    system_prompt = ""
+    prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
+
+    def build_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,
+        output: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        if self.system_prompt:
+            res = f"{self.system_prompt}\nUSER: {instruction}\nASSISTANT:"
+        else:
+            res = f"USER: {instruction}\nASSISTANT:"
+        if output:
+            res = f"{res}{output}"
+        yield res
+
+
+class CreativeAnswerPrompter(CreativePrompterBase):
+    system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity."
+
+
+class CreativeCritiquePrompter(CreativePrompterBase):
+    system_prompt = ""
+
+
+class CreativeRevisePrompter(CreativePrompterBase):
+    system_prompt = ""
+
+
+def load_answer(tokenizer, cfg):
+    return CreativeAnsweringPromptTokenizingStrategy(
+        CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
+
+
+def load_critique(tokenizer, cfg):
+    return CreativeCritiquePromptTokenizingStrategy(
+        CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
+
+
+def load_revise(tokenizer, cfg):
+    return CreativeRevisePromptTokenizingStrategy(
+        CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
src/axolotl/prompt_strategies/pygmalion.py ADDED
@@ -0,0 +1,110 @@
+import copy
+import logging
+from collections import defaultdict
+from typing import Generator
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+
+IGNORE_TOKEN_ID = -100
+
+
+class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
+    bot_prefix_token_ids = []
+
+    def __init__(self, prompter, tokenizer, *args, **kwargs):
+        super().__init__(prompter, tokenizer)
+        res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
+        self.bot_prefix_token_ids = res["input_ids"]
+
+    def tokenize_prompt(self, prompt):
+        result = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+        }
+        current_len = 0
+        for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
+            role, message = part
+            if role == "system":
+                prefix = "<|system|>"
+                # this should include a bos token, no eos token, strip trailing "\n<START>"
+                if message.endswith("\n<START>"):
+                    message = message[:-8]
+                res = self._tokenize(
+                    prefix + "Persona: " + message.strip(),
+                    add_eos_token=False,
+                    strip_bos_token=False,
+                )
+                # everything from this is masked out from the labels
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+            elif role == "human":
+                prefix = "<|user|>"
+                res = self._tokenize(
+                    prefix + " " + message.strip(),
+                    add_eos_token=False,
+                    strip_bos_token=True,
+                )
+                # everything from this is masked out from the labels
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+            elif role == "bot":
+                prefix = "<|model|>"
+                res = self._tokenize(
+                    prefix + " " + message.strip(),
+                    add_eos_token=True,
+                    strip_bos_token=True,
+                )
+                # mask out the prefix token, rest is not masked out from labels
+                # make sure we create the labels first, otherwise we get incorrect lengths
+                labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [
+                    *copy.deepcopy(res["input_ids"])
+                ][len(self.bot_prefix_token_ids) :]
+            else:
+                logging.warning(f"unknown role in conversation: {role}")
+                res = defaultdict(lambda: [])
+            input_ids = res["input_ids"]
+            input_len = len(input_ids)
+            result["input_ids"][current_len : current_len + input_len] = input_ids
+            result["attention_mask"][current_len : current_len + input_len] = [
+                1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
+            ]
+            result["labels"][current_len : current_len + input_len] = labels
+            current_len += input_len
+        return result
+
+    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.sequence_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.sequence_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(self.tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
+            result["input_ids"] = result["input_ids"][1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
+        result["labels"] = result["input_ids"].copy()
+        return result
+
+
+class PygmalionPrompter:
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
+        for msg in source:
+            yield msg["role"], msg["value"]
+
+
+def load(tokenizer, cfg):
+    return PygmalionPromptTokenizingStrategy(
+        PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
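For reference, a sketch of the conversation record shape the strategy above consumes; the sample values are made up:

```python
# Illustrative only: the "conversations" layout PygmalionPrompter.build_prompt() iterates.
# Only system/human/bot roles are handled; anything else is logged as a warning and skipped.
sample = {
    "conversations": [
        {"role": "system", "value": "A terse assistant persona.\n<START>"},  # trailing "\n<START>" is stripped
        {"role": "human", "value": "Hi there!"},
        {"role": "bot", "value": "Hello! How can I help today?"},  # only bot turns keep real labels
    ]
}
```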
src/axolotl/prompt_tokenizers.py CHANGED
@@ -1,7 +1,12 @@
 import abc
+import copy
+import functools
+import logging

 from transformers import PreTrainedTokenizer

+from axolotl.prompters import IGNORE_TOKEN_ID
+
 IGNORE_INDEX = -100
 LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
 LLAMA_DEFAULT_EOS_TOKEN = "</s>"
@@ -30,6 +35,20 @@ class PromptTokenizingStrategy(abc.ABC):
     def tokenize_prompt(self, prompt):
         pass

+    @functools.cache
+    def _get_user_token(self):
+        id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
+        if isinstance(id_or_ids, (int,)):
+            return id_or_ids
+        return False
+
+    @functools.cache
+    def _get_assistant_token(self):
+        id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
+        if isinstance(id_or_ids, (int,)):
+            return id_or_ids
+        return False
+

 class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
@@ -40,9 +59,13 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         full_prompt = self._build_full_prompt(instruction, input, response)
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
-            user_prompt = self.prompter.build_prompt(
-                instruction,
-                input,
+            user_prompt = next(
+                iter(
+                    self.prompter.build_prompt(
+                        instruction,
+                        input,
+                    )
+                )
             )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
@@ -54,13 +77,17 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt

     def _build_full_prompt(self, instruction, input, response):
-        return self.prompter.build_prompt(
-            instruction,
-            input,
-            response,
+        return next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
+                    response,
+                )
+            )
         )

-    def _tokenize(self, prompt, add_eos_token=True):
+    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         result = self.tokenizer(
             prompt,
             truncation=True,
@@ -76,6 +103,10 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)

+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
+            result["input_ids"] = result["input_ids"][1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
         result["labels"] = result["input_ids"].copy()
         return result

@@ -89,6 +120,15 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         )


+class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["question"],
+            "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
+            prompt["solution"] if "solution" in prompt else prompt["explanation"],
+        )
+
+
 class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
         return (
@@ -107,6 +147,15 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
         )


+class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["article"],
+            "",
+            prompt["summary"],
+        )
+
+
 class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
         return (
@@ -131,13 +180,13 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):

     def tokenize_prompt(self, prompt):
         instruction = self.parse_instruction_fields(prompt)
-        full_prompt = self._build_full_prompt(instruction)
+        full_prompt = self._build_full_prompt(instruction, None, None)
         tokenized_full_prompt = self._tokenize(full_prompt)

         return tokenized_full_prompt

-    def _build_full_prompt(self, instruction):
-        return self.prompter.build_prompt(instruction)
+    def _build_full_prompt(self, instruction, input, response):
+        return next(iter(self.prompter.build_prompt(instruction)))


 class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
@@ -157,9 +206,13 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         )
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
-            user_prompt = self.prompter.build_prompt(
-                instruction,
-                input,
+            user_prompt = next(
+                iter(
+                    self.prompter.build_prompt(
+                        instruction,
+                        input,
+                    )
+                )
             )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
@@ -171,12 +224,16 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt

     def _build_full_prompt(self, instruction, input, output, reflection, corrected):
-        return self.prompter.build_prompt(
-            instruction,
-            input,
-            output,
-            reflection,
-            corrected,
+        return next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
+                    output,
+                    reflection,
+                    corrected,
+                )
+            )
         )

     def _tokenize(self, prompt, add_eos_token=True):
@@ -212,7 +269,80 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):

 class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
     def tokenize_prompt(self, prompt):
+        result = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+        }
+        current_len = 0
+        user_token = self._get_user_token()
+        assistant_token = self._get_assistant_token()
         try:
-            return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
+            for i, part in enumerate(
+                self.prompter.build_prompt(prompt["conversations"])
+            ):
+                if isinstance(part, tuple):
+                    if part[0] == "USER:":
+                        part = part[0] + part[1] if not user_token else part[1]
+                        # this is still the user query, we should
+                        res = self._tokenize(
+                            part.strip(), add_eos_token=False, strip_bos_token=True
+                        )
+                        if user_token:
+                            res["input_ids"] = [user_token, *res["input_ids"]]
+                        # everything from this is masked out from the labels
+                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                    elif part[0] == "ASSISTANT:":
+                        # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
+                        part = part[0] + part[1] if not assistant_token else part[1]
+                        # this should be the assistent response, should end with an eos token
+                        res = self._tokenize(
+                            part.strip(), add_eos_token=True, strip_bos_token=True
+                        )
+                        if assistant_token:
+                            res["input_ids"] = [assistant_token, *res["input_ids"]]
+                        # not masked out from labels
+                        labels = copy.deepcopy(res["input_ids"])
+                    else:
+                        logging.warning("unhandled role: " + part[0])
+                else:
+                    # this is only ever the first part, should include the bos token and the user query
+                    res = self._tokenize(
+                        part.strip(), add_eos_token=False, strip_bos_token=False
+                    )
+                    # everything from this is masked out from the labels
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                input_ids = res["input_ids"]
+                input_len = len(input_ids)
+                result["input_ids"][current_len : current_len + input_len] = input_ids
+                result["attention_mask"][current_len : current_len + input_len] = [
+                    1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
+                ]
+                result["labels"][current_len : current_len + input_len] = labels
+                current_len += input_len
+            return result
         except (KeyError, AssertionError, IndexError) as e:
             raise InvalidDataException(str(e))
+
+    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.sequence_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.sequence_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(self.tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
+            result["input_ids"] = result["input_ids"][1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
+        result["labels"] = result["input_ids"].copy()
+        return result
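A sketch of the ShareGPT-style record the reworked strategy consumes and how the turns are labeled; the role names follow the usual ShareGPT convention and are illustrative here, the exact mapping lives in prompters.py:

```python
# Illustrative only: input record for ShareGPTPromptTokenizingStrategy.tokenize_prompt().
# The system preamble and "USER:" parts get IGNORE_TOKEN_ID labels; "ASSISTANT:" parts
# keep their token ids, so only assistant turns contribute to the loss.
sample = {
    "conversations": [
        {"from": "human", "value": "What does LoRA stand for?"},
        {"from": "gpt", "value": "Low-Rank Adaptation of large language models."},
    ]
}
```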
src/axolotl/prompters.py CHANGED
@@ -1,22 +1,52 @@
1
  import copy
2
  import dataclasses
 
3
  from enum import auto, Enum
4
- from typing import List, Tuple, Any, Union
5
 
6
  IGNORE_TOKEN_ID = -100
7
 
8
 
 
 
 
 
 
9
  class AlpacaPrompter:
10
- prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
11
- prompt_no_input = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
12
- response_split = "### Response:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def build_prompt(
15
  self,
16
  instruction: str,
17
  input: Union[None, str] = None,
18
  output: Union[None, str] = None,
19
- ) -> str:
20
  # returns the full prompt from instruction and optional input
21
  # if a label (=response, =output) is provided, it's also appended.
22
  if input:
@@ -25,19 +55,42 @@ class AlpacaPrompter:
25
  res = self.prompt_no_input.format(instruction=instruction)
26
  if output:
27
  res = f"{res}{output}"
28
- return res
29
 
30
  def get_response(self, output: str) -> str:
31
  return output.split(self.response_split)[1].strip()
32
 
33
 
 
 
 
 
 
34
  class JeopardyPrompter(AlpacaPrompter):
35
  prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  class CompletionPrompter(AlpacaPrompter):
39
- def build_prompt(self, instruction: str) -> str:
40
- return instruction
 
 
41
 
42
  def get_response(self, output: str) -> str:
43
  return output.strip()
@@ -52,11 +105,44 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
52
 
53
 
54
  class ReflectAlpacaPrompter:
55
- prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
56
- prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n"
57
- agent_label = "{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
 
 
 
 
 
58
  response_split = "### Response:"
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def build_prompt(
61
  self,
62
  instruction: str,
@@ -64,7 +150,7 @@ class ReflectAlpacaPrompter:
64
  output: Union[None, str] = None,
65
  reflection: Union[None, str] = None,
66
  corrected: Union[None, str] = None,
67
- ) -> str:
68
  # returns the full prompt from instruction and optional input
69
  # if a label (=response, =output) is provided, it's also appended.
70
  if input:
@@ -76,7 +162,7 @@ class ReflectAlpacaPrompter:
76
  output=output, reflection=reflection, corrected=corrected
77
  )
78
  res = f"{res}{label}"
79
- return res
80
 
81
  def get_response(self, output: str) -> str:
82
  return output.split(self.response_split)[1].strip()
@@ -103,15 +189,16 @@ class Conversation:
103
  sep: str = "###"
104
  sep2: str = None
105
 
106
- def get_prompt(self):
107
  seps = [self.sep, self.sep2]
108
- ret = self.system + seps[0]
 
109
  for i, (role, message) in enumerate(self.messages):
110
  if message:
111
- ret += role + ": " + message + seps[i % 2]
112
  else:
113
- ret += role + ":"
114
- return ret
115
 
116
  def copy(self):
117
  return Conversation(
@@ -136,12 +223,24 @@ conv_vicuna_v1_1 = Conversation(
136
  offset=0,
137
  sep_style=SeparatorStyle.TWO,
138
  sep=" ",
139
- sep2="</s>",
140
  )
141
 
142
 
143
  class ShareGPTPrompter:
144
- def build_prompt(self, source, tokenizer, sequence_len=2048):
 
 
 
 
 
 
 
 
 
 
 
 
145
  # ignore the system prompt if provided
146
  if source[0]["from"] == "system":
147
  source.pop(0)
@@ -171,61 +270,6 @@ class ShareGPTPrompter:
171
  role = roles[sentence["from"]]
172
  assert role == conv.roles[j % 2]
173
  conv.append_message(role, sentence["value"])
174
- # TODO, this concatenates everything, but doesn't seem to properly add the eos_token_id, as the eos_token gets split up
175
- conversation = conv.get_prompt()
176
-
177
- # Tokenize conversations
178
- tokenized_result = tokenizer(
179
- conversation,
180
- truncation=True,
181
- max_length=sequence_len, # FIXME
182
- padding=False,
183
- return_tensors=None,
184
- )
185
- target = copy.deepcopy(tokenized_result["input_ids"])
186
-
187
- # Mask targets
188
- sep = conv.sep + conv.roles[1] + ": "
189
-
190
- rounds = conversation.split(conv.sep2)
191
- rounds = [r + conv.sep2 for r in rounds]
192
- cur_len = 1
193
- target[0] = IGNORE_TOKEN_ID # mask out the bos
194
- for i, rou in enumerate(rounds):
195
- if rou == "":
196
- break
197
-
198
- parts = rou.split(sep)
199
- if len(parts) != 2:
200
- break
201
- parts[0] += sep
202
- round_len = (
203
- len(tokenizer(rou)["input_ids"]) - 1
204
- ) # -1 ignores the bos_token generated for this
205
- # we have to strip the initial part, any dangling whitespace creates an additional ghost token
206
- instruction_len = (
207
- len(tokenizer(parts[0].strip())["input_ids"]) - 1
208
- ) # -1 ignores the bos_token generated for this
209
- target[cur_len : cur_len + instruction_len] = [
210
- IGNORE_TOKEN_ID
211
- ] * instruction_len
212
-
213
- cur_len += round_len
214
- if cur_len >= sequence_len:
215
- break
216
-
217
- # Fix: Truncate the target to have the same length as input_ids
218
- target = target[: len(tokenized_result["input_ids"])]
219
- # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
220
-
221
- attention_mask = [
222
- 1 if x != tokenizer.pad_token_id else 0
223
- for x in tokenized_result["input_ids"]
224
- ]
225
-
226
- # TODO truncate len to sequence_len
227
- return dict(
228
- input_ids=tokenized_result["input_ids"],
229
- labels=target,
230
- attention_mask=attention_mask,
231
- )
 
1
  import copy
2
  import dataclasses
3
+ import logging
4
  from enum import auto, Enum
5
+ from typing import List, Tuple, Any, Union, Generator
6
 
7
  IGNORE_TOKEN_ID = -100
8
 
9
 
10
+ class PromptStyle(Enum):
11
+ instruct = "instruct"
12
+ chat = "chat"
13
+
14
+
15
  class AlpacaPrompter:
16
+ system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
17
+ system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
18
+ prompt_style = None
19
+
20
+ def __init__(self, prompt_style="instruct"):
21
+ self.prompt_style = prompt_style
22
+ self.match_prompt_style()
23
+
24
+ def match_prompt_style(self):
25
+ if self.prompt_style == PromptStyle.instruct.value:
26
+ self.prompt_input = (
27
+ self.system_prompt
28
+ + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
29
+ )
30
+ self.prompt_no_input = (
31
+ self.system_no_input_prompt
32
+ + "### Instruction:\n{instruction}\n\n### Response:\n"
33
+ )
34
+ self.response_split = "### Response:"
35
+ if self.prompt_style == PromptStyle.chat.value:
36
+ self.prompt_input = (
37
+ self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
38
+ )
39
+ self.prompt_no_input = (
40
+ self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
41
+ )
42
+ self.response_split = "ASSISTANT:"
43
 
44
  def build_prompt(
45
  self,
46
  instruction: str,
47
  input: Union[None, str] = None,
48
  output: Union[None, str] = None,
49
+ ) -> Generator[str, None, None]:
50
  # returns the full prompt from instruction and optional input
51
  # if a label (=response, =output) is provided, it's also appended.
52
  if input:
 
55
  res = self.prompt_no_input.format(instruction=instruction)
56
  if output:
57
  res = f"{res}{output}"
58
+ yield res
59
 
60
  def get_response(self, output: str) -> str:
61
  return output.split(self.response_split)[1].strip()
62
 
63
 
64
+ class UnpromptedPrompter(AlpacaPrompter):
65
+ system_prompt = ""
66
+ system_no_input_prompt = ""
67
+
68
+
69
  class JeopardyPrompter(AlpacaPrompter):
70
  prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
71
 
72
 
73
+ class MultipleChoiceExplainPrompter(AlpacaPrompter):
74
+ system_prompt = (
75
+ "Choose the answer that best answers the question. Explain your reasoning."
76
+ )
77
+
78
+
79
+ class MultipleChoiceConcisePrompter(AlpacaPrompter):
80
+ prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
81
+
82
+
83
+ class SummarizeTLDRPrompter(AlpacaPrompter):
84
+ prompt_no_input = (
85
+ "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
86
+ )
87
+
88
+
89
  class CompletionPrompter(AlpacaPrompter):
90
+ def build_prompt(
91
+ self, instruction: str, input=None, output=None
92
+ ) -> Generator[str, None, None]:
93
+ yield instruction
94
 
95
  def get_response(self, output: str) -> str:
96
  return output.strip()
 
105
 
106
 
107
  class ReflectAlpacaPrompter:
108
+ system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
109
+ system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
110
+
111
+ prompt_input = (
112
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
113
+ )
114
+ prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n"
115
+ agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
116
  response_split = "### Response:"
117
 
118
+ def __init__(self, prompt_style="instruct"):
119
+ self.prompt_style = prompt_style
120
+ self.match_prompt_style()
121
+
122
+ def match_prompt_style(self):
123
+ if self.prompt_style == PromptStyle.instruct.value:
124
+ self.prompt_input = (
125
+ self.system_prompt
126
+ + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
127
+ )
128
+ self.prompt_no_input = (
129
+ self.system_no_input_prompt
130
+ + "### Instruction:\n{instruction}\n\n### Response:\n"
131
+ )
132
+ self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
133
+ self.response_split = "### Final Response:"
134
+ if self.prompt_style == PromptStyle.chat.value:
135
+ self.prompt_input = (
136
+ self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
137
+ )
138
+ self.prompt_no_input = (
139
+ self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
140
+ )
141
+ self.agent_label = (
142
+ "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
143
+ )
144
+ self.response_split = "ASSISTANT:"
145
+
146
  def build_prompt(
147
  self,
148
  instruction: str,
 
150
  output: Union[None, str] = None,
151
  reflection: Union[None, str] = None,
152
  corrected: Union[None, str] = None,
153
+ ) -> Generator[str, None, None]:
154
  # returns the full prompt from instruction and optional input
155
  # if a label (=response, =output) is provided, it's also appended.
156
  if input:
 
162
  output=output, reflection=reflection, corrected=corrected
163
  )
164
  res = f"{res}{label}"
165
+ yield res
166
 
167
  def get_response(self, output: str) -> str:
168
  return output.split(self.response_split)[1].strip()
 
189
  sep: str = "###"
190
  sep2: str = None
191
 
192
+ def get_prompt(self) -> Generator[str, None, None]:
193
  seps = [self.sep, self.sep2]
194
+ preamble = self.system + seps[0]
195
+ yield preamble
196
  for i, (role, message) in enumerate(self.messages):
197
  if message:
198
+ yield (role + ":", " " + message)
199
  else:
200
+ logging.warning("role with empty message: " + role)
201
+ yield (role + ":",)
202
 
203
  def copy(self):
204
  return Conversation(
 
223
  offset=0,
224
  sep_style=SeparatorStyle.TWO,
225
  sep=" ",
226
+ sep2=" ",
227
  )
228
 
229
 
230
  class ShareGPTPrompter:
231
+ def __init__(self, prompt_style=None):
232
+ if prompt_style != PromptStyle.chat.value:
233
+ raise Exception(
234
+ f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
235
+ )
236
+
237
+ # def match_prompt_style(self):
238
+ # if self.prompt_style == PromptStyle.chat.value:
239
+ # self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
240
+ # self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
241
+ # self.response_split = "ASSISTANT:"
242
+
243
+ def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
244
  # ignore the system prompt if provided
245
  if source[0]["from"] == "system":
246
  source.pop(0)
 
270
  role = roles[sentence["from"]]
271
  assert role == conv.roles[j % 2]
272
  conv.append_message(role, sentence["value"])
273
+
274
+ for part in conv.get_prompt():
275
+ yield part
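
The prompters above now select their templates from a prompt_style argument and yield prompts from a generator instead of returning a single string. A minimal, self-contained sketch of that dispatch pattern follows; the PromptStyle enum and template strings here are illustrative stand-ins rather than the exact axolotl values.

from enum import Enum
from typing import Generator, Optional


class PromptStyle(Enum):
    instruct = "instruct"
    chat = "chat"


class ReflectionPrompterSketch:
    def __init__(self, prompt_style: str = PromptStyle.instruct.value):
        self.prompt_style = prompt_style
        self.match_prompt_style()

    def match_prompt_style(self) -> None:
        # Select the template set once, based on the configured style.
        if self.prompt_style == PromptStyle.instruct.value:
            self.prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n"
            self.response_split = "### Final Response:"
        elif self.prompt_style == PromptStyle.chat.value:
            self.prompt_no_input = "USER: {instruction}\nASSISTANT:"
            self.response_split = "ASSISTANT:"

    def build_prompt(
        self, instruction: str, label: Optional[str] = None
    ) -> Generator[str, None, None]:
        # Yield the rendered prompt, matching the generator interface used above.
        res = self.prompt_no_input.format(instruction=instruction)
        if label:
            res = f"{res}{label}"
        yield res


# Consume the generator to get the rendered prompt text.
print(next(ReflectionPrompterSketch("chat").build_prompt("Summarize the abstract.")))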
 
src/axolotl/utils/data.py CHANGED
@@ -8,10 +8,13 @@ from datasets import (
8
  IterableDataset,
9
  Dataset,
10
  concatenate_datasets,
 
11
  )
12
  from huggingface_hub import hf_hub_download
 
13
 
14
  from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
 
15
  from axolotl.prompt_tokenizers import (
16
  AlpacaPromptTokenizingStrategy,
17
  GPTeacherPromptTokenizingStrategy,
@@ -20,6 +23,8 @@ from axolotl.prompt_tokenizers import (
20
  ShareGPTPromptTokenizingStrategy,
21
  JeopardyPromptTokenizingStrategy,
22
  CompletionPromptTokenizingStrategy,
 
 
23
  )
24
  from axolotl.prompters import (
25
  AlpacaPrompter,
@@ -28,16 +33,24 @@ from axolotl.prompters import (
28
  ShareGPTPrompter,
29
  JeopardyPrompter,
30
  CompletionPrompter,
 
 
 
31
  )
32
 
33
 
34
- def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
 
 
 
35
  ds_hash = str(
36
  md5(
37
  (
38
  str(cfg.sequence_len)
39
  + "@"
40
  + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
 
 
41
  ).encode("utf-8")
42
  ).hexdigest()
43
  )
@@ -46,8 +59,19 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
46
  if cfg.dataset_prepared_path
47
  else Path(default_dataset_prepared_path) / ds_hash
48
  )
 
 
 
 
 
 
 
 
 
49
 
50
- if any(prepared_ds_path.glob("*")):
 
 
51
  logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
52
  dataset = load_from_disk(str(prepared_ds_path))
53
  logging.info("Prepared dataset loaded from disk...")
@@ -59,7 +83,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
59
  ds = None
60
  ds_from_hub = False
61
  try:
62
- load_dataset(d.path, streaming=True)
63
  ds_from_hub = True
64
  except FileNotFoundError:
65
  pass
@@ -67,64 +91,117 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
67
  # prefer local dataset, even if hub exists
68
  if Path(d.path).exists():
69
  ds: IterableDataset = load_dataset(
70
- "json", data_files=d.path, streaming=True, split=None
71
  )
72
  elif ds_from_hub:
73
  if d.data_files:
74
- ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
 
 
 
 
 
75
  else:
76
- ds = load_dataset(d.path, streaming=True)
77
  else:
78
  fp = hf_hub_download(
79
  repo_id=d.path, repo_type="dataset", filename=d.data_files
80
  )
81
- ds = load_dataset("json", data_files=fp, streaming=True, split=None)
82
  if not ds:
83
  raise Exception("unhandled dataset load")
84
-
85
- if d.type == "alpaca":
 
 
 
 
 
 
 
 
 
86
  ds_strategy = AlpacaPromptTokenizingStrategy(
87
- AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  )
89
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
90
  datasets.append(ds_wrapper)
91
- elif d.type == "jeopardy":
92
  ds_strategy = JeopardyPromptTokenizingStrategy(
93
- JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
 
 
 
94
  )
95
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
96
  datasets.append(ds_wrapper)
97
- elif d.type == "oasst":
98
  ds_strategy = OpenAssistantPromptTokenizingStrategy(
99
- AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
 
 
 
100
  )
101
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
102
  datasets.append(ds_wrapper)
103
- elif d.type == "gpteacher":
104
  ds_strategy = GPTeacherPromptTokenizingStrategy(
105
- GPTeacherPrompter(),
106
  tokenizer,
107
  cfg.train_on_inputs,
108
  cfg.sequence_len,
109
  )
110
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
111
  datasets.append(ds_wrapper)
112
- elif d.type == "reflection":
113
  ds_strategy = AlpacaReflectionPTStrategy(
114
- ReflectAlpacaPrompter(),
115
  tokenizer,
116
  cfg.train_on_inputs,
117
  cfg.sequence_len,
118
  )
119
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
120
  datasets.append(ds_wrapper)
121
- elif d.type == "sharegpt":
122
  ds_strategy = ShareGPTPromptTokenizingStrategy(
123
- ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
 
 
 
124
  )
125
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
126
  datasets.append(ds_wrapper)
127
- elif d.type == "completion":
128
  ds_strategy = CompletionPromptTokenizingStrategy(
129
  CompletionPrompter(),
130
  tokenizer,
@@ -146,11 +223,20 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
146
  f"Saving merged prepared dataset to disk... {prepared_ds_path}"
147
  )
148
  dataset.save_to_disk(prepared_ds_path)
 
 
 
 
 
 
 
149
 
150
  return dataset
151
 
152
 
153
- def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
 
 
154
  max_packed_sequence_len = (
155
  cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
156
  )
@@ -158,16 +244,20 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
158
  max_packed_sequence_len, cfg.sequence_len
159
  ) # make sure we don't accidentally set it larger than sequence_len
160
 
 
161
  if cfg.max_packed_sequence_len is not None:
162
  # see if we can go ahead and load the stacked dataset
163
-
164
  ds_hash = str(
165
  md5(
166
  (
167
  str(cfg.sequence_len)
168
  + "@"
169
  + str(max_packed_sequence_len)
 
170
  + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
 
 
171
  ).encode("utf-8")
172
  ).hexdigest()
173
  )
@@ -177,17 +267,42 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
177
  else Path(default_dataset_prepared_path) / ds_hash
178
  )
179
 
180
- if any(prepared_ds_path.glob("*")):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  logging.info(
182
  f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
183
  )
184
  dataset = load_from_disk(str(prepared_ds_path))
185
  logging.info("Prepared packed dataset loaded from disk...")
 
 
 
 
 
 
 
186
  else:
187
  dataset = load_tokenized_prepared_datasets(
188
  tokenizer, cfg, default_dataset_prepared_path
189
  )
190
 
 
 
 
191
  constant_len_dataset = ConstantLengthDataset(
192
  tokenizer,
193
  [dataset],
@@ -204,9 +319,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
204
  d
205
  for d in dataset
206
  if len(d["input_ids"]) < cfg.sequence_len
207
- and len(d["input_ids"]) > 0
208
- and len(d["input_ids"]) == len(d["attention_mask"])
209
- and len(d["input_ids"]) == len(d["labels"])
210
  ]
211
  )
212
 
@@ -215,6 +330,13 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
215
  f"Saving packed prepared dataset to disk... {prepared_ds_path}"
216
  )
217
  dataset.save_to_disk(prepared_ds_path)
 
 
 
 
 
 
 
218
  else:
219
  dataset = load_tokenized_prepared_datasets(
220
  tokenizer, cfg, default_dataset_prepared_path
 
8
  IterableDataset,
9
  Dataset,
10
  concatenate_datasets,
11
+ DatasetDict,
12
  )
13
  from huggingface_hub import hf_hub_download
14
+ from transformers import PreTrainedTokenizerBase
15
 
16
  from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
17
+ from axolotl.prompt_strategies import load
18
  from axolotl.prompt_tokenizers import (
19
  AlpacaPromptTokenizingStrategy,
20
  GPTeacherPromptTokenizingStrategy,
 
23
  ShareGPTPromptTokenizingStrategy,
24
  JeopardyPromptTokenizingStrategy,
25
  CompletionPromptTokenizingStrategy,
26
+ AlpacaMultipleChoicePromptTokenizingStrategy,
27
+ SummarizeTLDRPromptTokenizingStrategy,
28
  )
29
  from axolotl.prompters import (
30
  AlpacaPrompter,
 
33
  ShareGPTPrompter,
34
  JeopardyPrompter,
35
  CompletionPrompter,
36
+ MultipleChoiceExplainPrompter,
37
+ SummarizeTLDRPrompter,
38
+ MultipleChoiceConcisePrompter,
39
  )
40
 
41
 
42
+ def load_tokenized_prepared_datasets(
43
+ tokenizer, cfg, default_dataset_prepared_path
44
+ ) -> DatasetDict:
45
+ tokenizer_name = tokenizer.__class__.__name__
46
  ds_hash = str(
47
  md5(
48
  (
49
  str(cfg.sequence_len)
50
  + "@"
51
  + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
52
+ + "|"
53
+ + tokenizer_name
54
  ).encode("utf-8")
55
  ).hexdigest()
56
  )
 
59
  if cfg.dataset_prepared_path
60
  else Path(default_dataset_prepared_path) / ds_hash
61
  )
62
+ dataset = None
63
+ try:
64
+ if cfg.push_dataset_to_hub:
65
+ dataset = load_dataset(
66
+ f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
67
+ )
68
+ dataset = dataset["train"]
69
+ except:
70
+ pass
71
 
72
+ if dataset:
73
+ ...
74
+ elif any(prepared_ds_path.glob("*")):
75
  logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
76
  dataset = load_from_disk(str(prepared_ds_path))
77
  logging.info("Prepared dataset loaded from disk...")
 
83
  ds = None
84
  ds_from_hub = False
85
  try:
86
+ load_dataset(d.path, streaming=True, use_auth_token=True)
87
  ds_from_hub = True
88
  except FileNotFoundError:
89
  pass
 
91
  # prefer local dataset, even if hub exists
92
  if Path(d.path).exists():
93
  ds: IterableDataset = load_dataset(
94
+ "json", data_files=d.path, streaming=False, split=None
95
  )
96
  elif ds_from_hub:
97
  if d.data_files:
98
+ ds = load_dataset(
99
+ d.path,
100
+ streaming=False,
101
+ data_files=d.data_files,
102
+ use_auth_token=True,
103
+ )
104
  else:
105
+ ds = load_dataset(d.path, streaming=False, use_auth_token=True)
106
  else:
107
  fp = hf_hub_download(
108
  repo_id=d.path, repo_type="dataset", filename=d.data_files
109
  )
110
+ ds = load_dataset("json", data_files=fp, streaming=False, split=None)
111
  if not ds:
112
  raise Exception("unhandled dataset load")
113
+ # support for using a subset of the data
114
+ if d.shards:
115
+ ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
116
+ d_type = d.type
117
+ d_type_split = d_type.split(":")
118
+ d_base_type = d_type_split[0]
119
+ d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
120
+ if ds_strategy := load(d.type, tokenizer, cfg):
121
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
122
+ datasets.append(ds_wrapper)
123
+ elif d_base_type == "alpaca":
124
  ds_strategy = AlpacaPromptTokenizingStrategy(
125
+ AlpacaPrompter(d_prompt_style),
126
+ tokenizer,
127
+ cfg.train_on_inputs,
128
+ cfg.sequence_len,
129
+ )
130
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
131
+ datasets.append(ds_wrapper)
132
+ elif d_base_type == "explainchoice":
133
+ ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
134
+ MultipleChoiceExplainPrompter(d_prompt_style),
135
+ tokenizer,
136
+ cfg.train_on_inputs,
137
+ cfg.sequence_len,
138
+ )
139
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
140
+ datasets.append(ds_wrapper)
141
+ elif d_base_type == "concisechoice":
142
+ ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
143
+ MultipleChoiceConcisePrompter(d_prompt_style),
144
+ tokenizer,
145
+ cfg.train_on_inputs,
146
+ cfg.sequence_len,
147
+ )
148
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
149
+ datasets.append(ds_wrapper)
150
+ elif d_base_type == "summarizetldr":
151
+ ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
152
+ SummarizeTLDRPrompter(d_prompt_style),
153
+ tokenizer,
154
+ cfg.train_on_inputs,
155
+ cfg.sequence_len,
156
  )
157
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
158
  datasets.append(ds_wrapper)
159
+ elif d_base_type == "jeopardy":
160
  ds_strategy = JeopardyPromptTokenizingStrategy(
161
+ JeopardyPrompter(d_prompt_style),
162
+ tokenizer,
163
+ cfg.train_on_inputs,
164
+ cfg.sequence_len,
165
  )
166
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
167
  datasets.append(ds_wrapper)
168
+ elif d_base_type == "oasst":
169
  ds_strategy = OpenAssistantPromptTokenizingStrategy(
170
+ AlpacaPrompter(d_prompt_style),
171
+ tokenizer,
172
+ cfg.train_on_inputs,
173
+ cfg.sequence_len,
174
  )
175
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
176
  datasets.append(ds_wrapper)
177
+ elif d_base_type == "gpteacher":
178
  ds_strategy = GPTeacherPromptTokenizingStrategy(
179
+ GPTeacherPrompter(d_prompt_style),
180
  tokenizer,
181
  cfg.train_on_inputs,
182
  cfg.sequence_len,
183
  )
184
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
185
  datasets.append(ds_wrapper)
186
+ elif d_base_type == "reflection":
187
  ds_strategy = AlpacaReflectionPTStrategy(
188
+ ReflectAlpacaPrompter(d_prompt_style),
189
  tokenizer,
190
  cfg.train_on_inputs,
191
  cfg.sequence_len,
192
  )
193
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
194
  datasets.append(ds_wrapper)
195
+ elif d_base_type == "sharegpt":
196
  ds_strategy = ShareGPTPromptTokenizingStrategy(
197
+ ShareGPTPrompter(d_prompt_style),
198
+ tokenizer,
199
+ cfg.train_on_inputs,
200
+ cfg.sequence_len,
201
  )
202
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
203
  datasets.append(ds_wrapper)
204
+ elif d_base_type == "completion":
205
  ds_strategy = CompletionPromptTokenizingStrategy(
206
  CompletionPrompter(),
207
  tokenizer,
 
223
  f"Saving merged prepared dataset to disk... {prepared_ds_path}"
224
  )
225
  dataset.save_to_disk(prepared_ds_path)
226
+ if cfg.push_dataset_to_hub:
227
+ logging.info(
228
+ f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
229
+ )
230
+ dataset.push_to_hub(
231
+ f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
232
+ )
233
 
234
  return dataset
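
The prepared-dataset cache above is keyed by an md5 hash that now also folds in the tokenizer class name, and when push_dataset_to_hub is set, a previously pushed copy is tried before re-tokenizing. A rough, standalone sketch of that flow; the cfg namespace and hub org here are hypothetical placeholders.

from hashlib import md5
from types import SimpleNamespace

from datasets import load_dataset

cfg = SimpleNamespace(
    sequence_len=2048,
    datasets=[SimpleNamespace(path="tatsu-lab/alpaca", type="alpaca")],
    push_dataset_to_hub="my-org",  # hypothetical hub namespace
)
tokenizer_name = "LlamaTokenizer"  # tokenizer.__class__.__name__ in the real code

ds_hash = md5(
    (
        str(cfg.sequence_len)
        + "@"
        + "|".join(sorted(f"{d.path}:{d.type}" for d in cfg.datasets))
        + "|"
        + tokenizer_name
    ).encode("utf-8")
).hexdigest()

dataset = None
try:
    # Reuse a previously pushed prepared dataset if it exists on the hub.
    dataset = load_dataset(
        f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
    )["train"]
except Exception:
    pass  # otherwise fall back to tokenizing from scratch, as in the diff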
235
 
236
 
237
+ def load_prepare_datasets(
238
+ tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
239
+ ) -> (Dataset, Dataset):
240
  max_packed_sequence_len = (
241
  cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
242
  )
 
244
  max_packed_sequence_len, cfg.sequence_len
245
  ) # make sure we don't accidentally set it larger than sequence_len
246
 
247
+ tokenizer_name = tokenizer.__class__.__name__
248
  if cfg.max_packed_sequence_len is not None:
249
  # see if we can go ahead and load the stacked dataset
250
+ seed = f"@{str(cfg.seed)}" if cfg.seed else ""
251
  ds_hash = str(
252
  md5(
253
  (
254
  str(cfg.sequence_len)
255
  + "@"
256
  + str(max_packed_sequence_len)
257
+ + seed
258
  + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
259
+ + "|"
260
+ + tokenizer_name
261
  ).encode("utf-8")
262
  ).hexdigest()
263
  )
 
267
  else Path(default_dataset_prepared_path) / ds_hash
268
  )
269
 
270
+ dataset = None
271
+ try:
272
+ if cfg.push_dataset_to_hub:
273
+ logging.info(
274
+ f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
275
+ )
276
+ dataset = load_dataset(
277
+ f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
278
+ )
279
+ dataset = dataset["train"]
280
+ except:
281
+ pass
282
+
283
+ if dataset:
284
+ ...
285
+ elif any(prepared_ds_path.glob("*")):
286
  logging.info(
287
  f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
288
  )
289
  dataset = load_from_disk(str(prepared_ds_path))
290
  logging.info("Prepared packed dataset loaded from disk...")
291
+ if cfg.push_dataset_to_hub:
292
+ logging.info(
293
+ f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
294
+ )
295
+ dataset.push_to_hub(
296
+ f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
297
+ )
298
  else:
299
  dataset = load_tokenized_prepared_datasets(
300
  tokenizer, cfg, default_dataset_prepared_path
301
  )
302
 
303
+ if cfg.seed:
304
+ dataset = dataset.shuffle(seed=cfg.seed)
305
+
306
  constant_len_dataset = ConstantLengthDataset(
307
  tokenizer,
308
  [dataset],
 
319
  d
320
  for d in dataset
321
  if len(d["input_ids"]) < cfg.sequence_len
322
+ and len(d["input_ids"]) > 0
323
+ and len(d["input_ids"]) == len(d["attention_mask"])
324
+ and len(d["input_ids"]) == len(d["labels"])
325
  ]
326
  )
327
 
 
330
  f"Saving packed prepared dataset to disk... {prepared_ds_path}"
331
  )
332
  dataset.save_to_disk(prepared_ds_path)
333
+ if cfg.push_dataset_to_hub:
334
+ logging.info(
335
+ f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
336
+ )
337
+ dataset.push_to_hub(
338
+ f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
339
+ )
340
  else:
341
  dataset = load_tokenized_prepared_datasets(
342
  tokenizer, cfg, default_dataset_prepared_path
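
A dataset's type string can now carry an optional prompt style after a colon (for example "sharegpt:chat"), which is split off and passed to the matching prompter. A tiny sketch of that parsing convention, with illustrative assertions:

from typing import Optional, Tuple


def parse_dataset_type(d_type: str) -> Tuple[str, Optional[str]]:
    # "alpaca" -> ("alpaca", None); "sharegpt:chat" -> ("sharegpt", "chat")
    parts = d_type.split(":")
    return parts[0], parts[1] if len(parts) > 1 else None


assert parse_dataset_type("alpaca") == ("alpaca", None)
assert parse_dataset_type("sharegpt:chat") == ("sharegpt", "chat")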
src/axolotl/utils/models.py CHANGED
@@ -1,15 +1,18 @@
1
  import logging
 
2
  import os
3
  from pathlib import Path
4
  from typing import Optional, Tuple, TYPE_CHECKING
5
 
6
  import torch
7
  import transformers
 
8
  from transformers import (
9
  AutoModelForCausalLM,
10
  AutoTokenizer,
11
  PreTrainedModel,
12
  AutoConfig,
 
13
  )
14
 
15
  try:
@@ -80,6 +83,16 @@ def load_model(
80
  logging.exception(e)
81
  raise e
82
 
 
 
 
 
 
 
 
 
 
 
83
  try:
84
  if cfg.load_4bit and is_llama_derived_model:
85
  from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
@@ -123,16 +136,46 @@ def load_model(
123
  model = LlamaForCausalLM.from_pretrained(
124
  base_model,
125
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
 
126
  torch_dtype=torch_dtype,
127
  device_map=cfg.device_map,
 
128
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  elif model_type:
130
  model = getattr(transformers, model_type).from_pretrained(
131
  base_model,
132
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
 
133
  torch_dtype=torch_dtype,
134
  device_map=cfg.device_map,
135
  trust_remote_code=True if cfg.trust_remote_code is True else False,
 
136
  )
137
  else:
138
  config = AutoConfig.from_pretrained(
@@ -143,9 +186,11 @@ def load_model(
143
  base_model,
144
  config=config,
145
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
 
146
  torch_dtype=torch_dtype,
147
  device_map=cfg.device_map,
148
  trust_remote_code=True if cfg.trust_remote_code is True else False,
 
149
  )
150
  except Exception as e:
151
  logging.error(
@@ -158,16 +203,26 @@ def load_model(
158
  torch_dtype=torch_dtype,
159
  device_map=cfg.device_map,
160
  trust_remote_code=True if cfg.trust_remote_code is True else False,
 
161
  )
162
 
163
  if not tokenizer:
164
  try:
165
  if is_llama_derived_model and "LlamaTokenizer" in globals():
166
- tokenizer = LlamaTokenizer.from_pretrained(model)
 
 
 
167
  else:
168
- tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
 
 
 
169
  except:
170
- tokenizer = AutoTokenizer.from_pretrained(base_model_config, trust_remote_code=True if cfg.trust_remote_code is True else False)
 
 
 
171
 
172
  logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
173
  logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
@@ -181,14 +236,18 @@ def load_model(
181
  tokenizer.add_special_tokens({"pad_token": "[PAD]"})
182
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
183
 
184
- if cfg.tokens:
185
- for k, v in cfg.tokens.items():
186
  tokenizer.add_special_tokens({k: v})
 
 
187
 
188
- # this should only be needed if you are messing with new tokens in the vocab
189
- # model.resize_token_embeddings(len(tokenizer))
190
 
191
- if cfg.adapter and load_in_8bit and not cfg.load_4bit:
 
 
192
  logging.info("converting PEFT model w/ prepare_model_for_int8_training")
193
  model = prepare_model_for_int8_training(model)
194
 
@@ -209,7 +268,11 @@ def load_model(
209
  m.scales = m.scales.half()
210
  m.bias = m.bias.half()
211
 
212
- if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1 and cfg.load_4bit:
 
 
 
 
213
  # llama is PROBABLY model parallelizable, but the default isn't that it is
214
  # so let's only set it for the 4bit, see
215
  # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
@@ -222,6 +285,7 @@ def load_model(
222
  requires_grad.append(f"{name}: {param.requires_grad}")
223
  if len(requires_grad) == 0:
224
  logging.warning("there are no parameters that require gradient updates")
 
225
 
226
  # TODO resume_from_checkpoint handling
227
  return model, tokenizer, lora_config
@@ -232,7 +296,7 @@ def load_adapter(model, cfg, adapter):
232
 
233
  if adapter is None:
234
  return model, None
235
- if adapter == "lora":
236
  return load_lora(model, cfg)
237
  if adapter == "llama-adapter":
238
  return load_llama_adapter(model, cfg)
@@ -254,7 +318,8 @@ def load_llama_adapter(model, cfg):
254
  task_type="CAUSAL_LM",
255
  )
256
 
257
- if cfg.peft_model_dir:
 
258
  model = PeftModel.from_pretrained(
259
  model,
260
  cfg.lora_model_dir,
@@ -296,7 +361,7 @@ def load_lora(model, cfg):
296
  model,
297
  cfg.lora_model_dir,
298
  device_map=cfg.device_map,
299
- torch_dtype=torch.float16,
300
  )
301
  else:
302
  model = get_peft_model(model, lora_config)
 
1
  import logging
2
+ import math
3
  import os
4
  from pathlib import Path
5
  from typing import Optional, Tuple, TYPE_CHECKING
6
 
7
  import torch
8
  import transformers
9
+ from torch import nn
10
  from transformers import (
11
  AutoModelForCausalLM,
12
  AutoTokenizer,
13
  PreTrainedModel,
14
  AutoConfig,
15
+ BitsAndBytesConfig,
16
  )
17
 
18
  try:
 
83
  logging.exception(e)
84
  raise e
85
 
86
+ model_kwargs = {}
87
+ if cfg.adapter == "qlora":
88
+ model_kwargs["quantization_config"] = BitsAndBytesConfig(
89
+ load_in_4bit=True,
90
+ llm_int8_threshold=6.0,
91
+ llm_int8_has_fp16_weight=False,
92
+ bnb_4bit_compute_dtype=torch.float16,
93
+ bnb_4bit_use_double_quant=True,
94
+ bnb_4bit_quant_type="nf4",
95
+ )
96
  try:
97
  if cfg.load_4bit and is_llama_derived_model:
98
  from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
 
136
  model = LlamaForCausalLM.from_pretrained(
137
  base_model,
138
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
139
+ load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
140
  torch_dtype=torch_dtype,
141
  device_map=cfg.device_map,
142
+ **model_kwargs,
143
  )
144
+ # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
145
+ # This is a WIP, still an issue with the backward pass
146
+ # RuntimeError: grad can be implicitly created only for scalar outputs
147
+ # TODO: try config.sequence_parallel = False
148
+ # # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
149
+ # # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
150
+ # # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
151
+ # from flash_attn.utils.pretrained import state_dict_from_pretrained
152
+ # from flash_attn.models.gpt import GPTLMHeadModel
153
+ # from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
154
+ # from transformers import GPTNeoXConfig
155
+ # config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
156
+ # config.use_flash_attn = True
157
+ # config.fused_bias_fc = True
158
+ # config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
159
+ # config.activation_function = "gelu_fast"
160
+ # config.fused_dropout_add_ln = True
161
+ # # config.residual_in_fp32 = True
162
+ #
163
+ # model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
164
+ # base_model,
165
+ # config,
166
+ # dtype=torch_dtype,
167
+ # device=cfg.device,
168
+ # )
169
+ # model.train() # sets to train instead of eval mode
170
  elif model_type:
171
  model = getattr(transformers, model_type).from_pretrained(
172
  base_model,
173
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
174
+ load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
175
  torch_dtype=torch_dtype,
176
  device_map=cfg.device_map,
177
  trust_remote_code=True if cfg.trust_remote_code is True else False,
178
+ **model_kwargs,
179
  )
180
  else:
181
  config = AutoConfig.from_pretrained(
 
186
  base_model,
187
  config=config,
188
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
189
+ load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
190
  torch_dtype=torch_dtype,
191
  device_map=cfg.device_map,
192
  trust_remote_code=True if cfg.trust_remote_code is True else False,
193
+ **model_kwargs,
194
  )
195
  except Exception as e:
196
  logging.error(
 
203
  torch_dtype=torch_dtype,
204
  device_map=cfg.device_map,
205
  trust_remote_code=True if cfg.trust_remote_code is True else False,
206
+ **model_kwargs,
207
  )
208
 
209
  if not tokenizer:
210
  try:
211
  if is_llama_derived_model and "LlamaTokenizer" in globals():
212
+ tokenizer = LlamaTokenizer.from_pretrained(
213
+ model,
214
+ trust_remote_code=True if cfg.trust_remote_code is True else False,
215
+ )
216
  else:
217
+ tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
218
+ model,
219
+ trust_remote_code=True if cfg.trust_remote_code is True else False,
220
+ )
221
  except:
222
+ tokenizer = AutoTokenizer.from_pretrained(
223
+ base_model_config,
224
+ trust_remote_code=True if cfg.trust_remote_code is True else False,
225
+ )
226
 
227
  logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
228
  logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
 
236
  tokenizer.add_special_tokens({"pad_token": "[PAD]"})
237
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
238
 
239
+ if cfg.special_tokens:
240
+ for k, v in cfg.special_tokens.items():
241
  tokenizer.add_special_tokens({k: v})
242
+ if cfg.tokens:
243
+ tokenizer.add_tokens(list(cfg.tokens))
244
 
245
+ embeddings_len = math.ceil(len(tokenizer) / 32) * 32
246
+ model.resize_token_embeddings(embeddings_len)
247
 
248
+ if (
249
+ (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
250
+ ) and not cfg.load_4bit:
251
  logging.info("converting PEFT model w/ prepare_model_for_int8_training")
252
  model = prepare_model_for_int8_training(model)
253
 
 
268
  m.scales = m.scales.half()
269
  m.bias = m.bias.half()
270
 
271
+ if (
272
+ torch.cuda.device_count() > 1
273
+ and int(os.getenv("WORLD_SIZE", "1")) > 1
274
+ and cfg.load_4bit
275
+ ):
276
  # llama is PROBABLY model parallelizable, but the default isn't that it is
277
  # so let's only set it for the 4bit, see
278
  # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
 
285
  requires_grad.append(f"{name}: {param.requires_grad}")
286
  if len(requires_grad) == 0:
287
  logging.warning("there are no parameters that require gradient updates")
288
+ model.config.use_cache = False
289
 
290
  # TODO resume_from_checkpoint handling
291
  return model, tokenizer, lora_config
 
296
 
297
  if adapter is None:
298
  return model, None
299
+ if adapter == "lora" or adapter == "qlora":
300
  return load_lora(model, cfg)
301
  if adapter == "llama-adapter":
302
  return load_llama_adapter(model, cfg)
 
318
  task_type="CAUSAL_LM",
319
  )
320
 
321
+ if cfg.lora_model_dir:
322
+ logging.info("Loading pretained LORA")
323
  model = PeftModel.from_pretrained(
324
  model,
325
  cfg.lora_model_dir,
 
361
  model,
362
  cfg.lora_model_dir,
363
  device_map=cfg.device_map,
364
+ # torch_dtype=torch.float16,
365
  )
366
  else:
367
  model = get_peft_model(model, lora_config)
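
The qlora path above builds a 4-bit BitsAndBytesConfig, forwards it through from_pretrained via model_kwargs, and rounds the token embedding table up to a multiple of 32 once new tokens are added. A condensed sketch of that sequence, using a placeholder checkpoint name and an illustrative extra token:

import math

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

base_model = "huggyllama/llama-7b"  # placeholder checkpoint

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.add_tokens(["<extra_0>"])  # illustrative new token
embeddings_len = math.ceil(len(tokenizer) / 32) * 32  # pad vocab to a multiple of 32
model.resize_token_embeddings(embeddings_len)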
src/axolotl/utils/trainer.py CHANGED
@@ -9,13 +9,33 @@ import torch.cuda
9
  import transformers
10
  from torch import nn
11
  from torch.optim.lr_scheduler import OneCycleLR
12
- from transformers import EarlyStoppingCallback
13
  from transformers.trainer_pt_utils import get_parameter_names
14
 
15
  from axolotl.utils.schedulers import InterpolatingLogScheduler
16
  from axolotl.utils.callbacks import SavePeftModelCallback
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
20
  total_num_steps = int(
21
  math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -38,6 +58,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
38
  training_arguments_kwargs["bf16_full_eval"] = True
39
  else:
40
  training_arguments_kwargs["bf16"] = cfg.bf16
 
41
  training_arguments_kwargs["tf32"] = cfg.tf32
42
  training_arguments_kwargs["warmup_steps"] = warmup_steps
43
  training_arguments_kwargs["logging_steps"] = logging_steps
@@ -119,6 +140,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
119
  cfg.optimizer == "adamw_bnb_8bit"
120
  and not cfg.load_4bit
121
  and not "deepspeed" in training_arguments_kwargs
 
122
  ):
123
  decay_parameters = get_parameter_names(model, [nn.LayerNorm])
124
  decay_parameters = [name for name in decay_parameters if "bias" not in name]
@@ -157,7 +179,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
157
  cfg.learning_rate,
158
  total_steps=total_num_steps,
159
  epochs=cfg.num_epochs,
160
- div_factor=10,
161
  **lr_scheduler_kwargs,
162
  )
163
  elif cfg.lr_scheduler == "log_sweep":
@@ -182,8 +204,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
182
  cfg.early_stopping_patience,
183
  )
184
  callbacks.append(early_stop_cb)
185
-
186
- if cfg.local_rank == 0 and cfg.adapter == 'lora': # only save in rank 0
187
  callbacks.append(SavePeftModelCallback)
188
 
189
  data_collator_kwargs = {
@@ -194,7 +216,12 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
194
  else:
195
  data_collator_kwargs["pad_to_multiple_of"] = 8
196
 
197
- trainer = transformers.Trainer(
 
 
 
 
 
198
  model=model,
199
  train_dataset=train_dataset,
200
  eval_dataset=eval_dataset,
 
9
  import transformers
10
  from torch import nn
11
  from torch.optim.lr_scheduler import OneCycleLR
12
+ from transformers import EarlyStoppingCallback, Trainer
13
  from transformers.trainer_pt_utils import get_parameter_names
14
 
15
  from axolotl.utils.schedulers import InterpolatingLogScheduler
16
  from axolotl.utils.callbacks import SavePeftModelCallback
17
 
18
 
19
+ class OneCycleLRSchedulerTrainer(Trainer):
20
+ def create_scheduler(
21
+ self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
22
+ ):
23
+ optimizer = self.optimizer if optimizer is None else optimizer
24
+ num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
25
+ num_training_steps = num_training_steps
26
+ pct_start = num_warmup_steps / num_training_steps
27
+
28
+ self.lr_scheduler = OneCycleLR(
29
+ optimizer,
30
+ max_lr=self.args.learning_rate,
31
+ total_steps=num_training_steps,
32
+ pct_start=pct_start,
33
+ div_factor=6,
34
+ )
35
+
36
+ return self.lr_scheduler
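
The subclass above exists so that, when FSDP is enabled together with the one_cycle scheduler, the scheduler is created by the Trainer against the optimizer the Trainer itself builds after wrapping the model; the override simply maps the warmup fraction onto OneCycleLR's pct_start. A self-contained sketch of that scheduler math, with a toy model and placeholder step counts:

import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(4, 4)  # toy stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

num_training_steps = 1000  # placeholder totals
num_warmup_steps = 100
pct_start = num_warmup_steps / num_training_steps  # warmup as a fraction of the cycle

scheduler = OneCycleLR(
    optimizer,
    max_lr=3e-4,
    total_steps=num_training_steps,
    pct_start=pct_start,
    div_factor=6,  # initial lr = max_lr / 6, matching the value used above
)

for _ in range(5):
    optimizer.step()
    scheduler.step()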
37
+
38
+
39
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
40
  total_num_steps = int(
41
  math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
 
58
  training_arguments_kwargs["bf16_full_eval"] = True
59
  else:
60
  training_arguments_kwargs["bf16"] = cfg.bf16
61
+ training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
62
  training_arguments_kwargs["tf32"] = cfg.tf32
63
  training_arguments_kwargs["warmup_steps"] = warmup_steps
64
  training_arguments_kwargs["logging_steps"] = logging_steps
 
140
  cfg.optimizer == "adamw_bnb_8bit"
141
  and not cfg.load_4bit
142
  and not "deepspeed" in training_arguments_kwargs
143
+ and not cfg.fsdp
144
  ):
145
  decay_parameters = get_parameter_names(model, [nn.LayerNorm])
146
  decay_parameters = [name for name in decay_parameters if "bias" not in name]
 
179
  cfg.learning_rate,
180
  total_steps=total_num_steps,
181
  epochs=cfg.num_epochs,
182
+ div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6,
183
  **lr_scheduler_kwargs,
184
  )
185
  elif cfg.lr_scheduler == "log_sweep":
 
204
  cfg.early_stopping_patience,
205
  )
206
  callbacks.append(early_stop_cb)
207
+
208
+ if cfg.local_rank == 0 and cfg.adapter == "lora": # only save in rank 0
209
  callbacks.append(SavePeftModelCallback)
210
 
211
  data_collator_kwargs = {
 
216
  else:
217
  data_collator_kwargs["pad_to_multiple_of"] = 8
218
 
219
+ trainer_cls = (
220
+ OneCycleLRSchedulerTrainer
221
+ if cfg.lr_scheduler == "one_cycle" and cfg.fsdp
222
+ else transformers.Trainer
223
+ )
224
+ trainer = trainer_cls(
225
  model=model,
226
  train_dataset=train_dataset,
227
  eval_dataset=eval_dataset,