"""Module for testing prompt tokenizers.""" | |
import json | |
import logging | |
import unittest | |
from pathlib import Path | |
from transformers import AutoTokenizer | |
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter | |
from axolotl.prompt_strategies.alpaca_w_system import ( | |
InstructionWSystemPromptTokenizingStrategy, | |
SystemDataPrompter, | |
) | |
from axolotl.prompt_tokenizers import ( | |
AlpacaPromptTokenizingStrategy, | |
ShareGPTPromptTokenizingStrategy, | |
) | |
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter | |
LOG = logging.getLogger("axolotl") | |


class TestPromptTokenizationStrategies(unittest.TestCase):
    """
    Test class for prompt tokenization strategies.
    """

    def setUp(self) -> None:
        # pylint: disable=duplicate-code
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        self.tokenizer.add_special_tokens(
            {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
            }
        )

    def test_sharegpt_integration(self):
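        """
        tests that tokenizing a fixture conversation matches the stored reference output
        """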
        with open(
            Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
        ) as fin:
            data = fin.read()
            conversation = json.loads(data)
        with open(
            Path(__file__).parent / "fixtures/conversation.tokenized.json",
            encoding="utf-8",
        ) as fin:
            data = fin.read()
            tokenized_conversation = json.loads(data)
        prompter = ShareGPTPrompter("chat")
        strat = ShareGPTPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )
        example = strat.tokenize_prompt(conversation)
        for fields in ["input_ids", "attention_mask", "labels"]:
            self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
            self.assertEqual(example[fields], tokenized_conversation[fields])

    def test_no_sys_prompt(self):
        """
        tests the interface between the user and assistant parts
        """
        prompter = NoSystemPrompter()
        # pylint: disable=duplicate-code
        strat = AlpacaPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )
        sample = {
            "instruction": "hello cruel. lorem ipsum dolor sit amet.",
            "output": "world!",
        }
        example = strat.tokenize_prompt(sample)
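        # token 3186 is "world": the response token itself must appear in the
        # labels, while the prompt token just before it is masked with -100
        # (the ignore index used by the loss)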
        world_idx = example["input_ids"].index(3186)
        assert example["labels"][world_idx] == 3186
        assert example["labels"][world_idx - 1] == -100

    def test_alpaca(self):
        """
        tests the interface between the user and assistant parts
        """
        # pylint: disable=duplicate-code
        prompter = AlpacaPrompter()
        strat = AlpacaPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )
        sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
        example = strat.tokenize_prompt(sample)
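        # token 6324 is "Hi": the first response token is kept as a label,
        # and the prompt token before it is masked with -100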
        hi_idx = example["input_ids"].index(6324)
        assert example["labels"][hi_idx] == 6324
        assert example["labels"][hi_idx - 1] == -100


class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
    """
    Test class for prompt tokenization strategies with sys prompt from the dataset
    """

    def setUp(self) -> None:
        # pylint: disable=duplicate-code
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        self.tokenizer.add_special_tokens(
            {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
            }
        )

    def test_system_alpaca(self):
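        """
        tests that the per-sample system prompt from the dataset lands at the start of the prompt
        """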
        prompter = SystemDataPrompter(PromptStyle.CHAT.value)
        strat = InstructionWSystemPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )
        sample = {
            "system": "use cot",
            "instruction": "hello!",
            "output": "Hi! How can I help?",
        }
        example = strat.tokenize_prompt(sample)
        assert example["input_ids"][0:3] == [1, 671, 20118]  # <s>use cot
        assert example["input_ids"][3] == 11889  # USER


if __name__ == "__main__":
    unittest.main()