fix: `train_on_inputs: true` ignored for sharegpt (#1045) [skip ci]
* fix: `train_on_inputs: true` ignored for sharegpt
* enable unit test for train_on_inputs for sharegpt
---------
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -379,10 +379,12 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
379 |
add_eos_token=False,
|
380 |
strip_bos_token=True,
|
381 |
)
|
382 |
-
|
383 |
-
|
|
|
|
|
|
|
384 |
elif assistant in role:
|
385 |
-
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
386 |
role = (
|
387 |
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
388 |
if role_remap
|
@@ -406,18 +408,24 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
406 |
add_eos_token=False,
|
407 |
strip_bos_token=True,
|
408 |
)
|
409 |
-
# not masked out from labels
|
410 |
labels = copy.deepcopy(res["input_ids"])
|
411 |
-
|
412 |
-
|
|
|
|
|
|
|
|
|
413 |
elif role == "":
|
414 |
turn = content
|
415 |
# this is only ever the first part, should include the bos token and the user query
|
416 |
res = self._tokenize(
|
417 |
turn, add_eos_token=False, strip_bos_token=False
|
418 |
)
|
419 |
-
|
420 |
-
|
|
|
|
|
|
|
421 |
else:
|
422 |
LOG.warning(f"unhandled role: {role}")
|
423 |
continue
|
|
|
379 |
add_eos_token=False,
|
380 |
strip_bos_token=True,
|
381 |
)
|
382 |
+
if self.train_on_inputs:
|
383 |
+
labels = copy.deepcopy(res["input_ids"])
|
384 |
+
else:
|
385 |
+
# everything from this is masked out from the labels
|
386 |
+
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
387 |
elif assistant in role:
|
|
|
388 |
role = (
|
389 |
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
390 |
if role_remap
|
|
|
408 |
add_eos_token=False,
|
409 |
strip_bos_token=True,
|
410 |
)
|
|
|
411 |
labels = copy.deepcopy(res["input_ids"])
|
412 |
+
if not self.train_on_inputs:
|
413 |
+
# mask out role tokens from the labels
|
414 |
+
len_role = len(role_res["input_ids"])
|
415 |
+
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
416 |
+
len_role, len(labels)
|
417 |
+
)
|
418 |
elif role == "":
|
419 |
turn = content
|
420 |
# this is only ever the first part, should include the bos token and the user query
|
421 |
res = self._tokenize(
|
422 |
turn, add_eos_token=False, strip_bos_token=False
|
423 |
)
|
424 |
+
if self.train_on_inputs:
|
425 |
+
labels = copy.deepcopy(res["input_ids"])
|
426 |
+
else:
|
427 |
+
# everything from this is masked out from the labels
|
428 |
+
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
429 |
else:
|
430 |
LOG.warning(f"unhandled role: {role}")
|
431 |
continue
|
tests/prompt_strategies/test_sharegpt.py
CHANGED
@@ -104,7 +104,7 @@ class TestSharegpt:
|
|
104 |
role_key_human=None,
|
105 |
),
|
106 |
tokenizer,
|
107 |
-
|
108 |
2048, # sequence_len
|
109 |
)
|
110 |
|
@@ -124,30 +124,30 @@ class TestSharegpt:
|
|
124 |
]
|
125 |
# fmt: on
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
|
|
104 |
role_key_human=None,
|
105 |
),
|
106 |
tokenizer,
|
107 |
+
False, # train_on_inputs
|
108 |
2048, # sequence_len
|
109 |
)
|
110 |
|
|
|
124 |
]
|
125 |
# fmt: on
|
126 |
|
127 |
+
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
|
128 |
+
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
129 |
+
ShareGPTPrompterV2(
|
130 |
+
conversation="chatml",
|
131 |
+
role_key_model=None,
|
132 |
+
role_key_human=None,
|
133 |
+
),
|
134 |
+
tokenizer,
|
135 |
+
True, # train_on_inputs
|
136 |
+
2048, # sequence_len
|
137 |
+
)
|
138 |
+
|
139 |
+
dataset_wrapper = TokenizedPromptDataset(
|
140 |
+
strategy, sharegpt_dataset, process_count=1
|
141 |
+
)
|
142 |
+
|
143 |
+
labels = dataset_wrapper[0]["labels"]
|
144 |
+
# fmt: off
|
145 |
+
assert labels == [
|
146 |
+
1, # bos
|
147 |
+
32001, 1587, 13, 25997, 32000, 28705, 13, # system
|
148 |
+
32001, 2188, 13, 21558, 32000, 28705, 13, # human
|
149 |
+
32001, 13892, 13, 21558, 32000, 28705, 13, # gpt
|
150 |
+
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
|
151 |
+
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
|
152 |
+
]
|
153 |
+
# fmt: on
|