File size: 4,076 Bytes
2809f3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f74464
 
2809f3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import copy
import logging
from collections import defaultdict
from typing import Generator, Tuple

from axolotl.prompt_tokenizers import PromptTokenizingStrategy

IGNORE_TOKEN_ID = -100


class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """Tokenize Pygmalion-style conversations into ids, attention mask and labels.

    Each turn produced by the prompter is rendered with a role prefix
    (``<|system|>``, ``<|user|>`` or ``<|model|>``), tokenized on its own and
    concatenated. System and human turns are fully masked in the labels with
    ``IGNORE_TOKEN_ID``; bot turns keep their token ids as labels except for
    the leading ``<|model|>`` prefix tokens.
    """

    # Token ids of the "<|model|>" prefix; replaced per instance in __init__.
    bot_prefix_token_ids = []

    def __init__(self, prompter, tokenizer, *args, **kwargs):
        """Initialise the strategy and cache the ``<|model|>`` prefix token ids.

        Extra positional/keyword arguments are forwarded to the base class so
        that values supplied by the caller (``load`` passes
        ``cfg.train_on_inputs`` and ``cfg.sequence_len``) are not silently
        dropped.
        """
        super().__init__(prompter, tokenizer, *args, **kwargs)
        res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
        self.bot_prefix_token_ids = res["input_ids"]

    def tokenize_prompt(self, prompt):
        """Tokenize one conversation.

        Args:
            prompt: mapping whose ``"conversations"`` entry is handed to
                ``self.prompter.build_prompt``; that call must yield
                ``(role, message)`` pairs.

        Returns:
            dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``
            lists of equal length covering all recognised turns. Turns with an
            unrecognised role are logged and skipped.
        """
        result = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }
        # NOTE(review): if the tokenizer's pad token id equals its eos token id,
        # the eos appended to bot turns gets attention 0 below - confirm intended.
        pad_token_id = self.tokenizer.pad_token_id
        for role, message in self.prompter.build_prompt(prompt["conversations"]):
            if role == "system":
                # Keeps the bos token, adds no eos token, and drops a trailing
                # "\n<START>" marker from the persona text.
                start_marker = "\n<START>"
                if message.endswith(start_marker):
                    message = message[: -len(start_marker)]
                res = self._tokenize(
                    "<|system|>" + "Persona: " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=False,
                )
                # Everything in a system turn is masked out of the labels.
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif role == "human":
                res = self._tokenize(
                    "<|user|>" + " " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=True,
                )
                # Everything in a human turn is masked out of the labels.
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif role == "bot":
                res = self._tokenize(
                    "<|model|>" + " " + message.strip(),
                    add_eos_token=True,
                    strip_bos_token=True,
                )
                turn_ids = res["input_ids"]
                # Mask only the prefix tokens. Clamp to the turn length so the
                # labels can never be longer than the ids they describe.
                prefix_len = min(len(self.bot_prefix_token_ids), len(turn_ids))
                labels = [IGNORE_TOKEN_ID] * prefix_len + turn_ids[prefix_len:]
            else:
                # Skip the turn entirely: falling through would reuse the
                # previous turn's labels (or fail if there was none).
                logging.warning("unknown role in conversation: %s", role)
                continue

            input_ids = res["input_ids"]
            result["input_ids"].extend(input_ids)
            result["attention_mask"].extend(
                1 if token_id != pad_token_id else 0 for token_id in input_ids
            )
            result["labels"].extend(labels)
        return result

    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
        """Run the tokenizer on ``prompt`` and adjust bos/eos tokens.

        Args:
            prompt: text to tokenize.
            add_eos_token: append ``tokenizer.eos_token_id`` when the encoding
                does not already end with it and is shorter than
                ``self.sequence_len``.
            strip_bos_token: drop a leading ``tokenizer.bos_token_id``.

        Returns:
            The tokenizer result with ``"input_ids"`` and ``"attention_mask"``
            adjusted and a ``"labels"`` entry holding a copy of the ids. An
            empty encoding is handled without raising.
        """
        # self.sequence_len is not set in this class; presumably the base
        # class initialiser provides it - verify against PromptTokenizingStrategy.
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        input_ids = result["input_ids"]
        ends_with_eos = bool(input_ids) and input_ids[-1] == self.tokenizer.eos_token_id
        if add_eos_token and not ends_with_eos and len(input_ids) < self.sequence_len:
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        if (
            strip_bos_token
            and result["input_ids"]
            and result["input_ids"][0] == self.tokenizer.bos_token_id
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result


class PygmalionPrompter:
    """Prompter that passes conversation messages through as (role, value) pairs."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments; the prompter keeps no state."""

    def build_prompt(
        self, source, *args, **kwargs
    ) -> Generator[Tuple[str, str], None, None]:
        """Yield ``(msg["role"], msg["value"])`` for every message in ``source``.

        Args:
            source: iterable of mappings, each with ``"role"`` and ``"value"``
                keys. Extra positional/keyword arguments are ignored.

        Yields:
            A ``(role, value)`` tuple per message, in input order.

        Raises:
            KeyError: if a message lacks ``"role"`` or ``"value"``.
        """
        for msg in source:
            yield msg["role"], msg["value"]


def load(tokenizer, cfg):
    """Build the Pygmalion tokenizing strategy for ``tokenizer``.

    A fresh ``PygmalionPrompter`` is created and handed to
    ``PygmalionPromptTokenizingStrategy`` together with ``tokenizer``,
    ``cfg.train_on_inputs`` and ``cfg.sequence_len`` (in that positional order).

    Returns:
        The constructed ``PygmalionPromptTokenizingStrategy`` instance.
    """
    prompter = PygmalionPrompter()
    strategy = PygmalionPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return strategy