"""
Test module for raw i/o data for prompts
"""
import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.input_output import (
    RawInputOutputPrompter,
    RawInputOutputStrategy,
)


@pytest.fixture(name="segments_dataset")
def fixture_segments_dataset():
    """Single-example dataset of labeled text segments for raw i/o prompting."""
    return Dataset.from_list(
        [
            {
                "segments": [
                    {
                        "label": False,
                        "text": "<s>hello ",
                    },
                    {
                        "label": True,
                        "text": "hi there.<eot>",
                    },
                    {
                        "label": False,
                        "text": "goodbye ",
                    },
                    {
                        "label": True,
                        "text": "farewell<eot>",
                    },
                ]
            }
        ]
    )


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """Mistral tokenizer with an added <eot> token."""
    tokenizer = AutoTokenizer.from_pretrained(
        "casperhansen/mistral-7b-instruct-v0.1-awq"
    )
    tokenizer.add_tokens(
        [
            AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False),
        ]
    )

    return tokenizer


class TestRawInputOutputPrompts:
    """
    Test class for raw i/o prompter
    """

    def test_segment_prompts(self, segments_dataset, tokenizer):
        strategy = RawInputOutputStrategy(
            RawInputOutputPrompter(),
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, segments_dataset, process_count=1
        )

        input_ids = dataset_wrapper[0]["input_ids"]
        labels = dataset_wrapper[0]["labels"]

        assert (
            tokenizer.decode(input_ids)
            == "<s> hello  hi there.<eot> goodbye  farewell<eot>"
        )
        # fmt: off
        assert input_ids == [
            1,  # <s>
            6312,  # hell
            28709,  # o
            28705,  #
            12014,  # hi
            736,  # there
            28723,  # .
            32000,  # <eot>
            1179,  # good
            17664,  # bye
            28705,  #
            19111,  # fare
            5458,  # well
            32000,  # <eot>
        ]
        # fmt: on

        # fmt: off
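        # segments with label=False are masked to -100 (the ignore index),
        # so only the label=True spans contribute to the loss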
        assert labels == [
            -100,  # <s>
            -100,  # hell
            -100,  # o
            -100,  #
            12014,  # hi
            736,  # there
            28723,  # .
            32000,  # <eot>
            -100,  # good
            -100,  # bye
            -100,  #
            19111,  # fare
            5458,  # well
            32000,  # <eot>
        ]
        # fmt: on