Spaces:
Runtime error
Runtime error
File size: 1,434 Bytes
6eb1d7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import random
from collections import deque
from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence
# Type alias for a data source: any iterable of items. Used to annotate the
# collection of loaders accepted by CombinedDataLoader.
Loader = Iterable[Any]
def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]) -> Any:
    """
    Return the next buffered element, refilling the buffer from ``iterator`` if needed.

    Each item produced by ``iterator`` is treated as an iterable of elements
    (it is unpacked into ``pool`` via ``extend``). Elements are then handed out
    one at a time in FIFO order.

    Args:
        iterator: source whose items are iterables of elements.
        pool: buffer of elements not yet returned; mutated in place.

    Returns:
        The oldest element currently in ``pool``.

    Raises:
        StopIteration: if ``iterator`` is exhausted while ``pool`` is empty.
    """
    # Loop instead of refilling once: an empty item from the iterator would
    # otherwise leave the pool empty and make `popleft` raise IndexError.
    while not pool:
        pool.extend(next(iterator))
    return pool.popleft()
class CombinedDataLoader:
    """
    Combines data loaders using the provided sampling ratios.

    Every yielded batch contains ``batch_size`` elements. For each element, a
    source loader is drawn at random with weight taken from ``ratios``, and the
    element is taken from that loader's own buffer of pending elements.

    Args:
        loaders: loaders to draw from. Each item a loader yields is treated as
            an iterable of elements and is unpacked into that loader's buffer.
        batch_size: number of elements in every yielded batch.
        ratios: relative sampling weights, one per loader (passed as the
            ``weights`` argument of ``random.choices``).
    """

    # How many batches' worth of loader indices are sampled in one go.
    BATCH_COUNT = 100

    def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
        self.loaders = loaders
        self.batch_size = batch_size
        self.ratios = ratios

    def __iter__(self) -> Iterator[List[Any]]:
        """
        Yield lists of ``batch_size`` elements drawn from the loaders.

        Iteration stops, without yielding a partial batch, as soon as a loader
        whose buffer needs refilling raises ``StopIteration``.
        """
        iters = [iter(loader) for loader in self.loaders]
        indices: List[int] = []
        # One independent buffer per loader. `[deque()] * n` would create n
        # references to a single shared deque, so elements buffered from one
        # loader could be handed out as if they came from another.
        pools: List[Deque[Any]] = [deque() for _ in iters]
        # infinite iterator, as in D2 (ends only if a loader is exhausted)
        while True:
            if not indices:
                # just a buffer of indices, its size doesn't matter
                # as long as it's a multiple of batch_size
                k = self.batch_size * self.BATCH_COUNT
                indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
            try:
                batch = [_pooled_next(iters[i], pools[i]) for i in indices[: self.batch_size]]
            except StopIteration:
                break
            indices = indices[self.batch_size :]
            yield batch
|