File size: 13,022 Bytes
cb9e677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from mistral_common.protocol.instruct.messages import (
    FinetuningAssistantMessage,
    Roles,
    SystemMessage,
    ToolMessage,
    UserMessage,
)
from mistral_common.protocol.instruct.tool_calls import (
    Function,
    FunctionCall,
    Tool,
    ToolCall,
)
from mistral_common.protocol.instruct.validator import (
    MistralRequestValidatorV3,
    ValidationMode,
)
from mistral_common.tokens.instruct.request import InstructRequest
from mistral_common.tokens.tokenizers.base import Tokenizer
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase

from .exceptions import (
    ConversationFormatError,
    FunctionFormatError,
    MessageFormatError,
    ToolCallFormatError,
    UnrecognizedRoleError,
)

logger = logging.getLogger("tokenize")

Sequence = List[int]
Mask = List[bool]


class TrainingInstructSample(InstructRequest):
    """An InstructRequest extended with fine-tuning-specific fields."""

    # Tools the assistant may call in this conversation; None when tool use
    # is disabled for the sample.
    available_tools: Optional[List[Tool]] = None
    # When True, only the final assistant message contributes to the loss
    # (consumed by tokenize_instruct).
    only_last: bool = False


@dataclass()
class TokenSample:
    """A tokenized training sample: token ids plus a parallel loss mask."""

    # Token ids produced by the tokenizer.
    tokens: Sequence
    # masks[i] is True when tokens[i] contributes to the training loss.
    masks: Mask


class SampleType(str, Enum):
    """Kind of training sample contained in a raw data record."""

    PRETRAIN = "pretrain"
    INSTRUCT = "instruct"


def encode(
    data: Dict[str, Any],
    instruct_tokenizer: InstructTokenizerBase,
    as_type: SampleType,
) -> TokenSample:
    """Convert a raw data dict into a TokenSample.

    Args:
        data: Raw sample loaded from the dataset.
        instruct_tokenizer: Tokenizer used for both pretrain and instruct paths.
        as_type: Whether to treat ``data`` as a pretrain or instruct sample.

    Raises:
        ValueError: If ``as_type`` is not a recognized SampleType.
    """
    sample: Union[str, TrainingInstructSample]
    if as_type == SampleType.PRETRAIN:
        sample = get_pretrain_sample(data)
    elif as_type == SampleType.INSTRUCT:
        sample = build_instruct_sample(data)
    else:
        # Previously an unknown type fell through and crashed later with an
        # opaque UnboundLocalError; fail loudly and clearly instead.
        raise ValueError(f"Unknown sample type: {as_type}")

    return tokenize(sample=sample, instruct_tokenizer=instruct_tokenizer)


def get_pretrain_sample(data: Dict[str, Any]) -> str:
    """Return the raw training text stored under 'text' or 'content'.

    Exactly one of the two keys may be present and its value must be a
    non-None string; otherwise an AssertionError is raised.
    """
    content_keys = ["text", "content"]
    assert not all(
        k in data for k in content_keys
    ), "Make sure to have either 'text' or 'content' in your data. Not both."
    assert any(
        data.get(k) is not None for k in content_keys
    ), f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}"

    # Scan in reverse and take the first present key: equivalent to the
    # original forward loop where the last present key won.
    sample = next((data[k] for k in reversed(content_keys) if k in data), None)

    assert isinstance(sample, str), sample

    return sample


def build_instruct_sample(data: Dict[str, Any]) -> TrainingInstructSample:
    """Parse and validate a raw conversation dict into a TrainingInstructSample.

    ``data`` must contain exactly one of 'messages'/'interactions' (a list of
    role-tagged message dicts) and may contain one of 'tools'/'available_tools'
    plus an optional 'system_prompt'. Raises ConversationFormatError,
    MessageFormatError, or UnrecognizedRoleError on malformed input.
    """
    messages: List[
        SystemMessage | UserMessage | FinetuningAssistantMessage | ToolMessage
    ] = []
    # optional data fields that might be set
    available_tools: Optional[List[Tool]] = data.get("available_tools")
    system_prompt = data.get("system_prompt")

    messages_keys = ["messages", "interactions"]
    content_keys = ["content", "text"]  # both are accepted
    allowed_roles = [role.value for role in Roles]

    # Exactly one of the two message keys must be present.
    if not any(messages_key in data for messages_key in messages_keys):
        err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'."
        raise ConversationFormatError(err, str(data))

    if all(messages_key in data for messages_key in messages_keys):
        err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two."
        raise ConversationFormatError(err, str(data))

    # get first non-None value (only one key can exist at this point)
    data_messages: Optional[List[Dict[str, Any]]] = None
    for key in messages_keys:
        data_messages = data[key] if key in data else data_messages

    assert data_messages is not None, "data_messages can't be None"

    if "available_tools" in data and "tools" in data:
        err = "The conversation contains both an `available_tools` and `tools` key. You can only have one."
        raise ConversationFormatError(err, str(data))

    # Parse whichever tool list is present and non-empty. NOTE(review): if
    # 'available_tools' is present but empty, the raw (unparsed) value from
    # data.get() above is kept as-is — presumably coerced downstream; verify.
    if data.get("tools", None) is not None and len(data["tools"]) > 0:
        available_tools = _parse_available_tools(data["tools"])
    elif (
        data.get("available_tools", None) is not None
        and len(data["available_tools"]) > 0
    ):
        available_tools = _parse_available_tools(data["available_tools"])

    for data_message in data_messages:
        is_tool_call = data_message.get("tool_calls") is not None

        if "role" not in data_message:
            err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'."
            raise MessageFormatError(err, str(data))

        role = data_message["role"]

        # A message may use 'content' or 'text', but not both.
        if all(key in data_message for key in content_keys):
            err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two."
            raise MessageFormatError(err, str(data))

        content: Optional[str] = None
        for key in content_keys:
            content = content if content is not None else data_message.get(key)

        # non-function call message must have content
        if not is_tool_call and content is None:
            err = f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(data_message.keys())}. Make sure that the message includes one of '{content_keys}' keys."
            raise MessageFormatError(err, str(data))

        if role not in allowed_roles:
            raise UnrecognizedRoleError(role, allowed_roles)

        if data_message["role"] == "user":
            assert content is not None
            messages.append(UserMessage(content=content))
        elif data_message["role"] == "assistant":
            tool_calls: Optional[List[ToolCall]] = None

            if is_tool_call:
                tool_calls = _parse_tool_calls(data_message["tool_calls"])

            # Optional per-message loss weight; consumed by tokenize_instruct.
            weight = data_message.get("weight")
            messages.append(
                FinetuningAssistantMessage(
                    content=content, tool_calls=tool_calls, weight=weight
                )
            )
        elif data_message["role"] == "system":
            if system_prompt is not None:
                err = "Multiple messages with role 'system' encountered. Only one is allowed."
                raise MessageFormatError(err, str(data))

            # A system message is folded into the system prompt, not appended.
            system_prompt = content
        elif data_message["role"] == "tool":
            assert content is not None
            tool_message = _parse_tool_message(content, data_message)
            messages.append(tool_message)

    # validate created messages
    validator = MistralRequestValidatorV3(ValidationMode.finetuning)
    validator.validate_messages(messages)
    # NOTE(review): relies on a private validator method — may break on
    # mistral_common upgrades.
    validator._validate_tools(available_tools or [])

    # whether to train only on last assistant message
    only_last = data.get("only_last", False) or available_tools is not None

    return TrainingInstructSample(
        messages=messages,
        system_prompt=system_prompt,
        available_tools=available_tools,
        only_last=only_last,
    )


def _parse_available_tools(tools: List[Dict[str, Any]]) -> List[Tool]:
    """Validate raw tool dicts and convert every one of them into a Tool.

    Each entry must have a 'function' dict containing 'name', 'description'
    and 'parameters' (the latter a dict).

    Raises:
        FunctionFormatError: On any missing key or wrongly-typed 'parameters'.
    """
    available_tools = []
    for tool in tools:
        if "function" not in tool:
            raise FunctionFormatError(
                "A tool dict does not have a 'function' key.", str(tool)
            )

        func_data = tool["function"]

        for key in ["name", "description", "parameters"]:
            if key not in func_data:
                raise FunctionFormatError(
                    f"A function dict does not have a {key} key.", str(func_data)
                )

        if not isinstance(func_data["parameters"], dict):
            raise FunctionFormatError(
                f"A function 'parameters' key has to be of type dict, but is {type(func_data['parameters'])}. If the function has no parameters pass an empty dict ",
                str(func_data),
            )

        function = Function(
            name=func_data["name"],
            description=func_data["description"],
            parameters=func_data["parameters"],
        )

        # BUG FIX: this append used to sit outside the loop, so only the
        # last tool was ever returned; collect every parsed tool instead.
        available_tools.append(Tool(function=function))

    return available_tools


def _parse_tool_calls(calls: List[Dict[str, Any]]) -> List[ToolCall]:
    """Validate raw assistant tool-call dicts and convert them to ToolCall objects.

    Every call needs an 'id' and a 'function' dict holding a 'name' and a
    string-typed 'arguments'. Raises ToolCallFormatError otherwise.
    """
    for key in ("id", "function"):
        if any(key not in call for call in calls):
            raise ToolCallFormatError(
                f"A tool call of an assistant message does not have a {key} key",
                str(calls),
            )

    for key in ("name", "arguments"):
        if any(key not in call["function"] for call in calls):
            raise ToolCallFormatError(
                f"A tool call function of an assistant message does not have a {key} key",
                str(calls),
            )

    if any(not isinstance(call["function"]["arguments"], str) for call in calls):
        raise ToolCallFormatError(
            "A tool call function of an assistant message does not have a 'arguments' key of type str",
            str(calls),
        )

    return [
        ToolCall(
            id=call["id"],
            function=FunctionCall(
                name=call["function"]["name"],
                arguments=call["function"]["arguments"],
            ),
        )
        for call in calls
    ]


def _parse_tool_message(content: str, data_message: Dict[str, Any]) -> ToolMessage:
    """Build a ToolMessage from a raw 'tool'-role message dict.

    Raises MessageFormatError when the mandatory 'tool_call_id' is missing.
    """
    if "tool_call_id" not in data_message:
        raise MessageFormatError(
            f"A tool message does not contain a 'tool_call_id' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'tool_call_id'.",
            str(data_message),
        )

    # name is deprecated in v3, but we'll add it nevertheless for now
    return ToolMessage(
        content=content,
        tool_call_id=data_message["tool_call_id"],
        name=data_message.get("name"),
    )


def tokenize(
    sample: Union[str, TrainingInstructSample],
    instruct_tokenizer: InstructTokenizerBase,
) -> TokenSample:
    """Dispatch a sample to the pretrain or instruct tokenization path."""
    if isinstance(sample, TrainingInstructSample):
        return tokenize_instruct(sample, instruct_tokenizer)

    if isinstance(sample, str):
        # Pretraining only needs the plain underlying tokenizer.
        return tokenize_pretrain(sample, instruct_tokenizer.tokenizer)

    raise ValueError(
        f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}."
    )


def tokenize_pretrain(sample: str, tokenizer: Tokenizer) -> TokenSample:
    """Tokenize raw pretraining text; every token is trained on (mask True)."""
    token_ids = tokenizer.encode(sample, bos=True, eos=True)
    return TokenSample(token_ids, [True] * len(token_ids))


def tokenize_instruct(
    sample: TrainingInstructSample,
    instruct_tokenizer: InstructTokenizerBase,
) -> TokenSample:
    """Tokenize an instruct conversation, building a per-token loss mask.

    User and tool messages are always masked out (False). An assistant
    message is trained on (True) only when its weight is None or 1 and,
    if ``sample.only_last`` is set, it is the final message.

    Raises:
        ValueError: If a message is none of UserMessage, ToolMessage or
            FinetuningAssistantMessage.
    """
    tokens: List[int] = instruct_tokenizer.start()
    masks: List[bool] = [False]  # the start token is never predicted

    mask_all_but_last = sample.only_last

    # find first and last user message: the first may carry the system
    # prompt, the last determines tool-message encoding.
    user_messages = [
        i for i, msg in enumerate(sample.messages) if isinstance(msg, UserMessage)
    ]
    first_user_idx = user_messages[0] if user_messages else -1
    last_user_idx = user_messages[-1] if user_messages else -1

    for msg_idx, message in enumerate(sample.messages):
        if isinstance(message, UserMessage):
            curr_tokens = instruct_tokenizer.encode_user_message(
                message,
                available_tools=sample.available_tools,
                is_last=msg_idx == last_user_idx,
                is_first=msg_idx == first_user_idx,
                system_prompt=sample.system_prompt,
            )
            curr_masks = [False] * len(curr_tokens)  # only predict bot answers
        elif isinstance(message, ToolMessage):
            curr_tokens = instruct_tokenizer.encode_tool_message(
                message, is_before_last_user_message=msg_idx < last_user_idx
            )
            curr_masks = [False] * len(curr_tokens)  # only predict bot answers
        elif isinstance(message, FinetuningAssistantMessage):
            is_last_message = msg_idx == (len(sample.messages) - 1)

            # we don't want to predict a random call id
            message = maybe_remove_call_id(message, is_last_message=is_last_message)

            curr_tokens = instruct_tokenizer.encode_assistant_message(
                message, is_before_last_user_message=False
            )

            is_weighted = message.weight is None or message.weight == 1
            is_relevant = (not mask_all_but_last) or is_last_message
            if is_weighted and is_relevant:
                curr_masks = [True] * len(curr_tokens)  # only predict bot answers
            else:
                # in function calling we only backprop through last message
                curr_masks = [False] * len(curr_tokens)
        else:
            # BUG FIX: an unexpected message type previously reused the prior
            # iteration's curr_tokens/curr_masks (or raised UnboundLocalError
            # on the first message). Fail explicitly instead.
            raise ValueError(f"Unsupported message type: {type(message)}")

        tokens.extend(curr_tokens)
        masks.extend(curr_masks)

    return TokenSample(tokens, masks)


def maybe_remove_call_id(message: FinetuningAssistantMessage, is_last_message: bool):
    """Strip tool-call ids from the last assistant message (mutates in place).

    Messages that are not last, or that carry no tool calls, are returned
    unchanged; otherwise each call is rebuilt without its id so the model
    never learns to predict one.
    """
    if is_last_message and message.tool_calls is not None:
        message.tool_calls = [
            ToolCall(function=call.function) for call in message.tool_calls
        ]

    return message