datasets.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from __future__ import annotations
  2. from collections.abc import Callable
  3. from pathlib import Path
  4. from typing import Any
  5. import numpy as np
  6. import torch
  7. from PIL import Image
  8. from torch.utils.data import Dataset
  9. from .ddti import build_ddti_mask
  10. from .records import SegSampleRecord
  11. def default_image_loader(path: str | Path) -> torch.Tensor:
  12. image = Image.open(path).convert("RGB")
  13. array = np.asarray(image, dtype=np.float32) / 255.0
  14. return torch.from_numpy(array).permute(2, 0, 1).contiguous()
  15. def default_mask_loader(path: str | Path) -> torch.Tensor:
  16. mask = Image.open(path).convert("L")
  17. array = (np.asarray(mask, dtype=np.float32) > 0).astype(np.float32)
  18. return torch.from_numpy(array).unsqueeze(0).contiguous()
  19. class SegmentationRecordDataset(Dataset):
  20. def __init__(
  21. self,
  22. records: list[SegSampleRecord],
  23. image_loader: Callable[[str | Path], torch.Tensor] | None = None,
  24. mask_loader: Callable[[str | Path], torch.Tensor] | None = None,
  25. joint_transform: Callable[[torch.Tensor, torch.Tensor | None], tuple[torch.Tensor, torch.Tensor | None]] | None = None,
  26. image_transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
  27. mask_transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
  28. ) -> None:
  29. self.records = records
  30. self.image_loader = image_loader or default_image_loader
  31. self.mask_loader = mask_loader or default_mask_loader
  32. self.joint_transform = joint_transform
  33. self.image_transform = image_transform
  34. self.mask_transform = mask_transform
  35. def __len__(self) -> int:
  36. return len(self.records)
  37. def __getitem__(self, index: int) -> dict[str, Any]:
  38. record = self.records[index]
  39. image = self.image_loader(record.image_path)
  40. mask = None
  41. if record.mask_path is not None:
  42. mask = self.mask_loader(record.mask_path)
  43. elif record.annotation_path is not None and record.dataset_name == "DDTI":
  44. ddti_mask = build_ddti_mask(record.image_path, record.annotation_path)
  45. mask_array = (np.asarray(ddti_mask, dtype=np.float32) > 0).astype(np.float32)
  46. mask = torch.from_numpy(mask_array).unsqueeze(0).contiguous()
  47. if self.joint_transform is not None:
  48. image, mask = self.joint_transform(image, mask)
  49. if self.image_transform is not None:
  50. image = self.image_transform(image)
  51. if mask is not None and self.mask_transform is not None:
  52. mask = self.mask_transform(mask)
  53. return {
  54. "image": image,
  55. "mask": mask,
  56. "dataset_name": record.dataset_name,
  57. "sample_id": record.sample_id,
  58. "split": record.split,
  59. "class_name": record.class_name,
  60. "meta": record.meta,
  61. }
  62. __all__ = [
  63. "SegmentationRecordDataset",
  64. "default_image_loader",
  65. "default_mask_loader",
  66. ]