"""Filter a ShareGPT-style dataset for serving benchmarks.

Taken and modified from vllm:
https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/benchmarks/benchmark_serving.py

The dataset is filtered to:
1. Remove entries whose prompt or completion is too short or too long
   (measured in tokens).
2. Keep only the first human prompt of each conversation.
"""

import json
import random
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)


def filter_dataset_to_size(
    dataset_path: str,
    size: int,
    seed: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Load a JSON list from ``dataset_path`` and return a random sample of it.

    Args:
        dataset_path: Path to a JSON file whose top-level value is a list.
        size: Desired number of entries. If the file holds fewer entries,
            all of them are returned (in random order) instead of raising.
        seed: Optional seed for a private RNG, for reproducible samples.
            When ``None`` (the default), the global ``random`` state is used.

    Returns:
        A list of entries sampled without replacement.

    Raises:
        ValueError: If ``size`` is negative (propagated from ``random.sample``).
    """
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    sampler = random if seed is None else random.Random(seed)
    # Clamp so a small filtered dataset does not make random.sample raise.
    return sampler.sample(dataset, min(size, len(dataset)))


def filter_dataset(
    dataset_path: str,
    tokenizer: PreTrainedTokenizerBase,
    *,
    min_len: int = 4,
    max_len: int = 800,
) -> List[Dict[str, Any]]:
    """Keep the first human prompt of conversations within token-length limits.

    Args:
        dataset_path: Path to a ShareGPT-style JSON file: a list of objects
            with an ``"id"`` and a ``"conversations"`` list of
            ``{"value": ...}`` turns.
        tokenizer: Tokenizer used to measure prompt and completion lengths.
        min_len: Entries whose prompt or completion has fewer tokens than this
            are dropped (TGI errors on very short inputs/outputs).
        max_len: Entries whose prompt or completion has more tokens than this
            are dropped. The default of 800 is deliberately below 1024 to
            leave room for tokens added by the chat-completion wrapper.

    Returns:
        A list of ``{"id": ..., "conversations": [{"from": "human",
        "value": prompt}]}`` dicts, one per kept conversation.
    """
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Drop conversations with fewer than 2 turns and keep only the first
    # (prompt, completion) pair of the rest.
    conversations = [
        (
            data["id"],
            data["conversations"][0]["value"],
            data["conversations"][1]["value"],
        )
        for data in dataset
        if len(data["conversations"]) >= 2
    ]
    if not conversations:
        return []

    # Tokenize prompts and completions in two batched calls.
    prompts = [prompt for _, prompt, _ in conversations]
    completions = [completion for _, _, completion in conversations]
    prompt_token_ids = tokenizer(prompts).input_ids
    completion_token_ids = tokenizer(completions).input_ids

    filtered_dataset_json = []
    for (conv_id, prompt, _), prompt_ids, completion_ids in zip(
        conversations, prompt_token_ids, completion_token_ids
    ):
        prompt_len = len(prompt_ids)
        output_len = len(completion_ids)
        if prompt_len < min_len or output_len < min_len:
            # Prune too-short sequences: TGI errors when the input or output
            # length is too short.
            continue
        if prompt_len > max_len or output_len > max_len:
            # Prune too-long sequences.
            continue
        filtered_dataset_json.append(
            {
                "id": conv_id,
                "conversations": [
                    {
                        "from": "human",
                        "value": prompt,
                    }
                ],
            }
        )
    return filtered_dataset_json


def main():
    """Filter the ShareGPT V3 dump and write a full and a 500-entry subset."""
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

    # download: https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    filtered_dataset = filter_dataset(
        "ShareGPT_V3_unfiltered_cleaned_split.json", tokenizer
    )
    with open("ShareGPT_V3_filtered.json", "w", encoding="utf-8") as f:
        json.dump(filtered_dataset, f)

    sampled_dataset = filter_dataset_to_size("ShareGPT_V3_filtered.json", 500)
    with open("ShareGPT_V3_filtered_500.json", "w", encoding="utf-8") as f:
        json.dump(sampled_dataset, f)


if __name__ == "__main__":
    main()