Spaces:
Sleeping
Sleeping
File size: 1,204 Bytes
fe3e74d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import os
import numpy as np
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
class Txt2ImgIterableBaseDataset(IterableDataset):
    """Common interface that makes text-to-image ``IterableDataset``s chainable.

    This base class only stores bookkeeping attributes and reports a length.
    Concrete subclasses must provide ``__iter__``.

    Args:
        num_records: Value reported by ``len()`` and in the construction message.
        valid_ids: Stored as both ``self.valid_ids`` and ``self.sample_ids``
            (the same object; no copy is made).
        size: Stored as ``self.size``. NOTE(review): presumably an image
            resolution used by subclasses — confirm.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        # Both attribute names alias the caller's object.
        self.valid_ids = self.sample_ids = valid_ids
        self.size = size
        # Announce the dataset size once the attributes are in place.
        count = self.__len__()
        print(f'{self.__class__.__name__} dataset contains {count} examples.')

    def __len__(self):
        """Return the ``num_records`` value given at construction."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        """Yield samples; subclasses must override this."""
class PRNGMixin:
    """Mixin exposing a per-process NumPy ``RandomState`` via ``self.prng``.

    The generator is created lazily. It is re-created whenever the current
    process id differs from the one recorded at creation time. A forked
    worker therefore gets a freshly (entropy-)seeded generator instead of
    reusing the parent's state, which would produce synchronized sampling
    across processes.
    """

    @property
    def prng(self):
        """Return this process's ``np.random.RandomState``, creating it if needed."""
        pid = os.getpid()
        # Fast path: a generator already exists for this process.
        if pid == getattr(self, "_initpid", None):
            return self._prng
        # First access, or first access after a fork: record the pid and reseed.
        self._initpid = pid
        self._prng = np.random.RandomState()
        return self._prng
|