Merge pull request #39 from OpenAccess-AI-Collective/dev
- .github/workflows/base.yml +14 -1
- .github/workflows/main.yml +22 -4
- README.md +2 -2
- docker/Dockerfile-base +14 -1
- docker/Dockerfile-runpod +5 -2
- examples/replit-3b/config-lora.yml +55 -0
- requirements.txt +2 -2
- scripts/finetune.py +25 -10
- src/axolotl/datasets.py +1 -1
- src/axolotl/prompt_strategies/__init__.py +14 -0
- src/axolotl/prompt_strategies/alpaca_chat.py +32 -0
- src/axolotl/prompt_strategies/alpaca_instruct.py +11 -0
- src/axolotl/prompt_strategies/creative_acr.py +149 -0
- src/axolotl/prompt_strategies/pygmalion.py +110 -0
- src/axolotl/prompt_tokenizers.py +151 -21
- src/axolotl/prompters.py +122 -78
- src/axolotl/utils/data.py +149 -27
- src/axolotl/utils/models.py +77 -12
- src/axolotl/utils/trainer.py +32 -5
.github/workflows/base.yml
CHANGED
@@ -11,6 +11,15 @@ jobs:
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
     runs-on: self-hosted
+    strategy:
+      matrix:
+        include:
+          - cuda: cu118
+            cuda_version: 11.8.0
+            pytorch: 2.0.0
+          - cuda: cu117
+            cuda_version: 11.7.0
+            pytorch: 1.13.1
     steps:
       - name: Checkout
         uses: actions/checkout@v3
@@ -32,7 +41,11 @@ jobs:
          context: .
          file: ./docker/Dockerfile-base
          push: ${{ github.event_name != 'pull_request' }}
-         tags: ${{ steps.metadata.outputs.tags }}
+         tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
          labels: ${{ steps.metadata.outputs.labels }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
+         build-args: |
+           CUDA_VERSION=${{ matrix.cuda_version }}
+           CUDA=${{ matrix.cuda }}
+           PYTORCH_VERSION=${{ matrix.pytorch }}
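Not part of the diff: a small Python illustration of how the new build matrix fans out into per-CUDA image tags. The repository and tag prefix below are hypothetical; only the `-{cuda}-{pytorch}` suffix comes from the workflow change above.

```python
# Illustration only: expanding the matrix into the image tags the workflow now pushes.
matrix = [
    {"cuda": "cu118", "cuda_version": "11.8.0", "pytorch": "2.0.0"},
    {"cuda": "cu117", "cuda_version": "11.7.0", "pytorch": "1.13.1"},
]
metadata_tag = "winglian/axolotl-base:main"  # hypothetical output of the metadata step
for entry in matrix:
    print(f"{metadata_tag}-{entry['cuda']}-{entry['pytorch']}")
# winglian/axolotl-base:main-cu118-2.0.0
# winglian/axolotl-base:main-cu117-1.13.1
```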
.github/workflows/main.yml
CHANGED
@@ -10,6 +10,15 @@ jobs:
   build-axolotl:
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      matrix:
+        include:
+          - cuda: cu118
+            cuda_version: 11.8.0
+            pytorch: 2.0.0
+          - cuda: cu117
+            cuda_version: 11.7.0
+            pytorch: 1.13.1
     runs-on: self-hosted
     steps:
      - name: Checkout
@@ -31,10 +40,10 @@ jobs:
        with:
          context: .
          build-args: |
-           BASE_TAG=${{ github.ref_name }}-base
+           BASE_TAG=${{ github.ref_name }}-base-${{ matrix.cuda }}-${{ matrix.pytorch }}
          file: ./docker/Dockerfile
          push: ${{ github.event_name != 'pull_request' }}
-         tags: ${{ steps.metadata.outputs.tags }}
+         tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
          labels: ${{ steps.metadata.outputs.labels }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
@@ -42,6 +51,15 @@ jobs:
    needs: build-axolotl
    if: github.repository_owner == 'OpenAccess-AI-Collective'
    # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      matrix:
+        include:
+          - cuda: cu118
+            cuda_version: 11.8.0
+            pytorch: 2.0.0
+          - cuda: cu117
+            cuda_version: 11.7.0
+            pytorch: 1.13.1
    runs-on: self-hosted
    steps:
      - name: Checkout
@@ -63,10 +81,10 @@ jobs:
        with:
          context: .
          build-args: |
-           BASE_TAG=${{ github.ref_name }}
+           BASE_TAG=${{ github.ref_name }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
          file: ./docker/Dockerfile-runpod
          push: ${{ github.event_name != 'pull_request' }}
-         tags: ${{ steps.metadata.outputs.tags }}
+         tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
          labels: ${{ steps.metadata.outputs.labels }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
README.md
CHANGED
@@ -324,7 +324,7 @@ If you are inferencing a pretrained LORA, pass
 --lora_model_dir ./completed-model
 ```
 
-### Merge LORA to base
+### Merge LORA to base
 
 Add below flag to train command above
 
@@ -345,4 +345,4 @@ Please reduce any below
 
 Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
 
-PRs are **greatly welcome**!
+PRs are **greatly welcome**!
docker/Dockerfile-base
CHANGED
@@ -1,6 +1,7 @@
 ARG CUDA_VERSION="11.8.0"
 ARG CUDNN_VERSION="8"
 ARG UBUNTU_VERSION="22.04"
+ARG MAX_JOBS=4
 
 FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder
 
@@ -39,6 +40,14 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 
 RUN git clone https://github.com/HazyResearch/flash-attention.git && \
     cd flash-attention && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/fused_dense_lib && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/xentropy && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/rotary && \
+    python3 setup.py bdist_wheel && \
+    cd csrc/layer_norm && \
    python3 setup.py bdist_wheel
 
 FROM base-builder AS deepspeed-builder
@@ -60,8 +69,12 @@ RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --g
 RUN mkdir /workspace/wheels
 COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
 COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary-*.whl wheels
+COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels
 
-RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl
+RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xeontropy-*.whl wheels/rotary-*.whl wheels/dropout_layer_norm-*.whl
 RUN git lfs install --skip-repo
 RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
     "accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
docker/Dockerfile-runpod
CHANGED
@@ -1,11 +1,14 @@
 ARG BASE_TAG=main
 FROM winglian/axolotl:$BASE_TAG
 
+COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
+
 RUN apt install --yes --no-install-recommends openssh-server tmux && \
     mkdir -p ~/.ssh && \
     chmod 700 ~/.ssh && \
     printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
-    chmod +x /workspace/axolotl/scripts/runpod-entrypoint.sh
+    chmod +x /workspace/axolotl/scripts/runpod-entrypoint.sh && \
+    chmod +x /root/runpod-entrypoint.sh
 
-ENTRYPOINT ["/
+ENTRYPOINT ["/root/runpod-entrypoint.sh"]
 CMD ["sleep", "infinity"]
examples/replit-3b/config-lora.yml
ADDED
@@ -0,0 +1,55 @@
+base_model: replit/replit-code-v1-3b
+base_model_config: replit/replit-code-v1-3b
+trust_remote_code: true
+load_in_8bit: false
+datasets:
+  - path: vicgalle/alpaca-gpt4
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+adapter: lora
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len:
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - Wqkv
+  - mlp_up
+  - mlp_down
+lora_fan_in_fan_out:
+wandb_project: lora-replit
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./lora-replit
+batch_size: 8
+micro_batch_size: 1
+num_epochs: 3
+optimizer:
+torchdistx_path:
+lr_scheduler:
+learning_rate: 0.00001
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+gradient_checkpointing:
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 20
+eval_steps: 50
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0
+fsdp:
+fsdp_config:
+#special_tokens:
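Not part of the diff: a minimal sketch of how a config like this becomes the `cfg` object, assuming it is consumed the way `scripts/finetune.py` loads YAML into an `AttrDefault`.

```python
# Sketch only: reading this config the way the finetune script is assumed to consume it.
import yaml
from attrdict import AttrDefault

with open("examples/replit-3b/config-lora.yml", encoding="utf-8") as file_handle:
    cfg = AttrDefault(lambda: None, yaml.safe_load(file_handle))

print(cfg.base_model)    # replit/replit-code-v1-3b
print(cfg.sequence_len)  # 2048
print(cfg.optimizer)     # None (left empty in the YAML)
```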
requirements.txt
CHANGED
@@ -1,12 +1,12 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
+bitsandbytes>=0.39.0
 attrdict
 fire
 PyYAML==6.0
 black
-bitsandbytes==0.37.2
 datasets
-accelerate
+accelerate>=0.19.0
 sentencepiece
 wandb
 einops
scripts/finetune.py
CHANGED
@@ -1,7 +1,6 @@
 import importlib
 import logging
 import os
-import pathlib
 import random
 import signal
 import sys
@@ -10,12 +9,12 @@ from typing import Optional
 
 import fire
 import torch
-import transformers
 import yaml
 from attrdict import AttrDefault
 
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
+from axolotl.utils.validation import validate_config
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
@@ -33,7 +32,7 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 def choose_device(cfg):
     def get_device():
         if torch.cuda.is_available():
-            return "cuda"
+            return f"cuda:{cfg.local_rank}"
         else:
             try:
                 if torch.backends.mps.is_available():
@@ -69,7 +68,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         instruction = get_multi_line_input()
         if not instruction:
             return
-        prompt = prompter_module().build_prompt(instruction=instruction)
+        prompt: str = next(prompter_module().build_prompt(instruction=instruction))
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
         model.eval()
@@ -133,7 +132,8 @@ def train(
     # then overwrite the value
     cfg_keys = dict(cfg).keys()
     for k in kwargs:
-        if k in cfg_keys:
+        # if not strict, allow writing to cfg even if it's not in the yml already
+        if k in cfg_keys or cfg.strict is False:
            # handle booleans
            if isinstance(cfg[k], bool):
                cfg[k] = bool(kwargs[k])
@@ -159,6 +159,8 @@ def train(
        cfg.fp16 = True
        cfg.bf16 = False
 
+    validate_config(cfg)
+
    # Load the model and tokenizer
    logging.info("loading model, tokenizer, and peft_config...")
    model, tokenizer, peft_config = load_model(
@@ -171,6 +173,15 @@ def train(
        inference=("inference" in kwargs),
    )
 
+    if "merge_lora" in kwargs and cfg.adapter is not None:
+        logging.info("running merge of LoRA with base model")
+        model = model.merge_and_unload()
+
+        if cfg.local_rank == 0:
+            logging.info("saving merged model")
+            model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+        return
+
    if "inference" in kwargs:
        logging.info("calling do_inference function")
        do_inference(cfg, model, tokenizer)
@@ -184,10 +195,6 @@ def train(
        tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
    )
 
-    if prepare_ds_only:
-        logging.info("Finished preparing dataset. Exiting...")
-        return
-
    if cfg.debug:
        logging.info("check_dataset_labels...")
        check_dataset_labels(
@@ -197,6 +204,10 @@ def train(
        tokenizer,
    )
 
+    if prepare_ds_only:
+        logging.info("Finished preparing dataset. Exiting...")
+        return
+
    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
    model.config.use_cache = False
@@ -218,6 +229,8 @@ def train(
    )
 
    logging.info("Starting trainer...")
+    if cfg.group_by_length:
+        logging.info("hang tight... sorting dataset for group_by_length")
    resume_from_checkpoint = cfg.resume_from_checkpoint
    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
        possible_checkpoints = [
@@ -236,7 +249,9 @@ def train(
    logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-    model.save_pretrained(cfg.output_dir)
+    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
+    if cfg.local_rank == 0:
+        model.save_pretrained(cfg.output_dir)
    # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
 
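Not part of the diff: a hedged sketch of what the new `--merge_lora` path does, assuming `model` is a PEFT LoRA-wrapped model; the helper name is illustrative only.

```python
# Hedged sketch of the merge path added above; not part of the repository.
from pathlib import Path


def merge_lora_and_save(model, output_dir: str, local_rank: int = 0):
    merged = model.merge_and_unload()  # peft: folds the LoRA weights back into the base model
    if local_rank == 0:  # mirror the rank-0 guard so multi-GPU runs write the files only once
        merged.save_pretrained(str(Path(output_dir) / "merged"))
    return merged
```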
src/axolotl/datasets.py
CHANGED
@@ -106,7 +106,7 @@ class ConstantLengthDataset(IterableDataset):
                     }
                 else:
                     logging.warning(
-                        "dropping batch due to tensor size mismatch"
+                        f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
                     )
                 buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                 buffer_len = 0
src/axolotl/prompt_strategies/__init__.py
ADDED
@@ -0,0 +1,14 @@
+import importlib
+
+
+def load(strategy, tokenizer, cfg):
+    try:
+        load_fn = "load"
+        if strategy.split(".")[-1].startswith("load_"):
+            load_fn = strategy.split(".")[-1]
+            strategy = ".".join(strategy.split(".")[:-1])
+        m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
+        fn = getattr(m, load_fn)
+        return fn(tokenizer, cfg)
+    except:
+        pass
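Not part of the diff: a usage sketch of the new strategy loader. The tokenizer checkpoint and the bare-bones `cfg` here are placeholders; the dotted `load_qa` suffix selects an alternate `load_*` entry point, as implemented above.

```python
# Usage sketch of axolotl.prompt_strategies.load(); names below are placeholders.
from types import SimpleNamespace

from transformers import AutoTokenizer

from axolotl.prompt_strategies import load

tokenizer = AutoTokenizer.from_pretrained("replit/replit-code-v1-3b", trust_remote_code=True)
cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)

# "alpaca_chat" resolves to alpaca_chat.load(); a trailing ".load_qa" selects load_qa() instead.
strategy = load("alpaca_chat.load_qa", tokenizer, cfg)
# Note: load() swallows errors and returns None when the strategy cannot be resolved.
```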
src/axolotl/prompt_strategies/alpaca_chat.py
ADDED
@@ -0,0 +1,32 @@
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    InstructionPromptTokenizingStrategy,
+)
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+def load(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.chat.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["question"],
+            "",
+            prompt["answer"],
+        )
+
+
+def load_qa(tokenizer, cfg):
+    return AlpacaQAPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.chat.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
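Not part of the diff: a hypothetical dataset row for the `load_qa` strategy, reusing the tokenizer and cfg placeholders from the loader sketch above.

```python
# Hypothetical row; parse_instruction_fields maps it onto the (instruction, input, response)
# triple that the Alpaca chat template then renders.
sample = {"question": "What is LoRA?", "answer": "Low-Rank Adaptation of large language models."}
strategy = load_qa(tokenizer, cfg)  # tokenizer and cfg as in the loader sketch above
instruction, model_input, response = strategy.parse_instruction_fields(sample)
# -> ("What is LoRA?", "", "Low-Rank Adaptation of large language models.")
```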
src/axolotl/prompt_strategies/alpaca_instruct.py
ADDED
@@ -0,0 +1,11 @@
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+def load(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.instruct),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
src/axolotl/prompt_strategies/creative_acr.py
ADDED
@@ -0,0 +1,149 @@
+from typing import Union, Generator
+
+import yaml
+from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
+
+
+class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        question = prompt["instruction"]
+        answer = prompt[
+            "revision"
+        ]  # don't use prompt[answer], that's data we don't want in the dataset
+        return (
+            question,
+            "",
+            answer,
+        )
+
+
+class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
+refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
+prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
+creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
+comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
+Then for each criteria, give a detailed explanation of why the response doesn't meet the criteria (don't reference the score), what it did right, and how it could be improved for that criteria.
+
+Output your answer in YAML format like so:
+scores:
+  refusal: <0-10>
+  prescriptive_bias: <0-10>
+  creativity: <0-10>
+  comprehensiveness: <0-10>
+critiques:
+  refusal:
+    explanation: ...
+    improvements: ...
+  prescriptive_bias:
+    explanation: ...
+    improvements: ...
+  creativity:
+    explanation: ...
+    improvements: ...
+  comprehensiveness:
+    explanation: ...
+    improvements: ...
+
+Question: {question}
+Answer: {answer}
+"""
+
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        scores = yaml.dump(
+            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        critiques = yaml.dump(
+            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        evaluation = scores + critiques
+        question = prompt["instruction"]
+        answer = prompt["answer"]
+        return (
+            self.user_prompt.format(question=question, answer=answer),
+            "",
+            evaluation,
+        )
+
+
+class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    user_prompt = """Definitions:
+refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
+prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias.
+creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
+comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
+
+Given the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response.
+
+Question: {question}
+Answer: {answer}
+Evaluation:
+{evaluation}
+"""
+
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        scores = yaml.dump(
+            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        critiques = yaml.dump(
+            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        evaluation = scores + critiques
+        question = prompt["instruction"]
+        answer = prompt["answer"]
+        return (
+            self.user_prompt.format(
+                question=question, answer=answer, evaluation=evaluation
+            ),
+            "",
+            prompt["revision"],
+        )
+
+
+class CreativePrompterBase:
+    system_prompt = ""
+    prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
+
+    def build_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,
+        output: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        if self.system_prompt:
+            res = f"{self.system_prompt}\nUSER: {instruction}\nASSISTANT:"
+        else:
+            res = f"USER: {instruction}\nASSISTANT:"
+        if output:
+            res = f"{res}{output}"
+        yield res
+
+
+class CreativeAnswerPrompter(CreativePrompterBase):
+    system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity."
+
+
+class CreativeCritiquePrompter(CreativePrompterBase):
+    system_prompt = ""
+
+
+class CreativeRevisePrompter(CreativePrompterBase):
+    system_prompt = ""
+
+
+def load_answer(tokenizer, cfg):
+    return CreativeAnsweringPromptTokenizingStrategy(
+        CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
+
+
+def load_critique(tokenizer, cfg):
+    return CreativeCritiquePromptTokenizingStrategy(
+        CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
+
+
+def load_revise(tokenizer, cfg):
+    return CreativeRevisePromptTokenizingStrategy(
+        CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
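Not part of the diff: a hedged sketch of the record shape these strategies expect, with field names taken from the code above and values invented for illustration.

```python
# Invented sample record; the field names come from the strategies above.
record = {
    "instruction": "Write a haiku about rain.",
    "answer": "It rains. The ground is wet.",
    "revision": "Soft rain on tin roofs,\nquiet rings widen in pools,\nthe garden listens.",
    "scores": {"refusal": 10, "prescriptive_bias": 10, "creativity": 3, "comprehensiveness": 4},
    "critiques": {"creativity": {"explanation": "...", "improvements": "..."}},
}
# load_answer() trains on record["revision"] as the response to record["instruction"];
# load_critique() trains the model to emit the YAML-dumped scores + critiques for (question, answer);
# load_revise() trains it to emit record["revision"] given the question, answer, and that evaluation.
```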
src/axolotl/prompt_strategies/pygmalion.py
ADDED
@@ -0,0 +1,110 @@
+import copy
+import logging
+from collections import defaultdict
+from typing import Generator
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+
+IGNORE_TOKEN_ID = -100
+
+
+class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
+    bot_prefix_token_ids = []
+
+    def __init__(self, prompter, tokenizer, *args, **kwargs):
+        super().__init__(prompter, tokenizer)
+        res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
+        self.bot_prefix_token_ids = res["input_ids"]
+
+    def tokenize_prompt(self, prompt):
+        result = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+        }
+        current_len = 0
+        for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
+            role, message = part
+            if role == "system":
+                prefix = "<|system|>"
+                # this should include a bos token, no eos token, strip trailing "\n<START>"
+                if message.endswith("\n<START>"):
+                    message = message[:-8]
+                res = self._tokenize(
+                    prefix + "Persona: " + message.strip(),
+                    add_eos_token=False,
+                    strip_bos_token=False,
+                )
+                # everything from this is masked out from the labels
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+            elif role == "human":
+                prefix = "<|user|>"
+                res = self._tokenize(
+                    prefix + " " + message.strip(),
+                    add_eos_token=False,
+                    strip_bos_token=True,
+                )
+                # everything from this is masked out from the labels
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+            elif role == "bot":
+                prefix = "<|model|>"
+                res = self._tokenize(
+                    prefix + " " + message.strip(),
+                    add_eos_token=True,
+                    strip_bos_token=True,
+                )
+                # mask out the prefix token, rest is not masked out from labels
+                # make sure we create the labels first, otherwise we get incorrect lengths
+                labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [
+                    *copy.deepcopy(res["input_ids"])
+                ][len(self.bot_prefix_token_ids) :]
+            else:
+                logging.warning(f"unknown role in conversation: {role}")
+                res = defaultdict(lambda: [])
+            input_ids = res["input_ids"]
+            input_len = len(input_ids)
+            result["input_ids"][current_len : current_len + input_len] = input_ids
+            result["attention_mask"][current_len : current_len + input_len] = [
+                1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
+            ]
+            result["labels"][current_len : current_len + input_len] = labels
+            current_len += input_len
+        return result
+
+    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.sequence_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.sequence_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(self.tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
+            result["input_ids"] = result["input_ids"][1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
+        result["labels"] = result["input_ids"].copy()
+        return result
+
+
+class PygmalionPrompter:
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
+        for msg in source:
+            yield msg["role"], msg["value"]
+
+
+def load(tokenizer, cfg):
+    return PygmalionPromptTokenizingStrategy(
+        PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
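Not part of the diff: an illustrative conversation record in the role/value format this strategy consumes, with a note on how the masking above treats each role.

```python
# Illustrative conversation (values invented) in the format prompt["conversations"] expects.
conversations = [
    {"role": "system", "value": "A chat between a curious user and a helpful assistant.\n<START>"},
    {"role": "human", "value": "Hello!"},
    {"role": "bot", "value": "Hi! How can I help you today?"},
]
# system and human turns are fully masked with IGNORE_TOKEN_ID; only the bot turn, minus the
# <|model|> prefix tokens, contributes to the training labels.
```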
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -1,7 +1,12 @@
 import abc
+import copy
+import functools
+import logging
 
 from transformers import PreTrainedTokenizer
 
+from axolotl.prompters import IGNORE_TOKEN_ID
+
 IGNORE_INDEX = -100
 LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
 LLAMA_DEFAULT_EOS_TOKEN = "</s>"
@@ -30,6 +35,20 @@ class PromptTokenizingStrategy(abc.ABC):
     def tokenize_prompt(self, prompt):
         pass
 
+    @functools.cache
+    def _get_user_token(self):
+        id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
+        if isinstance(id_or_ids, (int,)):
+            return id_or_ids
+        return False
+
+    @functools.cache
+    def _get_assistant_token(self):
+        id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
+        if isinstance(id_or_ids, (int,)):
+            return id_or_ids
+        return False
+
 
 class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
@@ -40,9 +59,13 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         full_prompt = self._build_full_prompt(instruction, input, response)
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
-            user_prompt = self.prompter.build_prompt(
-                instruction,
-                input,
+            user_prompt = next(
+                iter(
+                    self.prompter.build_prompt(
+                        instruction,
+                        input,
+                    )
+                )
             )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
@@ -54,13 +77,17 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt
 
     def _build_full_prompt(self, instruction, input, response):
-        return self.prompter.build_prompt(
-            instruction,
-            input,
-            response,
+        return next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
+                    response,
+                )
+            )
         )
 
-    def _tokenize(self, prompt, add_eos_token=True):
+    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         result = self.tokenizer(
             prompt,
             truncation=True,
@@ -76,6 +103,10 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)
 
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
+            result["input_ids"] = result["input_ids"][1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
         result["labels"] = result["input_ids"].copy()
         return result
 
@@ -89,6 +120,15 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         )
 
 
+class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["question"],
+            "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
+            prompt["solution"] if "solution" in prompt else prompt["explanation"],
+        )
+
+
 class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
         return (
@@ -107,6 +147,15 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
         )
 
 
+class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["article"],
+            "",
+            prompt["summary"],
+        )
+
+
 class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
         return (
@@ -131,13 +180,13 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 
     def tokenize_prompt(self, prompt):
         instruction = self.parse_instruction_fields(prompt)
-        full_prompt = self._build_full_prompt(instruction)
+        full_prompt = self._build_full_prompt(instruction, None, None)
         tokenized_full_prompt = self._tokenize(full_prompt)
 
         return tokenized_full_prompt
 
-    def _build_full_prompt(self, instruction):
-        return self.prompter.build_prompt(instruction)
+    def _build_full_prompt(self, instruction, input, response):
+        return next(iter(self.prompter.build_prompt(instruction)))
 
 
 class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
@@ -157,9 +206,13 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         )
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
-            user_prompt = self.prompter.build_prompt(
-                instruction,
-                input,
+            user_prompt = next(
+                iter(
+                    self.prompter.build_prompt(
+                        instruction,
+                        input,
+                    )
+                )
             )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
@@ -171,12 +224,16 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt
 
     def _build_full_prompt(self, instruction, input, output, reflection, corrected):
-        return self.prompter.build_prompt(
-            instruction,
-            input,
-            output,
-            reflection,
-            corrected,
+        return next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
+                    output,
+                    reflection,
+                    corrected,
+                )
+            )
         )
 
     def _tokenize(self, prompt, add_eos_token=True):
@@ -212,7 +269,80 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
 
 class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
     def tokenize_prompt(self, prompt):
+        result = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+        }
+        current_len = 0
+        user_token = self._get_user_token()
+        assistant_token = self._get_assistant_token()
         try:
-
+            for i, part in enumerate(
+                self.prompter.build_prompt(prompt["conversations"])
+            ):
+                if isinstance(part, tuple):
+                    if part[0] == "USER:":
+                        part = part[0] + part[1] if not user_token else part[1]
+                        # this is still the user query, we should
+                        res = self._tokenize(
+                            part.strip(), add_eos_token=False, strip_bos_token=True
+                        )
+                        if user_token:
+                            res["input_ids"] = [user_token, *res["input_ids"]]
+                        # everything from this is masked out from the labels
+                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                    elif part[0] == "ASSISTANT:":
+                        # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
+                        part = part[0] + part[1] if not assistant_token else part[1]
+                        # this should be the assistent response, should end with an eos token
+                        res = self._tokenize(
+                            part.strip(), add_eos_token=True, strip_bos_token=True
+                        )
+                        if assistant_token:
+                            res["input_ids"] = [assistant_token, *res["input_ids"]]
+                        # not masked out from labels
+                        labels = copy.deepcopy(res["input_ids"])
+                    else:
+                        logging.warning("unhandled role: " + part[0])
+                else:
+                    # this is only ever the first part, should include the bos token and the user query
+                    res = self._tokenize(
+                        part.strip(), add_eos_token=False, strip_bos_token=False
+                    )
+                    # everything from this is masked out from the labels
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                input_ids = res["input_ids"]
+                input_len = len(input_ids)
+                result["input_ids"][current_len : current_len + input_len] = input_ids
+                result["attention_mask"][current_len : current_len + input_len] = [
+                    1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
+                ]
+                result["labels"][current_len : current_len + input_len] = labels
+                current_len += input_len
+            return result
         except (KeyError, AssertionError, IndexError) as e:
             raise InvalidDataException(str(e))
+
+    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.sequence_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.sequence_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(self.tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
+            result["input_ids"] = result["input_ids"][1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
+        result["labels"] = result["input_ids"].copy()
+        return result
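Not part of the diff: a minimal sketch of the calling-convention change, since prompters now yield prompts from a generator rather than returning a string.

```python
# Minimal sketch: the strategies above take the first yielded prompt instead of a return value.
def render_first(prompter, *args):
    return next(iter(prompter.build_prompt(*args)))
```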
src/axolotl/prompters.py
CHANGED
@@ -1,22 +1,52 @@
|
|
1 |
import copy
|
2 |
import dataclasses
|
|
|
3 |
from enum import auto, Enum
|
4 |
-
from typing import List, Tuple, Any, Union
|
5 |
|
6 |
IGNORE_TOKEN_ID = -100
|
7 |
|
8 |
|
|
|
|
|
|
|
|
|
|
|
9 |
class AlpacaPrompter:
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def build_prompt(
|
15 |
self,
|
16 |
instruction: str,
|
17 |
input: Union[None, str] = None,
|
18 |
output: Union[None, str] = None,
|
19 |
-
) -> str:
|
20 |
# returns the full prompt from instruction and optional input
|
21 |
# if a label (=response, =output) is provided, it's also appended.
|
22 |
if input:
|
@@ -25,19 +55,42 @@ class AlpacaPrompter:
|
|
25 |
res = self.prompt_no_input.format(instruction=instruction)
|
26 |
if output:
|
27 |
res = f"{res}{output}"
|
28 |
-
|
29 |
|
30 |
def get_response(self, output: str) -> str:
|
31 |
return output.split(self.response_split)[1].strip()
|
32 |
|
33 |
|
|
|
|
|
|
|
|
|
|
|
34 |
class JeopardyPrompter(AlpacaPrompter):
|
35 |
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
36 |
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
class CompletionPrompter(AlpacaPrompter):
|
39 |
-
def build_prompt(
|
40 |
-
|
|
|
|
|
41 |
|
42 |
def get_response(self, output: str) -> str:
|
43 |
return output.strip()
|
@@ -52,11 +105,44 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
|
|
52 |
|
53 |
|
54 |
class ReflectAlpacaPrompter:
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
58 |
response_split = "### Response:"
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
def build_prompt(
|
61 |
self,
|
62 |
instruction: str,
|
@@ -64,7 +150,7 @@ class ReflectAlpacaPrompter:
|
|
64 |
output: Union[None, str] = None,
|
65 |
reflection: Union[None, str] = None,
|
66 |
corrected: Union[None, str] = None,
|
67 |
-
) -> str:
|
68 |
# returns the full prompt from instruction and optional input
|
69 |
# if a label (=response, =output) is provided, it's also appended.
|
70 |
if input:
|
@@ -76,7 +162,7 @@ class ReflectAlpacaPrompter:
|
|
76 |
output=output, reflection=reflection, corrected=corrected
|
77 |
)
|
78 |
res = f"{res}{label}"
|
79 |
-
|
80 |
|
81 |
def get_response(self, output: str) -> str:
|
82 |
return output.split(self.response_split)[1].strip()
|
@@ -103,15 +189,16 @@ class Conversation:
|
|
103 |
sep: str = "###"
|
104 |
sep2: str = None
|
105 |
|
106 |
-
def get_prompt(self):
|
107 |
seps = [self.sep, self.sep2]
|
108 |
-
|
|
|
109 |
for i, (role, message) in enumerate(self.messages):
|
110 |
if message:
|
111 |
-
|
112 |
else:
|
113 |
-
|
114 |
-
|
115 |
|
116 |
def copy(self):
|
117 |
return Conversation(
|
@@ -136,12 +223,24 @@ conv_vicuna_v1_1 = Conversation(
|
|
136 |
offset=0,
|
137 |
sep_style=SeparatorStyle.TWO,
|
138 |
sep=" ",
|
139 |
-
sep2="
|
140 |
)
|
141 |
|
142 |
|
143 |
class ShareGPTPrompter:
|
144 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
# ignore the system prompt if provided
|
146 |
if source[0]["from"] == "system":
|
147 |
source.pop(0)
|
@@ -171,61 +270,6 @@ class ShareGPTPrompter:
|
|
171 |
role = roles[sentence["from"]]
|
172 |
assert role == conv.roles[j % 2]
|
173 |
conv.append_message(role, sentence["value"])
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
# Tokenize conversations
|
178 |
-
tokenized_result = tokenizer(
|
179 |
-
conversation,
|
180 |
-
truncation=True,
|
181 |
-
max_length=sequence_len, # FIXME
|
182 |
-
padding=False,
|
183 |
-
return_tensors=None,
|
184 |
-
)
|
185 |
-
target = copy.deepcopy(tokenized_result["input_ids"])
|
186 |
-
|
187 |
-
# Mask targets
|
188 |
-
sep = conv.sep + conv.roles[1] + ": "
|
189 |
-
|
190 |
-
rounds = conversation.split(conv.sep2)
|
191 |
-
rounds = [r + conv.sep2 for r in rounds]
|
192 |
-
cur_len = 1
|
193 |
-
target[0] = IGNORE_TOKEN_ID # mask out the bos
|
194 |
-
for i, rou in enumerate(rounds):
|
195 |
-
if rou == "":
|
196 |
-
break
|
197 |
-
|
198 |
-
parts = rou.split(sep)
|
199 |
-
if len(parts) != 2:
|
200 |
-
break
|
201 |
-
parts[0] += sep
|
202 |
-
round_len = (
|
203 |
-
len(tokenizer(rou)["input_ids"]) - 1
|
204 |
-
) # -1 ignores the bos_token generated for this
|
205 |
-
# we have to strip the initial part, any dangling whitespace creates an additional ghost token
|
206 |
-
instruction_len = (
|
207 |
-
len(tokenizer(parts[0].strip())["input_ids"]) - 1
|
208 |
-
) # -1 ignores the bos_token generated for this
|
209 |
-
target[cur_len : cur_len + instruction_len] = [
|
210 |
-
IGNORE_TOKEN_ID
|
211 |
-
] * instruction_len
|
212 |
-
|
213 |
-
cur_len += round_len
|
214 |
-
if cur_len >= sequence_len:
|
215 |
-
break
|
216 |
-
|
217 |
-
# Fix: Truncate the target to have the same length as input_ids
|
218 |
-
target = target[: len(tokenized_result["input_ids"])]
|
219 |
-
# target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
|
220 |
-
|
221 |
-
attention_mask = [
|
222 |
-
1 if x != tokenizer.pad_token_id else 0
|
223 |
-
for x in tokenized_result["input_ids"]
|
224 |
-
]
|
225 |
-
|
226 |
-
# TODO truncate len to sequence_len
|
227 |
-
return dict(
|
228 |
-
input_ids=tokenized_result["input_ids"],
|
229 |
-
labels=target,
|
230 |
-
attention_mask=attention_mask,
|
231 |
-
)
|
|
|
1 |
import copy
|
2 |
import dataclasses
|
3 |
+
import logging
|
4 |
from enum import auto, Enum
|
5 |
+
from typing import List, Tuple, Any, Union, Generator
|
6 |
|
7 |
IGNORE_TOKEN_ID = -100
|
8 |
|
9 |
|
10 |
+
class PromptStyle(Enum):
|
11 |
+
instruct = "instruct"
|
12 |
+
chat = "chat"
|
13 |
+
|
14 |
+
|
15 |
class AlpacaPrompter:
|
16 |
+
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
17 |
+
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
18 |
+
prompt_style = None
|
19 |
+
|
20 |
+
def __init__(self, prompt_style="instruct"):
|
21 |
+
self.prompt_style = prompt_style
|
22 |
+
self.match_prompt_style()
|
23 |
+
|
24 |
+
def match_prompt_style(self):
|
25 |
+
if self.prompt_style == PromptStyle.instruct.value:
|
26 |
+
self.prompt_input = (
|
27 |
+
self.system_prompt
|
28 |
+
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
29 |
+
)
|
30 |
+
self.prompt_no_input = (
|
31 |
+
self.system_no_input_prompt
|
32 |
+
+ "### Instruction:\n{instruction}\n\n### Response:\n"
|
33 |
+
)
|
34 |
+
self.response_split = "### Response:"
|
35 |
+
if self.prompt_style == PromptStyle.chat.value:
|
36 |
+
self.prompt_input = (
|
37 |
+
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
38 |
+
)
|
39 |
+
self.prompt_no_input = (
|
40 |
+
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
41 |
+
)
|
42 |
+
self.response_split = "ASSISTANT:"
|
43 |
|
44 |
def build_prompt(
|
45 |
self,
|
46 |
instruction: str,
|
47 |
input: Union[None, str] = None,
|
48 |
output: Union[None, str] = None,
|
49 |
+
) -> Generator[str, None, None]:
|
50 |
# returns the full prompt from instruction and optional input
|
51 |
# if a label (=response, =output) is provided, it's also appended.
|
52 |
if input:
|
|
|
55 |
res = self.prompt_no_input.format(instruction=instruction)
|
56 |
if output:
|
57 |
res = f"{res}{output}"
|
58 |
+
yield res
|
59 |
|
60 |
def get_response(self, output: str) -> str:
|
61 |
return output.split(self.response_split)[1].strip()
|
62 |
|
63 |
|
64 |
+
class UnpromptedPrompter(AlpacaPrompter):
|
65 |
+
system_prompt = ""
|
66 |
+
system_no_input_prompt = ""
|
67 |
+
|
68 |
+
|
69 |
class JeopardyPrompter(AlpacaPrompter):
|
70 |
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
71 |
|
72 |
|
73 |
+
class MultipleChoiceExplainPrompter(AlpacaPrompter):
|
74 |
+
system_prompt = (
|
75 |
+
"Choose the answer that best answers the question. Explain your reasoning."
|
76 |
+
)
|
77 |
+
|
78 |
+
|
79 |
+
class MultipleChoiceConcisePrompter(AlpacaPrompter):
|
80 |
+
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
|
81 |
+
|
82 |
+
|
83 |
+
class SummarizeTLDRPrompter(AlpacaPrompter):
|
84 |
+
prompt_no_input = (
|
85 |
+
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
class CompletionPrompter(AlpacaPrompter):
|
90 |
+
def build_prompt(
|
91 |
+
self, instruction: str, input=None, output=None
|
92 |
+
) -> Generator[str, None, None]:
|
93 |
+
yield instruction
|
94 |
|
95 |
def get_response(self, output: str) -> str:
|
96 |
return output.strip()
|
|
|
105 |
|
106 |
|
107 |
class ReflectAlpacaPrompter:
|
108 |
+
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
109 |
+
system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
110 |
+
|
111 |
+
prompt_input = (
|
112 |
+
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
113 |
+
)
|
114 |
+
prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n"
|
115 |
+
agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
|
116 |
response_split = "### Response:"
|
117 |
|
118 |
+
def __init__(self, prompt_style="instruct"):
|
119 |
+
self.prompt_style = prompt_style
|
120 |
+
self.match_prompt_style()
|
121 |
+
|
122 |
+
def match_prompt_style(self):
|
123 |
+
if self.prompt_style == PromptStyle.instruct.value:
|
124 |
+
self.prompt_input = (
|
125 |
+
self.system_prompt
|
126 |
+
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
127 |
+
)
|
128 |
+
self.prompt_no_input = (
|
129 |
+
self.system_no_input_prompt
|
130 |
+
+ "### Instruction:\n{instruction}\n\n### Response:\n"
|
131 |
+
)
|
132 |
+
self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
|
133 |
+
self.response_split = "### Final Response:"
|
134 |
+
if self.prompt_style == PromptStyle.chat.value:
|
135 |
+
self.prompt_input = (
|
136 |
+
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
137 |
+
)
|
138 |
+
self.prompt_no_input = (
|
139 |
+
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
140 |
+
)
|
141 |
+
self.agent_label = (
|
142 |
+
"\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
|
143 |
+
)
|
144 |
+
self.response_split = "ASSISTANT:"
|
145 |
+
|
146 |
def build_prompt(
|
147 |
self,
|
148 |
instruction: str,
|
|
|
150 |
output: Union[None, str] = None,
|
151 |
reflection: Union[None, str] = None,
|
152 |
corrected: Union[None, str] = None,
|
153 |
+
) -> Generator[str, None, None]:
|
154 |
# returns the full prompt from instruction and optional input
|
155 |
# if a label (=response, =output) is provided, it's also appended.
|
156 |
if input:
|
|
|
162 |
output=output, reflection=reflection, corrected=corrected
|
163 |
)
|
164 |
res = f"{res}{label}"
|
165 |
+
yield res
|
166 |
|
167 |
def get_response(self, output: str) -> str:
|
168 |
return output.split(self.response_split)[1].strip()
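For the reflection format, `build_prompt()` appends `agent_label` with the draft answer, the critique, and the corrected answer, and `get_response()` later recovers only the text after the final-response marker. A rough sketch in "instruct" mode, with invented field values:

```python
# Rough sketch of the reflection prompt; instruction, draft, reflection and
# corrected answer are made up purely for illustration.
prompter = ReflectAlpacaPrompter("instruct")
full_text = next(
    prompter.build_prompt(
        instruction="Summarize the water cycle.",
        output="Water evaporates, condenses, and falls as rain.",
        reflection="The draft omits transpiration and groundwater flow.",
        corrected="Water evaporates and transpires, condenses into clouds, precipitates, and returns via runoff.",
    )
)
# get_response() keeps only the text after "### Final Response:"
print(prompter.get_response(full_text))
```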
|
|
|
189 |
sep: str = "###"
|
190 |
sep2: str = None
|
191 |
|
192 |
+
def get_prompt(self) -> Generator[str, None, None]:
|
193 |
seps = [self.sep, self.sep2]
|
194 |
+
preamble = self.system + seps[0]
|
195 |
+
yield preamble
|
196 |
for i, (role, message) in enumerate(self.messages):
|
197 |
if message:
|
198 |
+
yield (role + ":", " " + message)
|
199 |
else:
|
200 |
+
logging.warning("role with empty message: " + role)
|
201 |
+
yield (role + ":",)
|
202 |
|
203 |
def copy(self):
|
204 |
return Conversation(
|
|
|
223 |
offset=0,
|
224 |
sep_style=SeparatorStyle.TWO,
|
225 |
sep=" ",
|
226 |
+
sep2=" ",
|
227 |
)
|
228 |
|
229 |
|
230 |
class ShareGPTPrompter:
|
231 |
+
def __init__(self, prompt_style=None):
|
232 |
+
if prompt_style != PromptStyle.chat.value:
|
233 |
+
raise Exception(
|
234 |
+
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
|
235 |
+
)
|
236 |
+
|
237 |
+
# def match_prompt_style(self):
|
238 |
+
# if self.prompt_style == PromptStyle.chat.value:
|
239 |
+
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
240 |
+
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
241 |
+
# self.response_split = "ASSISTANT:"
|
242 |
+
|
243 |
+
def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
|
244 |
# ignore the system prompt if provided
|
245 |
if source[0]["from"] == "system":
|
246 |
source.pop(0)
|
|
|
270 |
role = roles[sentence["from"]]
|
271 |
assert role == conv.roles[j % 2]
|
272 |
conv.append_message(role, sentence["value"])
|
273 |
+
|
274 |
+
for part in conv.get_prompt():
|
275 |
+
yield part
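`ShareGPTPrompter` consumes ShareGPT-style conversation records and renders them through the Vicuna-style `Conversation` above, yielding the prompt in role-delimited pieces rather than as a single string. A small sketch, assuming the usual ShareGPT role names `human`/`gpt` and that `PromptStyle.chat.value` is `"chat"`:

```python
# Hypothetical ShareGPT-style record: a list of {"from", "value"} turns.
# A leading "system" turn is dropped by build_prompt(), as noted in the diff.
conversation = [
    {"from": "system", "value": "You are a helpful assistant."},
    {"from": "human", "value": "What does LoRA do?"},
    {"from": "gpt", "value": "It fine-tunes a model through low-rank adapter weights."},
]

prompter = ShareGPTPrompter("chat")
for part in prompter.build_prompt(conversation):
    # parts are the system preamble plus (role, text) pieces from Conversation.get_prompt()
    print(part)
```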
|
|
|
src/axolotl/utils/data.py
CHANGED
@@ -8,10 +8,13 @@ from datasets import (
|
|
8 |
IterableDataset,
|
9 |
Dataset,
|
10 |
concatenate_datasets,
|
|
|
11 |
)
|
12 |
from huggingface_hub import hf_hub_download
|
|
|
13 |
|
14 |
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
|
|
|
15 |
from axolotl.prompt_tokenizers import (
|
16 |
AlpacaPromptTokenizingStrategy,
|
17 |
GPTeacherPromptTokenizingStrategy,
|
@@ -20,6 +23,8 @@ from axolotl.prompt_tokenizers import (
|
|
20 |
ShareGPTPromptTokenizingStrategy,
|
21 |
JeopardyPromptTokenizingStrategy,
|
22 |
CompletionPromptTokenizingStrategy,
|
|
|
|
|
23 |
)
|
24 |
from axolotl.prompters import (
|
25 |
AlpacaPrompter,
|
@@ -28,16 +33,24 @@ from axolotl.prompters import (
|
|
28 |
ShareGPTPrompter,
|
29 |
JeopardyPrompter,
|
30 |
CompletionPrompter,
|
|
|
|
|
|
|
31 |
)
|
32 |
|
33 |
|
34 |
-
def load_tokenized_prepared_datasets(
|
|
|
|
|
|
|
35 |
ds_hash = str(
|
36 |
md5(
|
37 |
(
|
38 |
str(cfg.sequence_len)
|
39 |
+ "@"
|
40 |
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
|
|
|
|
41 |
).encode("utf-8")
|
42 |
).hexdigest()
|
43 |
)
|
@@ -46,8 +59,19 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|
46 |
if cfg.dataset_prepared_path
|
47 |
else Path(default_dataset_prepared_path) / ds_hash
|
48 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
if
|
|
|
|
|
51 |
logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
52 |
dataset = load_from_disk(str(prepared_ds_path))
|
53 |
logging.info("Prepared dataset loaded from disk...")
|
@@ -59,7 +83,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|
59 |
ds = None
|
60 |
ds_from_hub = False
|
61 |
try:
|
62 |
-
load_dataset(d.path, streaming=True)
|
63 |
ds_from_hub = True
|
64 |
except FileNotFoundError:
|
65 |
pass
|
@@ -67,64 +91,117 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|
67 |
# prefer local dataset, even if hub exists
|
68 |
if Path(d.path).exists():
|
69 |
ds: IterableDataset = load_dataset(
|
70 |
-
"json", data_files=d.path, streaming=
|
71 |
)
|
72 |
elif ds_from_hub:
|
73 |
if d.data_files:
|
74 |
-
ds = load_dataset(
|
|
|
|
|
|
|
|
|
|
|
75 |
else:
|
76 |
-
ds = load_dataset(d.path, streaming=True)
|
77 |
else:
|
78 |
fp = hf_hub_download(
|
79 |
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
80 |
)
|
81 |
-
ds = load_dataset("json", data_files=fp, streaming=
|
82 |
if not ds:
|
83 |
raise Exception("unhandled dataset load")
|
84 |
-
|
85 |
-
if d.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
ds_strategy = AlpacaPromptTokenizingStrategy(
|
87 |
-
AlpacaPrompter(),
|
|
|
|
|
88 |
)
|
89 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
90 |
datasets.append(ds_wrapper)
|
91 |
-
elif
|
92 |
ds_strategy = JeopardyPromptTokenizingStrategy(
|
93 |
-
JeopardyPrompter(),
|
|
|
|
|
|
|
94 |
)
|
95 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
96 |
datasets.append(ds_wrapper)
|
97 |
-
elif
|
98 |
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
99 |
-
AlpacaPrompter(),
|
|
|
|
|
|
|
100 |
)
|
101 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
102 |
datasets.append(ds_wrapper)
|
103 |
-
elif
|
104 |
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
105 |
-
GPTeacherPrompter(),
|
106 |
tokenizer,
|
107 |
cfg.train_on_inputs,
|
108 |
cfg.sequence_len,
|
109 |
)
|
110 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
111 |
datasets.append(ds_wrapper)
|
112 |
-
elif
|
113 |
ds_strategy = AlpacaReflectionPTStrategy(
|
114 |
-
ReflectAlpacaPrompter(),
|
115 |
tokenizer,
|
116 |
cfg.train_on_inputs,
|
117 |
cfg.sequence_len,
|
118 |
)
|
119 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
120 |
datasets.append(ds_wrapper)
|
121 |
-
elif
|
122 |
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
123 |
-
ShareGPTPrompter(),
|
|
|
|
|
|
|
124 |
)
|
125 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
126 |
datasets.append(ds_wrapper)
|
127 |
-
elif
|
128 |
ds_strategy = CompletionPromptTokenizingStrategy(
|
129 |
CompletionPrompter(),
|
130 |
tokenizer,
|
@@ -146,11 +223,20 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|
146 |
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
147 |
)
|
148 |
dataset.save_to_disk(prepared_ds_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
return dataset
|
151 |
|
152 |
|
153 |
-
def load_prepare_datasets(
|
|
|
|
|
154 |
max_packed_sequence_len = (
|
155 |
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
156 |
)
|
@@ -158,16 +244,20 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
158 |
max_packed_sequence_len, cfg.sequence_len
|
159 |
) # make sure we don't accidentally set it larger than sequence_len
|
160 |
|
|
|
161 |
if cfg.max_packed_sequence_len is not None:
|
162 |
# see if we can go ahead and load the stacked dataset
|
163 |
-
|
164 |
ds_hash = str(
|
165 |
md5(
|
166 |
(
|
167 |
str(cfg.sequence_len)
|
168 |
+ "@"
|
169 |
+ str(max_packed_sequence_len)
|
|
|
170 |
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
|
|
|
|
171 |
).encode("utf-8")
|
172 |
).hexdigest()
|
173 |
)
|
@@ -177,17 +267,42 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
177 |
else Path(default_dataset_prepared_path) / ds_hash
|
178 |
)
|
179 |
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
logging.info(
|
182 |
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
183 |
)
|
184 |
dataset = load_from_disk(str(prepared_ds_path))
|
185 |
logging.info("Prepared packed dataset loaded from disk...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
else:
|
187 |
dataset = load_tokenized_prepared_datasets(
|
188 |
tokenizer, cfg, default_dataset_prepared_path
|
189 |
)
|
190 |
|
|
|
|
|
|
|
191 |
constant_len_dataset = ConstantLengthDataset(
|
192 |
tokenizer,
|
193 |
[dataset],
|
@@ -204,9 +319,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
204 |
d
|
205 |
for d in dataset
|
206 |
if len(d["input_ids"]) < cfg.sequence_len
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
]
|
211 |
)
|
212 |
|
@@ -215,6 +330,13 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
215 |
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
216 |
)
|
217 |
dataset.save_to_disk(prepared_ds_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
else:
|
219 |
dataset = load_tokenized_prepared_datasets(
|
220 |
tokenizer, cfg, default_dataset_prepared_path
|
|
|
8 |
IterableDataset,
|
9 |
Dataset,
|
10 |
concatenate_datasets,
|
11 |
+
DatasetDict,
|
12 |
)
|
13 |
from huggingface_hub import hf_hub_download
|
14 |
+
from transformers import PreTrainedTokenizerBase
|
15 |
|
16 |
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
|
17 |
+
from axolotl.prompt_strategies import load
|
18 |
from axolotl.prompt_tokenizers import (
|
19 |
AlpacaPromptTokenizingStrategy,
|
20 |
GPTeacherPromptTokenizingStrategy,
|
|
|
23 |
ShareGPTPromptTokenizingStrategy,
|
24 |
JeopardyPromptTokenizingStrategy,
|
25 |
CompletionPromptTokenizingStrategy,
|
26 |
+
AlpacaMultipleChoicePromptTokenizingStrategy,
|
27 |
+
SummarizeTLDRPromptTokenizingStrategy,
|
28 |
)
|
29 |
from axolotl.prompters import (
|
30 |
AlpacaPrompter,
|
|
|
33 |
ShareGPTPrompter,
|
34 |
JeopardyPrompter,
|
35 |
CompletionPrompter,
|
36 |
+
MultipleChoiceExplainPrompter,
|
37 |
+
SummarizeTLDRPrompter,
|
38 |
+
MultipleChoiceConcisePrompter,
|
39 |
)
|
40 |
|
41 |
|
42 |
+
def load_tokenized_prepared_datasets(
|
43 |
+
tokenizer, cfg, default_dataset_prepared_path
|
44 |
+
) -> DatasetDict:
|
45 |
+
tokenizer_name = tokenizer.__class__.__name__
|
46 |
ds_hash = str(
|
47 |
md5(
|
48 |
(
|
49 |
str(cfg.sequence_len)
|
50 |
+ "@"
|
51 |
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
52 |
+
+ "|"
|
53 |
+
+ tokenizer_name
|
54 |
).encode("utf-8")
|
55 |
).hexdigest()
|
56 |
)
|
|
|
59 |
if cfg.dataset_prepared_path
|
60 |
else Path(default_dataset_prepared_path) / ds_hash
|
61 |
)
|
62 |
+
dataset = None
|
63 |
+
try:
|
64 |
+
if cfg.push_dataset_to_hub:
|
65 |
+
dataset = load_dataset(
|
66 |
+
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
|
67 |
+
)
|
68 |
+
dataset = dataset["train"]
|
69 |
+
except:
|
70 |
+
pass
|
71 |
|
72 |
+
if dataset:
|
73 |
+
...
|
74 |
+
elif any(prepared_ds_path.glob("*")):
|
75 |
logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
76 |
dataset = load_from_disk(str(prepared_ds_path))
|
77 |
logging.info("Prepared dataset loaded from disk...")
|
|
|
83 |
ds = None
|
84 |
ds_from_hub = False
|
85 |
try:
|
86 |
+
load_dataset(d.path, streaming=True, use_auth_token=True)
|
87 |
ds_from_hub = True
|
88 |
except FileNotFoundError:
|
89 |
pass
|
|
|
91 |
# prefer local dataset, even if hub exists
|
92 |
if Path(d.path).exists():
|
93 |
ds: IterableDataset = load_dataset(
|
94 |
+
"json", data_files=d.path, streaming=False, split=None
|
95 |
)
|
96 |
elif ds_from_hub:
|
97 |
if d.data_files:
|
98 |
+
ds = load_dataset(
|
99 |
+
d.path,
|
100 |
+
streaming=False,
|
101 |
+
data_files=d.data_files,
|
102 |
+
use_auth_token=True,
|
103 |
+
)
|
104 |
else:
|
105 |
+
ds = load_dataset(d.path, streaming=False, use_auth_token=True)
|
106 |
else:
|
107 |
fp = hf_hub_download(
|
108 |
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
109 |
)
|
110 |
+
ds = load_dataset("json", data_files=fp, streaming=False, split=None)
|
111 |
if not ds:
|
112 |
raise Exception("unhandled dataset load")
|
113 |
+
# support for using a subset of the data
|
114 |
+
if d.shards:
|
115 |
+
ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
|
116 |
+
d_type = d.type
|
117 |
+
d_type_split = d_type.split(":")
|
118 |
+
d_base_type = d_type_split[0]
|
119 |
+
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
120 |
+
if ds_strategy := load(d.type, tokenizer, cfg):
|
121 |
+
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
122 |
+
datasets.append(ds_wrapper)
|
123 |
+
elif d_base_type == "alpaca":
|
124 |
ds_strategy = AlpacaPromptTokenizingStrategy(
|
125 |
+
AlpacaPrompter(d_prompt_style),
|
126 |
+
tokenizer,
|
127 |
+
cfg.train_on_inputs,
|
128 |
+
cfg.sequence_len,
|
129 |
+
)
|
130 |
+
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
131 |
+
datasets.append(ds_wrapper)
|
132 |
+
elif d_base_type == "explainchoice":
|
133 |
+
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
134 |
+
MultipleChoiceExplainPrompter(d_prompt_style),
|
135 |
+
tokenizer,
|
136 |
+
cfg.train_on_inputs,
|
137 |
+
cfg.sequence_len,
|
138 |
+
)
|
139 |
+
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
140 |
+
datasets.append(ds_wrapper)
|
141 |
+
elif d_base_type == "concisechoice":
|
142 |
+
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
143 |
+
MultipleChoiceConcisePrompter(d_prompt_style),
|
144 |
+
tokenizer,
|
145 |
+
cfg.train_on_inputs,
|
146 |
+
cfg.sequence_len,
|
147 |
+
)
|
148 |
+
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
149 |
+
datasets.append(ds_wrapper)
|
150 |
+
elif d_base_type == "summarizetldr":
|
151 |
+
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
152 |
+
SummarizeTLDRPrompter(d_prompt_style),
|
153 |
+
tokenizer,
|
154 |
+
cfg.train_on_inputs,
|
155 |
+
cfg.sequence_len,
|
156 |
)
|
157 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
158 |
datasets.append(ds_wrapper)
|
159 |
+
elif d_base_type == "jeopardy":
|
160 |
ds_strategy = JeopardyPromptTokenizingStrategy(
|
161 |
+
JeopardyPrompter(d_prompt_style),
|
162 |
+
tokenizer,
|
163 |
+
cfg.train_on_inputs,
|
164 |
+
cfg.sequence_len,
|
165 |
)
|
166 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
167 |
datasets.append(ds_wrapper)
|
168 |
+
elif d_base_type == "oasst":
|
169 |
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
170 |
+
AlpacaPrompter(d_prompt_style),
|
171 |
+
tokenizer,
|
172 |
+
cfg.train_on_inputs,
|
173 |
+
cfg.sequence_len,
|
174 |
)
|
175 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
176 |
datasets.append(ds_wrapper)
|
177 |
+
elif d_base_type == "gpteacher":
|
178 |
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
179 |
+
GPTeacherPrompter(d_prompt_style),
|
180 |
tokenizer,
|
181 |
cfg.train_on_inputs,
|
182 |
cfg.sequence_len,
|
183 |
)
|
184 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
185 |
datasets.append(ds_wrapper)
|
186 |
+
elif d_base_type == "reflection":
|
187 |
ds_strategy = AlpacaReflectionPTStrategy(
|
188 |
+
ReflectAlpacaPrompter(d_prompt_style),
|
189 |
tokenizer,
|
190 |
cfg.train_on_inputs,
|
191 |
cfg.sequence_len,
|
192 |
)
|
193 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
194 |
datasets.append(ds_wrapper)
|
195 |
+
elif d_base_type == "sharegpt":
|
196 |
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
197 |
+
ShareGPTPrompter(d_prompt_style),
|
198 |
+
tokenizer,
|
199 |
+
cfg.train_on_inputs,
|
200 |
+
cfg.sequence_len,
|
201 |
)
|
202 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
203 |
datasets.append(ds_wrapper)
|
204 |
+
elif d_base_type == "completion":
|
205 |
ds_strategy = CompletionPromptTokenizingStrategy(
|
206 |
CompletionPrompter(),
|
207 |
tokenizer,
|
|
|
223 |
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
224 |
)
|
225 |
dataset.save_to_disk(prepared_ds_path)
|
226 |
+
if cfg.push_dataset_to_hub:
|
227 |
+
logging.info(
|
228 |
+
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
229 |
+
)
|
230 |
+
dataset.push_to_hub(
|
231 |
+
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
232 |
+
)
|
233 |
|
234 |
return dataset
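The cache key for a prepared dataset is an MD5 over the sequence length, the sorted `path:type` pairs, and (new in this PR) the tokenizer class name, so changing any of them points at a fresh prepared-dataset directory and hub repo. A standalone sketch of the same scheme; the `cfg.datasets` entries are simplified to dicts and the example path is only illustrative:

```python
# Standalone sketch of the prepared-dataset cache key built above.
from hashlib import md5

def prepared_ds_hash(sequence_len, datasets, tokenizer_name):
    key = (
        str(sequence_len)
        + "@"
        + "|".join(sorted(f"{d['path']}:{d['type']}" for d in datasets))
        + "|"
        + tokenizer_name
    )
    return md5(key.encode("utf-8")).hexdigest()

print(prepared_ds_hash(2048, [{"path": "vicgalle/alpaca-gpt4", "type": "alpaca"}], "LlamaTokenizer"))
```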
|
235 |
|
236 |
|
237 |
+
def load_prepare_datasets(
|
238 |
+
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
|
239 |
+
) -> (Dataset, Dataset):
|
240 |
max_packed_sequence_len = (
|
241 |
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
242 |
)
|
|
|
244 |
max_packed_sequence_len, cfg.sequence_len
|
245 |
) # make sure we don't accidentally set it larger than sequence_len
|
246 |
|
247 |
+
tokenizer_name = tokenizer.__class__.__name__
|
248 |
if cfg.max_packed_sequence_len is not None:
|
249 |
# see if we can go ahead and load the stacked dataset
|
250 |
+
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
251 |
ds_hash = str(
|
252 |
md5(
|
253 |
(
|
254 |
str(cfg.sequence_len)
|
255 |
+ "@"
|
256 |
+ str(max_packed_sequence_len)
|
257 |
+
+ seed
|
258 |
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
259 |
+
+ "|"
|
260 |
+
+ tokenizer_name
|
261 |
).encode("utf-8")
|
262 |
).hexdigest()
|
263 |
)
|
|
|
267 |
else Path(default_dataset_prepared_path) / ds_hash
|
268 |
)
|
269 |
|
270 |
+
dataset = None
|
271 |
+
try:
|
272 |
+
if cfg.push_dataset_to_hub:
|
273 |
+
logging.info(
|
274 |
+
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
275 |
+
)
|
276 |
+
dataset = load_dataset(
|
277 |
+
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
|
278 |
+
)
|
279 |
+
dataset = dataset["train"]
|
280 |
+
except:
|
281 |
+
pass
|
282 |
+
|
283 |
+
if dataset:
|
284 |
+
...
|
285 |
+
elif any(prepared_ds_path.glob("*")):
|
286 |
logging.info(
|
287 |
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
288 |
)
|
289 |
dataset = load_from_disk(str(prepared_ds_path))
|
290 |
logging.info("Prepared packed dataset loaded from disk...")
|
291 |
+
if cfg.push_dataset_to_hub:
|
292 |
+
logging.info(
|
293 |
+
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
294 |
+
)
|
295 |
+
dataset.push_to_hub(
|
296 |
+
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
297 |
+
)
|
298 |
else:
|
299 |
dataset = load_tokenized_prepared_datasets(
|
300 |
tokenizer, cfg, default_dataset_prepared_path
|
301 |
)
|
302 |
|
303 |
+
if cfg.seed:
|
304 |
+
dataset = dataset.shuffle(seed=cfg.seed)
|
305 |
+
|
306 |
constant_len_dataset = ConstantLengthDataset(
|
307 |
tokenizer,
|
308 |
[dataset],
|
|
|
319 |
d
|
320 |
for d in dataset
|
321 |
if len(d["input_ids"]) < cfg.sequence_len
|
322 |
+
and len(d["input_ids"]) > 0
|
323 |
+
and len(d["input_ids"]) == len(d["attention_mask"])
|
324 |
+
and len(d["input_ids"]) == len(d["labels"])
|
325 |
]
|
326 |
)
|
327 |
|
|
|
330 |
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
331 |
)
|
332 |
dataset.save_to_disk(prepared_ds_path)
|
333 |
+
if cfg.push_dataset_to_hub:
|
334 |
+
logging.info(
|
335 |
+
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
336 |
+
)
|
337 |
+
dataset.push_to_hub(
|
338 |
+
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
339 |
+
)
|
340 |
else:
|
341 |
dataset = load_tokenized_prepared_datasets(
|
342 |
tokenizer, cfg, default_dataset_prepared_path
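Dataset `type` strings are now parsed as `<base_type>[:<prompt_style>]`, so a config entry like `sharegpt:chat` picks the ShareGPT strategy with the chat template, while a bare `alpaca` falls back to the default style. A tiny sketch of that parsing:

```python
# Sketch of the d_type parsing added above: an optional ":<prompt_style>" suffix.
def split_ds_type(d_type):
    parts = d_type.split(":")
    base_type = parts[0]
    prompt_style = parts[1] if len(parts) > 1 else None
    return base_type, prompt_style

assert split_ds_type("alpaca") == ("alpaca", None)
assert split_ds_type("sharegpt:chat") == ("sharegpt", "chat")
```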
|
src/axolotl/utils/models.py
CHANGED
@@ -1,15 +1,18 @@
|
|
1 |
import logging
|
|
|
2 |
import os
|
3 |
from pathlib import Path
|
4 |
from typing import Optional, Tuple, TYPE_CHECKING
|
5 |
|
6 |
import torch
|
7 |
import transformers
|
|
|
8 |
from transformers import (
|
9 |
AutoModelForCausalLM,
|
10 |
AutoTokenizer,
|
11 |
PreTrainedModel,
|
12 |
AutoConfig,
|
|
|
13 |
)
|
14 |
|
15 |
try:
|
@@ -80,6 +83,16 @@ def load_model(
|
|
80 |
logging.exception(e)
|
81 |
raise e
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
try:
|
84 |
if cfg.load_4bit and is_llama_derived_model:
|
85 |
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
@@ -123,16 +136,46 @@ def load_model(
|
|
123 |
model = LlamaForCausalLM.from_pretrained(
|
124 |
base_model,
|
125 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
|
126 |
torch_dtype=torch_dtype,
|
127 |
device_map=cfg.device_map,
|
|
|
128 |
)
|
|
|
|
|
|
|
129 |
elif model_type:
|
130 |
model = getattr(transformers, model_type).from_pretrained(
|
131 |
base_model,
|
132 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
|
133 |
torch_dtype=torch_dtype,
|
134 |
device_map=cfg.device_map,
|
135 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
|
|
136 |
)
|
137 |
else:
|
138 |
config = AutoConfig.from_pretrained(
|
@@ -143,9 +186,11 @@ def load_model(
|
|
143 |
base_model,
|
144 |
config=config,
|
145 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
|
146 |
torch_dtype=torch_dtype,
|
147 |
device_map=cfg.device_map,
|
148 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
|
|
149 |
)
|
150 |
except Exception as e:
|
151 |
logging.error(
|
@@ -158,16 +203,26 @@ def load_model(
|
|
158 |
torch_dtype=torch_dtype,
|
159 |
device_map=cfg.device_map,
|
160 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
|
|
161 |
)
|
162 |
|
163 |
if not tokenizer:
|
164 |
try:
|
165 |
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
166 |
-
tokenizer = LlamaTokenizer.from_pretrained(
|
|
|
|
|
|
|
167 |
else:
|
168 |
-
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
|
|
|
|
|
|
169 |
except:
|
170 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
|
|
171 |
|
172 |
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
173 |
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
@@ -181,14 +236,18 @@ def load_model(
|
|
181 |
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
182 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
183 |
|
184 |
-
if cfg.
|
185 |
-
for k, v in cfg.
|
186 |
tokenizer.add_special_tokens({k: v})
|
|
|
|
|
187 |
|
188 |
-
|
189 |
-
|
190 |
|
191 |
-
if
|
|
|
|
|
192 |
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
193 |
model = prepare_model_for_int8_training(model)
|
194 |
|
@@ -209,7 +268,11 @@ def load_model(
|
|
209 |
m.scales = m.scales.half()
|
210 |
m.bias = m.bias.half()
|
211 |
|
212 |
-
if
|
|
|
|
|
|
|
|
|
213 |
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
214 |
# so let's only set it for the 4bit, see
|
215 |
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
@@ -222,6 +285,7 @@ def load_model(
|
|
222 |
requires_grad.append(f"{name}: {param.requires_grad}")
|
223 |
if len(requires_grad) == 0:
|
224 |
logging.warning("there are no parameters that require gradient updates")
|
|
|
225 |
|
226 |
# TODO resume_from_checkpoint handling
|
227 |
return model, tokenizer, lora_config
|
@@ -232,7 +296,7 @@ def load_adapter(model, cfg, adapter):
|
|
232 |
|
233 |
if adapter is None:
|
234 |
return model, None
|
235 |
-
if adapter == "lora":
|
236 |
return load_lora(model, cfg)
|
237 |
if adapter == "llama-adapter":
|
238 |
return load_llama_adapter(model, cfg)
|
@@ -254,7 +318,8 @@ def load_llama_adapter(model, cfg):
|
|
254 |
task_type="CAUSAL_LM",
|
255 |
)
|
256 |
|
257 |
-
if cfg.
|
|
|
258 |
model = PeftModel.from_pretrained(
|
259 |
model,
|
260 |
cfg.lora_model_dir,
|
@@ -296,7 +361,7 @@ def load_lora(model, cfg):
|
|
296 |
model,
|
297 |
cfg.lora_model_dir,
|
298 |
device_map=cfg.device_map,
|
299 |
-
torch_dtype=torch.float16,
|
300 |
)
|
301 |
else:
|
302 |
model = get_peft_model(model, lora_config)
|
|
|
1 |
import logging
|
2 |
+
import math
|
3 |
import os
|
4 |
from pathlib import Path
|
5 |
from typing import Optional, Tuple, TYPE_CHECKING
|
6 |
|
7 |
import torch
|
8 |
import transformers
|
9 |
+
from torch import nn
|
10 |
from transformers import (
|
11 |
AutoModelForCausalLM,
|
12 |
AutoTokenizer,
|
13 |
PreTrainedModel,
|
14 |
AutoConfig,
|
15 |
+
BitsAndBytesConfig,
|
16 |
)
|
17 |
|
18 |
try:
|
|
|
83 |
logging.exception(e)
|
84 |
raise e
|
85 |
|
86 |
+
model_kwargs = {}
|
87 |
+
if cfg.adapter == "qlora":
|
88 |
+
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
89 |
+
load_in_4bit=True,
|
90 |
+
llm_int8_threshold=6.0,
|
91 |
+
llm_int8_has_fp16_weight=False,
|
92 |
+
bnb_4bit_compute_dtype=torch.float16,
|
93 |
+
bnb_4bit_use_double_quant=True,
|
94 |
+
bnb_4bit_quant_type="nf4",
|
95 |
+
)
|
96 |
try:
|
97 |
if cfg.load_4bit and is_llama_derived_model:
|
98 |
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
|
|
136 |
model = LlamaForCausalLM.from_pretrained(
|
137 |
base_model,
|
138 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
139 |
+
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
140 |
torch_dtype=torch_dtype,
|
141 |
device_map=cfg.device_map,
|
142 |
+
**model_kwargs,
|
143 |
)
|
144 |
+
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
145 |
+
# This is a WIP, still an issue with the backward pass
|
146 |
+
# RuntimeError: grad can be implicitly created only for scalar outputs
|
147 |
+
# TODO: try config.sequence_parallel = False
|
148 |
+
# # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
|
149 |
+
# # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
|
150 |
+
# # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
|
151 |
+
# from flash_attn.utils.pretrained import state_dict_from_pretrained
|
152 |
+
# from flash_attn.models.gpt import GPTLMHeadModel
|
153 |
+
# from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
|
154 |
+
# from transformers import GPTNeoXConfig
|
155 |
+
# config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
|
156 |
+
# config.use_flash_attn = True
|
157 |
+
# config.fused_bias_fc = True
|
158 |
+
# config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
|
159 |
+
# config.activation_function = "gelu_fast"
|
160 |
+
# config.fused_dropout_add_ln = True
|
161 |
+
# # config.residual_in_fp32 = True
|
162 |
+
#
|
163 |
+
# model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
|
164 |
+
# base_model,
|
165 |
+
# config,
|
166 |
+
# dtype=torch_dtype,
|
167 |
+
# device=cfg.device,
|
168 |
+
# )
|
169 |
+
# model.train() # sets to train instead of eval mode
|
170 |
elif model_type:
|
171 |
model = getattr(transformers, model_type).from_pretrained(
|
172 |
base_model,
|
173 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
174 |
+
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
175 |
torch_dtype=torch_dtype,
|
176 |
device_map=cfg.device_map,
|
177 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
178 |
+
**model_kwargs,
|
179 |
)
|
180 |
else:
|
181 |
config = AutoConfig.from_pretrained(
|
|
|
186 |
base_model,
|
187 |
config=config,
|
188 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
189 |
+
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
190 |
torch_dtype=torch_dtype,
|
191 |
device_map=cfg.device_map,
|
192 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
193 |
+
**model_kwargs,
|
194 |
)
|
195 |
except Exception as e:
|
196 |
logging.error(
|
|
|
203 |
torch_dtype=torch_dtype,
|
204 |
device_map=cfg.device_map,
|
205 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
206 |
+
**model_kwargs,
|
207 |
)
|
208 |
|
209 |
if not tokenizer:
|
210 |
try:
|
211 |
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
212 |
+
tokenizer = LlamaTokenizer.from_pretrained(
|
213 |
+
model,
|
214 |
+
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
215 |
+
)
|
216 |
else:
|
217 |
+
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
218 |
+
model,
|
219 |
+
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
220 |
+
)
|
221 |
except:
|
222 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
223 |
+
base_model_config,
|
224 |
+
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
225 |
+
)
|
226 |
|
227 |
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
228 |
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
|
|
236 |
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
237 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
238 |
|
239 |
+
if cfg.special_tokens:
|
240 |
+
for k, v in cfg.special_tokens.items():
|
241 |
tokenizer.add_special_tokens({k: v})
|
242 |
+
if cfg.tokens:
|
243 |
+
tokenizer.add_tokens(list(cfg.tokens))
|
244 |
|
245 |
+
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
246 |
+
model.resize_token_embeddings(embeddings_len)
|
247 |
|
248 |
+
if (
|
249 |
+
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
|
250 |
+
) and not cfg.load_4bit:
|
251 |
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
252 |
model = prepare_model_for_int8_training(model)
|
253 |
|
|
|
268 |
m.scales = m.scales.half()
|
269 |
m.bias = m.bias.half()
|
270 |
|
271 |
+
if (
|
272 |
+
torch.cuda.device_count() > 1
|
273 |
+
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
274 |
+
and cfg.load_4bit
|
275 |
+
):
|
276 |
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
277 |
# so let's only set it for the 4bit, see
|
278 |
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
|
|
285 |
requires_grad.append(f"{name}: {param.requires_grad}")
|
286 |
if len(requires_grad) == 0:
|
287 |
logging.warning("there are no parameters that require gradient updates")
|
288 |
+
model.config.use_cache = False
|
289 |
|
290 |
# TODO resume_from_checkpoint handling
|
291 |
return model, tokenizer, lora_config
|
|
|
296 |
|
297 |
if adapter is None:
|
298 |
return model, None
|
299 |
+
if adapter == "lora" or adapter == "qlora":
|
300 |
return load_lora(model, cfg)
|
301 |
if adapter == "llama-adapter":
|
302 |
return load_llama_adapter(model, cfg)
|
|
|
318 |
task_type="CAUSAL_LM",
|
319 |
)
|
320 |
|
321 |
+
if cfg.lora_model_dir:
|
322 |
+
logging.info("Loading pretained LORA")
|
323 |
model = PeftModel.from_pretrained(
|
324 |
model,
|
325 |
cfg.lora_model_dir,
|
|
|
361 |
model,
|
362 |
cfg.lora_model_dir,
|
363 |
device_map=cfg.device_map,
|
364 |
+
# torch_dtype=torch.float16,
|
365 |
)
|
366 |
else:
|
367 |
model = get_peft_model(model, lora_config)
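When `adapter: qlora` is set, the base model is loaded through bitsandbytes NF4 quantization via the `quantization_config` kwargs added above. A standalone sketch of an equivalent load; the model id is a placeholder, and `bitsandbytes` plus a CUDA-capable GPU are assumed:

```python
# Standalone sketch of the 4-bit (QLoRA) load path; "huggyllama/llama-7b" is a
# placeholder model id, not something this PR prescribes.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    quantization_config=bnb_config,
    device_map="auto",
)
```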
|
src/axolotl/utils/trainer.py
CHANGED
@@ -9,13 +9,33 @@ import torch.cuda
|
|
9 |
import transformers
|
10 |
from torch import nn
|
11 |
from torch.optim.lr_scheduler import OneCycleLR
|
12 |
-
from transformers import EarlyStoppingCallback
|
13 |
from transformers.trainer_pt_utils import get_parameter_names
|
14 |
|
15 |
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
16 |
from axolotl.utils.callbacks import SavePeftModelCallback
|
17 |
|
18 |
|
|
|
|
|
|
|
|
|
19 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
20 |
total_num_steps = int(
|
21 |
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
@@ -38,6 +58,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
38 |
training_arguments_kwargs["bf16_full_eval"] = True
|
39 |
else:
|
40 |
training_arguments_kwargs["bf16"] = cfg.bf16
|
|
|
41 |
training_arguments_kwargs["tf32"] = cfg.tf32
|
42 |
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
43 |
training_arguments_kwargs["logging_steps"] = logging_steps
|
@@ -119,6 +140,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
119 |
cfg.optimizer == "adamw_bnb_8bit"
|
120 |
and not cfg.load_4bit
|
121 |
and not "deepspeed" in training_arguments_kwargs
|
|
|
122 |
):
|
123 |
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
124 |
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
@@ -157,7 +179,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
157 |
cfg.learning_rate,
|
158 |
total_steps=total_num_steps,
|
159 |
epochs=cfg.num_epochs,
|
160 |
-
div_factor=
|
161 |
**lr_scheduler_kwargs,
|
162 |
)
|
163 |
elif cfg.lr_scheduler == "log_sweep":
|
@@ -182,8 +204,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
182 |
cfg.early_stopping_patience,
|
183 |
)
|
184 |
callbacks.append(early_stop_cb)
|
185 |
-
|
186 |
-
if cfg.local_rank == 0 and cfg.adapter ==
|
187 |
callbacks.append(SavePeftModelCallback)
|
188 |
|
189 |
data_collator_kwargs = {
|
@@ -194,7 +216,12 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
194 |
else:
|
195 |
data_collator_kwargs["pad_to_multiple_of"] = 8
|
196 |
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
198 |
model=model,
|
199 |
train_dataset=train_dataset,
|
200 |
eval_dataset=eval_dataset,
|
|
|
9 |
import transformers
|
10 |
from torch import nn
|
11 |
from torch.optim.lr_scheduler import OneCycleLR
|
12 |
+
from transformers import EarlyStoppingCallback, Trainer
|
13 |
from transformers.trainer_pt_utils import get_parameter_names
|
14 |
|
15 |
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
16 |
from axolotl.utils.callbacks import SavePeftModelCallback
|
17 |
|
18 |
|
19 |
+
class OneCycleLRSchedulerTrainer(Trainer):
|
20 |
+
def create_scheduler(
|
21 |
+
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
22 |
+
):
|
23 |
+
optimizer = self.optimizer if optimizer is None else optimizer
|
24 |
+
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
25 |
+
num_training_steps = num_training_steps
|
26 |
+
pct_start = num_warmup_steps / num_training_steps
|
27 |
+
|
28 |
+
self.lr_scheduler = OneCycleLR(
|
29 |
+
optimizer,
|
30 |
+
max_lr=self.args.learning_rate,
|
31 |
+
total_steps=num_training_steps,
|
32 |
+
pct_start=pct_start,
|
33 |
+
div_factor=6,
|
34 |
+
)
|
35 |
+
|
36 |
+
return self.lr_scheduler
|
37 |
+
|
38 |
+
|
39 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
40 |
total_num_steps = int(
|
41 |
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
|
|
58 |
training_arguments_kwargs["bf16_full_eval"] = True
|
59 |
else:
|
60 |
training_arguments_kwargs["bf16"] = cfg.bf16
|
61 |
+
training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
|
62 |
training_arguments_kwargs["tf32"] = cfg.tf32
|
63 |
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
64 |
training_arguments_kwargs["logging_steps"] = logging_steps
|
|
|
140 |
cfg.optimizer == "adamw_bnb_8bit"
|
141 |
and not cfg.load_4bit
|
142 |
and not "deepspeed" in training_arguments_kwargs
|
143 |
+
and not cfg.fsdp
|
144 |
):
|
145 |
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
146 |
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
|
|
179 |
cfg.learning_rate,
|
180 |
total_steps=total_num_steps,
|
181 |
epochs=cfg.num_epochs,
|
182 |
+
div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6,
|
183 |
**lr_scheduler_kwargs,
|
184 |
)
|
185 |
elif cfg.lr_scheduler == "log_sweep":
|
|
|
204 |
cfg.early_stopping_patience,
|
205 |
)
|
206 |
callbacks.append(early_stop_cb)
|
207 |
+
|
208 |
+
if cfg.local_rank == 0 and cfg.adapter == "lora": # only save in rank 0
|
209 |
callbacks.append(SavePeftModelCallback)
|
210 |
|
211 |
data_collator_kwargs = {
|
|
|
216 |
else:
|
217 |
data_collator_kwargs["pad_to_multiple_of"] = 8
|
218 |
|
219 |
+
trainer_cls = (
|
220 |
+
OneCycleLRSchedulerTrainer
|
221 |
+
if cfg.lr_scheduler == "one_cycle" and cfg.fsdp
|
222 |
+
else transformers.Trainer
|
223 |
+
)
|
224 |
+
trainer = trainer_cls(
|
225 |
model=model,
|
226 |
train_dataset=train_dataset,
|
227 |
eval_dataset=eval_dataset,
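The custom `OneCycleLRSchedulerTrainer` exists so the one-cycle schedule can still be built when FSDP is enabled, with the warm-up fraction derived from the configured warm-up steps. A quick check of the `pct_start` arithmetic, using made-up step counts:

```python
# Worked example of the pct_start computation in create_scheduler() above;
# the step counts are made up for illustration.
num_training_steps = 1000
num_warmup_steps = 100  # what args.get_warmup_steps() would return for this run
pct_start = num_warmup_steps / num_training_steps
assert pct_start == 0.1  # 10% of the one-cycle schedule is spent warming up
```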
|