"""Module for testing streaming dataset sequence packing"""
import pytest
from datasets import concatenate_datasets, load_dataset
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
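    # The LLaMA tokenizer ships without a pad token; reuse EOS for padding.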
    tokenizer.pad_token = "</s>"
    return tokenizer


@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
    return 4096


class TestBatchedSamplerPacking:
    """
    Test class for packing streaming dataset sequences
    """

    @pytest.mark.parametrize(
        "batch_size, num_workers",
        [
            (1, 0),
            (2, 0),
            (1, 2),
            (2, 2),
        ],
    )
    def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
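        # Side-effect import: monkeypatches torch's dataset fetcher so the
        # DataLoader can fetch the nested (batch of packs) index lists that
        # MultipackBatchSampler yields.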
        import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401

        dataset = load_dataset(
            "Trelis/tiny-shakespeare",
            split="train",
        )

        cfg = DictDefault(
            {
                "train_on_inputs": True,
                "sequence_len": max_seq_length,
            }
        )
        ds_cfg = DictDefault(
            {
                "field": "Text",
            }
        )
        completion_strategy = load(tokenizer, cfg, ds_cfg)
        dataset_wrapper = TokenizedPromptDataset(
            completion_strategy,
            dataset,
        )
        train_dataset = concatenate_datasets([dataset_wrapper])
        lengths = get_dataset_lengths(train_dataset)
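        # Bin-pack sample indices so each pack fits within batch_max_len tokens;
        # each batch from the sampler contains up to batch_size packs.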
        batch_sampler = MultipackBatchSampler(
            sampler=RandomSampler(train_dataset),
            lengths=lengths,
            batch_size=batch_size,
            batch_max_len=max_seq_length,
            group_size=100000,
            bin_size=200,
        )

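        # The V2 collator concatenates the samples in each pack into one row,
        # padding out to a multiple of max_seq_length.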
        loader = DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=V2BatchSamplerDataCollatorForSeq2Seq(  # pylint: disable=unexpected-keyword-arg
                tokenizer=tokenizer,
                padding=True,
                pad_to_multiple_of=max_seq_length,
                return_tensors="pt",
            ),
            num_workers=num_workers,
        )

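        # Flatten the sampler output (batches of packs) to collect every
        # dataset index the sampler emitted.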
        batch_idxs = []
        for batch in batch_sampler:
            for pack in batch:
                batch_idxs.extend(pack)

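        # Each collated batch holds at most batch_size packs, each padded to
        # max_seq_length, bounding the total token count per batch.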
        for batch in loader:
            assert len(batch["input_ids"]) <= batch_size * max_seq_length
            assert batch["input_ids"].shape[1] == max_seq_length

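        # Packing must not drop any sample: every dataset index must appear
        # in the sampler output.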
        original_idxs = set(range(len(train_dataset)))
        assert original_idxs == set(batch_idxs)