# Authors: Hui Ren (rhfeiyang.github.io) import os import pickle import random import shutil from torch.utils.data import Dataset from torchvision import transforms from PIL import Image class LhqDataset(Dataset): def __init__(self, image_folder_path:str, caption_folder_path:str, id_file:str = "clip_dissection/lhq/idx/subsample_100.pickle", transforms: transforms = None, get_img=True, get_cap=True,): if isinstance(id_file, list): self.ids = id_file elif isinstance(id_file, str): with open(id_file, 'rb') as f: print(f"Loading ids from {id_file}", flush=True) self.ids = pickle.load(f) print(f"Loaded ids from {id_file}", flush=True) self.image_folder_path = image_folder_path self.caption_folder_path = caption_folder_path self.transforms = transforms self.column_names = ["image", "text"] self.get_img = get_img self.get_cap = get_cap def __len__(self): return len(self.ids) def __getitem__(self, index: int): id = self.ids[index] ret={"id":id} if self.get_img: image = self._load_image(id) ret["image"]=image if self.get_cap: target = self._load_caption(id) ret["caption"]=[target] if self.transforms is not None: ret = self.transforms(ret) return ret def _load_image(self, id: int): image_path = f"{self.image_folder_path}/{id}.jpg" with open(image_path, 'rb') as f: img = Image.open(f).convert("RGB") return img def _load_caption(self, id: int): caption_path = f"{self.caption_folder_path}/{id}.txt" with open(caption_path, 'r') as f: caption_file = f.read() caption = [] for line in caption_file.split("\n"): line = line.strip() if len(line) > 0: caption.append(line) return caption def subsample(self, n: int = 10000): if n is None or n == -1: return self ori_len = len(self) assert n <= ori_len # equal interval subsample ids = self.ids[::ori_len // n][:n] self.ids = ids print(f"LHQ dataset subsampled from {ori_len} to {len(self)}") return self def with_transform(self, transform): self.transforms = transform return self def generate_idx(data_folder = "/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/", save_path = "/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle"): all_ids = os.listdir(data_folder) all_ids = [i.split(".")[0] for i in all_ids if i.endswith(".jpg") or i.endswith(".png")] os.makedirs(os.path.dirname(save_path), exist_ok=True) pickle.dump(all_ids, open(f"{save_path}", "wb")) print("all_ids generated") return all_ids def random_sample(all_ids, sample_num = 110, save_root = "/data/vision/torralba/clip_dissection/huiren/lhq/subsample"): chosen_id = random.sample(all_ids, sample_num) save_dir = f"{save_root}/{sample_num}" os.makedirs(save_dir, exist_ok=True) for id in chosen_id: img_path = f"/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/{id}.jpg" shutil.copy(img_path, save_dir) return chosen_id if __name__ == "__main__": # all_ids = generate_idx() # with open("/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle", "rb") as f: # all_ids = pickle.load(f) # # random_sample(all_ids, 1) # # # generate_idx(data_folder="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/100", # # save_path="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_100.pickle") # # # lhq 500 # with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_100.pickle", "rb") as f: # lhq_100_idx = pickle.load(f) # # extra_idx = set(all_ids) - set(lhq_100_idx) # add_idx = random.sample(extra_idx, 400) # lhq_500_idx = lhq_100_idx + add_idx # with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_500.pickle", "wb") as f: # pickle.dump(lhq_500_idx, f) # save_dir = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/500" # os.makedirs(save_dir, exist_ok=True) # for id in lhq_500_idx: # img_path = f"/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/{id}.jpg" # # softlink # os.symlink(img_path, os.path.join(save_dir, f"{id}.jpg")) # lhq9 all_ids = generate_idx(data_folder="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/9", save_path="/data/vision/torralba/clip_dissection/huiren/lhq/idx/subsample_9.pickle") print(all_ids)