import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from mistral_common.protocol.instruct.messages import (
FinetuningAssistantMessage,
Roles,
SystemMessage,
ToolMessage,
UserMessage,
)
from mistral_common.protocol.instruct.tool_calls import (
Function,
FunctionCall,
Tool,
ToolCall,
)
from mistral_common.protocol.instruct.validator import (
MistralRequestValidatorV3,
ValidationMode,
)
from mistral_common.tokens.instruct.request import InstructRequest
from mistral_common.tokens.tokenizers.base import Tokenizer
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase
from .exceptions import (
ConversationFormatError,
FunctionFormatError,
MessageFormatError,
ToolCallFormatError,
UnrecognizedRoleError,
)
logger = logging.getLogger("tokenize")
Sequence = List[int]
Mask = List[bool]
class TrainingInstructSample(InstructRequest):
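    """An InstructRequest extended with fine-tuning fields: the tools available in
    the conversation and a flag to train only on the last assistant message."""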
available_tools: Optional[List[Tool]] = None
only_last: bool = False
@dataclass()
class TokenSample:
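    """Token ids of one sample together with a boolean mask marking which tokens
    contribute to the loss."""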
tokens: Sequence
masks: Mask
class SampleType(str, Enum):
PRETRAIN = "pretrain"
INSTRUCT = "instruct"
def encode(
data: Dict[str, Any],
instruct_tokenizer: InstructTokenizerBase,
as_type: SampleType,
) -> TokenSample:
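    """Convert a raw data dict into a TokenSample, treating it either as a
    pretraining text sample or as an instruct conversation depending on `as_type`."""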
sample: Union[str, TrainingInstructSample]
if as_type == SampleType.PRETRAIN:
sample = get_pretrain_sample(data)
elif as_type == SampleType.INSTRUCT:
sample = build_instruct_sample(data)
return tokenize(sample=sample, instruct_tokenizer=instruct_tokenizer)
def get_pretrain_sample(data: Dict[str, Any]) -> str:
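    """Return the raw text of a pretraining sample, read from exactly one of the
    'text' or 'content' fields."""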
content_keys = ["text", "content"]
assert not all(
k in data for k in content_keys
), "Make sure to have either 'text' or 'content' in your data. Not both."
assert any(
data.get(k) is not None for k in content_keys
), f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}"
# get first non-None value
sample = None
for key in content_keys:
sample = data[key] if key in data else sample
assert isinstance(sample, str), sample
return sample
def build_instruct_sample(data: Dict[str, Any]) -> TrainingInstructSample:
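    """Parse and validate a raw conversation dict into a TrainingInstructSample.

    The dict must contain exactly one of the 'messages' / 'interactions' keys and
    may additionally provide a 'system_prompt' and tools under exactly one of
    'tools' / 'available_tools'. Minimal expected layout (values are illustrative):

        {
            "messages": [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
            ]
        }
    """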
messages: List[
SystemMessage | UserMessage | FinetuningAssistantMessage | ToolMessage
] = []
# optional data fields that might be set
available_tools: Optional[List[Tool]] = data.get("available_tools")
system_prompt = data.get("system_prompt")
messages_keys = ["messages", "interactions"]
content_keys = ["content", "text"] # both are accepted
allowed_roles = [role.value for role in Roles]
if not any(messages_key in data for messages_key in messages_keys):
err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'."
raise ConversationFormatError(err, str(data))
if all(messages_key in data for messages_key in messages_keys):
err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two."
raise ConversationFormatError(err, str(data))
# get first non-None value
data_messages: Optional[List[Dict[str, Any]]] = None
for key in messages_keys:
data_messages = data[key] if key in data else data_messages
assert data_messages is not None, "data_messages can't be None"
if "available_tools" in data and "tools" in data:
err = "The conversation contains both an `available_tools` and `tools` key. You can only have one."
raise ConversationFormatError(err, str(data))
if data.get("tools", None) is not None and len(data["tools"]) > 0:
available_tools = _parse_available_tools(data["tools"])
elif (
data.get("available_tools", None) is not None
and len(data["available_tools"]) > 0
):
available_tools = _parse_available_tools(data["available_tools"])
for data_message in data_messages:
is_tool_call = data_message.get("tool_calls") is not None
if "role" not in data_message:
err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'."
raise MessageFormatError(err, str(data))
role = data_message["role"]
if all(key in data_message for key in content_keys):
err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two."
raise MessageFormatError(err, str(data))
content: Optional[str] = None
for key in content_keys:
content = content if content is not None else data_message.get(key)
# non-function call message must have content
if not is_tool_call and content is None:
err = f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(data_message.keys())}. Make sure that the message includes one of '{content_keys}' keys."
raise MessageFormatError(err, str(data))
if role not in allowed_roles:
raise UnrecognizedRoleError(role, allowed_roles)
if data_message["role"] == "user":
assert content is not None
messages.append(UserMessage(content=content))
elif data_message["role"] == "assistant":
tool_calls: Optional[List[ToolCall]] = None
if is_tool_call:
tool_calls = _parse_tool_calls(data_message["tool_calls"])
weight = data_message.get("weight")
messages.append(
FinetuningAssistantMessage(
content=content, tool_calls=tool_calls, weight=weight
)
)
elif data_message["role"] == "system":
if system_prompt is not None:
err = "Multiple messages with role 'system' encountered. Only one is allowed."
raise MessageFormatError(err, str(data))
system_prompt = content
elif data_message["role"] == "tool":
assert content is not None
tool_message = _parse_tool_message(content, data_message)
messages.append(tool_message)
# validate created messages
validator = MistralRequestValidatorV3(ValidationMode.finetuning)
validator.validate_messages(messages)
validator._validate_tools(available_tools or [])
# whether to train only on last assistant message
only_last = data.get("only_last", False) or available_tools is not None
return TrainingInstructSample(
messages=messages,
system_prompt=system_prompt,
available_tools=available_tools,
only_last=only_last,
)
def _parse_available_tools(tools: List[Dict[str, Any]]) -> List[Tool]:
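    """Validate raw tool dicts and convert them into `Tool` objects.

    Each tool must carry a 'function' dict with 'name', 'description' and a
    dict-typed 'parameters' entry.
    """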
available_tools = []
for tool in tools:
if "function" not in tool:
raise FunctionFormatError(
"A tool dict does not have a 'function' key.", str(tool)
)
func_data = tool["function"]
for key in ["name", "description", "parameters"]:
if key not in func_data:
raise FunctionFormatError(
f"A function dict does not have a {key} key.", str(func_data)
)
if not isinstance(func_data["parameters"], dict):
raise FunctionFormatError(
f"A function 'parameters' key has to be of type dict, but is {type(func_data['parameters'])}. If the function has no parameters pass an empyt dict ", str(func_data)
)
description = func_data["description"]
function = Function(
name=func_data["name"],
description=description,
parameters=func_data["parameters"],
)
available_tools.append(Tool(function=function))
return available_tools
def _parse_tool_calls(calls: List[Dict[str, Any]]) -> List[ToolCall]:
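    """Validate raw tool-call dicts and convert them into `ToolCall` objects.

    Each call must carry an 'id' and a 'function' dict with a 'name' and
    string-typed 'arguments'.
    """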
for key in ["id", "function"]:
if not all(key in call for call in calls):
err = f"A tool call of an assistant message does not have a {key} key"
raise ToolCallFormatError(err, str(calls))
for key in ["name", "arguments"]:
if not all(key in call["function"] for call in calls):
err = (
f"A tool call function of an assistant message does not have a {key} key"
)
raise ToolCallFormatError(err, str(calls))
if not all(isinstance(call["function"]["arguments"], str) for call in calls):
err = "A tool call function of an assistant message does not have a 'arguments' key of type str"
raise ToolCallFormatError(err, str(calls))
tool_calls = [
ToolCall(
id=call["id"],
function=FunctionCall(
name=call["function"]["name"],
arguments=call["function"]["arguments"],
),
)
for call in calls
]
return tool_calls
def _parse_tool_message(content: str, data_message: Dict[str, Any]) -> ToolMessage:
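    """Build a `ToolMessage` from a raw tool-role message; 'tool_call_id' is required."""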
if "tool_call_id" not in data_message:
err = f"A tool message does not contain a 'tool_call_id' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'tool_call_id'."
raise MessageFormatError(err, str(data_message))
call_id = data_message["tool_call_id"]
# name is deprecated in v3, but we'll add it nevertheless for now
name = data_message.get("name")
return ToolMessage(content=content, tool_call_id=call_id, name=name)
def tokenize(
sample: Union[str, TrainingInstructSample],
instruct_tokenizer: InstructTokenizerBase,
) -> TokenSample:
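    """Dispatch to pretraining or instruct tokenization depending on the sample type."""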
if isinstance(sample, str):
tokenizer: Tokenizer = instruct_tokenizer.tokenizer
return tokenize_pretrain(sample, tokenizer)
elif isinstance(sample, TrainingInstructSample):
return tokenize_instruct(sample, instruct_tokenizer)
raise ValueError(
f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}."
)
def tokenize_pretrain(sample: str, tokenizer: Tokenizer) -> TokenSample:
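    """Encode raw text with BOS and EOS tokens; every token is kept in the loss mask."""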
tokens = tokenizer.encode(sample, bos=True, eos=True)
masks = [True] * len(tokens)
return TokenSample(tokens, masks)
def tokenize_instruct(
sample: TrainingInstructSample,
instruct_tokenizer: InstructTokenizerBase,
) -> TokenSample:
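    """Encode an instruct sample message by message and build the loss mask.

    Only assistant messages are unmasked; assistant messages with an explicit
    weight other than 1 are masked out, and when `only_last` is set only the
    final message of the conversation can contribute to the loss.
    """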
tokens: List[int] = instruct_tokenizer.start()
masks: List[bool] = [False]
mask_all_but_last = sample.only_last
# find first and last user message
user_messages = [
i for i, msg in enumerate(sample.messages) if isinstance(msg, UserMessage)
]
first_user_idx = user_messages[0] if user_messages else -1
last_user_idx = user_messages[-1] if user_messages else -1
for msg_idx, message in enumerate(sample.messages):
if isinstance(message, UserMessage):
curr_tokens = instruct_tokenizer.encode_user_message(
message,
available_tools=sample.available_tools,
is_last=msg_idx == last_user_idx,
is_first=msg_idx == first_user_idx,
system_prompt=sample.system_prompt,
)
curr_masks = [False] * len(curr_tokens) # only predict bot answers
elif isinstance(message, ToolMessage):
curr_tokens = instruct_tokenizer.encode_tool_message(
message, is_before_last_user_message=msg_idx < last_user_idx
)
curr_masks = [False] * len(curr_tokens) # only predict bot answers
elif isinstance(message, FinetuningAssistantMessage):
is_last_message = msg_idx == (len(sample.messages) - 1)
# we don't want to predict a random call id
message = maybe_remove_call_id(message, is_last_message=is_last_message)
curr_tokens = instruct_tokenizer.encode_assistant_message(
message, is_before_last_user_message=False
)
is_weighted = message.weight is None or message.weight == 1
is_relevant = (not mask_all_but_last) or is_last_message
if is_weighted and is_relevant:
curr_masks = [True] * len(curr_tokens) # only predict bot answers
else:
# in function calling we only backprop through last message
curr_masks = [False] * len(curr_tokens)
tokens.extend(curr_tokens)
masks.extend(curr_masks)
return TokenSample(tokens, masks)
def maybe_remove_call_id(message: FinetuningAssistantMessage, is_last_message: bool):
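    """Drop the call ids from the tool calls of the last assistant message so the
    model is not trained to predict arbitrary call ids."""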
if message.tool_calls is None or not is_last_message:
return message
# remove call id
message.tool_calls = [
ToolCall(function=call.function) for call in message.tool_calls
]
return message