fix(tokenizer): handle fast tokenizer properly for bos/eos (#914)
Browse files
- src/axolotl/utils/models.py +18 -0
src/axolotl/utils/models.py
CHANGED
@@ -92,6 +92,7 @@ def load_tokenizer(cfg):
|
|
92 |
"LlamaTokenizer",
|
93 |
"LlamaTokenizerFast",
|
94 |
"CodeLlamaTokenizer",
|
|
|
95 |
]
|
96 |
and hasattr(tokenizer, "pad_token")
|
97 |
and not tokenizer.pad_token
|
@@ -124,6 +125,23 @@ def load_tokenizer(cfg):
|
|
124 |
tokenizer.add_special_tokens(
|
125 |
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
126 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
if cfg.tokens:
|
128 |
tokenizer.add_tokens(
|
129 |
[
|
|
|
92 |
"LlamaTokenizer",
|
93 |
"LlamaTokenizerFast",
|
94 |
"CodeLlamaTokenizer",
|
95 |
+
"CodeLlamaTokenizerFast",
|
96 |
]
|
97 |
and hasattr(tokenizer, "pad_token")
|
98 |
and not tokenizer.pad_token
|
|
|
125 |
tokenizer.add_special_tokens(
|
126 |
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
127 |
)
|
128 |
+
|
129 |
+
# If we add bos_token and eos_token, we need to update the post processor to
|
130 |
+
# handle them correctly.
|
131 |
+
# https://github.com/huggingface/transformers/pull/24132
|
132 |
+
bos_or_eos_in_special_tokens = (
|
133 |
+
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
|
134 |
+
)
|
135 |
+
if (
|
136 |
+
tokenizer.__class__.__name__
|
137 |
+
in (
|
138 |
+
"LlamaTokenizerFast",
|
139 |
+
"CodeLlamaTokenizerFast",
|
140 |
+
)
|
141 |
+
and bos_or_eos_in_special_tokens
|
142 |
+
):
|
143 |
+
tokenizer.update_post_processor()
|
144 |
+
|
145 |
if cfg.tokens:
|
146 |
tokenizer.add_tokens(
|
147 |
[
|