from .mypath import MyPath
from copy import deepcopy
from datasets import load_dataset
from torch.utils.data import Dataset
import numpy as np
def get_dataset(dataset_name, transformation=None, train_subsample: int = None, val_subsample: int = 10000, get_val=True):
    """Build train/val splits for a named dataset.

    Args:
        dataset_name: key selecting one of the known datasets below; any other
            value is treated as a directory path and loaded with
            ``load_dataset('imagefolder', ...)``.
        transformation: optional transform forwarded to the dataset class.
        train_subsample: if not None, subsample the train split to this many
            items (-1 means "use everything").
        val_subsample: target size of the val split (default 10000).
        get_val: when False, the returned "val" entry is None.

    Returns:
        dict with "train" and "val" entries. Some names ("coco_val",
        "sam_whole_filtered_val") return only {"val": ...}; the wikiart and
        imagefolder branches return early with their own dicts.

    Raises:
        ValueError: if dataset_name is unknown and cannot be loaded as an
            imagefolder directory.
    """
    # Keep the val split no larger than the train split (-1 == "all").
    if train_subsample is not None and train_subsample < val_subsample and train_subsample != -1:
        print(f"Warning: train_subsample is smaller than val_subsample. val_subsample will be set to train_subsample: {train_subsample}")
        val_subsample = train_subsample
    if dataset_name == "imagenet":
        from .imagenet import Imagenet1k
        # NOTE(review): Label_prompt_transform is not imported in this file —
        # presumably supplied by a star-import elsewhere; confirm before use.
        train_set = Imagenet1k(data_dir=MyPath.db_root_dir(dataset_name), transform=transformation, split="train", prompt_transform=Label_prompt_transform(real=True))
    elif dataset_name == "coco_train":
        from .coco import CocoCaptions
        train_set = CocoCaptions(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
    elif dataset_name == "coco_val":
        from .coco import CocoCaptions
        train_set = CocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"))
        # Validation-only dataset: no train split.
        return {"val": train_set}
    elif dataset_name == "coco_clip_filtered":
        from .coco import CocoCaptions_clip_filtered
        train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
    elif dataset_name == "coco_filtered_sub100":
        from .coco import CocoCaptions_clip_filtered
        train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"), id_file=MyPath.db_root_dir("coco_clip_filtered_ids_sub100"))
    elif dataset_name == "cifar10":
        from .cifar import CIFAR10
        train_set = CIFAR10(root=MyPath.db_root_dir("cifar10"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
    elif dataset_name == "cifar100":
        from .cifar import CIFAR100
        train_set = CIFAR100(root=MyPath.db_root_dir("cifar100"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
    elif "wikiart" in dataset_name and "/" not in dataset_name:
        # "/" would indicate a filesystem path, which falls through to the
        # imagefolder branch below.
        from .wikiart.wikiart import Wikiart_caption
        dataset = Wikiart_caption(data_path=MyPath.db_root_dir(dataset_name))
        return {"train": dataset.subsample(train_subsample).get_dataset(), "val": deepcopy(dataset).subsample(val_subsample).get_dataset() if get_val else None}
    elif "imagepair" in dataset_name:
        from .imagepair import ImagePair
        train_set = ImagePair(folder1=MyPath.db_root_dir(dataset_name)[0], folder2=MyPath.db_root_dir(dataset_name)[1], transform=transformation).subsample(train_subsample)
    elif dataset_name == "sam_whole_filtered":
        from .sam import SamDataset
        train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
    elif dataset_name == "sam_whole_filtered_val":
        from .sam import SamDataset
        train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_val"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
        # Validation-only dataset: no train split.
        return {"val": train_set}
    elif dataset_name == "lhq_sub100":
        from .lhq import LhqDataset
        train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub100"), transforms=transformation)
    elif dataset_name == "lhq_sub500":
        from .lhq import LhqDataset
        train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub500"), transforms=transformation)
    elif dataset_name == "lhq_sub9":
        from .lhq import LhqDataset
        train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub9"), transforms=transformation)
    elif dataset_name == "custom_coco100":
        from .coco import CustomCocoCaptions
        train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
                                       custom_file=MyPath.db_root_dir("custom_coco100_captions"), transforms=transformation)
    elif dataset_name == "custom_coco500":
        from .coco import CustomCocoCaptions
        train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
                                       custom_file=MyPath.db_root_dir("custom_coco500_captions"), transforms=transformation)
    elif dataset_name == "laion_pop500":
        from .custom_caption import Laion_pop
        train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)
    elif dataset_name == "laion_pop500_first_sentence":
        from .custom_caption import Laion_pop
        train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500_first_sentence"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)
    else:
        # Fallback: treat the name as a local imagefolder directory.
        # The try is kept narrow so that only a failed load maps to
        # ValueError (previously a bare ``except:`` around the whole branch
        # also swallowed KeyboardInterrupt and misreported later errors).
        try:
            train_set = load_dataset('imagefolder', data_dir=dataset_name, split="train")
        except Exception as err:
            raise ValueError(f"dataset_name {dataset_name} not found.") from err
        val_set = deepcopy(train_set)
        if val_subsample is not None and val_subsample != -1:
            val_set = val_set.shuffle(seed=0).select(range(val_subsample))
        return {"train": train_set, "val": val_set if get_val else None}
    # Default: val is an independent (deep-copied) subsample of the train set.
    return {"train": train_set, "val": deepcopy(train_set).subsample(val_subsample) if get_val else None}
class MergeDataset(Dataset):
    """Concatenation of several caption datasets presented as one Dataset.

    Items are served from the member datasets in order. Each returned item
    is additionally tagged with "id" (the index *within* its member dataset)
    and "dataset" (the position of that member in the merged list).
    """

    @staticmethod
    def get_merged_dataset(dataset_names: list, transformation=None, train_subsample: int = None, val_subsample: int = 10000):
        """Load each named dataset via get_dataset and merge the splits.

        Fix: declared as @staticmethod — as a plain function in the class
        body, an instance-level call would have bound ``self`` to
        ``dataset_names``. Class-level calls behave exactly as before.

        Returns:
            dict with merged, subsampled "train" and "val" MergeDatasets.
        """
        train_datasets = []
        val_datasets = []
        for dataset_name in dataset_names:
            datasets = get_dataset(dataset_name, transformation, train_subsample, val_subsample)
            train_datasets.append(datasets["train"])
            val_datasets.append(datasets["val"])
        train_datasets = MergeDataset(train_datasets).subsample(train_subsample)
        val_datasets = MergeDataset(val_datasets).subsample(val_subsample)
        return {"train": train_datasets, "val": val_datasets}

    def __init__(self, datasets: list):
        # Members are assumed to share one schema; expose the first member's
        # column names for callers that inspect them.
        self.datasets = datasets
        self.column_names = self.datasets[0].column_names

    def define_resolution(self, resolution: int):
        # Propagate the target image resolution to every member dataset.
        for dataset in self.datasets:
            dataset.define_resolution(resolution)

    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)

    def __getitem__(self, index):
        # Walk the members, shifting the global index into each member's
        # local range until it fits.
        for i, dataset in enumerate(self.datasets):
            if index < len(dataset):
                ret = dataset[index]
                ret["id"] = index  # local index within the member dataset
                ret["dataset"] = i
                return ret
            index -= len(dataset)
        raise IndexError

    def subsample(self, num: int):
        """Proportionally subsample each member so the total approaches num.

        ``None`` disables subsampling and returns self unchanged. Per-member
        sizes are truncated with int(), so the result may contain slightly
        fewer than ``num`` items.
        """
        if num is None:
            return self
        dataset_ratio = np.array([len(dataset) for dataset in self.datasets]) / len(self)
        new_datasets = []
        for i, dataset in enumerate(self.datasets):
            new_datasets.append(dataset.subsample(int(num * dataset_ratio[i])))
        return MergeDataset(new_datasets)

    def with_transform(self, transform):
        # Fluent helper: apply the transform to every member, return self.
        for dataset in self.datasets:
            dataset.with_transform(transform)
        return self