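"""Dataset construction utilities.

get_dataset() maps a dataset name to {"train": ..., "val": ...} splits;
MergeDataset concatenates several such datasets while subsampling each
proportionally to its size.
"""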
from copy import deepcopy
from typing import Optional

import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset

from .mypath import MyPath
# NOTE: Label_prompt_transform is used by the imagenet/cifar branches below but
# was never imported in the original file; the module path here is an assumption
# and may need adjusting to wherever it is defined in this package.
from .custom_caption import Label_prompt_transform


def get_dataset(dataset_name, transformation=None, train_subsample: Optional[int] = None, val_subsample: Optional[int] = 10000, get_val=True):
    """Build the dataset `dataset_name` and return its splits.

    Most branches return {"train": ..., "val": ...}; the *_val names return
    only {"val": ...}. A subsample value of None or -1 means "keep the full
    split".
    """
    if train_subsample is not None and val_subsample is not None \
            and train_subsample < val_subsample and train_subsample != -1:
        print(f"Warning: train_subsample ({train_subsample}) is smaller than val_subsample; capping val_subsample at {train_subsample}.")
        val_subsample = train_subsample
    if dataset_name == "imagenet":
        from .imagenet import Imagenet1k
        train_set = Imagenet1k(
            data_dir=MyPath.db_root_dir(dataset_name),
            transform=transformation,
            split="train",
            prompt_transform=Label_prompt_transform(real=True),
        )
    elif dataset_name == "coco_train":
        # raise NotImplementedError("Use coco_filtered instead")
        from .coco import CocoCaptions
        train_set = CocoCaptions(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
    elif dataset_name == "coco_val":
        from .coco import CocoCaptions
        train_set = CocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"))
        return {"val": train_set}

    elif dataset_name == "coco_clip_filtered":
        from .coco import CocoCaptions_clip_filtered
        train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
    elif dataset_name == "coco_filtered_sub100":
        from .coco import CocoCaptions_clip_filtered
        train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"), id_file=MyPath.db_root_dir("coco_clip_filtered_ids_sub100"),)
    elif dataset_name == "cifar10":
        from .cifar import CIFAR10
        train_set = CIFAR10(root=MyPath.db_root_dir("cifar10"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
    elif dataset_name == "cifar100":
        from .cifar import CIFAR100
        train_set = CIFAR100(root=MyPath.db_root_dir("cifar100"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
    elif "wikiart" in dataset_name and "/" not in dataset_name:
        from .wikiart.wikiart import Wikiart_caption
        dataset = Wikiart_caption(data_path=MyPath.db_root_dir(dataset_name))
        return {"train": dataset.subsample(train_subsample).get_dataset(), "val": deepcopy(dataset).subsample(val_subsample).get_dataset() if get_val else None}
    elif "imagepair" in dataset_name:
        from .imagepair import ImagePair
        train_set = ImagePair(folder1=MyPath.db_root_dir(dataset_name)[0], folder2=MyPath.db_root_dir(dataset_name)[1], transform=transformation).subsample(train_subsample)
    # elif dataset_name == "sam_clip_filtered":
    #     from .sam import SamDataset
    #     train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_ids"), transforms=transformation).subsample(train_subsample)
    elif dataset_name == "sam_whole_filtered":
        from .sam import SamDataset
        train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
    elif dataset_name == "sam_whole_filtered_val":
        from .sam import SamDataset
        train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_val"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
        return {"val": train_set}
    elif dataset_name == "lhq_sub100":
        from .lhq import LhqDataset
        train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub100"), transforms=transformation)
    elif dataset_name == "lhq_sub500":
        from .lhq import LhqDataset
        train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub500"), transforms=transformation)
    elif dataset_name == "lhq_sub9":
        from .lhq import LhqDataset
        train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub9"), transforms=transformation)

    elif dataset_name == "custom_coco100":
        from .coco import CustomCocoCaptions
        train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
                           custom_file=MyPath.db_root_dir("custom_coco100_captions"), transforms=transformation)
    elif dataset_name == "custom_coco500":
        from .coco import CustomCocoCaptions
        train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
                           custom_file=MyPath.db_root_dir("custom_coco500_captions"), transforms=transformation)
    elif dataset_name == "laion_pop500":
        from .custom_caption import Laion_pop
        train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)

    elif dataset_name == "laion_pop500_first_sentence":
        from .custom_caption import Laion_pop
        train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500_first_sentence"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)


    else:
        # Fall back to a HuggingFace `imagefolder` dataset rooted at dataset_name.
        try:
            train_set = load_dataset('imagefolder', data_dir=dataset_name, split="train")
            val_set = deepcopy(train_set)
            if val_subsample is not None and val_subsample != -1:
                val_set = val_set.shuffle(seed=0).select(range(val_subsample))
            return {"train": train_set, "val": val_set if get_val else None}
        except Exception as e:
            raise ValueError(f"dataset_name {dataset_name} not found.") from e
    return {"train": train_set, "val": deepcopy(train_set).subsample(val_subsample) if get_val else None}


class MergeDataset(Dataset):
    @staticmethod
    def get_merged_dataset(dataset_names: list, transformation=None, train_subsample: Optional[int] = None, val_subsample: Optional[int] = 10000):
        """Build every named dataset and merge the train and val splits."""
        train_datasets = []
        val_datasets = []
        for dataset_name in dataset_names:
            datasets = get_dataset(dataset_name, transformation, train_subsample, val_subsample)
            train_datasets.append(datasets["train"])
            val_datasets.append(datasets["val"])
        train_set = MergeDataset(train_datasets).subsample(train_subsample)
        val_set = MergeDataset(val_datasets).subsample(val_subsample)
        return {"train": train_set, "val": val_set}

    def __init__(self, datasets: list):
        self.datasets = datasets
        # All constituent datasets are assumed to expose the same columns.
        self.column_names = self.datasets[0].column_names
        # self.ids = []
        # start = 0
        # for dataset in self.datasets:
        #     self.ids += [i+start for i in dataset.ids]

    def define_resolution(self, resolution: int):
        for dataset in self.datasets:
            dataset.define_resolution(resolution)

    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)

    def __getitem__(self, index):
        # Walk the datasets in order, mapping the global index to a local one.
        for i, dataset in enumerate(self.datasets):
            if index < len(dataset):
                ret = dataset[index]
                ret["id"] = index    # local index within the sub-dataset
                ret["dataset"] = i   # which sub-dataset the item came from
                return ret
            index -= len(dataset)
        raise IndexError

    def subsample(self, num: Optional[int]):
        if num is None or num == -1:
            # Match get_dataset's convention: None/-1 means "keep everything".
            return self
        # Split the budget across sub-datasets proportionally to their size;
        # integer truncation may leave the merged total slightly under `num`.
        dataset_ratio = np.array([len(dataset) for dataset in self.datasets]) / len(self)
        new_datasets = []
        for i, dataset in enumerate(self.datasets):
            new_datasets.append(dataset.subsample(int(num * dataset_ratio[i])))
        return MergeDataset(new_datasets)

    def with_transform(self, transform):
        for dataset in self.datasets:
            dataset.with_transform(transform)
        return self
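
# Illustrative usage sketch (assumptions: the names are keys MyPath knows and
# each wrapped dataset implements subsample() and column_names):
#
#   splits = MergeDataset.get_merged_dataset(
#       ["cifar10", "cifar100"],
#       transformation=None,
#       train_subsample=20000,
#       val_subsample=2000,
#   )
#   item = splits["train"][0]  # sample dict plus "id"/"dataset" bookkeeping keys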