fozziethebeat committed
Commit cc11c6b
1 Parent(s): 5f91064

Generalizing the chat_template prompt strategy (#1660) [skip ci]


The strategy now supports configuring several fields:

* the data field holding the message arrays
* the role and content fields for each message
* the role mapping from source role names to target role types

Additionally, this adds a sample Llama-3-8B Instruct LoRA config that uses the chat_template strategy.
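For illustration, the role mapping is configured target-first (each target role lists the source names it accepts) and is inverted into a source-to-target lookup when the prompter is built. A minimal standalone sketch of the same dict comprehension used in the new prompter:

# Sketch only: mirrors the inversion in ChatTemplatePrompter.__init__.
roles = {"user": ["human", "user"], "assistant": ["gpt", "assistant"]}
role_lookup = {source: target for target, sources in roles.items() for source in sources}
assert role_lookup == {
    "human": "user",
    "user": "user",
    "gpt": "assistant",
    "assistant": "assistant",
}

This reproduces the previously hard-coded ShareGPT remapping as the default while letting each dataset declare its own role vocabulary.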

examples/llama-3/instruct-lora-8b.yml ADDED
@@ -0,0 +1,76 @@
+base_model: meta-llama/Meta-Llama-3-8B-Instruct
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+chat_template: llama3
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+    chat_template: llama3
+    field_messages: messages
+    message_field_role: role
+    message_field_content: content
+    roles:
+      user:
+        - user
+      assistant:
+        - assistant
+
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
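As a concrete illustration of what the dataset keys above select, here is a hypothetical record in the shape this config expects (key names taken from the config; the record itself is made up):

# Hypothetical record shaped like the configured dataset.
record = {
    "messages": [  # located via field_messages: messages
        {"role": "user", "content": "hello"},          # message_field_role / message_field_content
        {"role": "assistant", "content": "hi there"},  # role must appear in the roles mapping
    ]
}

Training itself would be launched the usual way, e.g. accelerate launch -m axolotl.cli.train examples/llama-3/instruct-lora-8b.yml (the standard axolotl entry point, not part of this commit).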
src/axolotl/prompt_strategies/chat_template.py CHANGED
@@ -1,24 +1,55 @@
 """
 HF Chat Templates prompt strategy
 """
-from typing import Any, Dict, Optional
+
+import logging
+from typing import Any, Dict, List, Optional
 
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
 from axolotl.prompters import Prompter
 from axolotl.utils.chat_templates import chat_templates
 
+LOG = logging.getLogger("axolotl")
+
 
 class ChatTemplatePrompter(Prompter):
     """prompter for HF chat templates"""
 
-    def __init__(self, tokenizer, chat_template=None, max_length=2048):
+    def __init__(
+        self,
+        tokenizer,
+        chat_template=None,
+        max_length=2048,
+        message_field_role: str = "from",
+        message_field_content: str = "value",
+        roles: Optional[Dict[str, List[str]]] = None,
+    ):
+        if roles:
+            self.roles = {s: t for t, sources in roles.items() for s in sources}
+        else:
+            self.roles = {
+                "human": "user",
+                "user": "user",
+                "assistant": "assistant",
+                "gpt": "assistant",
+            }
+        self.message_field_role = message_field_role
+        self.message_field_content = message_field_content
         self.tokenizer = tokenizer
         self.chat_template = chat_template
         self.max_length = max_length
 
     def build_prompt(self, conversation, add_generation_prompt=False):
+        turns = [
+            {
+                "role": self.roles[t[self.message_field_role]],
+                "content": t[self.message_field_content],
+            }
+            for t in conversation
+        ]
+
         return self.tokenizer.apply_chat_template(
-            conversation,
+            turns,
             truncation=True,
             max_length=self.max_length,
             add_generation_prompt=add_generation_prompt,
@@ -31,9 +62,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for instruction-based prompts.
     """
 
+    _messages = "conversations"
+
+    @property
+    def messages(self):
+        return self._messages
+
+    @messages.setter
+    def messages(self, messages):
+        self._messages = messages
+
     def tokenize_prompt(self, prompt):
         turns = self.get_conversation_thread(prompt)
-        prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
+        prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
         input_ids = self.prompter.build_prompt(turns)
 
         if not self.train_on_inputs:
@@ -51,28 +92,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         return tokenized_prompt
 
     def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
-        # remap roles - allow for assistant turn
-        role_map = {
-            "human": "user",
-            "user": "user",
-            "assistant": "assistant",
-            "gpt": "assistant",
-        }
-        turns = [
-            {"role": role_map[t["from"]], "content": t["value"]} for t in conversations
-        ]
-        return turns
+        return prompt[self.messages]
 
 
 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     chat_template = (
         ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
     )
+    message_field_role = (
+        ds_cfg["message_field_role"]
+        if ds_cfg and "message_field_role" in ds_cfg
+        else "from"
+    )
+    message_field_content = (
+        ds_cfg["message_field_content"]
+        if ds_cfg and "message_field_content" in ds_cfg
+        else "value"
+    )
+    roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
+
     strategy = ChatTemplateStrategy(
-        ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
+        ChatTemplatePrompter(
+            tokenizer,
+            chat_templates(chat_template),
+            message_field_role=message_field_role,
+            message_field_content=message_field_content,
+            roles=roles,
+        ),
         tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
     )
+    if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
+        strategy.messages = ds_cfg["field_messages"]
     return strategy
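The behavioral fix worth noting: prompt_ids is now built from every turn except the last (turns[:-1], with the generation prompt appended) rather than from only the first turn, so multi-turn prompts are masked correctly. A rough sketch of the masking idea, assuming the elided code between the hunks keeps its usual shape:

# Sketch only; the exact lines between the hunks are not shown in this diff.
prompt_ids = prompter.build_prompt(turns[:-1], add_generation_prompt=True)
input_ids = prompter.build_prompt(turns)

if not train_on_inputs:
    # Everything up to the final turn gets the ignore index; only the
    # last assistant reply contributes to the loss.
    labels = [-100] * len(prompt_ids) + input_ids[len(prompt_ids):]
else:
    labels = input_ids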
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -110,6 +110,8 @@ class SFTDataset(BaseModel):
     field_human: Optional[str] = None
     field_model: Optional[str] = None
     field_messages: Optional[str] = None
+    message_field_role: Optional[str] = None
+    message_field_content: Optional[str] = None
 
     roles: Optional[Dict[str, List[str]]] = None
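For reference, the new fields are plain optional strings on the same pydantic model, so a dataset entry like the one in the example config should validate cleanly. A hedged sketch (this assumes SFTDataset can be constructed directly with these keyword arguments, as for any pydantic BaseModel):

# Sketch only: direct construction of the dataset config model.
ds = SFTDataset(
    path="fozziethebeat/alpaca_messages_2k_test",
    type="chat_template",
    field_messages="messages",
    message_field_role="role",
    message_field_content="content",
    roles={"user": ["user"], "assistant": ["assistant"]},
)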
 
tests/prompt_strategies/test_chat_templates.py CHANGED
@@ -1,6 +1,7 @@
 """
 tests for chat_template prompt strategy
 """
+
 import unittest
 
 import pytest
@@ -10,8 +11,39 @@ from transformers import AutoTokenizer
 from axolotl.prompt_strategies.chat_template import (
     ChatTemplatePrompter,
     ChatTemplateStrategy,
+    load,
 )
 from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="assistant_dataset")
+def fixture_assistant_dataset():
+    # pylint: disable=duplicate-code
+    return Dataset.from_list(
+        [
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "user",
+                        "content": "goodbye",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "goodbye",
+                    },
+                ]
+            }
+        ]
+    )
 
 
 @pytest.fixture(name="sharegpt_dataset")
@@ -51,6 +83,87 @@ def fixture_llama3_tokenizer():
     return tokenizer
 
 
+class TestAssistantChatTemplateLlama3:
+    """
+    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
+    """
+
+    def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
+        # pylint: disable=duplicate-code
+        strategy = load(
+            llama3_tokenizer,
+            DictDefault(
+                {
+                    "train_on_inputs": False,
+                    "sequence_len": 512,
+                }
+            ),
+            DictDefault(
+                {
+                    "chat_template": "llama3",
+                    "message_field_role": "role",
+                    "message_field_content": "content",
+                    "roles": {
+                        "user": ["user"],
+                        "assistant": ["assistant"],
+                        "system": ["system"],
+                    },
+                    "field_messages": "messages",
+                }
+            ),
+        )
+        res = strategy.tokenize_prompt(assistant_dataset[0])
+        input_ids = res["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
+    def test_llama3(self, llama3_tokenizer, assistant_dataset):
+        # pylint: disable=duplicate-code
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(
+                llama3_tokenizer,
+                chat_templates("llama3"),
+                message_field_role="role",
+                message_field_content="content",
+                roles={
+                    "user": ["user"],
+                    "assistant": ["assistant"],
+                    "system": ["system"],
+                },
+            ),
+            llama3_tokenizer,
+            False,
+            512,
+        )
+        strategy.messages = "messages"
+        res = strategy.tokenize_prompt(assistant_dataset[0])
+        input_ids = res["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
+
 class TestSharegptChatTemplateLlama3:
     """
     Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
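For readers checking the expected ids by hand: 128000 is Llama-3's <|begin_of_text|>, 128006 and 128007 bracket each role header (882 appears to be "user" and 78191 "assistant"), 271 appears to be the "\n\n" separator, and 128009 is <|eot_id|>. The new tests can be run in isolation with pytest tests/prompt_strategies/test_chat_templates.py -k Assistant.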