project_splits.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from __future__ import annotations
  2. import random
  3. from collections import defaultdict
  4. from pathlib import Path
  5. from .builder import build_dataset_index
  6. from .records import SegSampleRecord
  7. PROJECT_SPLIT_ROOT = Path("splits") / "project"
  8. PROJECT_SPLIT_DATASETS = {"BUS-UCLM", "BUSI", "BUS-BRA", "BUS_UC", "CCAUI", "DDTI"}
  9. def _project_split_dir(root: str | Path) -> Path:
  10. return Path(root) / PROJECT_SPLIT_ROOT
  11. def get_project_split_file(
  12. root: str | Path,
  13. split: str,
  14. ) -> Path:
  15. return _project_split_dir(root) / f"{split}.txt"
  16. def load_project_split_ids(
  17. root: str | Path,
  18. split: str,
  19. ) -> list[str]:
  20. path = get_project_split_file(root, split)
  21. if not path.exists():
  22. raise FileNotFoundError(f"Project split file not found: {path}")
  23. return [
  24. line.strip()
  25. for line in path.read_text(encoding="utf-8", errors="ignore").splitlines()
  26. if line.strip()
  27. ]
  28. def _write_split_ids(path: Path, sample_ids: list[str]) -> None:
  29. path.parent.mkdir(parents=True, exist_ok=True)
  30. if sample_ids:
  31. path.write_text("\n".join(sample_ids) + "\n", encoding="utf-8")
  32. else:
  33. path.write_text("", encoding="utf-8")
  34. def _deduplicate_records(
  35. dataset_name: str,
  36. records: list[SegSampleRecord],
  37. ) -> list[SegSampleRecord]:
  38. if dataset_name != "BUS_UC":
  39. return records
  40. # BUS_UC 的 All 与 Benign/Malignant 是重复样本,默认只保留 All 作为正式划分基底。
  41. all_records = [record for record in records if record.class_name == "all"]
  42. return all_records if all_records else records
  43. def select_project_split_base_records(
  44. dataset_name: str,
  45. records: list[SegSampleRecord],
  46. ) -> list[SegSampleRecord]:
  47. return _deduplicate_records(dataset_name, records)
  48. def _group_records_for_split(
  49. records: list[SegSampleRecord],
  50. ) -> dict[str, list[SegSampleRecord]]:
  51. groups: dict[str, list[SegSampleRecord]] = defaultdict(list)
  52. for record in records:
  53. key = record.class_name or "__default__"
  54. groups[key].append(record)
  55. return groups
  56. def _split_group(
  57. group_records: list[SegSampleRecord],
  58. *,
  59. val_ratio: float,
  60. rng: random.Random,
  61. ) -> tuple[list[SegSampleRecord], list[SegSampleRecord]]:
  62. shuffled = list(group_records)
  63. rng.shuffle(shuffled)
  64. val_count = int(round(len(shuffled) * val_ratio))
  65. if len(shuffled) >= 2:
  66. val_count = max(1, min(len(shuffled) - 1, val_count))
  67. elif len(shuffled) == 1:
  68. val_count = 0
  69. val_records = shuffled[:val_count]
  70. train_records = shuffled[val_count:]
  71. return train_records, val_records
  72. def generate_project_splits(
  73. dataset_name: str,
  74. root: str | Path,
  75. *,
  76. val_ratio: float = 0.2,
  77. seed: int = 42,
  78. stratify_by_class: bool = True,
  79. reuse_existing: bool = True,
  80. ) -> dict[str, list[str]]:
  81. if dataset_name not in PROJECT_SPLIT_DATASETS:
  82. raise ValueError(
  83. f"Dataset '{dataset_name}' is not enabled for project split generation."
  84. )
  85. if not 0.0 < val_ratio < 1.0:
  86. raise ValueError(f"val_ratio must be between 0 and 1, got {val_ratio}.")
  87. train_path = get_project_split_file(root, "train")
  88. val_path = get_project_split_file(root, "val")
  89. if reuse_existing and train_path.exists() and val_path.exists():
  90. return {
  91. "train": load_project_split_ids(root, "train"),
  92. "val": load_project_split_ids(root, "val"),
  93. }
  94. records = build_dataset_index(dataset_name, root)
  95. records = _deduplicate_records(dataset_name, records)
  96. rng = random.Random(seed)
  97. train_records: list[SegSampleRecord] = []
  98. val_records: list[SegSampleRecord] = []
  99. if stratify_by_class:
  100. groups = _group_records_for_split(records)
  101. for group_records in groups.values():
  102. group_train, group_val = _split_group(group_records, val_ratio=val_ratio, rng=rng)
  103. train_records.extend(group_train)
  104. val_records.extend(group_val)
  105. else:
  106. train_records, val_records = _split_group(records, val_ratio=val_ratio, rng=rng)
  107. train_ids = sorted(record.sample_id for record in train_records if record.sample_id is not None)
  108. val_ids = sorted(record.sample_id for record in val_records if record.sample_id is not None)
  109. split_dir = _project_split_dir(root)
  110. split_dir.mkdir(parents=True, exist_ok=True)
  111. _write_split_ids(train_path, train_ids)
  112. _write_split_ids(val_path, val_ids)
  113. return {
  114. "train": train_ids,
  115. "val": val_ids,
  116. }
  117. __all__ = [
  118. "PROJECT_SPLIT_DATASETS",
  119. "PROJECT_SPLIT_ROOT",
  120. "generate_project_splits",
  121. "get_project_split_file",
  122. "load_project_split_ids",
  123. "select_project_split_base_records",
  124. ]