Nanobit and winglian committed
Commit 043c386 · unverified · 1 Parent(s): 0f10080

fix: `train_on_inputs: true` ignored for sharegpt (#1045) [skip ci]

* fix: `train_on_inputs: true` ignored for sharegpt

* enable unit test for train_on_inputs for sharegpt

---------

Co-authored-by: Wing Lian <[email protected]>
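
Note: the behavior being fixed here is the label masking for ShareGPT conversations. When `train_on_inputs` is true, user and system turns keep their token ids as labels; when false, those turns are replaced with `IGNORE_TOKEN_ID` (-100 in axolotl) so they are excluded from the loss. A minimal standalone sketch of that branch, using hypothetical helper names rather than axolotl's own code:

import copy

IGNORE_TOKEN_ID = -100  # label index that HuggingFace loss functions skip

def labels_for_user_turn(input_ids, train_on_inputs):
    # hypothetical helper mirroring the user/system-turn branch in the diff below
    if train_on_inputs:
        return copy.deepcopy(input_ids)  # keep the prompt tokens in the loss
    return [IGNORE_TOKEN_ID] * len(input_ids)  # mask the whole turn out

turn_ids = [32001, 2188, 13, 21558, 32000]  # e.g. a chatml user turn
assert labels_for_user_turn(turn_ids, train_on_inputs=True) == turn_ids
assert labels_for_user_turn(turn_ids, train_on_inputs=False) == [-100] * 5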

src/axolotl/prompt_tokenizers.py CHANGED
@@ -379,10 +379,12 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                         add_eos_token=False,
                         strip_bos_token=True,
                     )
-                    # everything from this is masked out from the labels
-                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                    if self.train_on_inputs:
+                        labels = copy.deepcopy(res["input_ids"])
+                    else:
+                        # everything from this is masked out from the labels
+                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                 elif assistant in role:
-                    # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
                     role = (
                         role.replace(role_remap[1]["from"], role_remap[1]["to"])
                         if role_remap
@@ -406,18 +408,24 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                         add_eos_token=False,
                         strip_bos_token=True,
                     )
-                    # not masked out from labels
                     labels = copy.deepcopy(res["input_ids"])
-                    len_role = len(role_res["input_ids"])
-                    labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
+                    if not self.train_on_inputs:
+                        # mask out role tokens from the labels
+                        len_role = len(role_res["input_ids"])
+                        labels[:len_role] = [IGNORE_TOKEN_ID] * min(
+                            len_role, len(labels)
+                        )
                 elif role == "":
                     turn = content
                     # this is only ever the first part, should include the bos token and the user query
                     res = self._tokenize(
                         turn, add_eos_token=False, strip_bos_token=False
                     )
-                    # everything from this is masked out from the labels
-                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                    if self.train_on_inputs:
+                        labels = copy.deepcopy(res["input_ids"])
+                    else:
+                        # everything from this is masked out from the labels
+                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                 else:
                     LOG.warning(f"unhandled role: {role}")
                     continue
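
For the assistant turn in the hunk above, the reply itself always stays in the labels; `train_on_inputs: false` only masks the role-header tokens. A sketch of that logic, again with hypothetical names and example token ids taken from the test below:

import copy

IGNORE_TOKEN_ID = -100

def labels_for_assistant_turn(role_ids, turn_ids, train_on_inputs):
    # hypothetical helper: role_ids tokenize just the role header,
    # turn_ids tokenize the header plus the assistant reply
    labels = copy.deepcopy(turn_ids)
    if not train_on_inputs:
        len_role = len(role_ids)
        labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
    return labels

role_ids = [32001, 13892, 13]                # chatml assistant header
turn_ids = [32001, 13892, 13, 21558, 32000]  # header plus reply
assert labels_for_assistant_turn(role_ids, turn_ids, False) == [-100, -100, -100, 21558, 32000]
assert labels_for_assistant_turn(role_ids, turn_ids, True) == turn_ids
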
tests/prompt_strategies/test_sharegpt.py CHANGED
@@ -104,7 +104,7 @@ class TestSharegpt:
                 role_key_human=None,
             ),
             tokenizer,
-            True,  # train_on_inputs
+            False,  # train_on_inputs
             2048,  # sequence_len
         )
 
@@ -124,30 +124,30 @@ class TestSharegpt:
         ]
         # fmt: on
 
-    # def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
-    #     strategy = SimpleShareGPTPromptTokenizingStrategy(
-    #         ShareGPTPrompterV2(
-    #             conversation="chatml",
-    #             role_key_model=None,
-    #             role_key_human=None,
-    #         ),
-    #         tokenizer,
-    #         False,  # train_on_inputs
-    #         2048,  # sequence_len
-    #     )
-    #
-    #     dataset_wrapper = TokenizedPromptDataset(
-    #         strategy, sharegpt_dataset, process_count=1
-    #     )
-    #
-    #     labels = dataset_wrapper[0]["labels"]
-    #     # fmt: off
-    #     assert labels == [
-    #         1,  # bos
-    #         32001, 1587, 13, 25997, 32000, 28705, 13,  # system
-    #         32001, 2188, 13, 21558, 32000, 28705, 13,  # human
-    #         32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
-    #         32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
-    #         32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
-    #     ]
-    #     # fmt: on
+    def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="chatml",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            tokenizer,
+            True,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset, process_count=1
+        )
+
+        labels = dataset_wrapper[0]["labels"]
+        # fmt: off
+        assert labels == [
+            1,  # bos
+            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
+            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
+            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
+            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
+            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
+        ]
+        # fmt: on