File size: 4,552 Bytes
01c8a33 2809f3f 01c8a33 2809f3f 01c8a33 2809f3f 01c8a33 2809f3f ce34d64 2809f3f ce34d64 2809f3f ce34d64 2809f3f ce34d64 2809f3f ce34d64 2809f3f 0f74464 ce34d64 2809f3f ce34d64 2809f3f ce34d64 2809f3f 01c8a33 2809f3f 01c8a33 2809f3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
import copy
import logging
from collections import defaultdict
from typing import Generator, Tuple

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
# Label value written at every token position that is masked out of the labels
# (system turns, user turns and the "<|model|>" marker of bot turns).
# NOTE(review): presumably matches the loss function's ignore index - confirm.
IGNORE_TOKEN_ID = -100
class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for Pygmalion.

    Every conversation turn is tokenized on its own, prefixed with a role
    marker ("<|system|>", "<|user|>" or "<|model|>"), and the pieces are
    concatenated. System and user turns are fully masked in the labels; for
    bot turns only the "<|model|>" marker is masked.
    """

    # Token ids of the "<|model|>" marker; replaced per instance in __init__.
    bot_prefix_token_ids = []

    def __init__(self, prompter, tokenizer, *args, **kwargs):
        """
        Initialise the base strategy and cache the "<|model|>" marker ids.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(prompter, tokenizer, *args, **kwargs)
        # Tokenized without BOS/EOS so the length equals the number of marker
        # tokens that start every tokenized bot turn.
        res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
        self.bot_prefix_token_ids = res["input_ids"]

    def tokenize_prompt(self, prompt):
        """
        Tokenize one conversation into a single training example.

        Args:
            prompt: mapping whose "conversations" entry is handed to
                ``self.prompter.build_prompt``; each item it yields is
                unpacked as a ``(role, message)`` pair.

        Returns:
            A dict with "input_ids", "attention_mask" and "labels" lists of
            equal length. Turns with an unrecognised role are logged and
            contribute nothing to the result.
        """
        result = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }
        start_marker = "\n<START>"
        for role, message in self.prompter.build_prompt(prompt["conversations"]):
            if role == "system":
                # keeps the bos token, no eos token, strips trailing "\n<START>"
                if message.endswith(start_marker):
                    message = message[: -len(start_marker)]
                res = self._tokenize(
                    "<|system|>" + "Persona: " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=False,
                )
                # everything from this is masked out from the labels
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif role == "human":
                res = self._tokenize(
                    "<|user|>" + " " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=True,
                )
                # everything from this is masked out from the labels
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif role == "bot":
                res = self._tokenize(
                    "<|model|>" + " " + message.strip(),
                    add_eos_token=True,
                    strip_bos_token=True,
                )
                bot_ids = res["input_ids"]
                # Mask only the marker tokens. The masked length is clamped so
                # that a turn truncated to fewer tokens than the marker cannot
                # produce more labels than input ids.
                masked = min(len(self.bot_prefix_token_ids), len(bot_ids))
                labels = [IGNORE_TOKEN_ID] * masked + bot_ids[masked:]
            else:
                # Skip the turn entirely: there are no tokens to add, and no
                # labels have been computed for it.
                logging.warning("unknown role in conversation: %s", role)
                continue

            input_ids = res["input_ids"]
            pad_token_id = self.tokenizer.pad_token_id
            result["input_ids"].extend(input_ids)
            result["attention_mask"].extend(
                1 if token != pad_token_id else 0 for token in input_ids
            )
            result["labels"].extend(labels)
        return result

    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
        """
        Run the tokenizer on ``prompt`` and adjust the BOS/EOS tokens.

        Args:
            prompt: text to tokenize.
            add_eos_token: append the tokenizer's EOS id when the result does
                not already end with it and is shorter than
                ``self.sequence_len``.
            strip_bos_token: drop the first token when it is the tokenizer's
                BOS id.

        Returns:
            The tokenizer output with "input_ids" and "attention_mask"
            adjusted as above and a "labels" entry holding a copy of
            "input_ids". An empty tokenization is returned empty (no EOS is
            added) instead of raising IndexError.
        """
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if (
            add_eos_token
            and result["input_ids"]
            and result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        if (
            strip_bos_token
            and result["input_ids"]
            and result["input_ids"][0] == self.tokenizer.bos_token_id
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result
class PygmalionPrompter:
    """
    Prompter for Pygmalion.

    Does no formatting of its own: it walks a conversation and hands back the
    role and text of each turn.
    """

    def __init__(self, *args, **kwargs):
        # Accepts and ignores any arguments; the prompter keeps no state.
        pass

    def build_prompt(
        self, source, *args, **kwargs  # pylint: disable=unused-argument
    ) -> Generator[Tuple[str, str], None, None]:
        """
        Yield a ``(role, value)`` pair for every message in ``source``.

        Args:
            source: iterable of mappings, each with "role" and "value" keys.

        Yields:
            ``(msg["role"], msg["value"])`` for each message, in order.

        Raises:
            KeyError: if a message lacks the "role" or "value" key.
        """
        for msg in source:
            yield msg["role"], msg["value"]
def load(tokenizer, cfg):
    """
    Build the Pygmalion tokenizing strategy for ``tokenizer`` and ``cfg``.

    A fresh ``PygmalionPrompter`` is created, and ``cfg.train_on_inputs`` and
    ``cfg.sequence_len`` are passed positionally after the tokenizer.
    """
    prompter = PygmalionPrompter()
    strategy_args = (prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len)
    return PygmalionPromptTokenizingStrategy(*strategy_args)
|