winglian committed
Commit 651b7a3 · unverified · Parent(s): 04b978b

fix double eos token for chatml (#1054) [skip ci]

* fix double eos token for chatml
* isolate fix to chatml conversation
* fix add special tokens to include rstrip
* add test for train_on_inputs for sharegpt
* don't use rstrip for chatml
src/axolotl/prompt_tokenizers.py CHANGED
@@ -392,9 +392,13 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 # this should be the assistant response, should end with an eos token
                 if not content.strip():
                     LOG.warning(f"assistant turn has empty text: {prompt}")
+                add_eos_token = not (
+                    conversation.name == "chatml"
+                    and conversation.sep == self.tokenizer.eos_token
+                )
                 res = self._tokenize(
                     turn,
-                    add_eos_token=True,
+                    add_eos_token=add_eos_token,
                     strip_bos_token=True,
                 )
                 role_res = self._tokenize(
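The hunk above only suppresses the extra EOS when the conversation template is chatml and its separator is already the EOS token; every other template keeps the old add_eos_token=True behaviour. A minimal sketch of that condition, using a stand-in Conversation dataclass rather than the conversation object axolotl actually receives from its prompter:

# Hedged sketch of the guard added above; `Conversation` is a stand-in dataclass,
# not the real conversation template object used by axolotl.
from dataclasses import dataclass


@dataclass
class Conversation:
    name: str
    sep: str


eos_token = "<|im_end|>"

# chatml already terminates every assistant turn with its separator, which is the
# EOS token, so tokenizing with add_eos_token=True would append a second <|im_end|>.
conversation = Conversation(name="chatml", sep=eos_token)
add_eos_token = not (conversation.name == "chatml" and conversation.sep == eos_token)
assert add_eos_token is False  # chatml: skip the extra EOS

# any other template (e.g. a vicuna-style one with a plain separator) is unaffected
conversation = Conversation(name="vicuna_v1.1", sep=" ")
add_eos_token = not (conversation.name == "chatml" and conversation.sep == eos_token)
assert add_eos_token is True  # non-chatml: keep appending EOS as before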
tests/prompt_strategies/test_sharegpt.py ADDED
@@ -0,0 +1,153 @@
+"""
+Test module for sharegpt integration w chatml
+"""
+import pytest
+from datasets import Dataset
+from tokenizers import AddedToken
+from transformers import AutoTokenizer
+
+from axolotl.datasets import TokenizedPromptDataset
+from axolotl.prompt_strategies.sharegpt import SimpleShareGPTPromptTokenizingStrategy
+from axolotl.prompters import ShareGPTPrompterV2
+
+
+@pytest.fixture(name="sharegpt_dataset")
+def fixture_sharegpt_dataset():
+    return Dataset.from_list(
+        [
+            {
+                "conversations": [
+                    {
+                        "from": "system",
+                        "value": "repeat",
+                    },
+                    {
+                        "from": "human",
+                        "value": "hello",
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "hello",
+                    },
+                    {
+                        "from": "human",
+                        "value": "goodbye",
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "goodbye",
+                    },
+                ]
+            }
+        ]
+    )
+
+
+@pytest.fixture(name="tokenizer")
+def fixture_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+    tokenizer.add_special_tokens(
+        {
+            "eos_token": AddedToken(
+                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
+            )
+        }
+    )
+    tokenizer.add_tokens(
+        [
+            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
+        ]
+    )
+
+    return tokenizer
+
+
+class TestSharegpt:
+    """
+    Test class for sharegpt prompter
+    """
+
+    def test_no_double_im_end(self, sharegpt_dataset, tokenizer):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="chatml",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset, process_count=1
+        )
+
+        input_ids = dataset_wrapper[0]["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            # 28705, 13, is " \n"
+            1,  # bos
+            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
+            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
+            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
+            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
+            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
+        ]
+        # fmt: on
+
+    def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="chatml",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            tokenizer,
+            True,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset, process_count=1
+        )
+
+        labels = dataset_wrapper[0]["labels"]
+        # fmt: off
+        assert labels == [
+            -100,  # bos
+            -100, -100, -100, -100, -100, -100, -100,  # system
+            -100, -100, -100, -100, -100, -100, -100,  # human
+            -100, -100, 13, 21558, 32000, 28705, 13,  # gpt
+            -100, -100, -100, -100, -100, -100, -100, -100,  # human
+            -100, -100, 13, 12684, 17664, 32000, 28705, 13,  # gpt
+        ]
+        # fmt: on
+
+    # def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
+    #     strategy = SimpleShareGPTPromptTokenizingStrategy(
+    #         ShareGPTPrompterV2(
+    #             conversation="chatml",
+    #             role_key_model=None,
+    #             role_key_human=None,
+    #         ),
+    #         tokenizer,
+    #         False,  # train_on_inputs
+    #         2048,  # sequence_len
+    #     )
+    #
+    #     dataset_wrapper = TokenizedPromptDataset(
+    #         strategy, sharegpt_dataset, process_count=1
+    #     )
+    #
+    #     labels = dataset_wrapper[0]["labels"]
+    #     # fmt: off
+    #     assert labels == [
+    #         1,  # bos
+    #         32001, 1587, 13, 25997, 32000, 28705, 13,  # system
+    #         32001, 2188, 13, 21558, 32000, 28705, 13,  # human
+    #         32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
+    #         32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
+    #         32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
+    #     ]
+    #     # fmt: on
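For context on the tokenizer fixture: the "fix add special tokens to include rstrip" / "don't use rstrip for chatml" messages above correspond to registering <|im_end|> with rstrip=False. If the EOS token stripped whitespace on its right, the trailing " \n" (ids 28705, 13) that the assertions expect after every 32000 would be swallowed. A small check under the same assumption as the fixture (that the mistralai/Mistral-7B-v0.1 tokenizer can be downloaded):

# Hedged sketch: <|im_end|> registered with rstrip=False should leave the trailing
# " \n" intact, matching the 32000, 28705, 13 pattern asserted in the tests above.
from tokenizers import AddedToken
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.add_special_tokens(
    {
        "eos_token": AddedToken(
            "<|im_end|>", rstrip=False, lstrip=False, normalized=False
        )
    }
)

ids = tokenizer.encode("hello<|im_end|> \n", add_special_tokens=False)
# expected to end with [..., 32000, 28705, 13]; with rstrip=True the " \n" would be
# absorbed into the added token's match and dropped from the encoded output
print(ids)

The new module itself should be runnable on its own by pointing pytest at tests/prompt_strategies/test_sharegpt.py.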