File size: 9,285 Bytes
2ea70eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7803f09
2ea70eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d1d22f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ea70eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d1d22f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
"""chatml prompt tokenization strategy for ORPO"""
from typing import Any, Dict, Generator, List, Optional, Tuple

from pydantic import BaseModel

from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates


class Message(BaseModel):
    """message/turn: one entry of a conversation (speaker, text, optional training flag)"""

    # Speaker of the turn; the prompter in this file only emits output for
    # "system", "user" and "assistant".
    role: str
    # Text of the turn.
    content: str
    # The dataset parser in this file sets True on the final assistant response
    # and False on every context turn; None when the flag is not supplied.
    label: Optional[bool] = None


class MessageList(BaseModel):
    """conversation: a list of Message turns wrapped in a single model"""

    # Turns of the conversation, in the order they were appended by the parser.
    messages: List[Message]


def load(
    tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build the ORPO tokenizing strategy for datasets with system, input,
    chosen, rejected.

    The chat template defaults to "chatml". When `ds_cfg` carries a
    "chat_template" entry, it is first looked up through `chat_templates`;
    if that lookup raises ValueError the entry itself is used as the template.
    The resolved template is also assigned to `tokenizer.chat_template`.
    """

    template = chat_templates("chatml")
    if ds_cfg and "chat_template" in ds_cfg:
        requested = ds_cfg["chat_template"]
        try:
            template = chat_templates(requested)
        except ValueError:
            # Lookup rejected the value: keep it verbatim as the template.
            template = requested
    tokenizer.chat_template = template

    prompter = ORPOPrompter(template, tokenizer)
    train_on_inputs = cfg.train_on_inputs
    sequence_len = cfg.sequence_len
    parser = ORPODatasetParsingStrategy()
    return ORPOTokenizingStrategy(
        prompter,
        tokenizer,
        train_on_inputs,
        sequence_len,
        dataset_parser=parser,
    )


class ORPODatasetParsingStrategy:
    """Turns a chosen/rejected dataset row into MessageList conversations"""

    def _single_turn_thread(self, prompt, response_key) -> MessageList:
        """Optional system turn, the row's "prompt" as user turn, then the reply at index 1 of `prompt[response_key]`"""
        turns: List[Message] = []
        system = prompt.get("system", None)
        if system:
            turns.append(Message(role="system", content=system, label=False))
        turns.append(Message(role="user", content=prompt["prompt"], label=False))
        reply = prompt[response_key][1]["content"]
        turns.append(Message(role="assistant", content=reply, label=True))
        return MessageList(messages=turns)

    def _with_final_reply(self, prompt, response_key) -> MessageList:
        """Context from `get_prompt` followed by the last message of `prompt[response_key]`, labelled True"""
        thread = self.get_prompt(prompt)
        final_reply = Message(
            role="assistant", content=prompt[response_key][-1]["content"], label=True
        )
        thread.messages.append(final_reply)
        return thread

    def get_chosen_conversation_thread(self, prompt) -> MessageList:
        """Single-turn conversation ending with the chosen response"""
        return self._single_turn_thread(prompt, "chosen")

    def get_rejected_conversation_thread(self, prompt) -> MessageList:
        """Single-turn conversation ending with the rejected response"""
        return self._single_turn_thread(prompt, "rejected")

    def get_prompt(self, prompt) -> MessageList:
        """Collect every turn that precedes the final assistant response"""
        chosen = prompt["chosen"]
        turn_count, leftover = divmod(len(chosen), 2)
        assert leftover == 0, "invalid number of turns"

        turns: List[Message] = []
        system = prompt.get("system", None)
        if system:
            turns.append(Message(role="system", content=system, label=False))

        # NOTE(review): when the row has a "prompt" key, that same text is used
        # for every user turn, not only the last one - confirm this is intended
        # for multi-turn rows.
        has_flat_prompt = "prompt" in prompt
        final_turn = turn_count - 1
        for turn in range(turn_count):
            if has_flat_prompt:
                user_text = prompt["prompt"]
            else:
                user_text = chosen[turn * 2]["content"]
            turns.append(Message(role="user", content=user_text, label=False))
            if turn == final_turn:
                # The last assistant message is the response itself, not context.
                continue
            turns.append(
                Message(
                    role="assistant",
                    content=chosen[turn * 2 + 1]["content"],
                    label=False,
                )
            )

        return MessageList(messages=turns)

    def get_chosen(self, prompt) -> MessageList:
        """Prompt context plus the last chosen message as the labelled response"""
        return self._with_final_reply(prompt, "chosen")

    def get_rejected(self, prompt) -> MessageList:
        """Prompt context plus the last rejected message as the labelled response"""
        return self._with_final_reply(prompt, "rejected")


class ORPOTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizes one preference row into a rejected and a chosen sequence.

    Keys of the returned row:
    rejected_input_ids
    input_ids
    rejected_attention_mask
    attention_mask
    rejected_labels
    labels
    prompt_attention_mask
    """

    def __init__(
        self,
        *args,
        dataset_parser=None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.dataset_parser = dataset_parser

    def _encode_message_list(self, message_list):
        """Encode a conversation; returns (token ids, labels, length after the last unlabelled part)"""
        token_ids = []
        targets = []
        context_len = 0
        for text, trainable in self.prompter.build_prompt(message_list):
            if not text:
                continue
            encoded = self.tokenizer.encode(text, add_special_tokens=False)
            # Only the tokens past what has been collected so far are appended.
            start = len(token_ids)
            token_ids.extend(encoded[start:])
            if trainable:
                targets.extend(token_ids[start:])
            else:
                targets.extend([IGNORE_INDEX] * (len(token_ids) - start))
                context_len = len(token_ids)
        return token_ids, targets, context_len

    def tokenize_prompt(self, prompt):
        """Tokenize the rejected thread, then the chosen thread, and pack both into one row"""
        # The prompt length is taken from the rejected thread.
        rejected_thread = self.dataset_parser.get_rejected_conversation_thread(
            prompt
        )
        rejected_ids, rejected_targets, prompt_len = self._encode_message_list(
            rejected_thread
        )

        chosen_thread = self.dataset_parser.get_chosen_conversation_thread(prompt)
        chosen_ids, chosen_targets, _ = self._encode_message_list(chosen_thread)

        prompt_mask = [1] * prompt_len + [0] * (len(chosen_targets) - prompt_len)
        return {
            "rejected_input_ids": rejected_ids,
            "rejected_labels": rejected_targets,
            "rejected_attention_mask": [1] * len(rejected_targets),
            "input_ids": chosen_ids,
            "labels": chosen_targets,
            "attention_mask": [1] * len(chosen_targets),
            "prompt_attention_mask": prompt_mask,
        }


class ORPOPrompter(Prompter):
    """Renders a conversation through the tokenizer's chat template, one turn at a time"""

    def __init__(self, chat_template, tokenizer):
        self.chat_template = chat_template
        self.tokenizer = tokenizer

    def build_prompt(
        self,
        message_list: MessageList,
    ) -> Generator[Tuple[str, bool], None, None]:
        """
        Yield (rendered text, label) after each system/user/assistant turn.

        The rendered text always covers the whole conversation up to and
        including the current turn. Turns with any other role are added to the
        conversation but produce no output of their own.
        """
        # role -> (add_generation_prompt, label yielded alongside the text)
        role_flags = {
            "system": (False, False),
            "user": (True, False),
            "assistant": (False, True),
        }
        history = []
        for message in message_list.messages:
            history.append(message.model_dump())
            flags = role_flags.get(message.role)
            if flags is None:
                continue
            with_generation_prompt, is_trainable = flags
            rendered = self.tokenizer.apply_chat_template(
                history,
                add_generation_prompt=with_generation_prompt,
                chat_template=self.chat_template,
                tokenize=False,
            )
            yield rendered, is_trainable


def argilla(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Return a row transform producing "prompt", "chosen" and "rejected" strings.

    Each string is rendered with the chat template named by `cfg.chat_template`.
    "chosen" and "rejected" are the full rendered conversations with the first
    len(prompt) characters sliced off.
    """
    parser = ORPODatasetParsingStrategy()

    template = chat_templates(cfg.chat_template)

    def transform_fn(sample, tokenizer=None):
        def render(build_thread, with_generation_prompt):
            # Look up the tokenizer method before the thread is built.
            apply_template = tokenizer.apply_chat_template
            turns = [turn.model_dump() for turn in build_thread(sample).messages]
            return apply_template(
                turns,
                add_generation_prompt=with_generation_prompt,
                chat_template=template,
                tokenize=False,
            )

        prompt_text = render(parser.get_prompt, True)
        offset = len(prompt_text)
        chosen_text = render(parser.get_chosen, False)[offset:]
        rejected_text = render(parser.get_rejected, False)[offset:]

        return {
            "prompt": prompt_text,
            "chosen": chosen_text,
            "rejected": rejected_text,
        }

    return transform_fn