"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""

import logging
from collections import defaultdict
from typing import Generator, Tuple

from axolotl.prompt_tokenizers import PromptTokenizingStrategy

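# label value ignored by PyTorch's CrossEntropyLoss (its default ignore_index)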
IGNORE_TOKEN_ID = -100


class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for Pygmalion.
    """

    bot_prefix_token_ids = []

    def __init__(self, prompter, tokenizer, *args, **kwargs):
        super().__init__(prompter, tokenizer, *args, **kwargs)
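        # tokenize the bot prefix once up front; its length is used in
        # tokenize_prompt to mask the prefix tokens out of the labels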
        res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
        self.bot_prefix_token_ids = res["input_ids"]

    def tokenize_prompt(self, prompt):
        result = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }
        current_len = 0
        for role, message in self.prompter.build_prompt(prompt["conversations"]):
            if role == "system":
                prefix = "<|system|>"
                # keep the bos token, add no eos token, and strip a trailing "\n<START>"
                if message.endswith("\n<START>"):
                    message = message[: -len("\n<START>")]
                res = self._tokenize(
                    prefix + "Persona: " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=False,
                )
                # the system turn is fully masked out of the labels
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif role == "human":
                prefix = "<|user|>"
                res = self._tokenize(
                    prefix + " " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=True,
                )
                # the user turn is fully masked out of the labels
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif role == "bot":
                prefix = "<|model|>"
                res = self._tokenize(
                    prefix + " " + message.strip(),
                    add_eos_token=True,
                    strip_bos_token=True,
                )
                # mask out only the bot prefix tokens; the rest of the bot turn
                # stays in the labels so the model is trained to produce it
                # (slicing already returns a new list, so no copy is needed)
                labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + res[
                    "input_ids"
                ][len(self.bot_prefix_token_ids) :]
            else:
                logging.warning(f"unknown role in conversation: {role}")
                res = defaultdict(lambda: [])
                # reset labels so a stale value from the previous turn is not
                # spliced into the results below
                labels = []
            input_ids = res["input_ids"]
            input_len = len(input_ids)
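            # slice assignment at the current end of each list appends in
            # place, growing input_ids/attention_mask/labels turn by turn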
            result["input_ids"][current_len : current_len + input_len] = input_ids
            result["attention_mask"][current_len : current_len + input_len] = [
                1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
            ]
            result["labels"][current_len : current_len + input_len] = labels
            current_len += input_len
        return result

    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
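        """Tokenize one prompt chunk, truncating to sequence_len and optionally
        appending the EOS token and/or stripping the BOS token."""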
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if (
            len(result["input_ids"]) > 0
            and result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        if (
            len(result["input_ids"]) > 0
            and result["input_ids"][0] == self.tokenizer.bos_token_id
            and strip_bos_token
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result


class PygmalionPrompter:
    """
    Prompter for Pygmalion.
    """

    def __init__(self, *args, **kwargs):
        pass

    def build_prompt(
        self, source, *args, **kwargs  # pylint: disable=unused-argument
    ) -> Generator[Tuple[str, str], None, None]:
        for msg in source:
            yield msg["role"], msg["value"]


def load(tokenizer, cfg):
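    """Entry point for axolotl to build the Pygmalion strategy from the config."""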
    return PygmalionPromptTokenizingStrategy(
        PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
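
# Example usage (a minimal sketch: `tokenizer` is assumed to be a Hugging Face
# tokenizer and `cfg` an axolotl config exposing `train_on_inputs` and
# `sequence_len`, per the `load` signature above; the sample data is purely
# illustrative):
#
#   strategy = load(tokenizer, cfg)
#   sample = {
#       "conversations": [
#           {"role": "system", "value": "A friendly assistant.\n<START>"},
#           {"role": "human", "value": "Hello!"},
#           {"role": "bot", "value": "Hi! How can I help?"},
#       ]
#   }
#   tokenized = strategy.tokenize_prompt(sample)
#   # tokenized["labels"] is IGNORE_TOKEN_ID everywhere except the bot reply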