fozziethebeat committed
Commit: cc11c6b
Parent(s): 5f91064

Generalizing the chat_template prompt strategy (#1660) [skip ci]
The strategy now supports configuring several fields:

* the data field holding message arrays
* the role and content fields for each message
* role mapping from source to target types

Additionally, this adds a sample llama3-8b instruct config that uses the chat template strategy; a sketch of the kind of dataset row these options consume follows below.
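For illustration, a dataset row in the shape these options are written against might look like the following (a hypothetical record; the field names messages, role, and content match the sample config in this commit):

    row = {
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "Paris."},
        ]
    }
    # field_messages picks the "messages" array; message_field_role and
    # message_field_content pick "role" and "content" inside each message;
    # the roles mapping canonicalizes each source role name.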
examples/llama-3/instruct-lora-8b.yml
ADDED
@@ -0,0 +1,76 @@
+base_model: meta-llama/Meta-Llama-3-8B-Instruct
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+chat_template: llama3
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+    chat_template: llama3
+    field_messages: messages
+    message_field_role: role
+    message_field_content: content
+    roles:
+      user:
+        - user
+      assistant:
+        - assistant
+
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
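Assuming a standard axolotl checkout, the example config should launch with the usual training entrypoint (a sketch; the exact invocation depends on your environment):

    accelerate launch -m axolotl.cli.train examples/llama-3/instruct-lora-8b.yml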
src/axolotl/prompt_strategies/chat_template.py
CHANGED
@@ -1,24 +1,55 @@
 """
 HF Chat Templates prompt strategy
 """
-from typing import Any, Dict, Optional
+
+import logging
+from typing import Any, Dict, List, Optional
 
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
 from axolotl.prompters import Prompter
 from axolotl.utils.chat_templates import chat_templates
 
+LOG = logging.getLogger("axolotl")
+
 
 class ChatTemplatePrompter(Prompter):
     """prompter for HF chat templates"""
 
-    def __init__(self, tokenizer, chat_template=None, max_length=2048):
+    def __init__(
+        self,
+        tokenizer,
+        chat_template=None,
+        max_length=2048,
+        message_field_role: str = "from",
+        message_field_content: str = "value",
+        roles: Optional[Dict[str, List[str]]] = None,
+    ):
+        if roles:
+            self.roles = {s: t for t, sources in roles.items() for s in sources}
+        else:
+            self.roles = {
+                "human": "user",
+                "user": "user",
+                "assistant": "assistant",
+                "gpt": "assistant",
+            }
+        self.message_field_role = message_field_role
+        self.message_field_content = message_field_content
         self.tokenizer = tokenizer
         self.chat_template = chat_template
         self.max_length = max_length
 
     def build_prompt(self, conversation, add_generation_prompt=False):
+        turns = [
+            {
+                "role": self.roles[t[self.message_field_role]],
+                "content": t[self.message_field_content],
+            }
+            for t in conversation
+        ]
+
         return self.tokenizer.apply_chat_template(
-            conversation,
+            turns,
             truncation=True,
             max_length=self.max_length,
             add_generation_prompt=add_generation_prompt,
@@ -31,9 +62,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for instruction-based prompts.
     """
 
+    _messages = "conversations"
+
+    @property
+    def messages(self):
+        return self._messages
+
+    @messages.setter
+    def messages(self, messages):
+        self._messages = messages
+
     def tokenize_prompt(self, prompt):
         turns = self.get_conversation_thread(prompt)
-        prompt_ids = self.prompter.build_prompt(
+        prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
         input_ids = self.prompter.build_prompt(turns)
 
         if not self.train_on_inputs:
@@ -51,28 +92,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         return tokenized_prompt
 
     def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
-        # remap roles - allow for assistant turn
-        role_map = {
-            "human": "user",
-            "user": "user",
-            "assistant": "assistant",
-            "gpt": "assistant",
-        }
-        turns = [
-            {"role": role_map[t["from"]], "content": t["value"]} for t in conversations
-        ]
-        return turns
+        return prompt[self.messages]
 
 
 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     chat_template = (
         ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
     )
+    message_field_role = (
+        ds_cfg["message_field_role"]
+        if ds_cfg and "message_field_role" in ds_cfg
+        else "from"
+    )
+    message_field_content = (
+        ds_cfg["message_field_content"]
+        if ds_cfg and "message_field_content" in ds_cfg
+        else "value"
+    )
+    roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
+
     strategy = ChatTemplateStrategy(
-        ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
+        ChatTemplatePrompter(
+            tokenizer,
+            chat_templates(chat_template),
+            message_field_role=message_field_role,
+            message_field_content=message_field_content,
+            roles=roles,
+        ),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+    if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
+        strategy.messages = ds_cfg["field_messages"]
     return strategy
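A note on the new roles option: the config maps each target role to a list of source names, and the prompter inverts that into a source-to-target lookup. A minimal standalone sketch of that inversion (the dict comprehension mirrors the diff; the inputs here are illustrative):

    roles = {"user": ["human", "user"], "assistant": ["gpt", "assistant"]}
    # invert target -> [sources] into source -> target
    role_lookup = {s: t for t, sources in roles.items() for s in sources}
    assert role_lookup == {
        "human": "user",
        "user": "user",
        "gpt": "assistant",
        "assistant": "assistant",
    }

When roles is omitted, the prompter falls back to the same human/user/gpt/assistant defaults that get_conversation_thread previously hard-coded.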
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -110,6 +110,8 @@ class SFTDataset(BaseModel):
     field_human: Optional[str] = None
     field_model: Optional[str] = None
     field_messages: Optional[str] = None
+    message_field_role: Optional[str] = None
+    message_field_content: Optional[str] = None
 
     roles: Optional[Dict[str, List[str]]] = None
 
tests/prompt_strategies/test_chat_templates.py
CHANGED
@@ -1,6 +1,7 @@
 """
 tests for chat_template prompt strategy
 """
+
 import unittest
 
 import pytest
@@ -10,8 +11,39 @@ from transformers import AutoTokenizer
 from axolotl.prompt_strategies.chat_template import (
     ChatTemplatePrompter,
     ChatTemplateStrategy,
+    load,
 )
 from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="assistant_dataset")
+def fixture_assistant_dataset():
+    # pylint: disable=duplicate-code
+    return Dataset.from_list(
+        [
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "user",
+                        "content": "goodbye",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "goodbye",
+                    },
+                ]
+            }
+        ]
+    )
 
 
 @pytest.fixture(name="sharegpt_dataset")
@@ -51,6 +83,87 @@ def fixture_llama3_tokenizer():
     return tokenizer
 
 
+class TestAssistantChatTemplateLlama3:
+    """
+    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
+    """
+
+    def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
+        # pylint: disable=duplicate-code
+        strategy = load(
+            llama3_tokenizer,
+            DictDefault(
+                {
+                    "train_on_inputs": False,
+                    "sequence_len": 512,
+                }
+            ),
+            DictDefault(
+                {
+                    "chat_template": "llama3",
+                    "message_field_role": "role",
+                    "message_field_content": "content",
+                    "roles": {
+                        "user": ["user"],
+                        "assistant": ["assistant"],
+                        "system": ["system"],
+                    },
+                    "field_messages": "messages",
+                }
+            ),
+        )
+        res = strategy.tokenize_prompt(assistant_dataset[0])
+        input_ids = res["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
+    def test_llama3(self, llama3_tokenizer, assistant_dataset):
+        # pylint: disable=duplicate-code
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(
+                llama3_tokenizer,
+                chat_templates("llama3"),
+                message_field_role="role",
+                message_field_content="content",
+                roles={
+                    "user": ["user"],
+                    "assistant": ["assistant"],
+                    "system": ["system"],
+                },
+            ),
+            llama3_tokenizer,
+            False,
+            512,
+        )
+        strategy.messages = "messages"
+        res = strategy.tokenize_prompt(assistant_dataset[0])
+        input_ids = res["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
+
 class TestSharegptChatTemplateLlama3:
     """
     Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.