import gc
from typing import Union, Optional, Iterator, Callable

import torch
from datasets import load_dataset
from litgpt.tokenizer import Tokenizer
from transformers import AutoTokenizer


def _batch_text_iterator(path: str,
                         name: Optional[str]=None,
                         data_dir: Optional[str]=None,
                         data_files: Optional[str]=None,
                         keep_in_memory: bool=False,
                         revision: Optional[str]=None,
                         split: str='train',
                         num_proc: Optional[int]=None,
                         format: Optional[Callable | str]=None) -> Iterator[str]:
    # `format` is required in practice despite the Optional annotation: either a
    # callable mapping a dataset row to a string, or a str.format template whose
    # fields are filled from the row's columns.
    assert isinstance(format, str) or callable(format), repr(format)

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(format):
        for row in dataset:
            text = format(row)
            yield text
    else:
        for row in dataset:
            text = format.format(**row)
            yield text

    # Release the dataset eagerly; these iterators are typically chained over
    # many datasets and the memory adds up.
    del dataset
    gc.collect()


def _batch_chat_iterator(path: str,
                         name: Optional[str]=None,
                         data_dir: Optional[str]=None,
                         data_files: Optional[str]=None,
                         keep_in_memory: bool=False,
                         revision: Optional[str]=None,
                         split: str='train',
                         num_proc: Optional[int]=None,
                         field: Optional[str]=None,
                         transform: Optional[Callable]=None) -> Iterator[list[dict[str, str]]]:
    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(transform):
        for row in dataset:
            # Apply the transform to the selected column, or to the whole row.
            if field:
                messages = transform(row[field])
            else:
                messages = transform(row)

            yield messages
    else:
        # Without a transform the rows must already hold chat messages in the
        # given column, so fail fast if no column was specified.
        if not field:
            raise ValueError(f'field is required when transform is not provided, got: {field!r}')

        for row in dataset:
            messages = row[field]
            yield messages

    del dataset
    gc.collect()


def batch_text_iterator(dataset_config: Union[list, dict]) -> Iterator[str]:
    # Accept a single dataset config (dict) or a list of configs, and chain the
    # per-dataset iterators into one stream of texts.
    assert isinstance(dataset_config, (dict, list)), dataset_config

    if isinstance(dataset_config, dict):
        yield from _batch_text_iterator(**dataset_config)
    elif isinstance(dataset_config, list):
        for dc in dataset_config:
            yield from _batch_text_iterator(**dc)
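

# Illustrative config for batch_text_iterator; the dataset path and template
# below are assumptions for the example, not values this module depends on:
#
#   config = {
#       'path': 'wikimedia/wikipedia',
#       'name': '20231101.en',
#       'split': 'train',
#       'format': '{title}\n{text}',
#   }
#
#   for text in batch_text_iterator(config):
#       print(text[:80])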


def batch_chat_iterator(dataset_config: Union[list, dict]) -> Iterator[list[dict[str, str]]]:
    # Same chaining as batch_text_iterator, but yields lists of chat messages.
    assert isinstance(dataset_config, (dict, list)), dataset_config

    if isinstance(dataset_config, dict):
        yield from _batch_chat_iterator(**dataset_config)
    elif isinstance(dataset_config, list):
        for dc in dataset_config:
            yield from _batch_chat_iterator(**dc)
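

# Illustrative chat config; the dataset path and field name are assumptions.
# Rows (or the transform's output) must be lists of {'role': ..., 'content': ...}
# dicts, which is what apply_chat_template expects downstream:
#
#   config = {
#       'path': 'HuggingFaceH4/ultrachat_200k',
#       'split': 'train_sft',
#       'field': 'messages',
#   }
#
#   for messages in batch_chat_iterator(config):
#       ...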


def tokenize_text_fn(dataset_config: list, tokenizer: Tokenizer, min_len: Optional[int]=None, max_len: Optional[int]=None) -> Iterator[torch.Tensor]:
    for text in batch_text_iterator(dataset_config):
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)

        if min_len is None and max_len is None:
            yield text_ids
            continue  # without this, the ids would fall through and be yielded twice

        # Resolve the bounds per example instead of overwriting min_len/max_len,
        # which would make the first example's length the cap for all later ones.
        lo = 0 if min_len is None else min_len
        hi = len(text_ids) if max_len is None else max_len

        if lo <= len(text_ids) <= hi:
            yield text_ids
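

# Usage sketch; the checkpoint directory is an assumption (litgpt's Tokenizer
# is constructed from a checkpoint directory):
#
#   tokenizer = Tokenizer('checkpoints/my-model')
#   for ids in tokenize_text_fn([config], tokenizer, min_len=4, max_len=2048):
#       ...  # ids is a 1-D tensor of token ids ending with EOS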


def tokenize_chat_fn(dataset_config: list, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer, min_len: Optional[int]=None, max_len: Optional[int]=None) -> Iterator[torch.Tensor]:
    for messages in batch_chat_iterator(dataset_config):
        # Render the conversation with the HF chat template, then tokenize with
        # the litgpt tokenizer. bos/eos are disabled because the template
        # typically renders its own special tokens.
        text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)

        if min_len is None and max_len is None:
            yield text_ids
            continue  # same double-yield guard as in tokenize_text_fn

        # Per-example bounds; do not mutate the min_len/max_len arguments.
        lo = 0 if min_len is None else min_len
        hi = len(text_ids) if max_len is None else max_len

        if lo <= len(text_ids) <= hi:
            yield text_ids
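

# End-to-end sketch for the chat path; the model name and checkpoint directory
# are assumptions:
#
#   hf_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2')
#   tokenizer = Tokenizer('checkpoints/my-model')
#
#   for ids in tokenize_chat_fn([config], hf_tokenizer, tokenizer, max_len=4096):
#       ...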