| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- from __future__ import annotations
- import random
- from collections import defaultdict
- from pathlib import Path
- from .builder import build_dataset_index
- from .records import SegSampleRecord
# Location of the project split files, relative to the `root` argument
# accepted by the functions in this module.
PROJECT_SPLIT_ROOT = Path("splits", "project")

# Dataset names for which `generate_project_splits` is enabled; any other
# name is rejected with a ValueError.
PROJECT_SPLIT_DATASETS = {
    "BUS-UCLM",
    "BUSI",
    "BUS-BRA",
    "BUS_UC",
    "CCAUI",
    "DDTI",
}
def _project_split_dir(root: str | Path) -> Path:
    """Return the directory under *root* that holds the project split files."""
    base = Path(root)
    return base.joinpath(PROJECT_SPLIT_ROOT)
def get_project_split_file(
    root: str | Path,
    split: str,
) -> Path:
    """Return the path of the ``<split>.txt`` ID list for *root*.

    The path is only computed; the file is not checked for existence.

    Args:
        root: Base directory the project split directory is resolved against.
        split: Split name used as the file stem (e.g. ``"train"`` or ``"val"``).

    Returns:
        Path to ``<split>.txt`` inside the project split directory.
    """
    filename = f"{split}.txt"
    return _project_split_dir(root).joinpath(filename)
def load_project_split_ids(
    root: str | Path,
    split: str,
) -> list[str]:
    """Load the sample IDs listed in a project split file.

    The file is read as text, one ID per line. Surrounding whitespace is
    stripped from every line and lines that end up empty are skipped.

    Args:
        root: Base directory the project split directory is resolved against.
        split: Split name whose ``<split>.txt`` file should be read.

    Returns:
        The non-empty, stripped lines of the split file, in file order.

    Raises:
        FileNotFoundError: If the split file does not exist.
    """
    path = get_project_split_file(root, split)
    if not path.exists():
        raise FileNotFoundError(f"Project split file not found: {path}")

    # "utf-8-sig" decodes BOM-less UTF-8 exactly like "utf-8" but also drops a
    # leading byte-order mark. With plain "utf-8" a BOM (e.g. from a file
    # re-saved in a Windows editor) survives as "\ufeff" glued to the first
    # ID, and str.strip() does not remove it.
    # errors="ignore" keeps the original best-effort handling of undecodable
    # bytes.
    text = path.read_text(encoding="utf-8-sig", errors="ignore")

    # Strip each line once, then discard the ones that were blank.
    stripped_lines = (line.strip() for line in text.splitlines())
    return [sample_id for sample_id in stripped_lines if sample_id]
def _write_split_ids(path: Path, sample_ids: list[str]) -> None:
    """Write *sample_ids* to *path*, one ID per line.

    Missing parent directories are created first. Every ID is terminated by
    a newline; an empty list produces an empty file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Terminating each ID individually (instead of joining and appending a
    # final newline) naturally yields "" when there are no IDs at all.
    payload = "".join(sample_id + "\n" for sample_id in sample_ids)
    path.write_text(payload, encoding="utf-8")
def _deduplicate_records(
    dataset_name: str,
    records: list[SegSampleRecord],
) -> list[SegSampleRecord]:
    """Remove duplicated samples from BUS_UC; pass other datasets through.

    Args:
        dataset_name: Name of the dataset the records belong to.
        records: Records to filter.

    Returns:
        For ``"BUS_UC"``, only the records whose ``class_name`` is ``"all"``,
        provided at least one exists. In every other case the input list
        itself is returned unchanged.
    """
    if dataset_name == "BUS_UC":
        # In BUS_UC the "All" samples duplicate the Benign/Malignant ones, so
        # by default only "All" is kept as the base for the official split.
        kept = [item for item in records if item.class_name == "all"]
        if kept:
            return kept
    return records
def select_project_split_base_records(
    dataset_name: str,
    records: list[SegSampleRecord],
) -> list[SegSampleRecord]:
    """Return the records that form the base of a project split.

    Public entry point for the same de-duplication that
    ``generate_project_splits`` applies before splitting.

    Args:
        dataset_name: Name of the dataset the records belong to.
        records: Candidate records.

    Returns:
        The de-duplicated records for *dataset_name*.
    """
    base_records = _deduplicate_records(dataset_name, records)
    return base_records
def _group_records_for_split(
    records: list[SegSampleRecord],
) -> dict[str, list[SegSampleRecord]]:
    """Bucket *records* by their ``class_name``.

    Records whose ``class_name`` is falsy (``None`` or empty) are collected
    under the ``"__default__"`` key. Buckets appear in first-seen order and
    keep the relative order of their records.
    """
    buckets: dict[str, list[SegSampleRecord]] = defaultdict(list)
    for item in records:
        label = item.class_name
        if not label:
            label = "__default__"
        buckets[label].append(item)
    return buckets
def _split_group(
    group_records: list[SegSampleRecord],
    *,
    val_ratio: float,
    rng: random.Random,
) -> tuple[list[SegSampleRecord], list[SegSampleRecord]]:
    """Randomly divide one group of records into train and validation parts.

    A copy of the group is shuffled with *rng*; the first ``val_count``
    shuffled records become validation and the rest become training.

    Args:
        group_records: Records to divide; the input list is not modified.
        val_ratio: Target fraction of records assigned to validation.
        rng: Random generator used for shuffling.

    Returns:
        ``(train_records, val_records)``.
    """
    pool = list(group_records)
    rng.shuffle(pool)

    total = len(pool)
    # Rounded target size; adjusted below for small groups.
    val_count = int(round(total * val_ratio))
    if total == 1:
        # A single record cannot be shared, so it goes to training.
        val_count = 0
    elif total >= 2:
        # Guarantee at least one record on each side of the split.
        val_count = min(max(val_count, 1), total - 1)

    return pool[val_count:], pool[:val_count]
def generate_project_splits(
    dataset_name: str,
    root: str | Path,
    *,
    val_ratio: float = 0.2,
    seed: int = 42,
    stratify_by_class: bool = True,
    reuse_existing: bool = True,
) -> dict[str, list[str]]:
    """Create (or reload) the train/val project split for a dataset.

    When both split files already exist and *reuse_existing* is true, their
    contents are returned as-is. Otherwise the dataset index is built,
    de-duplicated, split with a generator seeded by *seed*, and the sorted
    sample IDs are written to the ``train`` and ``val`` split files.

    Args:
        dataset_name: Dataset to split; must be in ``PROJECT_SPLIT_DATASETS``.
        root: Directory passed to the index builder and used as the base of
            the project split directory.
        val_ratio: Fraction of records assigned to validation; must lie
            strictly between 0 and 1.
        seed: Seed for the random generator that shuffles the records.
        stratify_by_class: If true, split each ``class_name`` group
            separately; otherwise split all records as a single group.
        reuse_existing: If true, return existing split files instead of
            regenerating them.

    Returns:
        A dict with keys ``"train"`` and ``"val"`` mapping to lists of
        sample IDs.

    Raises:
        ValueError: If the dataset is not enabled or *val_ratio* is out of
            range.
    """
    if dataset_name not in PROJECT_SPLIT_DATASETS:
        raise ValueError(
            f"Dataset '{dataset_name}' is not enabled for project split generation."
        )
    if not 0.0 < val_ratio < 1.0:
        raise ValueError(f"val_ratio must be between 0 and 1, got {val_ratio}.")

    split_names = ("train", "val")
    split_paths = {name: get_project_split_file(root, name) for name in split_names}

    if reuse_existing and all(path.exists() for path in split_paths.values()):
        return {name: load_project_split_ids(root, name) for name in split_names}

    base_records = _deduplicate_records(
        dataset_name,
        build_dataset_index(dataset_name, root),
    )
    rng = random.Random(seed)

    # One partition per class when stratifying, otherwise a single partition
    # holding every record; each partition is split independently with the
    # shared generator.
    if stratify_by_class:
        partitions = list(_group_records_for_split(base_records).values())
    else:
        partitions = [base_records]

    train_records: list[SegSampleRecord] = []
    val_records: list[SegSampleRecord] = []
    for partition in partitions:
        part_train, part_val = _split_group(partition, val_ratio=val_ratio, rng=rng)
        train_records.extend(part_train)
        val_records.extend(part_val)

    # Records without a sample ID are left out of the written split.
    split_ids = {
        "train": sorted(
            item.sample_id for item in train_records if item.sample_id is not None
        ),
        "val": sorted(
            item.sample_id for item in val_records if item.sample_id is not None
        ),
    }

    _project_split_dir(root).mkdir(parents=True, exist_ok=True)
    for name in split_names:
        _write_split_ids(split_paths[name], split_ids[name])

    return split_ids
# Public API of this module.
__all__ = [
    # Constants
    "PROJECT_SPLIT_DATASETS",
    "PROJECT_SPLIT_ROOT",
    # Functions
    "generate_project_splits",
    "get_project_split_file",
    "load_project_split_ids",
    "select_project_split_base_records",
]
|