"""Record-backed segmentation dataset and its default image/mask loaders."""

from __future__ import annotations

from collections.abc import Callable
from pathlib import Path
from typing import Any

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

from .ddti import build_ddti_mask
from .records import SegSampleRecord


def _binary_mask_tensor(mask_like: Any) -> torch.Tensor:
    """Binarize an array-like mask and add a leading channel axis.

    Every value strictly greater than zero becomes ``1.0``; everything else
    becomes ``0.0``.  The result is a contiguous ``float32`` tensor with one
    extra dimension inserted at position 0.
    """
    binary = (np.asarray(mask_like, dtype=np.float32) > 0).astype(np.float32)
    return torch.from_numpy(binary).unsqueeze(0).contiguous()


def default_image_loader(path: str | Path) -> torch.Tensor:
    """Load an image file as a channel-first ``float32`` tensor.

    The image is converted to RGB, scaled by ``1 / 255`` and permuted from
    height-width-channel to channel-height-width order.

    Args:
        path: Location of the image file.

    Returns:
        A contiguous ``float32`` tensor with the three RGB channels first.
    """
    # Pillow opens files lazily; the context manager guarantees the handle is
    # released.  ``convert`` returns an independent image, so it stays usable
    # after the file is closed.
    with Image.open(path) as source:
        rgb = source.convert("RGB")
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    return torch.from_numpy(pixels).permute(2, 0, 1).contiguous()


def default_mask_loader(path: str | Path) -> torch.Tensor:
    """Load a mask file as a single-channel binary ``float32`` tensor.

    The image is converted to 8-bit grayscale (mode ``"L"``); any non-zero
    pixel is treated as foreground (``1.0``), zero pixels as background
    (``0.0``).

    Args:
        path: Location of the mask file.

    Returns:
        A contiguous ``float32`` tensor with a leading channel axis of size 1.
    """
    with Image.open(path) as source:
        gray = source.convert("L")
    return _binary_mask_tensor(gray)


class SegmentationRecordDataset(Dataset):
    """Dataset that yields image/mask samples described by ``SegSampleRecord``s.

    For each record the image is read with ``image_loader``.  The mask comes
    from ``record.mask_path`` when set; otherwise, for records whose
    ``dataset_name`` is ``"DDTI"`` and that carry an ``annotation_path``, it is
    built with ``build_ddti_mask``.  Records with neither yield ``mask=None``.
    """

    def __init__(
        self,
        records: list[SegSampleRecord],
        image_loader: Callable[[str | Path], torch.Tensor] | None = None,
        mask_loader: Callable[[str | Path], torch.Tensor] | None = None,
        joint_transform: Callable[
            [torch.Tensor, torch.Tensor | None],
            tuple[torch.Tensor, torch.Tensor | None],
        ]
        | None = None,
        image_transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
        mask_transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
    ) -> None:
        """Store the records and the loading/transform callables.

        Args:
            records: Sample descriptions, indexed by ``__getitem__``.
            image_loader: Reads an image path into a tensor; defaults to
                ``default_image_loader``.
            mask_loader: Reads a mask path into a tensor; defaults to
                ``default_mask_loader``.
            joint_transform: Applied first to ``(image, mask)`` together; it
                is also called when ``mask`` is ``None``.
            image_transform: Applied to the image after ``joint_transform``.
            mask_transform: Applied to the mask after ``joint_transform``,
                only when a mask exists.
        """
        super().__init__()
        self.records = records
        # Explicit ``is None`` checks so that a valid callable which happens
        # to be falsy (e.g. defines ``__len__``) is not silently replaced.
        self.image_loader = (
            default_image_loader if image_loader is None else image_loader
        )
        self.mask_loader = (
            default_mask_loader if mask_loader is None else mask_loader
        )
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.records)

    def _load_mask(self, record: SegSampleRecord) -> torch.Tensor | None:
        """Return the mask tensor for ``record``, or ``None`` if it has none."""
        if record.mask_path is not None:
            return self.mask_loader(record.mask_path)
        if record.annotation_path is not None and record.dataset_name == "DDTI":
            # DDTI samples ship annotations instead of mask images; the mask
            # is built from them and binarized like a loaded mask.
            ddti_mask = build_ddti_mask(record.image_path, record.annotation_path)
            return _binary_mask_tensor(ddti_mask)
        return None

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Load, transform and package the sample at ``index``.

        Returns:
            A dict with keys ``"image"``, ``"mask"`` (possibly ``None``),
            ``"dataset_name"``, ``"sample_id"``, ``"split"``,
            ``"class_name"`` and ``"meta"``; the non-tensor entries are copied
            from the record.
        """
        record = self.records[index]
        image = self.image_loader(record.image_path)
        mask = self._load_mask(record)

        if self.joint_transform is not None:
            image, mask = self.joint_transform(image, mask)
        if self.image_transform is not None:
            image = self.image_transform(image)
        if mask is not None and self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return {
            "image": image,
            "mask": mask,
            "dataset_name": record.dataset_name,
            "sample_id": record.sample_id,
            "split": record.split,
            "class_name": record.class_name,
            "meta": record.meta,
        }


__all__ = [
    "SegmentationRecordDataset",
    "default_image_loader",
    "default_mask_loader",
]