Spaces:
Running
on
Zero
Running
on
Zero
# Authors: Hui Ren (rhfeiyang.github.io) | |
import os
import pickle
import random
import shutil
from typing import Any, Callable, Optional, Union

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class LhqDataset(Dataset):
    """Map-style dataset over LHQ images and their caption text files.

    Each sample is identified by an id from ``self.ids``. The image is read
    from ``{image_folder_path}/{id}.jpg`` and the caption from
    ``{caption_folder_path}/{id}.txt``.

    Args:
        image_folder_path: Directory containing ``<id>.jpg`` files.
        caption_folder_path: Directory containing ``<id>.txt`` files.
        id_file: Either a list/tuple of ids, or a path (``str`` or
            ``os.PathLike``) to a pickle file holding the id list.
        transforms: Optional callable applied to the whole sample dict built
            by ``__getitem__`` (not only to the image); its return value is
            what ``__getitem__`` returns.
        get_img: Whether samples include the ``"image"`` key.
        get_cap: Whether samples include the ``"caption"`` key.

    Raises:
        TypeError: If ``id_file`` is neither a list/tuple of ids nor a path.
    """

    def __init__(self, image_folder_path: str, caption_folder_path: str,
                 id_file: Union[str, os.PathLike, list, tuple] = "clip_dissection/lhq/idx/subsample_100.pickle",
                 transforms: Optional[Callable[[dict], Any]] = None,
                 get_img: bool = True,
                 get_cap: bool = True,):
        if isinstance(id_file, list):
            self.ids = id_file
        elif isinstance(id_file, tuple):
            self.ids = list(id_file)
        elif isinstance(id_file, (str, os.PathLike)):
            # NOTE: pickle.load can execute arbitrary code; only point this at
            # id files you generated yourself (e.g. via generate_idx).
            with open(id_file, 'rb') as f:
                print(f"Loading ids from {id_file}", flush=True)
                self.ids = pickle.load(f)
            print(f"Loaded ids from {id_file}", flush=True)
        else:
            # Previously this fell through silently and self.ids was never set,
            # surfacing later as an AttributeError in __len__/__getitem__.
            raise TypeError(
                "id_file must be a list/tuple of ids or a path to a pickle file, "
                f"got {type(id_file).__name__}"
            )
        self.image_folder_path = image_folder_path
        self.caption_folder_path = caption_folder_path
        self.transforms = transforms
        self.column_names = ["image", "text"]
        self.get_img = get_img
        self.get_cap = get_cap

    def __len__(self):
        """Return the number of ids currently in the dataset."""
        return len(self.ids)

    def __getitem__(self, index: int):
        """Build the sample for position ``index``.

        The dict always has ``"id"``. It has ``"image"`` (RGB PIL image) when
        ``get_img`` is set, and ``"caption"`` when ``get_cap`` is set. The
        caption value is a one-element list wrapping the list of caption
        lines. If a transform is set, its output is returned instead of the
        dict.
        """
        sample_id = self.ids[index]
        ret = {"id": sample_id}
        if self.get_img:
            ret["image"] = self._load_image(sample_id)
        if self.get_cap:
            # Deliberately wrapped: "caption" holds [list_of_lines].
            ret["caption"] = [self._load_caption(sample_id)]
        if self.transforms is not None:
            ret = self.transforms(ret)
        return ret

    def _load_image(self, id: Union[int, str]):
        """Open ``{image_folder_path}/{id}.jpg`` and return it as an RGB PIL image."""
        image_path = f"{self.image_folder_path}/{id}.jpg"
        with open(image_path, 'rb') as f:
            # convert() reads the pixel data and returns a new image, so the
            # result stays usable after the file handle is closed.
            img = Image.open(f).convert("RGB")
        return img

    def _load_caption(self, id: Union[int, str]):
        """Return the non-empty, whitespace-stripped lines of ``{caption_folder_path}/{id}.txt``."""
        caption_path = f"{self.caption_folder_path}/{id}.txt"
        # Explicit encoding so results don't depend on the platform locale.
        with open(caption_path, 'r', encoding="utf-8") as f:
            caption_file = f.read()
        return [line.strip() for line in caption_file.split("\n") if line.strip()]

    def subsample(self, n: int = 10000):
        """Keep ``n`` ids taken at an equal stride over the current ids (in place).

        ``n`` of ``None`` or ``-1`` leaves the dataset unchanged. Returns
        ``self`` so calls can be chained.

        Raises:
            ValueError: If ``n`` is negative (other than -1) or larger than
                ``len(self)``.
        """
        if n is None or n == -1:
            return self
        ori_len = len(self)
        # Raise instead of assert: asserts vanish under -O, after which an
        # oversized n only failed later with "slice step cannot be zero".
        if n < 0 or n > ori_len:
            raise ValueError(f"n must be in [0, {ori_len}] (or None/-1), got {n}")
        if n == 0:
            # The stride below would divide by zero.
            self.ids = self.ids[:0]
        else:
            # equal interval subsample
            self.ids = self.ids[::ori_len // n][:n]
        print(f"LHQ dataset subsampled from {ori_len} to {len(self)}")
        return self

    def with_transform(self, transform):
        """Set the sample-dict transform (see ``transforms`` in the class doc) and return ``self``."""
        self.transforms = transform
        return self
def generate_idx(data_folder = "/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/", save_path = "/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle"):
    """Collect image ids from ``data_folder`` and pickle them to ``save_path``.

    An id is a ``.jpg``/``.png`` file name with its final extension removed.
    Other files are ignored. The list keeps ``os.listdir`` order, which is
    arbitrary (not sorted). Parent directories of ``save_path`` are created
    as needed.

    Returns:
        The list of ids that was written.
    """
    all_ids = [
        # Strip only the final extension so "v1.2.jpg" -> "v1.2"; the dataset
        # later rebuilds the path as f"{id}.jpg".
        name.rsplit(".", 1)[0]
        for name in os.listdir(data_folder)
        if name.endswith((".jpg", ".png"))
    ]
    save_dir = os.path.dirname(save_path)
    if save_dir:  # os.makedirs("") raises when save_path is a bare file name
        os.makedirs(save_dir, exist_ok=True)
    # Close the handle deterministically so the pickle is fully flushed.
    with open(save_path, "wb") as f:
        pickle.dump(all_ids, f)
    print("all_ids generated")
    return all_ids
def random_sample(all_ids, sample_num = 110, save_root = "/data/vision/torralba/clip_dissection/huiren/lhq/subsample",
                  image_root = "/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg"):
    """Draw ``sample_num`` random ids and copy their images into ``{save_root}/{sample_num}``.

    Args:
        all_ids: Sequence of ids to draw from, without replacement.
            ``random.sample`` requires a sequence; sets are rejected on
            Python 3.11+.
        sample_num: Number of ids to draw. Also names the output subfolder.
        save_root: Parent directory of the output folder (created if missing).
        image_root: Directory containing the source ``<id>.jpg`` files.
            Defaults to the path that used to be hard-coded here.

    Returns:
        The list of chosen ids.
    """
    chosen_id = random.sample(all_ids, sample_num)
    save_dir = f"{save_root}/{sample_num}"
    os.makedirs(save_dir, exist_ok=True)
    for sample_id in chosen_id:
        shutil.copy(f"{image_root}/{sample_id}.jpg", save_dir)
    return chosen_id
if __name__ == "__main__":
    # One-off maintenance entry point: rebuild the id pickle for the 9-image
    # LHQ subset ("lhq9") and print the resulting ids.
    #
    # Earlier one-off steps that previously sat here as commented-out code:
    #   * all_ids.pickle was produced by generate_idx() with its defaults;
    #   * idx/subsample_100.pickle was produced by generate_idx() over the
    #     clip_dissection/lhq/subsample/100 image folder;
    #   * idx/subsample_500.pickle = the subsample_100 ids followed by 400 ids
    #     drawn at random from the rest of all_ids, with each chosen image
    #     symlinked into clip_dissection/lhq/subsample/500.
    #     If re-running that step, pass random.sample a sequence
    #     (e.g. sorted(remaining_ids)); sets are rejected on Python 3.11+.
    lhq9_image_dir = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/9"
    lhq9_id_pickle = "/data/vision/torralba/clip_dissection/huiren/lhq/idx/subsample_9.pickle"
    all_ids = generate_idx(data_folder=lhq9_image_dir, save_path=lhq9_id_pickle)
    print(all_ids)