|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Tuple |
|
|
|
from torchvision.datasets import VisionDataset |
|
|
|
from .decoders import TargetDecoder, ImageDataDecoder |
|
|
|
|
|
class ExtendedVisionDataset(VisionDataset):
    """Base class for datasets that yield decoded ``(image, target)`` pairs.

    Concrete subclasses supply three hooks, each of which raises
    ``NotImplementedError`` in this base class:

    * :meth:`get_image_data` -- raw image bytes for one sample,
    * :meth:`get_target` -- the undecoded target for one sample,
    * :meth:`__len__` -- the number of samples.

    :meth:`__getitem__` decodes the image with ``ImageDataDecoder`` and the
    target with ``TargetDecoder``. When ``self.transforms`` is not ``None``,
    it then applies ``self.transforms`` to both values together.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Pure pass-through: every argument is forwarded unchanged to
        # ``VisionDataset.__init__``.
        # NOTE(review): ``self.transforms`` (read in ``__getitem__``) is never
        # assigned in this class, so presumably the base initializer sets it
        # -- confirm.
        super().__init__(*args, **kwargs)

    def get_image_data(self, index: int) -> bytes:
        """Return the raw image bytes for sample ``index``.

        :meth:`__getitem__` passes the result to ``ImageDataDecoder``.
        Subclasses must override this.
        """
        raise NotImplementedError

    def get_target(self, index: int) -> Any:
        """Return the undecoded target for sample ``index``.

        :meth:`__getitem__` passes the result to ``TargetDecoder``.
        Subclasses must override this.
        """
        raise NotImplementedError

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Fetch, decode and optionally transform sample ``index``.

        Returns:
            A 2-tuple ``(image, target)``. If ``self.transforms`` is set, this
            is its output unpacked into exactly two values.

        Raises:
            RuntimeError: fetching or decoding the image failed for any
                reason. The original exception is chained as ``__cause__``.
                This also wraps the ``NotImplementedError`` raised when
                :meth:`get_image_data` has not been overridden.

        Exceptions from :meth:`get_target` or ``TargetDecoder`` are not
        wrapped; they propagate unchanged.
        """
        # Image stage: any failure is re-raised with the offending index.
        try:
            raw_bytes = self.get_image_data(index)
            decoded_image = ImageDataDecoder(raw_bytes).decode()
        except Exception as err:
            raise RuntimeError(f"can not read image for sample {index}") from err

        # Target stage: outside the try, so its errors surface as-is.
        raw_target = self.get_target(index)
        decoded_target = TargetDecoder(raw_target).decode()

        if self.transforms is None:
            return decoded_image, decoded_target

        # The joint transform receives both values. Its result is unpacked
        # into exactly two items, so it must yield an (image, target) pair.
        out_image, out_target = self.transforms(decoded_image, decoded_target)
        return out_image, out_target

    def __len__(self) -> int:
        """Return the number of samples; subclasses must override this."""
        raise NotImplementedError
|
|