winglian's picture
ORPO Trainer replacement (#1551)
7d1d22f unverified
raw
history blame
9.29 kB
"""chatml prompt tokenization strategy for ORPO"""
from typing import Any, Dict, Generator, List, Optional, Tuple
from pydantic import BaseModel
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
class Message(BaseModel):
    """One turn of a conversation: speaker role, text, and a training flag."""

    # Speaker of the turn; this module produces "system", "user" and "assistant".
    role: str
    # Text of the turn.
    content: str
    # The parsers in this module set True on the final assistant response and
    # False on every context turn; unset turns default to None.
    label: Optional[bool] = None
class MessageList(BaseModel):
    """An ordered list of ``Message`` turns forming one conversation."""

    # Turns in conversation order.
    messages: List[Message]
def load(
    tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build the ORPO tokenizing strategy for system/prompt/chosen/rejected datasets.

    The chat template defaults to the "chatml" template. When ``ds_cfg``
    contains a "chat_template" entry, that value is first looked up through
    ``chat_templates``; if the lookup raises ``ValueError`` the value itself is
    used as the template. The selected template is also written to
    ``tokenizer.chat_template``.

    Returns:
        An ``ORPOTokenizingStrategy`` configured with ``cfg.train_on_inputs``
        and ``cfg.sequence_len``.
    """
    chat_template = chat_templates("chatml")
    if ds_cfg and "chat_template" in ds_cfg:
        requested = ds_cfg["chat_template"]
        try:
            chat_template = chat_templates(requested)
        except ValueError:
            # Not a registered template name: use the configured value as-is.
            chat_template = requested

    tokenizer.chat_template = chat_template

    prompter = ORPOPrompter(chat_template, tokenizer)
    return ORPOTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        dataset_parser=ORPODatasetParsingStrategy(),
    )
class ORPODatasetParsingStrategy:
    """
    Strategy to parse a chosen/rejected dataset row into a ``MessageList``.

    Rows are read as mappings with an optional "system" entry, an optional
    "prompt" entry, and "chosen" / "rejected" lists of ``{"content": ...}``
    messages.
    """

    @staticmethod
    def _single_turn_thread(prompt, response_key: str) -> "MessageList":
        """Build [system?, user, assistant] using ``prompt[response_key][1]`` as the reply."""
        messages: List[Message] = []
        if system := prompt.get("system", None):
            messages.append(Message(role="system", content=system, label=False))
        messages.append(Message(role="user", content=prompt["prompt"], label=False))
        messages.append(
            Message(
                role="assistant",
                content=prompt[response_key][1]["content"],
                label=True,
            )
        )
        return MessageList(messages=messages)

    def _with_final_response(self, prompt, response_key: str) -> "MessageList":
        """Return ``get_prompt(prompt)`` extended with the last ``response_key`` message."""
        res = self.get_prompt(prompt)
        res.messages.append(
            Message(
                role="assistant",
                content=prompt[response_key][-1]["content"],
                label=True,
            )
        )
        return res

    def get_chosen_conversation_thread(self, prompt) -> "MessageList":
        """Single-turn thread: optional system, ``prompt["prompt"]``, then the chosen reply."""
        return self._single_turn_thread(prompt, "chosen")

    def get_rejected_conversation_thread(self, prompt) -> "MessageList":
        """Single-turn thread: optional system, ``prompt["prompt"]``, then the rejected reply."""
        return self._single_turn_thread(prompt, "rejected")

    def get_prompt(self, prompt) -> "MessageList":
        """
        Map the data to extract everything up to (but excluding) the last turn.

        ``prompt["chosen"]`` is read as alternating user/assistant messages.

        Raises:
            ValueError: if ``prompt["chosen"]`` has an odd number of messages.
        """
        total_msg_turns, remainder = divmod(len(prompt["chosen"]), 2)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # otherwise an odd-length conversation would silently lose a message.
        if remainder != 0:
            raise ValueError("invalid number of turns")

        messages: List[Message] = []
        if system := prompt.get("system", None):
            messages.append(Message(role="system", content=system, label=False))

        # NOTE(review): when the row has a top-level "prompt" entry it is used
        # for every user turn, so multi-turn rows repeat it - confirm intended.
        has_top_level_prompt = "prompt" in prompt
        last_turn = total_msg_turns - 1
        for i in range(total_msg_turns):
            if has_top_level_prompt:
                user_content = prompt["prompt"]
            else:
                user_content = prompt["chosen"][i * 2]["content"]
            messages.append(Message(role="user", content=user_content, label=False))
            # The final assistant message is the response, not part of the prompt.
            if i < last_turn:
                messages.append(
                    Message(
                        role="assistant",
                        content=prompt["chosen"][i * 2 + 1]["content"],
                        label=False,
                    )
                )
        return MessageList(messages=messages)

    def get_chosen(self, prompt) -> "MessageList":
        """Prompt turns followed by the last "chosen" message as the labelled reply."""
        return self._with_final_response(prompt, "chosen")

    def get_rejected(self, prompt) -> "MessageList":
        """Prompt turns followed by the last "rejected" message as the labelled reply."""
        return self._with_final_response(prompt, "rejected")
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy that emits paired chosen/rejected features.

    ``tokenize_prompt`` returns a dict with the keys:
        rejected_input_ids
        input_ids
        rejected_attention_mask
        attention_mask
        rejected_labels
        labels
        prompt_attention_mask
    """

    def __init__(
        self,
        *args,
        dataset_parser=None,
        **kwargs,
    ):
        """Forward all other arguments to the base strategy and keep the dataset parser."""
        super().__init__(*args, **kwargs)
        self.dataset_parser = dataset_parser

    def _tokenize_thread(self, message_list):
        """
        Tokenize one conversation thread incrementally.

        Returns:
            ``(input_ids, labels, prompt_len)`` where ``prompt_len`` is the
            token count after the last non-label part (0 if there was none).
        """
        input_ids = []
        labels = []
        prompt_len = 0
        for part, label in self.prompter.build_prompt(message_list):
            if not part:
                continue
            # Each part is encoded in full and only the tokens beyond what has
            # already been collected are appended.
            part_ids = self.tokenizer.encode(part, add_special_tokens=False)
            prev_idx = len(input_ids)
            input_ids += part_ids[prev_idx:]
            if label:
                labels += input_ids[prev_idx:]
            else:
                # Context tokens are masked out of the loss.
                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
                prompt_len = len(input_ids)
        return input_ids, labels, prompt_len

    def tokenize_prompt(self, prompt):
        """
        Tokenize the rejected and chosen threads of one dataset row.

        The prompt length used for ``prompt_attention_mask`` is taken from the
        rejected thread; the mask itself is sized to the chosen sequence.
        """
        rejected_message_list = self.dataset_parser.get_rejected_conversation_thread(
            prompt
        )
        rejected_input_ids, rejected_labels, prompt_len = self._tokenize_thread(
            rejected_message_list
        )

        chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt)
        input_ids, labels, _ = self._tokenize_thread(chosen_message_list)

        return {
            "rejected_input_ids": rejected_input_ids,
            "rejected_labels": rejected_labels,
            "rejected_attention_mask": [1] * len(rejected_labels),
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(labels),
            "prompt_attention_mask": [1] * prompt_len
            + [0] * (len(labels) - prompt_len),
        }
class ORPOPrompter(Prompter):
    """Renders a message list through the tokenizer's chat template, one turn at a time."""

    def __init__(self, chat_template, tokenizer):
        """Store the chat template and the tokenizer used to apply it."""
        self.chat_template = chat_template
        self.tokenizer = tokenizer

    def build_prompt(
        self,
        message_list: MessageList,
    ) -> Generator[Tuple[str, bool], None, None]:
        """
        Yield ``(rendered_text, is_label)`` after each system/user/assistant turn.

        ``rendered_text`` is the whole conversation up to and including the
        current turn. User turns are rendered with the generation prompt
        appended; only assistant turns are flagged as labels. Turns with any
        other role are added to the conversation but yield nothing.
        """
        # role -> (add_generation_prompt, is_label)
        role_settings = {
            "system": (False, False),
            "user": (True, False),
            "assistant": (False, True),
        }
        conversation = []
        for message in message_list.messages:
            conversation.append(message.model_dump())
            settings = role_settings.get(message.role)
            if settings is None:
                continue
            with_generation_prompt, is_label = settings
            rendered = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=with_generation_prompt,
                chat_template=self.chat_template,
                tokenize=False,
            )
            yield rendered, is_label
def argilla(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Return a row transform producing "prompt", "chosen" and "rejected" strings.

    The template is resolved once from ``cfg.chat_template``. For each sample,
    the prompt is rendered with the generation prompt appended; the chosen and
    rejected conversations are rendered in full and then cut by the prompt's
    character length so only the response text remains.
    """
    dataset_parser = ORPODatasetParsingStrategy()
    chat_template_str = chat_templates(cfg.chat_template)

    def transform_fn(sample, tokenizer=None):
        def render(thread_getter, with_generation_prompt):
            # Render the thread returned by ``thread_getter(sample)`` as text.
            return tokenizer.apply_chat_template(
                [msg.model_dump() for msg in thread_getter(sample).messages],
                add_generation_prompt=with_generation_prompt,
                chat_template=chat_template_str,
                tokenize=False,
            )

        prompt_text = render(dataset_parser.get_prompt, True)
        response_start = len(prompt_text)
        return {
            "prompt": prompt_text,
            "chosen": render(dataset_parser.get_chosen, False)[response_start:],
            "rejected": render(dataset_parser.get_rejected, False)[response_start:],
        }

    return transform_fn