# AioMedica/feature_extractor/data_aug/dataset_wrapper.py
# chris1nexus - "First commit" (d60982d), 3.94 kB
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
from data_aug.gaussian_blur import GaussianBlur
from torchvision import datasets
import pandas as pd
from PIL import Image
from skimage import io, img_as_ubyte
# Seed NumPy's global RNG at import time. The train/validation split in
# DataSetWrapper.get_train_validation_data_loaders shuffles indices with
# np.random.shuffle, so a fixed seed keeps that split repeatable across runs
# (provided nothing else draws from the global RNG first).
np.random.seed(0)
class Dataset():
    """Map-style dataset that loads images from a list of file paths.

    The paths come from a plain-text file with one image path per line.
    Each item is the image converted to a tensor, optionally passed through
    ``transform``.
    """

    def __init__(self, csv_file, transform=None):
        """Read the image paths listed in ``csv_file``.

        Args:
            csv_file: Path to a text file holding one image path per line.
                Surrounding whitespace is stripped and blank lines are
                skipped.
            transform: Optional callable applied to the image tensor in
                ``__getitem__``. When omitted, the tensor itself is returned.
        """
        with open(csv_file) as f:
            stripped = (line.strip() for line in f)
            # Blank lines (e.g. a trailing newline at the end of the file)
            # are not paths; keeping them would make Image.open fail later
            # and inflate len(dataset).
            self.files_list = [path for path in stripped if path]
        self.transform = transform

    def __len__(self):
        """Return the number of image paths read from the list file."""
        return len(self.files_list)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return the (transformed) sample.

        Args:
            idx: Index into the list of image paths.

        Returns:
            ``transform(tensor)`` when a transform was given, otherwise the
            image tensor produced by ``to_tensor``.
        """
        # The context manager closes the underlying file handle as soon as
        # the pixel data has been copied into the tensor.
        with Image.open(self.files_list[idx]) as img:
            sample = transforms.functional.to_tensor(img)
        # Fall back to the plain tensor instead of hitting an unbound
        # ``sample`` when no transform is configured.
        if self.transform:
            sample = self.transform(sample)
        return sample
class ToPIL(object):
    """Callable transform that turns its input into a PIL image."""

    def __call__(self, sample):
        """Return ``sample`` converted with torchvision's ``to_pil_image``."""
        return transforms.functional.to_pil_image(sample)
class DataSetWrapper(object):
    """Builds SimCLR train/validation data loaders over a list of image files."""

    def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
        """Store the loader and augmentation settings.

        Args:
            batch_size: Batch size passed to both data loaders.
            num_workers: Worker count passed to both data loaders.
            valid_size: Fraction of the dataset held out for validation.
            input_shape: Image shape, either as a string such as
                ``"(224, 224, 3)"`` or as an already-parsed tuple/list.
                Entries 0 and 1 are used as the resize height and width.
            s: Strength multiplier applied to the color-jitter parameters.
        """
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        if isinstance(input_shape, str):
            # NOTE(security): eval() executes arbitrary code. It is kept for
            # backward compatibility with string-valued config entries; only
            # pass trusted configuration here (ast.literal_eval would be the
            # safe choice if only literal tuples need to be supported).
            self.input_shape = eval(input_shape)
        else:
            # Already-parsed shapes are accepted as-is instead of failing
            # inside eval() with a TypeError.
            self.input_shape = tuple(input_shape)

    def get_data_loaders(self, csv_file='all_patches.csv'):
        """Create the training and validation loaders.

        Args:
            csv_file: Text file listing one image path per line. Defaults to
                ``'all_patches.csv'`` in the current working directory.

        Returns:
            A ``(train_loader, valid_loader)`` tuple.
        """
        data_augment = self._get_simclr_pipeline_transform()
        train_dataset = Dataset(csv_file=csv_file,
                                transform=SimCLRDataTransform(data_augment))
        return self.get_train_validation_data_loaders(train_dataset)

    def _get_simclr_pipeline_transform(self):
        """Compose the augmentation pipeline applied to each image tensor."""
        # get a set of data augmentation transformations as described in the SimCLR paper.
        jitter_strength = 0.8 * self.s
        color_jitter = transforms.ColorJitter(jitter_strength, jitter_strength,
                                              jitter_strength, 0.2 * self.s)
        height, width = self.input_shape[0], self.input_shape[1]
        return transforms.Compose([
            # The dataset yields tensors; the transforms below are applied to
            # a PIL image and the result is converted back at the end.
            ToPIL(),
            transforms.Resize((height, width)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(kernel_size=int(0.06 * height)),
            transforms.ToTensor(),
        ])

    def get_train_validation_data_loaders(self, train_dataset):
        """Split ``train_dataset`` at random and wrap both parts in loaders.

        Args:
            train_dataset: Map-style dataset (supports ``len`` and indexing).

        Returns:
            A ``(train_loader, valid_loader)`` tuple. Both loaders read from
            the same dataset object through disjoint index samplers and drop
            the last incomplete batch.
        """
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)

        # The first ``split`` shuffled indices form the validation subset.
        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # define samplers for obtaining training and validation batches
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # shuffle must stay False: ordering is delegated to the sampler.
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size,
                                  sampler=train_sampler,
                                  num_workers=self.num_workers,
                                  drop_last=True, shuffle=False)
        valid_loader = DataLoader(train_dataset, batch_size=self.batch_size,
                                  sampler=valid_sampler,
                                  num_workers=self.num_workers,
                                  drop_last=True)
        return train_loader, valid_loader
class SimCLRDataTransform(object):
    """Produces two independently transformed views of one sample."""

    def __init__(self, transform):
        """Remember the callable used to generate each view."""
        self.transform = transform

    def __call__(self, sample):
        """Apply the transform to ``sample`` twice and return both results.

        Returns:
            A 2-tuple ``(first_view, second_view)``; the transform is invoked
            once per view, in that order.
        """
        augment = self.transform
        return augment(sample), augment(sample)