File size: 4,019 Bytes
262b155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Authors: Hui Ren (rhfeiyang.github.io)
import torch
import pandas as pd
import numpy as np
import os
from PIL import Image

class Caption_set(torch.utils.data.Dataset):

    style_set_names=[
        "andre-derain_subset1",
        "andy_subset1",
        "camille-corot_subset1",
        "gerhard-richter_subset1",
        "henri-matisse_subset1",
        "katsushika-hokusai_subset1",
        "klimt_subset3",
        "monet_subset2",
        "picasso_subset1",
        "van_gogh_subset1",
    ]
    style_set_map={f"{name}":f"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/{name}/style_captions.csv" for name in style_set_names}

    def __init__(self, prompts_path=None, set_name=None, transform=None):
        assert prompts_path is not None or set_name is not None, "Either prompts_path or set_name should be provided"
        if prompts_path is None:
            prompts_path = self.style_set_map[set_name]

        self.prompts = pd.read_csv(prompts_path, delimiter=';')
        self.transform = transform
    def __len__(self):
        return len(self.prompts)
    def __getitem__(self, idx):
        ret={}
        ret["id"] = idx
        info = self.prompts.iloc[idx]
        ret.update(info)
        for k,v in ret.items():
            if isinstance(v,np.int64):
                ret[k] = int(v)
        ret["caption"] = [ret["caption"]]
        if self.transform:
            ret = self.transform(ret)
        return ret

    def with_transform(self, transform):
        self.transform = transform
        return self


class HRS_caption(Caption_set):
    def __init__(self, prompts_path="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv", transform=None, delimiter=','):
        self.prompts = pd.read_csv(prompts_path, delimiter=delimiter)
        self.transform = transform
        self.caption_key = "original_prompts"

    def __getitem__(self, idx):
        ret={}
        ret["id"] = idx
        info = self.prompts.iloc[idx]
        ret["caption"] = [info[self.caption_key]]
        ret["seed"] = idx
        if self.transform:
            ret = self.transform(ret)
        return ret

class Laion_pop(torch.utils.data.Dataset):
    def __init__(self, anno_file="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/custom_datasets/laion_pop500.csv",image_root="/vision-nfs/torralba/scratch/jomat/sam_dataset/laion_pop",transform=None):
        self.transform = transform
        self.info = pd.read_csv(anno_file, delimiter=";")
        self.caption_key = "caption"
        self.image_root = image_root
        self.get_img=True
        self.get_caption=True
    def __len__(self):
        return len(self.info)

    # def subsample(self, num:int):
    #     self.data = self.data.select(range(num))
    #     return self

    def load_image(self, key):
        image_path = os.path.join(self.image_root, f"{key:09}.jpg")
        with open(image_path, "rb") as f:
            image = Image.open(f).convert("RGB")
        return image

    def __getitem__(self, idx):
        info = self.info.iloc[idx]
        ret = {}
        key = info["key"]
        ret["id"] = key
        if self.get_caption:
            ret["caption"] = [info[self.caption_key]]
        ret["seed"] = int(key)
        if self.get_img:
            ret["image"] = self.load_image(key)

        if self.transform:
            ret = self.transform(ret)
        return ret

    def with_transform(self, transform):
        self.transform = transform
        return self

    def subset(self, ids:list):
        self.info = self.info[self.info["key"].isin(ids)]
        return self

if __name__ == "__main__":
    dataset = Caption_set("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv")
    dataset[0]