loaders.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. from __future__ import annotations
  2. from pathlib import Path
  3. from typing import Any
  4. from torch.utils.data import DataLoader
  5. from .augment import build_segmentation_augmentation
  6. from .builder import build_dataset_index
  7. from .collate import record_collate_fn
  8. from .datasets import SegmentationRecordDataset
  9. from .project_splits import (
  10. PROJECT_SPLIT_DATASETS,
  11. get_project_split_file,
  12. load_project_split_ids,
  13. select_project_split_base_records,
  14. )
  15. from .records import SegSampleRecord
  16. from .splits import load_id_txt, load_json_split
  17. OFFICIAL_SPLIT_FILES: dict[str, dict[str, str]] = {
  18. "OTU_2d": {
  19. "train": "train.txt",
  20. "val": "val.txt",
  21. },
  22. "TN3K": {
  23. "train": "tn3k-trainval.json",
  24. "val": "tn3k-trainval.json",
  25. "test": "tn3k-trainval.json",
  26. },
  27. "TG3K": {
  28. "train": "tg3k-trainval.json",
  29. "val": "tg3k-trainval.json",
  30. "test": "tg3k-trainval.json",
  31. },
  32. }
  33. def _normalize_id_set(values: list[str]) -> set[str]:
  34. normalized = set()
  35. for item in values:
  36. normalized.add(item)
  37. try:
  38. normalized.add(f"{int(item):04d}")
  39. except ValueError:
  40. pass
  41. return normalized
  42. def _as_exact_id_set(values: list[str]) -> set[str]:
  43. return {item for item in values}
  44. def _clone_record(record: SegSampleRecord, split_name: str | None) -> SegSampleRecord:
  45. return SegSampleRecord(
  46. dataset_name=record.dataset_name,
  47. image_path=record.image_path,
  48. mask_path=record.mask_path,
  49. annotation_path=record.annotation_path,
  50. split=split_name,
  51. sample_id=record.sample_id,
  52. class_name=record.class_name,
  53. meta=dict(record.meta),
  54. )
  55. def _filter_by_sample_ids(records: list[SegSampleRecord], sample_ids: set[str], split_name: str) -> list[SegSampleRecord]:
  56. filtered = []
  57. for record in records:
  58. if record.sample_id in sample_ids:
  59. filtered.append(_clone_record(record, split_name))
  60. return filtered
  61. def _filter_by_existing_split(records: list[SegSampleRecord], split: str) -> list[SegSampleRecord]:
  62. return [_clone_record(record, split) for record in records if record.split == split]
  63. def get_official_split_file(
  64. dataset_name: str,
  65. root: str | Path,
  66. split: str,
  67. ) -> Path | None:
  68. split_map = OFFICIAL_SPLIT_FILES.get(dataset_name)
  69. if split_map is None:
  70. return None
  71. relative_path = split_map.get(split)
  72. if relative_path is None:
  73. return None
  74. return Path(root) / relative_path
  75. def list_supported_splits(dataset_name: str) -> list[str]:
  76. official = OFFICIAL_SPLIT_FILES.get(dataset_name)
  77. if official is not None:
  78. return list(official.keys())
  79. if dataset_name in PROJECT_SPLIT_DATASETS:
  80. return ["train", "val"]
  81. return []
  82. def apply_official_split(
  83. dataset_name: str,
  84. root: str | Path,
  85. records: list[SegSampleRecord],
  86. split: str,
  87. *,
  88. split_file: str | Path | None = None,
  89. ) -> list[SegSampleRecord]:
  90. root = Path(root)
  91. if dataset_name == "OTU_2d":
  92. if split not in {"train", "val"}:
  93. raise ValueError("OTU_2d currently supports official splits: train, val.")
  94. split_path = Path(split_file) if split_file is not None else get_official_split_file(dataset_name, root, split)
  95. ids = _normalize_id_set(load_id_txt(split_path))
  96. return _filter_by_sample_ids(records, ids, split_name=split)
  97. if dataset_name == "TN3K":
  98. if split == "test":
  99. return _filter_by_existing_split(records, "test")
  100. split_path = Path(split_file) if split_file is not None else get_official_split_file(dataset_name, root, split)
  101. split_map = load_json_split(split_path)
  102. if split not in split_map:
  103. raise ValueError(f"Split '{split}' not found in {split_path}.")
  104. ids = _normalize_id_set(split_map[split])
  105. trainval_records = [record for record in records if record.split == "trainval"]
  106. return _filter_by_sample_ids(trainval_records, ids, split_name=split)
  107. if dataset_name == "TG3K":
  108. split_path = Path(split_file) if split_file is not None else get_official_split_file(dataset_name, root, split)
  109. split_map = load_json_split(split_path)
  110. if split not in split_map:
  111. raise ValueError(f"Split '{split}' not found in {split_path}.")
  112. ids = _normalize_id_set(split_map[split])
  113. return _filter_by_sample_ids(records, ids, split_name=split)
  114. if dataset_name in PROJECT_SPLIT_DATASETS:
  115. if split not in {"train", "val"}:
  116. raise ValueError(
  117. f"{dataset_name} currently supports project splits: train, val."
  118. )
  119. records = select_project_split_base_records(dataset_name, records)
  120. split_path = Path(split_file) if split_file is not None else get_project_split_file(root, split)
  121. ids = _as_exact_id_set(load_project_split_ids(root, split) if split_file is None else load_id_txt(split_path))
  122. return _filter_by_sample_ids(records, ids, split_name=split)
  123. filtered = _filter_by_existing_split(records, split)
  124. if filtered:
  125. return filtered
  126. raise ValueError(
  127. f"No split handler registered for dataset '{dataset_name}' and split '{split}'."
  128. )
  129. def build_record_dataset(
  130. dataset_name: str,
  131. root: str | Path,
  132. *,
  133. split: str | None = None,
  134. split_file: str | Path | None = None,
  135. augmentation_config: dict[str, Any] | None = None,
  136. image_transform=None,
  137. mask_transform=None,
  138. ) -> SegmentationRecordDataset:
  139. records = build_dataset_index(dataset_name, root)
  140. if split is not None:
  141. records = apply_official_split(
  142. dataset_name=dataset_name,
  143. root=root,
  144. records=records,
  145. split=split,
  146. split_file=split_file,
  147. )
  148. return SegmentationRecordDataset(
  149. records=records,
  150. joint_transform=build_segmentation_augmentation(augmentation_config),
  151. image_transform=image_transform,
  152. mask_transform=mask_transform,
  153. )
  154. def build_dataloader(
  155. dataset_name: str,
  156. root: str | Path,
  157. *,
  158. split: str | None = None,
  159. split_file: str | Path | None = None,
  160. batch_size: int = 4,
  161. shuffle: bool = False,
  162. num_workers: int = 0,
  163. augmentation_config: dict[str, Any] | None = None,
  164. image_transform=None,
  165. mask_transform=None,
  166. **loader_kwargs: Any,
  167. ) -> DataLoader:
  168. dataset = build_record_dataset(
  169. dataset_name=dataset_name,
  170. root=root,
  171. split=split,
  172. split_file=split_file,
  173. augmentation_config=augmentation_config,
  174. image_transform=image_transform,
  175. mask_transform=mask_transform,
  176. )
  177. return DataLoader(
  178. dataset,
  179. batch_size=batch_size,
  180. shuffle=shuffle,
  181. num_workers=num_workers,
  182. collate_fn=loader_kwargs.pop("collate_fn", record_collate_fn),
  183. **loader_kwargs,
  184. )
  185. __all__ = [
  186. "OFFICIAL_SPLIT_FILES",
  187. "apply_official_split",
  188. "build_record_dataset",
  189. "build_dataloader",
  190. "get_official_split_file",
  191. "list_supported_splits",
  192. ]