# fclong's picture
# Upload 396 files
# 8ebda9e
# This dataset is only for temporary testing, so it lives here in the simplest possible form for now; it will be improved later.
import os

from PIL import Image
from torch.utils.data import Dataset
class flickr30k_CNA(Dataset):
    """Image-caption dataset backed by a tab-separated annotation file.

    Every non-blank annotation line is parsed as ``<key>#<suffix><TAB><caption>``.
    The text before the first ``#`` is stored as the sample label and, with
    ``.jpg`` appended, as the image file name relative to ``img_root_path``.
    Only the first field after the tab is kept as the caption.
    """

    def __init__(self, img_root_path=None,
                 text_annot_path=None,
                 data_process_fn=None):
        """Read the whole annotation file into memory.

        Args:
            img_root_path: Directory that contains the ``<key>.jpg`` images.
            text_annot_path: Path of the annotation file; it is read as UTF-8.
            data_process_fn: Callable taking ``(PIL image, caption)`` and
                returning ``(image, text)``; applied in ``__getitem__``.
                When ``None``, the RGB image and the caption are returned
                unprocessed.

        Raises:
            ValueError: If a non-blank line has no tab-separated caption.
        """
        super().__init__()
        self.images = []
        self.captions = []
        self.labels = []
        self.root = img_root_path
        self.data_process_fn = data_process_fn

        # Decode explicitly: without `encoding` the result depends on the
        # platform locale, which breaks non-ASCII captions on some systems.
        with open(text_annot_path, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                fields = line.split('\t')
                if len(fields) < 2:
                    raise ValueError(
                        f"{text_annot_path}:{line_no}: expected "
                        f"'<key>#<n>\\t<caption>', got {raw_line!r}"
                    )
                key = fields[0].split('#')[0]
                self.images.append(key + '.jpg')
                self.captions.append(fields[1])
                self.labels.append(key)

    def __len__(self):
        """Return the number of (image, caption) pairs that were loaded."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, text, label)`` for the sample at position ``idx``.

        The image is always handed to ``data_process_fn`` in RGB mode.
        """
        img_path = os.path.join(self.root, self.images[idx])
        # `Image.open` is lazy; converting inside the `with` block forces the
        # pixel data to be read and lets the file handle be closed right away.
        with Image.open(img_path) as img:
            instance_image = img.convert("RGB")
        captions = self.captions[idx]
        label = self.labels[idx]
        if self.data_process_fn is None:
            return instance_image, captions, label
        image, text = self.data_process_fn(instance_image, captions)
        return image, text, label