| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- from __future__ import annotations
- from collections.abc import Callable
- from pathlib import Path
- from typing import Any
- import numpy as np
- import torch
- from PIL import Image
- from torch.utils.data import Dataset
- from .ddti import build_ddti_mask
- from .records import SegSampleRecord
def default_image_loader(path: str | Path) -> torch.Tensor:
    """Load an image file as a float RGB tensor scaled to ``[0, 1]``.

    The image is always converted to 3-channel 8-bit RGB, so grayscale or
    palette inputs also come back with three channels.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        A contiguous ``float32`` tensor of shape ``(3, H, W)``.
    """
    # Close the file deterministically: Pillow only auto-closes single-frame
    # files after loading, multi-frame ones (e.g. TIFF/GIF) stay open until
    # garbage-collected. ``convert`` returns an in-memory copy, so it remains
    # valid after the source image is closed.
    with Image.open(path) as source:
        rgb = source.convert("RGB")
    array = np.asarray(rgb, dtype=np.float32) / 255.0
    # HWC -> CHW (channel-first layout).
    return torch.from_numpy(array).permute(2, 0, 1).contiguous()
def default_mask_loader(path: str | Path) -> torch.Tensor:
    """Load a mask image as a binary float tensor.

    The file is converted to single-channel grayscale (``"L"``); every
    non-zero pixel becomes ``1.0`` and every zero pixel ``0.0``.

    Args:
        path: Filesystem path of the mask image to read.

    Returns:
        A contiguous ``float32`` tensor of shape ``(1, H, W)`` with values in
        ``{0.0, 1.0}``.
    """
    # Release the file handle explicitly (see Pillow file-handling rules for
    # multi-frame formats); the converted image is an in-memory copy.
    with Image.open(path) as source:
        gray = source.convert("L")
    # "L" data is uint8, so thresholding it directly matches thresholding a
    # float32 copy, without the intermediate allocation.
    binary = (np.asarray(gray) > 0).astype(np.float32)
    # Add the leading channel axis -> (1, H, W).
    return torch.from_numpy(binary).unsqueeze(0).contiguous()
class SegmentationRecordDataset(Dataset):
    """Map-style dataset turning ``SegSampleRecord`` entries into sample dicts.

    Each item is a dict with keys ``"image"``, ``"mask"``, ``"dataset_name"``,
    ``"sample_id"``, ``"split"``, ``"class_name"`` and ``"meta"``.

    Mask source precedence for a record:
      1. ``record.mask_path`` (read with ``mask_loader``) if not ``None``;
      2. otherwise, for ``dataset_name == "DDTI"`` records with an
         ``annotation_path``, the result of ``build_ddti_mask`` binarized at 0;
      3. otherwise ``None``.

    NOTE(review): ``torch.utils.data.default_collate`` does not accept
    ``None``; batches that may contain mask-less samples likely need a custom
    ``collate_fn`` — confirm against the training code.
    """

    def __init__(
        self,
        records: list[SegSampleRecord],
        image_loader: Callable[[str | Path], torch.Tensor] | None = None,
        mask_loader: Callable[[str | Path], torch.Tensor] | None = None,
        joint_transform: Callable[[torch.Tensor, torch.Tensor | None], tuple[torch.Tensor, torch.Tensor | None]] | None = None,
        image_transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
        mask_transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
    ) -> None:
        """Store the records and the loading/transform callables.

        Args:
            records: Samples to serve; kept by reference, not copied.
            image_loader: ``path -> tensor`` reader for images; defaults to
                ``default_image_loader``.
            mask_loader: ``path -> tensor`` reader for ``mask_path`` files;
                defaults to ``default_mask_loader``.
            joint_transform: Applied first to ``(image, mask)`` together; it
                receives ``mask=None`` for mask-less samples.
            image_transform: Applied to the image after ``joint_transform``.
            mask_transform: Applied to the mask after ``joint_transform``,
                only when a mask is present.
        """
        self.records = records
        self.image_loader = image_loader or default_image_loader
        self.mask_loader = mask_loader or default_mask_loader
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.records)

    @staticmethod
    def _binarize(mask_like: Any) -> torch.Tensor:
        """Threshold ``mask_like`` at zero into a float tensor with a leading axis."""
        binary = (np.asarray(mask_like, dtype=np.float32) > 0).astype(np.float32)
        return torch.from_numpy(binary).unsqueeze(0).contiguous()

    def _load_mask(self, record: SegSampleRecord) -> torch.Tensor | None:
        """Return the mask for ``record`` following the class-level precedence."""
        if record.mask_path is not None:
            return self.mask_loader(record.mask_path)
        if record.annotation_path is not None and record.dataset_name == "DDTI":
            # Output format of build_ddti_mask is defined elsewhere; it is
            # only assumed to be array-convertible here.
            return self._binarize(build_ddti_mask(record.image_path, record.annotation_path))
        return None

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Load, transform and package the sample at ``index``."""
        record = self.records[index]
        image = self.image_loader(record.image_path)
        mask = self._load_mask(record)
        if self.joint_transform is not None:
            # Joint first so spatial ops stay aligned between image and mask.
            image, mask = self.joint_transform(image, mask)
        if self.image_transform is not None:
            image = self.image_transform(image)
        if mask is not None and self.mask_transform is not None:
            mask = self.mask_transform(mask)
        return {
            "image": image,
            "mask": mask,
            "dataset_name": record.dataset_name,
            "sample_id": record.sample_id,
            "split": record.split,
            "class_name": record.class_name,
            "meta": record.meta,
        }
# Names exported by ``from <this module> import *``.
__all__ = ["SegmentationRecordDataset", "default_image_loader", "default_mask_loader"]
|