File size: 2,925 Bytes
4d09b42 c10563c 4d09b42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
"""
Test module for raw i/o data for prompts
"""
import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.input_output import (
RawInputOutputPrompter,
RawInputOutputStrategy,
)
@pytest.fixture(name="segments_dataset")
def fixture_sharegpt_dataset():
    """
    Build a one-row dataset whose "segments" column holds four text pieces,
    alternating between label False and label True.
    """
    # (label, text) pairs in the order they appear in the row
    pieces = (
        (False, "<s>hello "),
        (True, "hi there.<eot>"),
        (False, "goodbye "),
        (True, "farewell<eot>"),
    )
    segments = [{"label": label, "text": text} for label, text in pieces]
    return Dataset.from_list([{"segments": segments}])
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """
    Load the pretrained tokenizer used by these tests and register "<eot>"
    as an added token.
    """
    eot_token = AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False)
    tok = AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
    tok.add_tokens([eot_token])
    return tok
class TestRawInputOutputPrompts:
    """
    Tests for tokenizing raw input/output segments
    """

    def test_segment_prompts(self, segments_dataset, tokenizer):
        """
        Tokenize the segments fixture and check the decoded text, the token
        ids, and the labels (label-False segments are expected as -100).
        """
        prompter = RawInputOutputPrompter()
        strategy = RawInputOutputStrategy(
            prompter,
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )
        tokenized = TokenizedPromptDataset(
            strategy, segments_dataset, process_count=1
        )

        row = tokenized[0]
        input_ids = row["input_ids"]
        labels = row["labels"]

        decoded = tokenizer.decode(input_ids)
        assert decoded == "<s> hello hi there.<eot> goodbye farewell<eot>"

        # One entry per expected token: (token id, kept in labels).
        # fmt: off
        expected = [
            (1,     False),  # <s>
            (6312,  False),  # hell
            (28709, False),  # o
            (28705, False),  # (space)
            (12014, True),   # hi
            (736,   True),   # there
            (28723, True),   # .
            (32000, True),   # <eot>
            (1179,  False),  # good
            (17664, False),  # bye
            (28705, False),  # (space)
            (19111, True),   # fare
            (5458,  True),   # well
            (32000, True),   # <eot>
        ]
        # fmt: on

        expected_ids = [token_id for token_id, _ in expected]
        expected_labels = [
            token_id if kept else -100 for token_id, kept in expected
        ]

        assert input_ids == expected_ids
        assert labels == expected_labels
|