from copy import deepcopy

import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset

from .mypath import MyPath

# NOTE: Label_prompt_transform (used by the imagenet/cifar branches below) is
# assumed to be provided by the surrounding package; it is not imported here.


def get_dataset(dataset_name, transformation=None, train_subsample: int = None, val_subsample: int = 10000, get_val=True):
    """Build the dataset(s) for ``dataset_name``.

    Returns ``{"train": ..., "val": ...}``; validation-only names (e.g.
    "coco_val", "sam_whole_filtered_val") return ``{"val": ...}`` instead.
    A subsample value of None (or -1) means "use everything".
    """
    if train_subsample is not None and train_subsample < val_subsample and train_subsample != -1:
        print(f"Warning: train_subsample ({train_subsample}) is smaller than val_subsample; setting val_subsample to {train_subsample}.")
        val_subsample = train_subsample
if dataset_name == "imagenet":
from .imagenet import Imagenet1k
train_set = Imagenet1k(data_dir = MyPath.db_root_dir(dataset_name), transform = transformation, split="train", prompt_transform=Label_prompt_transform(real=True))
elif dataset_name == "coco_train":
# raise NotImplementedError("Use coco_filtered instead")
from .coco import CocoCaptions
train_set = CocoCaptions(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
elif dataset_name == "coco_val":
from .coco import CocoCaptions
train_set = CocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"))
return {"val": train_set}
elif dataset_name == "coco_clip_filtered":
from .coco import CocoCaptions_clip_filtered
train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
elif dataset_name == "coco_filtered_sub100":
from .coco import CocoCaptions_clip_filtered
train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"), id_file=MyPath.db_root_dir("coco_clip_filtered_ids_sub100"),)
elif dataset_name == "cifar10":
from .cifar import CIFAR10
train_set = CIFAR10(root=MyPath.db_root_dir("cifar10"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
elif dataset_name == "cifar100":
from .cifar import CIFAR100
train_set = CIFAR100(root=MyPath.db_root_dir("cifar100"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
elif "wikiart" in dataset_name and "/" not in dataset_name:
from .wikiart.wikiart import Wikiart_caption
dataset = Wikiart_caption(data_path=MyPath.db_root_dir(dataset_name))
return {"train": dataset.subsample(train_subsample).get_dataset(), "val": deepcopy(dataset).subsample(val_subsample).get_dataset() if get_val else None}
elif "imagepair" in dataset_name:
from .imagepair import ImagePair
train_set = ImagePair(folder1=MyPath.db_root_dir(dataset_name)[0], folder2=MyPath.db_root_dir(dataset_name)[1], transform=transformation).subsample(train_subsample)
# elif dataset_name == "sam_clip_filtered":
# from .sam import SamDataset
# train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_ids"), transforms=transformation).subsample(train_subsample)
elif dataset_name == "sam_whole_filtered":
from .sam import SamDataset
train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
elif dataset_name == "sam_whole_filtered_val":
from .sam import SamDataset
train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_val"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
return {"val": train_set}
elif dataset_name == "lhq_sub100":
from .lhq import LhqDataset
train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub100"), transforms=transformation)
elif dataset_name == "lhq_sub500":
from .lhq import LhqDataset
train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub500"), transforms=transformation)
elif dataset_name == "lhq_sub9":
from .lhq import LhqDataset
train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub9"), transforms=transformation)
elif dataset_name == "custom_coco100":
from .coco import CustomCocoCaptions
train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
custom_file=MyPath.db_root_dir("custom_coco100_captions"), transforms=transformation)
elif dataset_name == "custom_coco500":
from .coco import CustomCocoCaptions
train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
custom_file=MyPath.db_root_dir("custom_coco500_captions"), transforms=transformation)
elif dataset_name == "laion_pop500":
from .custom_caption import Laion_pop
train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)
elif dataset_name == "laion_pop500_first_sentence":
from .custom_caption import Laion_pop
train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500_first_sentence"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)
    else:
        # Fall back to a generic Hugging Face imagefolder dataset; narrow the
        # try block to the load itself so other errors are not swallowed.
        try:
            train_set = load_dataset("imagefolder", data_dir=dataset_name, split="train")
        except Exception as exc:
            raise ValueError(f"dataset_name {dataset_name} not found.") from exc
        val_set = deepcopy(train_set)
        if val_subsample is not None and val_subsample != -1:
            val_set = val_set.shuffle(seed=0).select(range(val_subsample))
        return {"train": train_set, "val": val_set if get_val else None}
    return {"train": train_set, "val": deepcopy(train_set).subsample(val_subsample) if get_val else None}


class MergeDataset(Dataset):
    """Concatenate several datasets that share the same column schema."""

    @staticmethod
    def get_merged_dataset(dataset_names: list, transformation=None, train_subsample: int = None, val_subsample: int = 10000):
        train_datasets = []
        val_datasets = []
        for dataset_name in dataset_names:
            datasets = get_dataset(dataset_name, transformation, train_subsample, val_subsample)
            train_datasets.append(datasets["train"])
            val_datasets.append(datasets["val"])
        # Subsample once more after merging so the combined splits respect the caps.
        train_datasets = MergeDataset(train_datasets).subsample(train_subsample)
        val_datasets = MergeDataset(val_datasets).subsample(val_subsample)
        return {"train": train_datasets, "val": val_datasets}
    def __init__(self, datasets: list):
        self.datasets = datasets
        self.column_names = self.datasets[0].column_names
        # self.ids = []
        # start = 0
        # for dataset in self.datasets:
        #     self.ids += [i + start for i in dataset.ids]
    def define_resolution(self, resolution: int):
        for dataset in self.datasets:
            dataset.define_resolution(resolution)

    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)
    def __getitem__(self, index):
        # Walk the sub-datasets until the global index lands in one of them;
        # "id" is the local index within that sub-dataset, "dataset" its position.
        for i, dataset in enumerate(self.datasets):
            if index < len(dataset):
                ret = dataset[index]
                ret["id"] = index
                ret["dataset"] = i
                return ret
            index -= len(dataset)
        raise IndexError(f"index out of range for MergeDataset of length {len(self)}")
    def subsample(self, num: int):
        # Draw from each sub-dataset in proportion to its size; integer
        # truncation can make the result slightly smaller than num.
        if num is None:
            return self
        dataset_ratio = np.array([len(dataset) for dataset in self.datasets]) / len(self)
        new_datasets = []
        for i, dataset in enumerate(self.datasets):
            new_datasets.append(dataset.subsample(int(num * dataset_ratio[i])))
        return MergeDataset(new_datasets)
    def with_transform(self, transform):
        # Applied in place to every sub-dataset; returns self to allow chaining.
        for dataset in self.datasets:
            dataset.with_transform(transform)
        return self
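

# Minimal usage sketch (illustrative only; not part of the module's API).
# The directory path below is a placeholder: any folder laid out for Hugging
# Face's "imagefolder" loader is handled by the generic fallback branch of
# get_dataset.
if __name__ == "__main__":
    splits = get_dataset("/path/to/imagefolder_dataset", val_subsample=100)
    print(f"train: {len(splits['train'])}, val: {len(splits['val'])}")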