import os.path
import pickle
import random
from typing import Any, Callable, List, Optional, Tuple

import pandas as pd
import torch
from PIL import Image
from torchvision.datasets.vision import VisionDataset
from tqdm import tqdm  # used by the commented-out clip_filter below

# from torchvision.datasets import CocoDetection
# from utils.clip_filter import Clip_filter

try:
    from .mypath import MyPath
except ImportError:  # allow running this file directly as a script
    from mypath import MyPath

class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes an input sample and its target as entry
            and returns a transformed version.
    """
    def __init__(
        self,
        root: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        get_img: bool = True,
        get_cap: bool = True,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        from pycocotools.coco import COCO  # lazy import: pycocotools is only needed here

        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))
        self.column_names = ["image", "text"]
        self.get_img = get_img  # whether __getitem__ loads the image
        self.get_cap = get_cap  # whether __getitem__ loads the annotations

    def _load_image(self, id: int) -> Image.Image:
        path = self.coco.loadImgs(id)[0]["file_name"]
        with open(os.path.join(self.root, path), "rb") as f:
            img = Image.open(f).convert("RGB")
        return img

    def _load_target(self, id: int) -> List[Any]:
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> dict:
        id = self.ids[index]
        ret = {"id": id}
        if self.get_img:
            ret["image"] = self._load_image(id)
        if self.get_cap:
            ret["caption"] = [self._load_target(id)]
        if self.transforms is not None:
            ret = self.transforms(ret)
        return ret

    def subsample(self, n: int = 10000):
        """Subsample to ``n`` images at equal intervals (e.g. ori_len=100, n=10 keeps every 10th id)."""
        if n is None or n == -1:
            return self
        ori_len = len(self)
        assert n <= ori_len
        self.ids = self.ids[:: ori_len // n][:n]
        print(f"COCO dataset subsampled from {ori_len} to {len(self)}")
        return self

    def with_transform(self, transform):
        self.transforms = transform
        return self

    def __len__(self) -> int:
        return len(self.ids)
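

# A minimal usage sketch for the wrapper above; the paths are placeholders for
# a local COCO download, not paths shipped with this repo:
#
#     dset = CocoDetection(root="/path/to/coco/train2017",
#                          annFile="/path/to/coco/annotations/captions_train2017.json")
#     dset = dset.subsample(1000)  # keep 1000 evenly spaced images
#     sample = dset[0]             # {"id": ..., "image": <PIL.Image>, "caption": [[ann, ...]]}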


class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes an input sample and its target as entry
            and returns a transformed version.

    Example (upstream torchvision interface):

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root='dir where images are',
                                    annFile='json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3]  # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: torch.Size([3, 427, 640])
            ['A plane emitting smoke stream flying over a mountain.',
             'A plane darts across a bright blue sky behind a mountain covered in snow',
             'A plane leaves a contrail above the snowy mountain top.',
             'A mountain that has a plane flying overheard in the distance.',
             'A mountain view with a plume of smoke in the background']

    Note:
        Unlike the upstream torchvision class this docstring was copied from,
        ``__getitem__`` here returns a dict (see ``CocoDetection.__getitem__``
        above), with ``ret["caption"]`` holding the list of caption strings.
    """

    def _load_target(self, id: int) -> List[str]:
        return [ann["caption"] for ann in super()._load_target(id)]


class CocoCaptions_clip_filtered(CocoCaptions):
    positive_prompt = ["painting", "drawing", "graffiti"]

    def __init__(
        self,
        root: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        regenerate: bool = False,
        id_file: Optional[str] = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/data/coco/coco_clip_filtered_ids.pickle",
    ) -> None:
        super().__init__(root, annFile, transform, target_transform, transforms)
        os.makedirs(os.path.dirname(id_file), exist_ok=True)
        if os.path.exists(id_file) and not regenerate:
            with open(id_file, "rb") as f:
                self.ids = pickle.load(f)
        else:
            # NOTE: regenerating requires the clip_filter method, which is
            # commented out below because it depends on utils.clip_filter.Clip_filter.
            self.ids, naive_filtered_num = self.naive_filter()
            self.ids, clip_filtered_num = self.clip_filter(0.7)
            print(f"Naive filter removed {naive_filtered_num} images")
            print(f"CLIP filter removed {clip_filtered_num} images")
            with open(id_file, "wb") as f:
                pickle.dump(self.ids, f)
            print(f"Filtered ids saved to {id_file}")
        print(f"COCO filtered dataset size: {len(self)}")

    def naive_filter(self, filter_prompt="painting"):
        """Drop every image with a caption that mentions ``filter_prompt`` (case-insensitive)."""
        new_ids = []
        naive_filtered_num = 0
        for id in self.ids:
            captions = self._load_target(id)
            if any(filter_prompt in caption.lower() for caption in captions):
                naive_filtered_num += 1
            else:
                new_ids.append(id)
        return new_ids, naive_filtered_num
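
    # For example, an image captioned ["A painting of a dog", "A dog in a field"]
    # is dropped (one caption contains "painting"), while an image whose captions
    # never mention the word is kept.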

    # def clip_filter(self, threshold=0.7):
    #
    #     def collate_fn(examples):
    #         # {"image": image, "text": [target], "id": id}
    #         pixel_values = [example["image"] for example in examples]
    #         prompts = [example["text"] for example in examples]
    #         id = [example["id"] for example in examples]
    #         return {"images": pixel_values, "prompts": prompts, "ids": id}
    #
    #     clip_filtered_num = 0
    #     clip_filter = Clip_filter(positive_prompt=self.positive_prompt)
    #     clip_logs = {"positive_prompt": clip_filter.positive_prompt, "negative_prompt": clip_filter.negative_prompt,
    #                  "ids": torch.Tensor([]), "logits": torch.Tensor([])}
    #     clip_log_file = "data/coco/clip_logs.pth"
    #     new_ids = []
    #     batch_size = 128
    #     dataloader = torch.utils.data.DataLoader(self, batch_size=batch_size, num_workers=10, shuffle=False,
    #                                              collate_fn=collate_fn)
    #     for i, batch in enumerate(tqdm(dataloader)):
    #         images = batch["images"]
    #         filter_result, logits = clip_filter.filter(images, threshold=threshold)
    #         ids = torch.IntTensor(batch["ids"])
    #         clip_logs["ids"] = torch.cat([clip_logs["ids"], ids])
    #         clip_logs["logits"] = torch.cat([clip_logs["logits"], logits])
    #
    #         new_ids.extend(ids[~filter_result].tolist())
    #         clip_filtered_num += filter_result.sum().item()
    #         if i % 50 == 0:
    #             torch.save(clip_logs, clip_log_file)
    #     torch.save(clip_logs, clip_log_file)
    #
    #     return new_ids, clip_filtered_num
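

# ``Clip_filter`` (the commented-out import at the top) is not included in this
# file, so the method above cannot run as-is. The class below is a minimal
# sketch of an equivalent filter built on the Hugging Face ``transformers``
# CLIP API; the class name, the checkpoint, the negative prompt, and the
# threshold semantics are assumptions, not the original implementation.
class ClipFilterSketch:
    """Flag images whose CLIP probability mass on the positive prompts exceeds a threshold."""

    def __init__(self, positive_prompt, negative_prompt=("a photo",)):
        from transformers import CLIPModel, CLIPProcessor  # lazy import, as with pycocotools above

        self.positive_prompt = list(positive_prompt)
        self.negative_prompt = list(negative_prompt)
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    @torch.no_grad()
    def filter(self, images, threshold=0.7):
        # score every image against positive + negative prompts in one batch
        texts = self.positive_prompt + self.negative_prompt
        inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
        logits = self.model(**inputs).logits_per_image  # (num_images, num_prompts)
        probs = logits.softmax(dim=-1)
        # an image is filtered out when the combined probability of the
        # positive prompts passes the threshold
        positive_mass = probs[:, : len(self.positive_prompt)].sum(dim=-1)
        return positive_mass > threshold, logits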


class CustomCocoCaptions(CocoCaptions):
    def __init__(
        self,
        root: str = MyPath.db_root_dir("coco_val"),
        annFile: str = MyPath.db_root_dir("coco_caption_val"),
        custom_file: str = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/jomat-code/filtering/ms_coco_captions_testset100.txt",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, annFile, transform, target_transform, transforms)
        self.column_names = ["image", "text"]
        self.custom_file = custom_file
        self.load_custom_data(custom_file)
        # keep the raw callable (or None); __getitem__ applies it to the sample dict
        self.transforms = transforms

    def load_custom_data(self, custom_file):
        self.custom_data = []
        with open(custom_file, "r") as f:
            data = f.readlines()
        head = data[0].strip().split(",")
        self.head = head
        for line in data[1:]:
            sub_data = line.strip().split(",")
            if len(sub_data) > len(head):
                # the caption itself contains commas: re-join everything between
                # the first and last field into a single caption column
                sub_data = [sub_data[0], ",".join(sub_data[1:-1]), sub_data[-1]]
            assert len(sub_data) == len(head)
            self.custom_data.append(sub_data)
        self.custom_data = pd.DataFrame(self.custom_data, columns=head)
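
    # Expected layout of ``custom_file``, inferred from the parsing above and
    # the columns read in __getitem__ below; the sample row is illustrative
    # only. Note that the comma re-joining assumes exactly three columns, with
    # the caption in the middle:
    #
    #     image_id,caption,random_seed
    #     139,A woman stands in the dining area, at the table.,42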

    def __len__(self) -> int:
        return len(self.custom_data)

    def __getitem__(self, index: int) -> dict:
        data = self.custom_data.iloc[index]
        id = int(data["image_id"])
        ret = {"id": id}
        if self.get_img:
            ret["image"] = self._load_image(id)
        if self.get_cap:
            ret["caption"] = [data["caption"]]
        ret["seed"] = int(data["random_seed"])
        if self.transforms is not None:
            ret = self.transforms(ret)
        return ret


def get_validation_set():
    coco_instance = CocoDetection(
        root="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/.datasets/coco_2017/train2017/",
        annFile="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/.datasets/coco_2017/annotations/instances_train2017.json",
    )
    # discard every image that contains a person or an animal
    discard_cat_id = coco_instance.coco.getCatIds(supNms=["person", "animal"])
    discard_img_id = []
    for cat_id in discard_cat_id:
        discard_img_id += coco_instance.coco.catToImgs[cat_id]
    coco_clip_filtered = CocoCaptions_clip_filtered(
        root=MyPath.db_root_dir("coco_train"),
        annFile=MyPath.db_root_dir("coco_caption_train"),
        regenerate=False,
    )
    new_ids = list(set(coco_clip_filtered.ids) - set(discard_img_id))
    new_ids = random.sample(new_ids, 100)
    with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/data/coco/coco_clip_filtered_subset100.pickle", "wb") as f:
        pickle.dump(new_ids, f)


if __name__ == "__main__":
    # get_validation_set()

    # coco_filtered_remain_id = pickle.load(open("data/coco/coco_clip_filtered_ids.pickle", "rb"))
    # coco_filtered_subset100 = random.sample(coco_filtered_remain_id, 100)
    # save_path = "data/coco/coco_clip_filtered_subset100.pickle"
    # with open(save_path, "wb") as f:
    #     pickle.dump(coco_filtered_subset100, f)

    # dataset = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"),
    #                                      regenerate=False)
    dataset = CustomCocoCaptions(
        root=MyPath.db_root_dir("coco_val"),
        annFile=MyPath.db_root_dir("coco_caption_val"),
        custom_file="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/jomat-code/filtering/ms_coco_captions_testset100.txt",
    )
    dataset[0]