# memorizing_transformer_gpt2 / batched_dataloader.py
# Author: lavawolfiee, commit 6bc49a9 ("Finally"), 1.41 kB
from typing import Any

from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
# Training utilities: split books into fixed-size chunks and group the chunks into batches.
class BooksBatcherIter:
    """Iterate over fixed-size chunks of books, grouped into batches.

    ``batch_size`` independent streams share one underlying ``data_iter``.
    Each stream walks through one book chunk by chunk. When that book is
    exhausted, the stream pulls the next book from ``data_iter``. Every batch
    takes one chunk from each stream. Position ``i`` of consecutive batches
    therefore holds consecutive chunks of the same book until that book ends.

    Iteration stops as soon as any stream runs out of books. The chunks
    already drawn for that final, incomplete batch are discarded.

    Args:
        data_iter: Iterator yielding books. Each book must support ``len()``
            and slicing. It is presumably a sequence of token ids; TODO
            confirm against the dataset.
        batch_size: Number of parallel streams, i.e. chunks per batch.
        tokenizer: Tokenizer used to build the default padding collator.
        chunk_size: Maximum length of each chunk.
        collate_fn: Optional callable that turns a list of chunks into a
            batch. Defaults to ``DataCollatorWithPadding(tokenizer)``.

    Raises:
        ValueError: If ``batch_size`` or ``chunk_size`` is less than 1.
    """

    def __init__(self, data_iter, batch_size, tokenizer, chunk_size=1024,
                 collate_fn=None):
        # With zero streams every batch would be empty and iteration would
        # never end. A step of zero (or a negative step) breaks chunking.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        self.data_iter = data_iter
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        # One chunk generator per batch position. All of them draw books
        # from the same shared data_iter.
        self.batch_fns = [self._batch_fn() for _ in range(batch_size)]
        if collate_fn is None:
            collate_fn = DataCollatorWithPadding(tokenizer)
        self.collate_fn = collate_fn
        self._exhausted = False

    def _batch_fn(self):
        """Yield successive ``chunk_size`` slices of books from ``data_iter``."""
        for book in self.data_iter:
            for start in range(0, len(book), self.chunk_size):
                yield book[start:start + self.chunk_size]

    def __iter__(self) -> 'BooksBatcherIter':
        return self

    def __next__(self) -> Any:
        # Once exhausted, stay exhausted, as the iterator protocol requires.
        # Otherwise the surviving streams would keep consuming books.
        if self._exhausted:
            raise StopIteration
        try:
            # A list comprehension (not a generator expression) lets
            # StopIteration from next() propagate unchanged.
            batch = [next(stream) for stream in self.batch_fns]
        except StopIteration:
            self._exhausted = True
            raise
        return self.collate_fn(batch)
class BooksBatcher:
    """Reusable iterable that yields batches of book chunks.

    Wraps ``dataset`` in a ``DataLoader`` that returns individual, unbatched
    samples (``batch_size=None``). Each ``iter()`` call builds a fresh
    ``BooksBatcherIter`` that splits those samples into chunks and groups
    them into batches.

    Args:
        dataset: Dataset of books handed to ``DataLoader``.
        batch_size: Number of chunks per batch, forwarded to
            ``BooksBatcherIter``.
        tokenizer: Tokenizer forwarded to ``BooksBatcherIter``.
        chunk_size: Maximum chunk length forwarded to ``BooksBatcherIter``.
        shuffle: Whether the ``DataLoader`` shuffles books. Must be False for
            an ``IterableDataset``, which ``DataLoader`` cannot shuffle.
        num_workers: Number of ``DataLoader`` worker processes.
        prefetch_factor: Samples prefetched per worker. Only forwarded when
            ``num_workers > 0``, because ``DataLoader`` rejects it without
            worker processes.
    """

    def __init__(self, dataset, batch_size, tokenizer, *, chunk_size=1024,
                 shuffle=True, num_workers=2, prefetch_factor=4) -> None:
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        loader_kwargs = {}
        if num_workers > 0:
            loader_kwargs["prefetch_factor"] = prefetch_factor
        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=None,  # return raw samples; BooksBatcherIter does the batching
            shuffle=shuffle,
            num_workers=num_workers,
            **loader_kwargs,
        )

    def __iter__(self) -> 'BooksBatcherIter':
        # A new DataLoader iterator per call, so the batcher can be iterated
        # more than once.
        return BooksBatcherIter(
            iter(self.dataloader),
            self.batch_size,
            self.tokenizer,
            chunk_size=self.chunk_size,
        )