Spaces:
Running
on
Zero
Running
on
Zero
# Authors: Hui Ren (rhfeiyang.github.io)
import torch | |
import pandas as pd | |
import numpy as np | |
import os | |
from PIL import Image | |
class Caption_set(torch.utils.data.Dataset):
    """Caption dataset backed by a ``;``-delimited CSV file.

    Every item is a dict containing all columns of the selected CSV row plus an
    ``"id"`` entry holding the requested index. The ``"caption"`` column is
    wrapped in a one-element list, and NumPy integer scalars are converted to
    built-in ``int``.
    """

    # Names of the predefined style caption sets.
    style_set_names = [
        "andre-derain_subset1",
        "andy_subset1",
        "camille-corot_subset1",
        "gerhard-richter_subset1",
        "henri-matisse_subset1",
        "katsushika-hokusai_subset1",
        "klimt_subset3",
        "monet_subset2",
        "picasso_subset1",
        "van_gogh_subset1",
    ]
    # Maps each predefined set name to the path of its caption CSV.
    style_set_map = {
        name: f"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/{name}/style_captions.csv"
        for name in style_set_names
    }

    def __init__(self, prompts_path=None, set_name=None, transform=None):
        """Load the caption table.

        Args:
            prompts_path: Path (or file-like object) handed to
                ``pandas.read_csv``. Takes precedence over ``set_name``.
            set_name: Key of ``style_set_map``; only used when
                ``prompts_path`` is ``None``.
            transform: Optional callable applied to every item dict.

        Raises:
            ValueError: If neither ``prompts_path`` nor ``set_name`` is given.
            KeyError: If ``set_name`` is not a key of ``style_set_map``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if prompts_path is None and set_name is None:
            raise ValueError("Either prompts_path or set_name should be provided")
        if prompts_path is None:
            try:
                prompts_path = self.style_set_map[set_name]
            except KeyError:
                raise KeyError(
                    f"Unknown set_name {set_name!r}; available sets: {self.style_set_names}"
                ) from None
        self.prompts = pd.read_csv(prompts_path, delimiter=';')
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of rows in the caption table."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return the item for positional row ``idx``.

        The result is the output of ``transform`` when one is set, otherwise
        the item dict itself.
        """
        row = self.prompts.iloc[idx]
        # "id" is inserted first, so a CSV column named "id" overrides it.
        sample = {"id": idx}
        sample.update(row)
        # Convert every NumPy integer scalar (any width) to a built-in int.
        sample = {
            key: int(value) if isinstance(value, np.integer) else value
            for key, value in sample.items()
        }
        sample["caption"] = [sample["caption"]]
        if self.transform:
            sample = self.transform(sample)
        return sample

    def with_transform(self, transform):
        """Replace the transform in place and return ``self`` for chaining."""
        self.transform = transform
        return self
class HRS_caption(Caption_set):
    """Caption dataset that takes its prompts from the ``original_prompts`` column.

    The constructor does not call ``Caption_set.__init__``; it reads the CSV
    itself so the delimiter can be chosen by the caller (default ``,``).
    """

    def __init__(self, prompts_path="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv", transform=None, delimiter=','):
        """Read the prompt table from ``prompts_path`` and store the transform."""
        self.caption_key = "original_prompts"
        self.transform = transform
        self.prompts = pd.read_csv(prompts_path, delimiter=delimiter)

    def __getitem__(self, idx):
        """Return ``{"id", "caption", "seed"}`` for positional row ``idx``.

        Both ``"id"`` and ``"seed"`` are the index itself; ``"caption"`` is a
        one-element list. The transform, when set, is applied to the dict.
        """
        row = self.prompts.iloc[idx]
        sample = {
            "id": idx,
            "caption": [row[self.caption_key]],
            "seed": idx,
        }
        if not self.transform:
            return sample
        return self.transform(sample)
class Laion_pop(torch.utils.data.Dataset):
    """Image/caption dataset driven by a ``;``-delimited annotation CSV.

    Rows are addressed by position. The ``key`` column supplies the item id and
    the image file name (``key`` zero-padded to nine digits plus ``.jpg``,
    located under ``image_root``). The ``get_img`` and ``get_caption``
    attributes switch the image and caption/seed fields on or off.
    """

    def __init__(self, anno_file="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/custom_datasets/laion_pop500.csv",image_root="/vision-nfs/torralba/scratch/jomat/sam_dataset/laion_pop",transform=None):
        """Read the annotation table and store the image root and transform."""
        self.info = pd.read_csv(anno_file, delimiter=";")
        self.image_root = image_root
        self.transform = transform
        self.caption_key = "caption"
        self.get_caption = True
        self.get_img = True

    def __len__(self):
        """Return the number of annotation rows currently held."""
        return self.info.shape[0]

    def load_image(self, key):
        """Open ``<image_root>/<key:09>.jpg`` and return it converted to RGB."""
        filename = f"{key:09}.jpg"
        with open(os.path.join(self.image_root, filename), "rb") as handle:
            return Image.open(handle).convert("RGB")

    def __getitem__(self, idx):
        """Build the item dict for positional row ``idx``.

        ``"id"`` is always present; ``"caption"`` (one-element list) and
        ``"seed"`` are added when ``get_caption`` is set, ``"image"`` when
        ``get_img`` is set. The transform, when set, is applied to the dict.
        """
        row = self.info.iloc[idx]
        key = row["key"]
        sample = {"id": key}
        if self.get_caption:
            sample["caption"] = [row[self.caption_key]]
            sample["seed"] = int(key)
        if self.get_img:
            sample["image"] = self.load_image(key)
        if not self.transform:
            return sample
        return self.transform(sample)

    def with_transform(self, transform):
        """Replace the transform in place and return ``self`` for chaining."""
        self.transform = transform
        return self

    def subset(self, ids:list):
        """Keep only rows whose ``key`` is in ``ids``; mutates and returns ``self``."""
        keep = self.info["key"].isin(ids)
        self.info = self.info[keep]
        return self
if __name__ == "__main__":
    # Smoke test: load one predefined caption CSV and fetch its first item.
    csv_path = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv"
    dataset = Caption_set(csv_path)
    dataset[0]