import gc
from typing import Callable, Iterator, Optional, Union

import torch
from datasets import load_dataset
from litgpt.tokenizer import Tokenizer
from transformers import AutoTokenizer


def _batch_text_iterator(path: str,
                         name: Optional[str] = None,
                         data_dir: Optional[str] = None,
                         data_files: Optional[str] = None,
                         keep_in_memory: bool = False,
                         revision: Optional[str] = None,
                         split: str = 'train',
                         num_proc: Optional[int] = None,
                         format: Optional[Union[str, Callable]] = None) -> Iterator[str]:
    # `format` is required despite the None default: either a str template
    # applied as format.format(**row), or a callable mapping a row to text.
    assert isinstance(format, str) or callable(format), repr(format)

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(format):
        for row in dataset:
            text = format(row)
            yield text
    else:
        for row in dataset:
            text = format.format(**row)
            yield text

    del dataset
    gc.collect()


def _batch_chat_iterator(path: str,
                         name: Optional[str] = None,
                         data_dir: Optional[str] = None,
                         data_files: Optional[str] = None,
                         keep_in_memory: bool = False,
                         revision: Optional[str] = None,
                         split: str = 'train',
                         num_proc: Optional[int] = None,
                         field: Optional[str] = None,
                         transform: Optional[Callable] = None) -> Iterator[list[dict[str, str]]]:
    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(transform):
        for row in dataset:
            if field:
                messages = transform(row[field])
            else:
                messages = transform(row)
            yield messages
    else:
        for row in dataset:
            if field:
                messages = row[field]
            else:
                raise ValueError('field is required when no transform is given')
            yield messages

    del dataset
    gc.collect()


def batch_text_iterator(dataset_config: Union[list, dict]) -> Iterator[str]:
    assert isinstance(dataset_config, (dict, list)), dataset_config

    if isinstance(dataset_config, dict):
        yield from _batch_text_iterator(**dataset_config)
    elif isinstance(dataset_config, list):
        for dc in dataset_config:
            yield from _batch_text_iterator(**dc)


def batch_chat_iterator(dataset_config: Union[list, dict]) -> Iterator[list[dict[str, str]]]:
    assert isinstance(dataset_config, (dict, list)), dataset_config

    if isinstance(dataset_config, dict):
        yield from _batch_chat_iterator(**dataset_config)
    elif isinstance(dataset_config, list):
        for dc in dataset_config:
            yield from _batch_chat_iterator(**dc)


def tokenize_text_fn(dataset_config: list,
                     tokenizer: Tokenizer,
                     min_len: Optional[int] = None,
                     max_len: Optional[int] = None) -> Iterator[torch.Tensor]:
    for text in batch_text_iterator(dataset_config):
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)

        # No length constraints: yield everything.
        if min_len is None and max_len is None:
            yield text_ids
            continue

        # Resolve unset bounds per example without mutating the arguments;
        # otherwise the first example's length would become a sticky limit
        # for every subsequent example.
        lo = 0 if min_len is None else min_len
        hi = len(text_ids) if max_len is None else max_len

        if lo <= len(text_ids) <= hi:
            yield text_ids


def tokenize_chat_fn(dataset_config: list,
                     hf_tokenizer: AutoTokenizer,
                     tokenizer: Tokenizer,
                     min_len: Optional[int] = None,
                     max_len: Optional[int] = None) -> Iterator[torch.Tensor]:
    for messages in batch_chat_iterator(dataset_config):
        # Render the chat template with the HF tokenizer, then encode the
        # resulting text with the litgpt tokenizer so the ids match the model.
        # text_ids: torch.Tensor = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors='pt')
        # text_ids = text_ids.to(torch.int)
        text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)

        if min_len is None and max_len is None:
            yield text_ids
            continue

        lo = 0 if min_len is None else min_len
        hi = len(text_ids) if max_len is None else max_len

        if lo <= len(text_ids) <= hi:
            yield text_ids
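

# Minimal usage sketch (illustrative only): the dataset path, format template,
# length bounds, and checkpoint directory below are hypothetical placeholders,
# not part of this module; substitute your own dataset configs and tokenizer
# checkpoint. tokenize_chat_fn works analogously, with an AutoTokenizer
# supplying the chat template.
if __name__ == '__main__':
    tokenizer = Tokenizer('checkpoints/my-model')  # hypothetical checkpoint dir

    text_config = [
        # Each entry is a kwargs dict for _batch_text_iterator; 'format' may be
        # a str template over the row's columns or a callable(row) -> str.
        {'path': 'wikitext', 'name': 'wikitext-2-raw-v1', 'split': 'train',
         'format': '{text}'},
    ]

    # Stream tokenized examples, keeping only those between 4 and 2048 tokens.
    for text_ids in tokenize_text_fn(text_config, tokenizer, min_len=4, max_len=2048):
        print(text_ids.shape)
        break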