# Authors: Hui Ren (rhfeiyang.github.io)
import os
import pickle
import random
import shutil
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class LhqDataset(Dataset):
    """LHQ image/caption dataset: images stored as <id>.jpg, captions as <id>.txt."""
    def __init__(self, image_folder_path: str, caption_folder_path: str,
                 id_file="clip_dissection/lhq/idx/subsample_100.pickle",
                 transforms=None, get_img=True, get_cap=True):
        # id_file may be an in-memory list of ids or a path to a pickled id list
        if isinstance(id_file, list):
            self.ids = id_file
        elif isinstance(id_file, str):
            with open(id_file, 'rb') as f:
                print(f"Loading ids from {id_file}", flush=True)
                self.ids = pickle.load(f)
                print(f"Loaded ids from {id_file}", flush=True)
        else:
            raise TypeError(f"id_file must be a list of ids or a path to a pickle file, got {type(id_file)}")
        self.image_folder_path = image_folder_path
        self.caption_folder_path = caption_folder_path
        self.transforms = transforms
        self.column_names = ["image", "text"]
        self.get_img = get_img
        self.get_cap = get_cap

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index: int):
        id = self.ids[index]
        ret = {"id": id}
        if self.get_img:
            ret["image"] = self._load_image(id)
        if self.get_cap:
            # _load_caption returns a list of caption lines; wrap it so the sample
            # carries a list of caption lists
            ret["caption"] = [self._load_caption(id)]
        if self.transforms is not None:
            ret = self.transforms(ret)
        return ret

    def _load_image(self, id: int):
        # images are stored as <id>.jpg inside image_folder_path
        image_path = f"{self.image_folder_path}/{id}.jpg"
        with open(image_path, 'rb') as f:
            img = Image.open(f).convert("RGB")
        return img

    def _load_caption(self, id: int):
        # captions are stored as <id>.txt, one caption per non-empty line
        caption_path = f"{self.caption_folder_path}/{id}.txt"
        with open(caption_path, 'r') as f:
            caption_file = f.read()
        caption = [line.strip() for line in caption_file.split("\n") if line.strip()]
        return caption

    def subsample(self, n: int = 10000):
        """Keep n ids taken at (roughly) equal intervals; n=None or n=-1 keeps all ids."""
        if n is None or n == -1:
            return self
        ori_len = len(self)
        assert n <= ori_len, f"Cannot subsample {n} ids from a dataset of {ori_len}"
        # equal-interval subsample
        self.ids = self.ids[::ori_len // n][:n]
        print(f"LHQ dataset subsampled from {ori_len} to {len(self)}")
        return self

    def with_transform(self, transform):
        self.transforms = transform
        return self
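
# Illustrative usage sketch (not part of the original file): the image/caption folder
# paths below are placeholders, assuming images are stored as <id>.jpg and captions
# as <id>.txt with one caption per line, matching the loaders above.
#   dataset = LhqDataset(
#       image_folder_path="clip_dissection/lhq/images",
#       caption_folder_path="clip_dissection/lhq/captions",
#       id_file="clip_dissection/lhq/idx/subsample_100.pickle",
#   ).subsample(50)
#   sample = dataset[0]
#   # sample is a dict: {"id": <id>, "image": PIL.Image, "caption": [[line1, line2, ...]]}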


def generate_idx(data_folder="/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/", save_path="/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle"):
    # collect the ids (filenames without extension) of all jpg/png images in data_folder
    all_ids = os.listdir(data_folder)
    all_ids = [i.split(".")[0] for i in all_ids if i.endswith(".jpg") or i.endswith(".png")]
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "wb") as f:
        pickle.dump(all_ids, f)
    print("all_ids generated")
    return all_ids

def random_sample(all_ids, sample_num=110, save_root="/data/vision/torralba/clip_dissection/huiren/lhq/subsample",
                  data_folder="/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg"):
    # copy a random subset of images from data_folder into save_root/<sample_num>
    chosen_id = random.sample(all_ids, sample_num)
    save_dir = f"{save_root}/{sample_num}"
    os.makedirs(save_dir, exist_ok=True)
    for id in chosen_id:
        img_path = f"{data_folder}/{id}.jpg"
        shutil.copy(img_path, save_dir)

    return chosen_id
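
# Sketch of how the helpers chain together (illustrative; the paths are placeholders):
# build an id index over the full image folder with generate_idx, then copy a small
# random subset into its own folder with random_sample.
#   all_ids = generate_idx(data_folder="/path/to/lhq_1024_jpg",
#                          save_path="/path/to/idx/all_ids.pickle")
#   chosen = random_sample(all_ids, sample_num=100, save_root="/path/to/subsample")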

if __name__ == "__main__":
    # all_ids = generate_idx()
    # with open("/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle", "rb") as f:
    #     all_ids = pickle.load(f)
    # # random_sample(all_ids, 1)
    #
    # # generate_idx(data_folder="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/100",
    # #              save_path="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_100.pickle")
    #
    # # lhq 500
    # with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_100.pickle", "rb") as f:
    #     lhq_100_idx = pickle.load(f)
    #
    # extra_idx = set(all_ids) - set(lhq_100_idx)
    # add_idx = random.sample(extra_idx, 400)
    # lhq_500_idx = lhq_100_idx + add_idx
    # with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_500.pickle", "wb") as f:
    #     pickle.dump(lhq_500_idx, f)
    # save_dir = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/500"
    # os.makedirs(save_dir, exist_ok=True)
    # for id in lhq_500_idx:
    #     img_path = f"/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/{id}.jpg"
    #     # softlink
    #     os.symlink(img_path, os.path.join(save_dir, f"{id}.jpg"))

    # lhq9
    all_ids = generate_idx(data_folder="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/9",
                           save_path="/data/vision/torralba/clip_dissection/huiren/lhq/idx/subsample_9.pickle")
    print(all_ids)