# Authors: Hui Ren (rhfeiyang.github.io)
"""SAM image/caption dataset with an on-disk cache of resized images."""
import os.path
import pickle
import sys
from typing import Any, Callable, List, Optional, Tuple, Union

import tqdm
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Machine-specific dataset root. Its presence changes where resized images are
# cached (see ``SamDataset.__init__`` and ``SamDataset.define_resolution``).
_LOCAL_DATASET_ROOT = "/var/jomat/datasets/"


class SamDataset(Dataset):
    """Dataset yielding ``{"id", "image", "text"}`` dicts for SAM images.

    Images are read from ``<image_folder>/[<subfolder>/]sa_<id>.jpg`` and
    captions from ``<caption_folder>/sa_<id>.txt``. When a resolution is set,
    resized copies are cached in a sibling folder and reused on later reads.
    """

    def __init__(
        self,
        image_folder_path: str,
        caption_folder_path: str,
        id_file: Union[str, List] = "data/sam/clip_filtered_ids.pickle",
        id_dict_file: Optional[str] = None,
        transforms: Optional[Callable] = None,
        resolution=None,
        get_img=True,
        get_cap=True,
    ):
        """Set up the id list, folders and loading options.

        Args:
            image_folder_path: Folder holding the original images.
            caption_folder_path: Folder holding the ``sa_<id>.txt`` captions.
            id_file: Either a list of ids, or the path of a pickle file
                containing them.
            id_dict_file: Optional path of a pickle file mapping each id to
                the subfolder its image lives in.
            transforms: Optional callable applied to the sample dict.
            resolution: If not None, images are served from (and cached in)
                ``f"{image_folder_path}_{resolution}"``.
            get_img: Whether ``__getitem__`` loads the image.
            get_cap: Whether ``__getitem__`` loads the caption.

        Raises:
            TypeError: If ``id_file`` is neither a list nor a path.
        """
        # SECURITY: both files are read with pickle.load, which can execute
        # arbitrary code. Only pass files from a trusted source.
        if id_dict_file is not None:
            self.id_dict = self._load_pickle(id_dict_file, "id_dict")
        else:
            self.id_dict = None

        if isinstance(id_file, list):
            self.ids = id_file
        elif isinstance(id_file, (str, os.PathLike)):
            self.ids = self._load_pickle(id_file, "ids")
        else:
            # Fail here instead of leaving ``self.ids`` unset, which used to
            # surface later as an AttributeError in ``__len__``.
            raise TypeError(
                "id_file must be a list of ids or a path to a pickle file, "
                f"got {type(id_file).__name__}"
            )

        self.resolution = resolution
        self.ori_image_folder_path = image_folder_path
        if resolution is not None:
            self.image_folder_path = f"{image_folder_path}_{resolution}"
            # The cache folder is only pre-created when the machine-specific
            # dataset root is absent; _load_image creates it on demand anyway.
            if not os.path.exists(_LOCAL_DATASET_ROOT):
                os.makedirs(self.image_folder_path, exist_ok=True)
        else:
            self.image_folder_path = image_folder_path

        self.caption_folder_path = caption_folder_path
        self.transforms = transforms
        self.column_names = ["image", "text"]
        self.get_img = get_img
        self.get_cap = get_cap

    @staticmethod
    def _load_pickle(path, label: str):
        """Unpickle ``path`` and return its content, logging before and after."""
        with open(path, 'rb') as f:
            print(f"Loading {label} from {path}", flush=True)
            content = pickle.load(f)
        print(f"Loaded {label} from {path}", flush=True)
        return content

    def __len__(self):
        """Return the number of ids in the dataset."""
        return len(self.ids)

    def __getitem__(self, index: int):
        """Return the sample dict for ``index``.

        The dict always has ``"id"``; ``"image"`` is added when ``get_img`` is
        set and ``"text"`` (a one-element list, whose element is None when the
        caption file is missing) when ``get_cap`` is set. ``self.transforms``,
        if any, is applied to the dict and its result is returned.

        Loading errors propagate to the caller.
        """
        sample_id = self.ids[index]
        sample = {"id": sample_id}
        if self.get_img:
            sample["image"] = self._load_image(sample_id)
        if self.get_cap:
            sample["text"] = [self._load_caption(sample_id)]
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample

    def define_resolution(self, resolution: int):
        """Change the resolution and point the image cache at its folder.

        Unlike ``__init__``, this uses ``/var/jomat/datasets/SAM_<resolution>``
        as the cache folder when that machine-specific root exists.
        """
        self.resolution = resolution
        if os.path.exists(_LOCAL_DATASET_ROOT):
            self.image_folder_path = f"{_LOCAL_DATASET_ROOT}SAM_{resolution}"
        else:
            self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
        print(f"SamDataset resolution defined to {resolution}, new image folder path: {self.image_folder_path}")

    def _image_path(self, folder: str, id: int) -> str:
        """Build the path of image ``id`` under ``folder``."""
        if self.id_dict is not None:
            return f"{folder}/{self.id_dict[id]}/sa_{id}.jpg"
        return f"{folder}/sa_{id}.jpg"

    @staticmethod
    def _open_rgb(path: str) -> Image.Image:
        """Open the image at ``path`` and return it converted to RGB."""
        with open(path, 'rb') as f:
            # convert() decodes the pixels while the file is still open.
            return Image.open(f).convert("RGB")

    def _load_image(self, id: int) -> Image.Image:
        """Load image ``id`` as RGB, building the cached copy if needed.

        The image is first read from ``self.image_folder_path``. If that
        fails, the original is read from ``self.ori_image_folder_path``,
        resized when a resolution is set, saved to the cache path and
        returned.

        Raises:
            FileNotFoundError: If the original image does not exist either.
        """
        image_path = self._image_path(self.image_folder_path, id)
        try:
            return self._open_rgb(image_path)
        except Exception:
            # Cache miss or unreadable cached file: rebuild from the original.
            # ``Exception`` (not a bare except) so Ctrl-C is not swallowed.
            ori_image_path = self._image_path(self.ori_image_folder_path, id)
            img = self._open_rgb(ori_image_path)
            if self.resolution is not None:
                # With an int resolution, Resize matches the smaller edge and
                # keeps the aspect ratio.
                resize = transforms.Resize(
                    self.resolution,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                )
                img = resize(img)
            # Never re-encode an image on top of its own original file.
            if image_path != ori_image_path:
                os.makedirs(os.path.dirname(image_path), exist_ok=True)
                img.save(image_path)
            return img

    def _load_caption(self, id: int) -> Optional[str]:
        """Read and clean the caption of ``id``.

        The text is split on ``'.'``; empty sentences and sentences containing
        "black and white" (too many false predictions) are dropped, and the
        rest is re-joined with ``". "`` plus a final period.

        Returns:
            The cleaned caption (possibly an empty string), or None when the
            caption file does not exist. Read errors propagate to the caller.
        """
        caption_path = f"{self.caption_folder_path}/sa_{id}.txt"
        if not os.path.exists(caption_path):
            return None
        with open(caption_path, 'r', encoding="utf-8") as f:
            content = f.read()

        kept = [
            sentence.strip()
            for sentence in content.split('.')
            if sentence.strip() and "black and white" not in sentence
        ]
        caption = ". ".join(kept)
        # The kept sentences contain no '.', so a non-empty caption never
        # already ends with one.
        if caption:
            caption += '.'
        return caption

    def with_transform(self, transform):
        """Set the sample transform and return ``self``."""
        self.transforms = transform
        return self

    def subsample(self, n: int = 10000):
        """Keep ``n`` ids taken at equal intervals, in place, and return ``self``.

        ``n`` of None or -1 leaves the dataset untouched.

        Raises:
            ValueError: If ``n`` is negative (other than -1) or larger than
                the dataset.
        """
        if n is None or n == -1:
            return self
        ori_len = len(self)
        if not 0 <= n <= ori_len:
            raise ValueError(
                f"n must be between 0 and the dataset size {ori_len}, got {n}"
            )
        if n == 0:
            # Avoids the division by zero in the stride computation below.
            self.ids = self.ids[:0]
        else:
            # Equal-interval subsample; the stride is >= 1 because n <= ori_len.
            self.ids = self.ids[::ori_len // n][:n]
        print(f"SAM dataset subsampled from {ori_len} to {len(self)}")
        return self


if __name__ == "__main__":
    from custom_datasets.sam_caption.mypath import MyPath

    dataset = SamDataset(
        image_folder_path=MyPath.db_root_dir("sam_images"),
        caption_folder_path=MyPath.db_root_dir("sam_captions"),
        id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"),
        id_dict_file=MyPath.db_root_dir("sam_id_dict"),
    )
    dataset.get_img = False
    # Smoke test: read every caption with a progress bar.
    for sample in tqdm.tqdm(dataset):
        _ = sample['text']