File size: 5,071 Bytes
9eb0b88 5281790 9eb0b88 82c786e 9eb0b88 5281790 9eb0b88 5281790 9eb0b88 5281790 9eb0b88 5281790 82c786e 9eb0b88 82c786e 9eb0b88 5281790 f6fa207 5281790 f6fa207 5281790 9eb0b88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import gc
from typing import Union, Optional, Iterator, Callable
import torch
from datasets import load_dataset
from litgpt.tokenizer import Tokenizer
from transformers import AutoTokenizer
def _batch_text_iterator(path: str,
                         name: Optional[str]=None,
                         data_dir: Optional[str]=None,
                         data_files: Optional[str]=None,
                         keep_in_memory: bool=False,
                         revision: Optional[str]=None,
                         split: str='train',
                         num_proc: Optional[int]=None,
                         format: Optional[Callable|str]=None) -> Iterator[str]:
    """Yield one text per row of a dataset loaded with ``load_dataset``.

    Each row is rendered either by calling ``format(row)`` (when ``format`` is
    callable) or by ``format.format(**row)`` (when it is a template string, so
    its placeholders name row columns).

    Args:
        path, name, data_dir, data_files, keep_in_memory, revision, split,
        num_proc: forwarded unchanged to ``load_dataset`` (which is always
            called with ``trust_remote_code=True``).
        format: callable mapping a row to text, or a ``str.format`` template.
            Required despite the ``None`` default.

    Yields:
        The rendered text for each row, in dataset order.

    Raises:
        TypeError: if ``format`` is neither a string nor callable. Raised on
            first iteration, before the dataset is loaded.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``.
    if not (isinstance(format, str) or callable(format)):
        raise TypeError(f'format must be a str template or a callable, got {format!r}')

    render = format if callable(format) else (lambda row: format.format(**row))

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)
    try:
        for row in dataset:
            yield render(row)
    finally:
        # Runs on exhaustion and also when the consumer stops early (generator
        # close), so the dataset reference is always dropped before collecting.
        del dataset
        gc.collect()
def _batch_chat_iterator(path: str,
                         name: Optional[str]=None,
                         data_dir: Optional[str]=None,
                         data_files: Optional[str]=None,
                         keep_in_memory: bool=False,
                         revision: Optional[str]=None,
                         split: str='train',
                         num_proc: Optional[int]=None,
                         field: Optional[str]=None,
                         transform: Optional[Callable]=None) -> Iterator[list[dict[str, str]]]:
    """Yield one chat-message list per row of a dataset loaded with ``load_dataset``.

    How messages are extracted from a row:

    * callable ``transform`` and ``field`` set -> ``transform(row[field])``
    * callable ``transform``, no ``field``     -> ``transform(row)``
    * no callable ``transform``, ``field`` set -> ``row[field]``

    Args:
        path, name, data_dir, data_files, keep_in_memory, revision, split,
        num_proc: forwarded unchanged to ``load_dataset`` (which is always
            called with ``trust_remote_code=True``).
        field: row column holding the messages (a falsy value means "unset").
        transform: optional callable converting a row (or ``row[field]``) to
            messages.

    Yields:
        Whatever the extraction above produces for each row. It is annotated
        as a list of role/content dicts, but it is not validated here.

    Raises:
        ValueError: if neither a callable ``transform`` nor ``field`` is given.
            Raised on first iteration, before the dataset is loaded.
    """
    if callable(transform):
        extract = (lambda row: transform(row[field])) if field else transform
    elif field:
        extract = lambda row: row[field]
    else:
        # Nothing says how to get messages from a row. Fail before paying for
        # load_dataset, and fail even when the dataset happens to be empty.
        raise ValueError(f'field must be set when no callable transform is given (field={field!r})')

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)
    try:
        for row in dataset:
            yield extract(row)
    finally:
        # Also runs when the consumer stops iterating early.
        del dataset
        gc.collect()
def batch_text_iterator(dataset_config: Union[list, dict]) -> Iterator[str]:
    """Yield texts from one dataset config (a dict) or several (a list of dicts).

    Each config dict is expanded as keyword arguments to
    ``_batch_text_iterator``. Lists are consumed in order.

    Raises:
        TypeError: if ``dataset_config`` is neither a dict nor a list. Raised on
            first iteration.
    """
    if isinstance(dataset_config, dict):
        configs = [dataset_config]
    elif isinstance(dataset_config, list):
        configs = dataset_config
    else:
        # Explicit raise rather than ``assert``: under ``-O`` the assert vanished
        # and a bad config silently produced no output.
        raise TypeError(f'dataset_config must be a dict or a list of dicts, got {dataset_config!r}')

    for dc in configs:
        yield from _batch_text_iterator(**dc)
def batch_chat_iterator(dataset_config: Union[list, dict]) -> Iterator[list[dict[str, str]]]:
    """Yield chat-message lists from one dataset config (a dict) or several (a list).

    Each config dict is expanded as keyword arguments to
    ``_batch_chat_iterator``. Lists are consumed in order.

    Raises:
        TypeError: if ``dataset_config`` is neither a dict nor a list. Raised on
            first iteration.
    """
    if isinstance(dataset_config, dict):
        configs = [dataset_config]
    elif isinstance(dataset_config, list):
        configs = dataset_config
    else:
        # Explicit raise rather than ``assert``, which is stripped under ``-O``.
        raise TypeError(f'dataset_config must be a dict or a list of dicts, got {dataset_config!r}')

    for dc in configs:
        yield from _batch_chat_iterator(**dc)
def tokenize_text_fn(dataset_config: Union[list, dict], tokenizer: Tokenizer, min_len: Optional[int]=None, max_len: Optional[int]=None) -> Iterator[torch.Tensor]:
    """Tokenize every text from ``dataset_config``, keeping those within length bounds.

    Each text is encoded with ``tokenizer.encode(text, bos=False, eos=True)``.

    Args:
        dataset_config: a dataset config dict or a list of them (see
            ``batch_text_iterator``).
        tokenizer: tokenizer used to encode each text.
        min_len: inclusive lower bound on the token count; ``None`` means 0.
        max_len: inclusive upper bound on the token count; ``None`` means
            unbounded.

    Yields:
        The encoded token ids of each text whose length lies in
        ``[min_len, max_len]``, each yielded exactly once.
    """
    # Bounds are resolved once and the parameters are never rebound, so every
    # text is checked against the caller's limits rather than against the
    # length of whichever text happened to come first.
    lower = 0 if min_len is None else min_len
    for text in batch_text_iterator(dataset_config):
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
        n_tokens = len(text_ids)
        if n_tokens < lower:
            continue
        if max_len is not None and n_tokens > max_len:
            continue
        yield text_ids
def tokenize_chat_fn(dataset_config: Union[list, dict], hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer, min_len: Optional[int]=None, max_len: Optional[int]=None) -> Iterator[torch.Tensor]:
    """Render and tokenize every chat from ``dataset_config``, keeping those within length bounds.

    Each message list is rendered to a string with
    ``hf_tokenizer.apply_chat_template(messages, tokenize=False)``. The string
    is then encoded with ``tokenizer.encode(text, bos=False, eos=False)``. No
    BOS/EOS is added here; presumably the chat template supplies any special
    tokens (verify against the template in use).

    Args:
        dataset_config: a dataset config dict or a list of them (see
            ``batch_chat_iterator``).
        hf_tokenizer: Hugging Face tokenizer providing the chat template.
        tokenizer: tokenizer used to encode the rendered text.
        min_len: inclusive lower bound on the token count; ``None`` means 0.
        max_len: inclusive upper bound on the token count; ``None`` means
            unbounded.

    Yields:
        The encoded token ids of each chat whose length lies in
        ``[min_len, max_len]``, each yielded exactly once.
    """
    # Resolve bounds once without rebinding the parameters, so the limits do
    # not drift to the first chat's length.
    lower = 0 if min_len is None else min_len
    for messages in batch_chat_iterator(dataset_config):
        text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
        n_tokens = len(text_ids)
        if n_tokens < lower:
            continue
        if max_len is not None and n_tokens > max_len:
            continue
        yield text_ids
|