hamel tokestermw commited on
Commit
5ada140
·
unverified ·
1 Parent(s): 712fd27

Fix prompt assembly for llama (#952)

Browse files

* start at index 0

* add test to check for missing turns

* apply black

* Update test_prompt_tokenizers.py

* Update src/axolotl/monkeypatch/fastchat_conversation_turns.py

Co-authored-by: Motoki Wu <[email protected]>

* fix linting

* apply black

* add more tests for llama/sharegpt

* make logic clearer

---------

Co-authored-by: Motoki Wu <[email protected]>

src/axolotl/monkeypatch/fastchat_conversation_turns.py CHANGED
@@ -83,14 +83,21 @@ def get_turns( # pylint: disable=too-many-return-statements
83
  yield role + ":", ""
84
  return
85
  if self.sep_style == SeparatorStyle.LLAMA2:
86
- seps = [self.sep, self.sep2]
87
  if self.system_message:
 
 
 
 
 
 
88
  yield "", system_prompt
89
- else:
90
- yield "", "[INST] "
91
- for i, (role, message) in enumerate(self.messages[1:]):
92
  if message:
93
- yield role + " ", message + seps[i % 2]
 
 
 
 
94
  else:
95
  yield role, ""
96
  return
 
83
  yield role + ":", ""
84
  return
85
  if self.sep_style == SeparatorStyle.LLAMA2:
 
86
  if self.system_message:
87
+ if self.messages:
88
+ # For llama, the system message is incorporated into the first human instruction
89
+ first_role, first_msg = self.messages[0]
90
+ if first_role == self.roles[0]:
91
+ system_prompt += first_msg
92
+ self.messages.pop(0)
93
  yield "", system_prompt
94
+ for i, (role, message) in enumerate(self.messages):
 
 
95
  if message:
96
+ if (i % 2 == 0 and not self.system_message) or (
97
+ i % 2 != 0 and self.system_message
98
+ ):
99
+ role = "<s> " + role
100
+ yield role + " ", message
101
  else:
102
  yield role, ""
103
  return
tests/test_prompt_tokenizers.py CHANGED
@@ -114,6 +114,76 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
114
  in self._caplog.records[0].message
115
  )
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def test_sharegpt_changes_roles(self):
118
  conversation = {
119
  "roles": ["USER", "CHARACTER"],
 
114
  in self._caplog.records[0].message
115
  )
116
 
117
+ def test_sharegpt_llama(self):
118
+ "Make sure the sharegpt/llama is tokenized and formatted correctly."
119
+ prompter = ShareGPTPrompterV2(conversation="llama-2")
120
+ strat = ShareGPTPromptTokenizingStrategy(
121
+ prompter,
122
+ self.tokenizer,
123
+ False,
124
+ 2048,
125
+ )
126
+
127
+ def tokenize(conv):
128
+ return strat.tokenize_prompt(conv)["input_ids"]
129
+
130
+ def decode(ids):
131
+ return strat.tokenizer.decode(ids)
132
+
133
+ # Multi-turn conversations
134
+ multi_turn_conv = {
135
+ "conversations": [
136
+ {"from": "system", "value": "lorem"},
137
+ {"from": "human", "value": "abc"},
138
+ {"from": "gpt", "value": "ipsum"},
139
+ {"from": "human", "value": "123"},
140
+ {"from": "gpt", "value": "sit"},
141
+ ]
142
+ }
143
+ # fmt: off
144
+ mt_ids = tokenize(multi_turn_conv)
145
+ assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
146
+ assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
147
+
148
+ # Single-turn conversations
149
+ single_turn_conv = {
150
+ "conversations": [
151
+ {"from": "system", "value": "lorem"},
152
+ {"from": "human", "value": "abc"},
153
+ {"from": "gpt", "value": "ipsum"},
154
+ ]
155
+ }
156
+
157
+ st_ids = tokenize(single_turn_conv)
158
+ assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
159
+ assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
160
+
161
+ # No system message, single-turn
162
+ no_sys_conv = {
163
+ "conversations": [
164
+ {"from": "human", "value": "abc"},
165
+ {"from": "gpt", "value": "ipsum"},
166
+ ]
167
+ }
168
+
169
+ ns_ids = tokenize(no_sys_conv)
170
+ assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
171
+ assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
172
+
173
+ # No system message, multi-turn
174
+ no_sys_mt_conv = {
175
+ "conversations": [
176
+ {"from": "human", "value": "abc"},
177
+ {"from": "gpt", "value": "ipsum"},
178
+ {"from": "human", "value": "123"},
179
+ {"from": "gpt", "value": "sit"},
180
+ ]
181
+ }
182
+ ns_mt_ids = tokenize(no_sys_mt_conv)
183
+ assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
184
+ assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
185
+ # fmt: on
186
+
187
  def test_sharegpt_changes_roles(self):
188
  conversation = {
189
  "roles": ["USER", "CHARACTER"],