import csv
import json
import random
from pathlib import Path
from typing import Callable, Optional

from torchvision.datasets import VisionDataset
from torchvision.io import ImageReadMode, read_image


class MIMICDataset(VisionDataset):
    """
    Dataset for loading image-text data for tasks like CLIP training or image captioning.

    Args:
        root (string): The root path where the dataset is stored.
        file_path (string): Path to the file containing the image paths and associated
            captions, relative to ``root``. The expected format is JSON Lines, where each
            line is a JSON object containing two keys:
                ``image_path``: The path to the image, relative to ``root``.
                ``caption``: A mapping from report section names (e.g. ``impression``,
                    ``findings``) to their text.
        mode (string): Target format:
            * 'longest': return the longest section.
            * 'docs': return the findings or impression section, chosen at random when
              both are present; falls back to 'longest' when neither is.
        transform (callable, optional): A function/transform that takes in an image
            and returns a transformed version. E.g., ``transforms.ToTensor``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes in an input
            sample and its target and returns a transformed version of both.
    """

    def __init__(
        self,
        root: str,
        file_path: str,
        mode: str = 'longest',
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)

        root = Path(root)

        if mode not in {'longest', 'docs'}:
            raise ValueError(f"Invalid mode {mode!r}; expected 'longest' or 'docs'")

        self.mode = mode

        with open(root / file_path, "r") as f:
            examples = [json.loads(line) for line in f]

        self.captions = []
        self.image_paths = []

        # Keep only examples whose image file actually exists on disk.
        for example in examples:
            img_path = root / example["image_path"]
            if img_path.exists():
                self.captions.append(example["caption"])
                self.image_paths.append(str(img_path))

    def _load_image(self, idx: int):
        path = self.image_paths[idx]
        # read_image returns a uint8 CHW tensor; decode as 3-channel RGB.
        return read_image(path, mode=ImageReadMode.RGB)

    def _load_target(self, idx: int) -> str:
        sections = self.captions[idx]

        if self.mode == 'docs':
            # Collect whichever of the narrative sections are present.
            collection = [sections[key] for key in ('impression', 'findings') if key in sections]
            if collection:
                # Pick one at random when both are available.
                return random.choice(collection)

        # 'longest' mode, or the fallback when neither section exists:
        # return the longest string-valued section.
        return max(
            (value for value in sections.values() if isinstance(value, str)),
            key=len,
        )

    def __getitem__(self, index: int):
        image = self._load_image(index)
        target = self._load_target(index)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.captions)
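

# A minimal usage sketch for MIMICDataset. The root path and file name below are
# hypothetical; they depend on how your MIMIC-CXR export is laid out on disk:
#
#     dataset = MIMICDataset(root="/data/mimic-cxr", file_path="train.jsonl", mode="docs")
#     image, caption = dataset[0]  # image: uint8 CHW tensor, caption: str

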
class ROCODataset(VisionDataset):
    """
    Dataset for loading image-text data for tasks like CLIP training or image captioning.

    Args:
        root (string): The root path where the dataset is stored.
        split (string): The dataset split to load (e.g. 'train'). Captions are read from
            ``{split}/radiology/{split}.csv``, a CSV file with a header row followed by one
            ``(id, image file name, caption)`` row per example, and images from
            ``{split}/radiology/images/``.
        transform (callable, optional): A function/transform that takes in an image
            and returns a transformed version. E.g., ``transforms.ToTensor``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes in an input
            sample and its target and returns a transformed version of both.
    """

    def __init__(
        self,
        root: str,
        split: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)

        root = Path(root) / f"{split}/radiology/"
        file_path = f"{split}.csv"

        self.captions = []
        self.image_paths = []

        with open((root / file_path).resolve(), 'r') as buf:
            csv_reader = csv.reader(buf)
            # Skip the header row.
            next(csv_reader)

            for row in csv_reader:
                if len(row) != 3:
                    # Skip malformed rows instead of silently reusing the
                    # previous row's values.
                    print(f"Skipping malformed row: {row}")
                    continue
                _, fname, caption = row
                self.captions.append(caption.strip())
                self.image_paths.append(str(root / 'images' / fname.strip()))

    def _load_image(self, idx: int):
        path = self.image_paths[idx]
        # read_image returns a uint8 CHW tensor; decode as 3-channel RGB.
        return read_image(path, mode=ImageReadMode.RGB)

    def _load_target(self, idx: int) -> str:
        return self.captions[idx]

    def __getitem__(self, index: int):
        image = self._load_image(index)
        target = self._load_target(index)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.captions)
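

# A minimal usage sketch for ROCODataset; assumes the standard ROCO layout with
# `{split}/radiology/{split}.csv` next to an `images/` directory:
#
#     dataset = ROCODataset(root="/data/roco", split="train")
#     image, caption = dataset[0]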