builder.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from __future__ import annotations
  2. from pathlib import Path
  3. from typing import Callable
  4. from .indexers import (
  5. build_filename_matched_records,
  6. build_paired_folder_records,
  7. build_prefixed_paired_records,
  8. build_pre_split_records,
  9. build_stem_paired_records,
  10. build_xml_annotation_records,
  11. )
  12. from .records import SegSampleRecord
  13. def _build_bus_uclm(root: Path) -> list[SegSampleRecord]:
  14. return build_paired_folder_records(
  15. dataset_name="BUS-UCLM",
  16. image_dir=root / "images",
  17. mask_dir=root / "masks",
  18. )
  19. def _build_tg3k(root: Path) -> list[SegSampleRecord]:
  20. return build_paired_folder_records(
  21. dataset_name="TG3K",
  22. image_dir=root / "thyroid-image",
  23. mask_dir=root / "thyroid-mask",
  24. )
  25. def _build_tn3k(root: Path) -> list[SegSampleRecord]:
  26. return build_pre_split_records(
  27. dataset_name="TN3K",
  28. train_image_dir=root / "trainval-image",
  29. train_mask_dir=root / "trainval-mask",
  30. test_image_dir=root / "test-image",
  31. test_mask_dir=root / "test-mask",
  32. )
  33. def _build_otu_2d(root: Path) -> list[SegSampleRecord]:
  34. return build_stem_paired_records(
  35. dataset_name="OTU_2d",
  36. image_dir=root / "images",
  37. mask_dir=root / "annotations",
  38. )
  39. def _build_busi(root: Path) -> list[SegSampleRecord]:
  40. base = root / "Dataset_BUSI_with_GT"
  41. records: list[SegSampleRecord] = []
  42. for class_name in ["benign", "malignant", "normal"]:
  43. class_dir = base / class_name
  44. records.extend(
  45. build_filename_matched_records(
  46. dataset_name="BUSI",
  47. folder=class_dir,
  48. class_name=class_name,
  49. )
  50. )
  51. return records
  52. def _build_bus_bra(root: Path) -> list[SegSampleRecord]:
  53. image_dir = root / "BUSBRA" / "BUSBRA" / "Images"
  54. mask_dir = root / "BUSBRA" / "BUSBRA" / "Masks"
  55. return build_prefixed_paired_records(
  56. dataset_name="BUS-BRA",
  57. image_dir=image_dir,
  58. mask_dir=mask_dir,
  59. image_prefix_to_strip="bus_",
  60. mask_prefix_to_strip="mask_",
  61. )
  62. def _build_bus_uc(root: Path) -> list[SegSampleRecord]:
  63. base = root / "BUS_UC" / "BUS_UC"
  64. records: list[SegSampleRecord] = []
  65. records.extend(
  66. build_paired_folder_records(
  67. dataset_name="BUS_UC",
  68. image_dir=base / "All" / "images",
  69. mask_dir=base / "All" / "masks",
  70. split="all",
  71. class_name="all",
  72. )
  73. )
  74. records.extend(
  75. build_paired_folder_records(
  76. dataset_name="BUS_UC",
  77. image_dir=base / "Benign" / "images",
  78. mask_dir=base / "Benign" / "masks",
  79. class_name="benign",
  80. )
  81. )
  82. records.extend(
  83. build_paired_folder_records(
  84. dataset_name="BUS_UC",
  85. image_dir=base / "Malignant" / "images",
  86. mask_dir=base / "Malignant" / "masks",
  87. class_name="malignant",
  88. )
  89. )
  90. return records
  91. def _build_ccaui(root: Path) -> list[SegSampleRecord]:
  92. base = root / "Common Carotid Artery Ultrasound Images"
  93. return build_paired_folder_records(
  94. dataset_name="CCAUI",
  95. image_dir=base / "US images",
  96. mask_dir=base / "Expert mask images",
  97. )
  98. def _build_ddti(root: Path) -> list[SegSampleRecord]:
  99. return build_xml_annotation_records(
  100. dataset_name="DDTI",
  101. root=root,
  102. )
  103. DATASET_REGISTRY: dict[str, Callable[[Path], list[SegSampleRecord]]] = {
  104. "BUS-UCLM": _build_bus_uclm,
  105. "TG3K": _build_tg3k,
  106. "TN3K": _build_tn3k,
  107. "OTU_2d": _build_otu_2d,
  108. "BUSI": _build_busi,
  109. "BUS-BRA": _build_bus_bra,
  110. "BUS_UC": _build_bus_uc,
  111. "CCAUI": _build_ccaui,
  112. "DDTI": _build_ddti,
  113. }
  114. def build_dataset_index(dataset_name: str, root: str | Path) -> list[SegSampleRecord]:
  115. builder = DATASET_REGISTRY.get(dataset_name)
  116. if builder is None:
  117. raise ValueError(
  118. f"Unsupported dataset '{dataset_name}'. Expected one of: {', '.join(DATASET_REGISTRY)}."
  119. )
  120. root = Path(root)
  121. if not root.exists():
  122. raise FileNotFoundError(f"Dataset root not found: {root}")
  123. return builder(root)
  124. __all__ = ["DATASET_REGISTRY", "build_dataset_index"]