"""Module for tokenization utilities""" import logging import re from typing import Dict, List from termcolor import colored LOG = logging.getLogger("axolotl") def check_dataset_labels( dataset, tokenizer, num_examples=5, text_only=False, rl_mode=False, ): # the dataset is already shuffled, so let's just check the first 5 elements for idx in range(num_examples): if not rl_mode: check_example_labels(dataset[idx], tokenizer, text_only=text_only) else: check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only) def check_example_labels(example, tokenizer, text_only=False): # Get the input_ids, labels, and attention_mask from the dataset input_ids = example["input_ids"] labels = example["labels"] # You can compare the input_ids and labels element-wise # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 colored_tokens = [] for _, (input_id, label_id) in enumerate(zip(input_ids, labels)): decoded_input_token = tokenizer.decode(input_id) # Choose the color based on whether the label has the ignore value or not color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") colored_token = colored(decoded_input_token, color) + ( not text_only and colored(f"({label_id}, {input_id})", "white") or "" ) colored_tokens.append(colored_token) delimiter = "" if text_only else " " LOG.info(delimiter.join(colored_tokens)) LOG.info("\n\n\n") return " ".join(colored_tokens) def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only): """Helper function to color tokens based on their type.""" colored_text = colored(decoded_token, color) return ( colored_text if text_only else f"{colored_text}{colored(f'({encoded_token})', 'white')}" ) def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only): """Helper function to process and color tokens.""" colored_tokens = [ color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only) for token in tokenizer.encode(tokens) ] return colored_tokens def check_rl_example_labels(example, tokenizer, text_only=False): field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected" input_tokens = example[field_prompt] labels_chosen, labels_rejected = example[field_chosen], example[field_rejected] # Process and color each type of token colored_tokens = process_tokens_for_rl_debug( input_tokens, "yellow", tokenizer, text_only ) colored_chosens = process_tokens_for_rl_debug( labels_chosen, "green", tokenizer, text_only ) colored_rejecteds = process_tokens_for_rl_debug( labels_rejected, "red", tokenizer, text_only ) # Create a delimiter based on text_only flag delimiter = "" if text_only else " " # Logging information LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n") LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n") LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n") return delimiter.join(colored_tokens) GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"] GLAIVE_TO_SHAREGPT_ROLE = { "SYSTEM": "system", "USER": "human", "ASSISTANT": "gpt", "FUNCTION RESPONSE": "tool", } GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ") def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]: """ Converts a ChatML formatted row to a list of messages in ShareGPT format. Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb. """ system_prompt = row.get("system") if system_prompt: system_prompt = system_prompt.removeprefix("SYSTEM: ") chat_str = row["chat"] chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s] chat_msg_dicts = [ {"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value} for role, value in zip(chat_msgs[::2], chat_msgs[1::2]) ] if system_prompt: chat_msg_dicts = [ {"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt} ] + chat_msg_dicts return chat_msg_dicts def merge_consecutive_messages(messages): """ Merge consecutive messages from the same sender into a single message. This can be useful with datasets that contain multiple consecutive tool calls. """ merged_messages = [] current_from = None current_message = "" for msg in messages: if current_from == msg["from"]: current_message += msg["value"] else: if current_from is not None: merged_messages.append({"from": current_from, "value": current_message}) current_from = msg["from"] current_message = msg["value"] if current_from is not None: merged_messages.append({"from": current_from, "value": current_message}) return merged_messages