import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
from PIL import Image

from data_aug.gaussian_blur import GaussianBlur

np.random.seed(0)

class Dataset():
    """Dataset of image patches listed, one path per line, in a CSV/text file."""

    def __init__(self, csv_file, transform=None):
        lines = []
        with open(csv_file) as f:
            for line in f:
                line = line.strip()
                if line:
                    lines.append(line)
        self.files_list = lines
        self.transform = transform

    def __len__(self):
        return len(self.files_list)

    def __getitem__(self, idx):
        temp_path = self.files_list[idx]
        img = Image.open(temp_path)
        img = transforms.functional.to_tensor(img)
        # without a transform, return the raw tensor instead of raising NameError
        if self.transform:
            img = self.transform(img)
        return img


class ToPIL(object):
    """Convert a tensor image back to a PIL image so the PIL-based transforms below can run."""

    def __call__(self, sample):
        return transforms.functional.to_pil_image(sample)


class DataSetWrapper(object):

    def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        # input_shape is passed as a string such as "(96, 96, 3)" and parsed here
        self.input_shape = eval(input_shape)

    def get_data_loaders(self):
        data_augment = self._get_simclr_pipeline_transform()
        train_dataset = Dataset(csv_file='all_patches.csv',
                                transform=SimCLRDataTransform(data_augment))
        train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
        return train_loader, valid_loader

    def _get_simclr_pipeline_transform(self):
        # SimCLR augmentation pipeline: resize, random flip, colour jitter,
        # random grayscale, Gaussian blur, then back to a tensor
        color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
        data_transforms = transforms.Compose([ToPIL(),
                                              transforms.Resize((self.input_shape[0], self.input_shape[1])),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.06 * self.input_shape[0])),
                                              transforms.ToTensor()])
        return data_transforms

    def get_train_validation_data_loaders(self, train_dataset):
        # shuffle all indices, then split them into training and validation subsets
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)

        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # each sampler draws only from its own subset of indices
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # shuffle stays at its default (False) because an explicit sampler is supplied
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
                                  num_workers=self.num_workers, drop_last=True)
        valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
                                  num_workers=self.num_workers, drop_last=True)
        return train_loader, valid_loader


class SimCLRDataTransform(object):
    """Apply the same stochastic transform twice to obtain the two SimCLR views."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj
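

# A minimal usage sketch (not part of the original module): the parameter values
# below are assumptions chosen for illustration, and 'all_patches.csv' must exist
# with one image path per line.
if __name__ == "__main__":
    wrapper = DataSetWrapper(batch_size=32, num_workers=4, valid_size=0.1,
                             input_shape="(96, 96, 3)", s=1)
    train_loader, valid_loader = wrapper.get_data_loaders()
    # each batch is a pair of augmented views produced by SimCLRDataTransform
    for xi, xj in train_loader:
        print(xi.shape, xj.shape)
        break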