improve: Enhance code readability of prompt_tokenizers.py (#707)
src/axolotl/prompt_tokenizers.py (+80 -107)
@@ -45,6 +45,8 @@ class PromptTokenizingStrategy(abc.ABC):
         self.prompter = prompter
         self.tokenizer: PreTrainedTokenizer = tokenizer
         self.train_on_inputs = train_on_inputs
+        # sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
+        # TODO: Document how they are different.
         self.sequence_len = sequence_len
         self.max_length = sequence_len
 
@@ -59,34 +61,31 @@ class PromptTokenizingStrategy(abc.ABC):
     def _tokenize(
         self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
     ) -> BatchEncoding:
-        result: BatchEncoding
+        empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
         if not prompt:
             LOG.warning("Empty text requested for tokenization.")
-            result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
-        else:
-            result = self.tokenizer(
-                prompt,
-                truncation=True,
-                max_length=self.max_length,
-                padding=False,
-                return_tensors=None,
-            )
+            return empty
+
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.max_length,
+            padding=False,
+            return_tensors=None,
+        )
         if len(result["input_ids"]) == 0:
             LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
+            return empty
+
         if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][-1] != self.tokenizer.eos_token_id
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
             and len(result["input_ids"]) < self.max_length
             and add_eos_token
         ):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)
 
-        if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
             result["input_ids"] = result["input_ids"][1:]
             result["attention_mask"] = result["attention_mask"][1:]
 
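Note (not part of the diff): the refactor above swaps nested if/else for early returns but keeps the same contract: EOS is appended only when the result is non-empty, shorter than max_length, and add_eos_token is set; BOS is stripped only when strip_bos_token is set. A minimal standalone sketch of that contract, using toy token ids (bos=1, eos=2) instead of a real tokenizer:

from typing import Dict, List

def toy_tokenize(
    ids: List[int],
    max_length: int,
    add_eos_token: bool = True,
    strip_bos_token: bool = False,
    bos_id: int = 1,
    eos_id: int = 2,
) -> Dict[str, List[int]]:
    # Early return replaces the old if/else nesting for empty input.
    if not ids:
        return {"input_ids": [], "attention_mask": []}
    result = {"input_ids": list(ids), "attention_mask": [1] * len(ids)}
    # Append EOS only if there is room and the sequence does not already end with it.
    if (
        result["input_ids"][-1] != eos_id
        and len(result["input_ids"]) < max_length
        and add_eos_token
    ):
        result["input_ids"].append(eos_id)
        result["attention_mask"].append(1)
    # Optionally drop a leading BOS (used when concatenating turns).
    if result["input_ids"][0] == bos_id and strip_bos_token:
        result["input_ids"] = result["input_ids"][1:]
        result["attention_mask"] = result["attention_mask"][1:]
    return result

assert toy_tokenize([1, 5, 6], max_length=8)["input_ids"] == [1, 5, 6, 2]
assert toy_tokenize([1, 5, 6], max_length=8, strip_bos_token=True)["input_ids"] == [5, 6, 2]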
@@ -122,7 +121,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         if not self.train_on_inputs:
             user_prompt_len = len(tokenized_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
-            tokenized_prompt["labels"] = [-100] * user_prompt_len
+            tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
             tokenized_res_prompt = self._tokenize(
                 response, strip_bos_token=True, add_eos_token=True
             )
@@ -270,7 +269,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         user_prompt_len = len(tokenized_user_prompt["input_ids"])
         # TODO this could be sped up using numpy array slicing
         tokenized_full_prompt["labels"] = [
-            -100
+            IGNORE_INDEX
         ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
 
         return tokenized_full_prompt
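Note (not part of the diff): replacing the bare -100 literals with a named IGNORE_INDEX constant changes nothing at runtime, assuming IGNORE_INDEX = -100 is defined elsewhere in the module; -100 is the default ignore_index of PyTorch's cross-entropy loss, so masked label positions contribute nothing to loss or gradients:

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed value of the constant introduced by this PR

logits = torch.randn(5, 32000)  # 5 token positions, toy 32k vocab
labels = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 42, 7, 99])
# ignore_index defaults to -100, so only the last three positions are scored
loss = F.cross_entropy(logits, labels)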
@@ -334,6 +333,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
         return prompt["conversations"]
 
     def tokenize_prompt(self, prompt):
+        # Initial values. We will append to these as we go through the conversation.
         result, current_len = tokenize_prompt_default()
         conversation: Conversation = (
             self.prompter._conversation.copy()  # pylint: disable=protected-access
@@ -355,62 +355,67 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
             for _, part in enumerate(
                 self.prompter.build_prompt(self.get_conversation_thread(prompt))
             ):
-                if isinstance(part, tuple):
     [... 55 further deleted lines are not recoverable from this extract; judging by the replacement below, they held the same role-handling branches nested one level deeper under the isinstance check ...]
+                if not isinstance(part, tuple):
+                    LOG.warning(f"expected tuple, got {part}")
+                    continue
+
+                user, assistant = conversation.roles
+                role, content = part
+
+                # Uses "in" because role contains extra characters
+                if user in role:
+                    role = (
+                        role.replace(role_remap[0]["from"], role_remap[0]["to"])
+                        if role_remap
+                        else role
+                    )
+                    turn = role + content
+                    # this is still the user query, we should
+                    if not content.strip():
+                        LOG.warning(f"user turn has empty text: {prompt}")
+                    res = self._tokenize(
+                        turn,
+                        add_eos_token=False,
+                        strip_bos_token=True,
+                    )
+                    # everything from this is masked out from the labels
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                elif assistant in role:
+                    # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
+                    role = (
+                        role.replace(role_remap[1]["from"], role_remap[1]["to"])
+                        if role_remap
+                        else role
+                    )
+                    turn = role + content
+                    # this should be the assistant response, should end with an eos token
+                    if not content.strip():
+                        LOG.warning(f"assistant turn has empty text: {prompt}")
+                    res = self._tokenize(
+                        turn,
+                        add_eos_token=True,
+                        strip_bos_token=True,
+                    )
+                    role_res = self._tokenize(
+                        role.rstrip(),
+                        add_eos_token=False,
+                        strip_bos_token=True,
+                    )
+                    # not masked out from labels
+                    labels = copy.deepcopy(res["input_ids"])
+                    len_role = len(role_res["input_ids"])
+                    labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
+                elif role == "":
+                    turn = content
+                    # this is only ever the first part, should include the bos token and the user query
+                    res = self._tokenize(
+                        turn, add_eos_token=False, strip_bos_token=False
+                    )
+                    # everything from this is masked out from the labels
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                else:
+                    LOG.warning(f"unhandled role: {role}")
+                    continue
 
                 # pylint: disable=duplicate-code
                 result, current_len = parse_tokenized_to_result(
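Note (not part of the diff): the flattened loop builds one (input_ids, labels) pair per turn; user and bare-role turns get labels of all IGNORE_TOKEN_ID, while assistant turns keep their token ids as labels apart from the role prefix. A toy sketch of how the per-turn pieces concatenate (hypothetical ids, IGNORE_TOKEN_ID assumed to be -100):

IGNORE_TOKEN_ID = -100

# (token ids for the turn, whether the turn is trained on)
turns = [
    ([11, 12], False),      # first part: system text plus user query, fully masked
    ([13, 14, 15], False),  # user turn, fully masked
    ([16, 17, 18], True),   # assistant turn, kept in the labels
]

input_ids, labels = [], []
for ids, train_on in turns:
    input_ids += ids
    labels += ids if train_on else [IGNORE_TOKEN_ID] * len(ids)

assert input_ids == [11, 12, 13, 14, 15, 16, 17, 18]
assert labels == [-100, -100, -100, -100, -100, 16, 17, 18]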
@@ -424,38 +429,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
         except (KeyError, AssertionError, IndexError) as err:
             raise InvalidDataException(str(err)) from err
 
-    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
-        if not prompt.strip():
-            LOG.warning("Empty text requested for tokenization.")
-            result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
-        else:
-            result = self.tokenizer(
-                prompt,
-                truncation=True,
-                max_length=self.sequence_len,
-                padding=False,
-                return_tensors=None,
-            )
-        if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.sequence_len
-            and add_eos_token
-        ):
-            result["input_ids"].append(self.tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-
-        if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
-            result["input_ids"] = result["input_ids"][1:]
-            result["attention_mask"] = result["attention_mask"][1:]
-
-        result["labels"] = result["input_ids"].copy()
-        return result
-
 
 def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
     """
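Note (not part of the diff): dropping the ShareGPT-specific _tokenize override is safe length-wise because __init__ sets max_length = sequence_len, but the empty-prompt guard shifts subtly: the override checked prompt.strip() while the base method checks prompt, so whitespace-only turns now reach the tokenizer instead of short-circuiting:

prompt = "  "
print(not prompt)          # False -> the base-class guard tokenizes it
print(not prompt.strip())  # True  -> the removed override returned empty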