qwerrwe / src /axolotl /utils /tokenization.py
winglian's picture
add tests and support for loader for sys prompt data
3a38271
raw
history blame
1.34 kB
"""Module for tokenization utilities"""
import logging
from termcolor import colored
def check_dataset_labels(dataset, tokenizer, num_examples=5):
    """Log a colorized label dump for the first few examples of *dataset*.

    Each selected example is passed to ``check_example_labels`` together
    with *tokenizer*.

    Args:
        dataset: Sized, integer-indexable collection of examples.
        tokenizer: Tokenizer forwarded unchanged to ``check_example_labels``.
        num_examples: Upper bound on how many leading examples to inspect.
            Defaults to 5, which matches the previous hard-coded behaviour.
    """
    # The dataset is already shuffled, so the leading elements are a fair
    # sample. Clamp to the dataset size so that a dataset with fewer than
    # ``num_examples`` rows does not raise IndexError.
    for idx in range(min(num_examples, len(dataset))):
        check_example_labels(dataset[idx], tokenizer)
def check_example_labels(example, tokenizer):
    """Render one tokenized example as a colorized, annotated string.

    Reads ``input_ids``, ``labels`` and ``attention_mask`` from *example*,
    decodes every input id individually with ``tokenizer.decode`` and colors
    the decoded text by its label:

    * red    -- label equals the ignore value (-100)
    * yellow -- label equals 0
    * green  -- any other label

    Each token is followed by a white ``(label, mask, input_id)`` suffix.
    The space-joined result is logged at INFO level, followed by a blank
    spacer message, and is also returned to the caller.
    """
    token_ids = example["input_ids"]
    label_ids = example["labels"]
    mask_values = example["attention_mask"]

    rendered = []
    # zip() walks the three sequences in lockstep and stops at the shortest.
    for token_id, label, mask_bit in zip(token_ids, label_ids, mask_values):
        text = tokenizer.decode(token_id)

        # Pick the highlight color from the label value.
        if label == -100:
            shade = "red"
        elif label == 0:
            shade = "yellow"
        else:
            shade = "green"

        piece = colored(text, shade)
        piece += colored(f"({label}, {mask_bit}, {token_id})", "white")
        rendered.append(piece)

    joined = " ".join(rendered)
    logging.info(joined)
    logging.info("\n\n\n")

    return joined