File size: 3,563 Bytes
c5e73ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Taken and modified from vllm: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/benchmarks/benchmark_serving.py
   Filter the ShareGPT dataset to:
   1. Remove conversations with fewer than two turns, and entries whose prompts
      or completions are too short or too long (measured in tokens)
   2. Only keep the first human prompt for each conversation
"""

import json
import random
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)


def filter_dataset_to_size(
    dataset_path: str,
    size: int,
    seed: Optional[int] = None,
) -> List[Any]:
    """Return ``size`` entries sampled without replacement from a JSON list.

    Args:
        dataset_path: Path to a UTF-8 JSON file whose top-level value is a list.
        size: Number of entries to sample.
        seed: Optional seed for a private RNG, making the sample reproducible.
            When ``None`` the module-level ``random`` generator is used, so a
            prior ``random.seed(...)`` call still controls the result.

    Returns:
        A new list of ``size`` entries from the file, in random order.

    Raises:
        ValueError: If ``size`` is negative or larger than the number of
            entries (propagated from ``random.sample``).
    """
    # Load the dataset. Explicit UTF-8: ShareGPT text is not ASCII-only and the
    # platform default encoding may differ.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Use a private generator only when a seed is requested; otherwise keep the
    # original behavior of drawing from the global random state.
    rng = random if seed is None else random.Random(seed)
    return rng.sample(dataset, size)


def filter_dataset(
    dataset_path: str,
    tokenizer: "PreTrainedTokenizerBase",
    min_len: int = 4,
    max_len: int = 800,
) -> List[Dict[str, Any]]:
    """Filter a ShareGPT-style dataset down to single-prompt conversations.

    Loads the JSON list at ``dataset_path`` and drops conversations with fewer
    than two turns. For each remaining conversation it tokenizes the first
    turn (prompt) and the second turn (completion). It keeps only the
    conversations whose prompt and completion token counts both lie in
    ``[min_len, max_len]``.

    Args:
        dataset_path: Path to a UTF-8 JSON file containing a list of records,
            each with an ``"id"`` and a ``"conversations"`` list of turns that
            carry a ``"value"`` text field.
        tokenizer: Tokenizer called on a list of strings. Its result's
            ``input_ids`` is read as one token-id sequence per string.
        min_len: Minimum token count for both prompt and completion. Defaults
            to 4 because TGI causes errors when the input or output length is
            too short.
        max_len: Maximum token count for both prompt and completion. Defaults
            to 800, kept below 1024 to leave room for extra tokens introduced
            by the chat completion wrapper.

    Returns:
        A list of ``{"id": ..., "conversations": [{"from": "human",
        "value": prompt}]}`` records, in the original dataset order.
    """
    # Load the dataset (explicit UTF-8: ShareGPT text is not ASCII-only).
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Keep (id, prompt, completion) from the first two turns of every
    # conversation that has at least two turns.
    # NOTE(review): the first turn is labelled "human" in the output without
    # checking its "from" field; confirm source conversations open with a
    # human turn.
    pairs = [
        (
            data["id"],
            data["conversations"][0]["value"],
            data["conversations"][1]["value"],
        )
        for data in dataset
        if len(data["conversations"]) >= 2
    ]
    if not pairs:
        # Nothing to tokenize; avoid calling the tokenizer on an empty batch.
        return []

    conversation_ids, prompts, completions = map(list, zip(*pairs))

    # Batch-tokenize prompts and completions; only the lengths are needed.
    prompt_lens = [len(ids) for ids in tokenizer(prompts).input_ids]
    completion_lens = [len(ids) for ids in tokenizer(completions).input_ids]

    filtered_dataset_json = []
    for conv_id, prompt, prompt_len, output_len in zip(
        conversation_ids, prompts, prompt_lens, completion_lens
    ):
        # Prune too short sequences.
        if prompt_len < min_len or output_len < min_len:
            continue
        # Prune too long sequences.
        if prompt_len > max_len or output_len > max_len:
            continue
        filtered_dataset_json.append(
            {
                "id": conv_id,
                "conversations": [
                    {
                        "from": "human",
                        "value": prompt,
                    }
                ],
            }
        )

    return filtered_dataset_json


def main():
    """Build the filtered ShareGPT dataset and a 500-conversation random sample.

    Reads ShareGPT_V3_unfiltered_cleaned_split.json from the current working
    directory and tokenizes it with the huggyllama/llama-7b tokenizer. Writes
    the filtered result to ShareGPT_V3_filtered.json, then samples 500
    conversations from that file into ShareGPT_V3_filtered_500.json.
    """
    source_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
    filtered_path = "ShareGPT_V3_filtered.json"
    sample_path = "ShareGPT_V3_filtered_500.json"

    llama_tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    # download: https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    kept_conversations = filter_dataset(source_path, llama_tokenizer)
    with open(filtered_path, "w") as out_file:
        json.dump(kept_conversations, out_file)

    # The sample is drawn from the file just written, not from memory.
    subset = filter_dataset_to_size(filtered_path, 500)
    with open(sample_path, "w") as out_file:
        json.dump(subset, out_file)


if __name__ == "__main__":
    main()