from __future__ import annotations

import random
from collections import defaultdict
from pathlib import Path

from .builder import build_dataset_index
from .records import SegSampleRecord

# Location of the project-level split files, relative to a dataset root.
PROJECT_SPLIT_ROOT = Path("splits") / "project"

# Dataset names accepted by ``generate_project_splits``; any other name is
# rejected with a ValueError.
PROJECT_SPLIT_DATASETS = {"BUS-UCLM", "BUSI", "BUS-BRA", "BUS_UC", "CCAUI", "DDTI"}


def _project_split_dir(root: str | Path) -> Path:
    """Return the directory under *root* that holds the project split files."""
    base = Path(root)
    return base / PROJECT_SPLIT_ROOT


def get_project_split_file(
    root: str | Path,
    split: str,
) -> Path:
    """Return the path of the ``<split>.txt`` file under the project split directory.

    The path is only computed; the file is neither created nor checked for
    existence.
    """
    filename = f"{split}.txt"
    return _project_split_dir(root) / filename


def load_project_split_ids(
    root: str | Path,
    split: str,
) -> list[str]:
    """Read the sample ids listed in the split file for *split*.

    Each line is stripped of surrounding whitespace and lines that end up
    empty are skipped. Bytes that are not valid UTF-8 are dropped while
    decoding instead of raising.

    Raises:
        FileNotFoundError: If the split file does not exist.
    """
    split_path = get_project_split_file(root, split)
    if not split_path.exists():
        raise FileNotFoundError(f"Project split file not found: {split_path}")

    raw_text = split_path.read_text(encoding="utf-8", errors="ignore")
    stripped_lines = (raw_line.strip() for raw_line in raw_text.splitlines())
    return [sample_id for sample_id in stripped_lines if sample_id]


def _write_split_ids(path: Path, sample_ids: list[str]) -> None:
    """Write *sample_ids* to *path*, one id per line, creating parent directories.

    A non-empty list is written with a trailing newline; an empty list
    produces an empty file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = ("\n".join(sample_ids) + "\n") if sample_ids else ""
    path.write_text(payload, encoding="utf-8")


def _deduplicate_records(
    dataset_name: str,
    records: list[SegSampleRecord],
) -> list[SegSampleRecord]:
    """Drop duplicated BUS_UC records; return other datasets' records untouched.

    For BUS_UC only the records whose ``class_name`` is ``"all"`` are kept.
    If no such record exists, the input list is returned unchanged.
    """
    if dataset_name == "BUS_UC":
        # In BUS_UC the "All" samples duplicate the Benign/Malignant ones, so
        # by default only "All" is kept as the base for the official split.
        all_only = [item for item in records if item.class_name == "all"]
        if all_only:
            return all_only
    return records


def select_project_split_base_records(
    dataset_name: str,
    records: list[SegSampleRecord],
) -> list[SegSampleRecord]:
    """Return the records that project splits for *dataset_name* are built from.

    Public entry point for the same de-duplication that
    ``generate_project_splits`` applies before splitting.
    """
    return _deduplicate_records(dataset_name, records)


def _group_records_for_split(
    records: list[SegSampleRecord],
) -> dict[str, list[SegSampleRecord]]:
    """Bucket *records* by ``class_name``, preserving first-seen group order.

    Records with a falsy ``class_name`` are collected under ``"__default__"``.
    """
    grouped: dict[str, list[SegSampleRecord]] = defaultdict(list)
    for item in records:
        grouped[item.class_name or "__default__"].append(item)
    return grouped


def _split_group(
    group_records: list[SegSampleRecord],
    *,
    val_ratio: float,
    rng: random.Random,
) -> tuple[list[SegSampleRecord], list[SegSampleRecord]]:
    """Shuffle a copy of *group_records* and cut it into ``(train, val)``.

    The validation size is ``round(len * val_ratio)``, then adjusted so that
    a group of two or more records contributes at least one record to each
    side, and a single-record group goes entirely to train. The input list
    itself is not modified.
    """
    pool = list(group_records)
    rng.shuffle(pool)

    total = len(pool)
    val_count = int(round(total * val_ratio))
    if total == 1:
        # A lone record cannot be shared; keep it for training.
        val_count = 0
    elif total >= 2:
        # Guarantee both sides are non-empty.
        val_count = max(1, min(total - 1, val_count))

    # The first ``val_count`` shuffled records form the validation side.
    return pool[val_count:], pool[:val_count]


def _sorted_sample_ids(records: list[SegSampleRecord]) -> list[str]:
    """Return the sorted ``sample_id`` values of *records*, skipping ``None`` ids."""
    return sorted(item.sample_id for item in records if item.sample_id is not None)


def generate_project_splits(
    dataset_name: str,
    root: str | Path,
    *,
    val_ratio: float = 0.2,
    seed: int = 42,
    stratify_by_class: bool = True,
    reuse_existing: bool = True,
) -> dict[str, list[str]]:
    """Create (or reuse) the train/val project split files for a dataset.

    Args:
        dataset_name: Must be one of ``PROJECT_SPLIT_DATASETS``.
        root: Dataset root; split files live under ``root / PROJECT_SPLIT_ROOT``.
        val_ratio: Fraction of records assigned to validation, strictly
            between 0 and 1.
        seed: Seed for the ``random.Random`` instance used for shuffling.
        stratify_by_class: When true, records are split per ``class_name``
            group (all groups share one RNG); otherwise all records are
            split as a single group.
        reuse_existing: When true and both ``train.txt`` and ``val.txt``
            already exist, their contents are returned and nothing is
            regenerated or written.

    Returns:
        A dict with keys ``"train"`` and ``"val"`` mapping to lists of sample
        ids. Freshly generated lists are sorted and omit records whose
        ``sample_id`` is ``None``.

    Raises:
        ValueError: If *dataset_name* is not enabled or *val_ratio* is not
            strictly between 0 and 1.
    """
    if dataset_name not in PROJECT_SPLIT_DATASETS:
        raise ValueError(
            f"Dataset '{dataset_name}' is not enabled for project split generation."
        )
    if not 0.0 < val_ratio < 1.0:
        raise ValueError(f"val_ratio must be between 0 and 1, got {val_ratio}.")

    train_path = get_project_split_file(root, "train")
    val_path = get_project_split_file(root, "val")

    if reuse_existing and train_path.exists() and val_path.exists():
        return {
            split_name: load_project_split_ids(root, split_name)
            for split_name in ("train", "val")
        }

    base_records = _deduplicate_records(
        dataset_name,
        build_dataset_index(dataset_name, root),
    )
    rng = random.Random(seed)

    if stratify_by_class:
        train_records: list[SegSampleRecord] = []
        val_records: list[SegSampleRecord] = []
        # Groups are visited in first-seen order and share ``rng``.
        for members in _group_records_for_split(base_records).values():
            part_train, part_val = _split_group(members, val_ratio=val_ratio, rng=rng)
            train_records.extend(part_train)
            val_records.extend(part_val)
    else:
        train_records, val_records = _split_group(
            base_records,
            val_ratio=val_ratio,
            rng=rng,
        )

    train_ids = _sorted_sample_ids(train_records)
    val_ids = _sorted_sample_ids(val_records)

    _project_split_dir(root).mkdir(parents=True, exist_ok=True)
    _write_split_ids(train_path, train_ids)
    _write_split_ids(val_path, val_ids)

    return {"train": train_ids, "val": val_ids}
# Names exported by ``from <this module> import *``.
__all__ = [
    "PROJECT_SPLIT_DATASETS",
    "PROJECT_SPLIT_ROOT",
    "generate_project_splits",
    "get_project_split_file",
    "load_project_split_ids",
    "select_project_split_base_records",
]