# Dataset / dataloader factory: resolves official or project-defined splits for
# the supported segmentation datasets and wraps the selected records in a
# ``SegmentationRecordDataset`` or a ``DataLoader``.
from __future__ import annotations

from pathlib import Path
from typing import Any

from torch.utils.data import DataLoader

from .augment import build_segmentation_augmentation
from .builder import build_dataset_index
from .collate import record_collate_fn
from .datasets import SegmentationRecordDataset
from .project_splits import (
    PROJECT_SPLIT_DATASETS,
    get_project_split_file,
    load_project_split_ids,
    select_project_split_base_records,
)
from .records import SegSampleRecord
from .splits import load_id_txt, load_json_split

# Split-file location, relative to the dataset root, for each dataset/split.
# TN3K and TG3K keep all of their splits in a single JSON file.
OFFICIAL_SPLIT_FILES: dict[str, dict[str, str]] = {
    "OTU_2d": {
        "train": "train.txt",
        "val": "val.txt",
    },
    "TN3K": {
        "train": "tn3k-trainval.json",
        "val": "tn3k-trainval.json",
        "test": "tn3k-trainval.json",
    },
    "TG3K": {
        "train": "tg3k-trainval.json",
        "val": "tg3k-trainval.json",
        "test": "tg3k-trainval.json",
    },
}


def _normalize_id_set(values: list[str]) -> set[str]:
    """Return the ids in ``values`` plus alternative spellings used for matching.

    For every entry the set contains the raw value, its ``str()`` form (JSON
    split files may store ids as integers) and, when the value parses as an
    integer, its zero-padded 4-digit form (``"12"`` -> ``"0012"``), so ids
    written without padding still match padded sample ids.
    """
    normalized: set[str] = set()
    for item in values:
        normalized.add(item)
        normalized.add(str(item))
        try:
            normalized.add(f"{int(item):04d}")
        except (TypeError, ValueError):
            # Non-numeric ids are only matched verbatim.
            pass
    return normalized


def _as_exact_id_set(values: list[str]) -> set[str]:
    """Return ``values`` as a set, without any padding normalization."""
    return set(values)


def _clone_record(record: SegSampleRecord, split_name: str | None) -> SegSampleRecord:
    """Return a copy of ``record`` with its ``split`` replaced by ``split_name``.

    ``meta`` is shallow-copied so the clone's dict can be mutated without
    touching the source record's dict.
    """
    return SegSampleRecord(
        dataset_name=record.dataset_name,
        image_path=record.image_path,
        mask_path=record.mask_path,
        annotation_path=record.annotation_path,
        split=split_name,
        sample_id=record.sample_id,
        class_name=record.class_name,
        meta=dict(record.meta),
    )


def _filter_by_sample_ids(
    records: list[SegSampleRecord],
    sample_ids: set[str],
    split_name: str,
) -> list[SegSampleRecord]:
    """Clone the records whose ``sample_id`` is in ``sample_ids``, tagged with ``split_name``."""
    return [
        _clone_record(record, split_name)
        for record in records
        if record.sample_id in sample_ids
    ]


def _filter_by_existing_split(records: list[SegSampleRecord], split: str) -> list[SegSampleRecord]:
    """Clone the records whose current ``split`` attribute equals ``split``."""
    return [_clone_record(record, split) for record in records if record.split == split]


def get_official_split_file(
    dataset_name: str,
    root: str | Path,
    split: str,
) -> Path | None:
    """Return ``root / <official split file>`` or ``None`` if none is registered.

    ``None`` is returned both for datasets absent from ``OFFICIAL_SPLIT_FILES``
    and for splits the dataset does not register.
    """
    split_map = OFFICIAL_SPLIT_FILES.get(dataset_name)
    if split_map is None:
        return None
    relative_path = split_map.get(split)
    if relative_path is None:
        return None
    return Path(root) / relative_path


def _resolve_official_split_path(
    dataset_name: str,
    root: Path,
    split: str,
    split_file: str | Path | None,
) -> Path:
    """Return ``split_file`` if given, otherwise the registered official split file.

    Raises:
        ValueError: If ``split_file`` is ``None`` and no official split file is
            registered for ``(dataset_name, split)``.
    """
    if split_file is not None:
        return Path(split_file)
    split_path = get_official_split_file(dataset_name, root, split)
    if split_path is None:
        raise ValueError(
            f"No official split file registered for dataset '{dataset_name}' and split '{split}'."
        )
    return split_path


def list_supported_splits(dataset_name: str) -> list[str]:
    """Return the split names ``apply_official_split`` knows for ``dataset_name``.

    Official splits take precedence; project-split datasets support
    ``train``/``val``; any other dataset yields an empty list.
    """
    official = OFFICIAL_SPLIT_FILES.get(dataset_name)
    if official is not None:
        return list(official.keys())
    if dataset_name in PROJECT_SPLIT_DATASETS:
        return ["train", "val"]
    return []


def apply_official_split(
    dataset_name: str,
    root: str | Path,
    records: list[SegSampleRecord],
    split: str,
    *,
    split_file: str | Path | None = None,
) -> list[SegSampleRecord]:
    """Select the records belonging to ``split`` for ``dataset_name``.

    Per dataset:

    * ``OTU_2d`` -- ``train``/``val`` only; ids come from a text id file and are
      matched with zero-padding normalization.
    * ``TN3K`` -- ``test`` keeps records already tagged ``"test"``; ``train``/``val``
      ids come from the JSON split file and are applied to the records tagged
      ``"trainval"``.
    * ``TG3K`` -- ids for ``split`` come from the JSON split file and are applied
      to all records.
    * Project-split datasets -- ``train``/``val`` only; ids are matched exactly
      against the base records chosen by ``select_project_split_base_records``.
    * Anything else -- records whose ``split`` attribute already equals ``split``.

    Args:
        dataset_name: Dataset key.
        root: Dataset root directory.
        records: Indexed records to filter.
        split: Requested split name.
        split_file: Optional explicit split file overriding the default one.

    Returns:
        Cloned records whose ``split`` attribute is set to ``split``.

    Raises:
        ValueError: For unsupported splits, a missing split file mapping, a split
            absent from the JSON file, or an unknown dataset with no records
            tagged ``split``.
    """
    root = Path(root)

    if dataset_name == "OTU_2d":
        if split not in {"train", "val"}:
            raise ValueError("OTU_2d currently supports official splits: train, val.")
        split_path = _resolve_official_split_path(dataset_name, root, split, split_file)
        ids = _normalize_id_set(load_id_txt(split_path))
        return _filter_by_sample_ids(records, ids, split_name=split)

    if dataset_name == "TN3K":
        if split == "test":
            return _filter_by_existing_split(records, "test")
        split_path = _resolve_official_split_path(dataset_name, root, split, split_file)
        split_map = load_json_split(split_path)
        if split not in split_map:
            raise ValueError(f"Split '{split}' not found in {split_path}.")
        ids = _normalize_id_set(split_map[split])
        # The JSON train/val ids index into the records tagged "trainval".
        trainval_records = [record for record in records if record.split == "trainval"]
        return _filter_by_sample_ids(trainval_records, ids, split_name=split)

    if dataset_name == "TG3K":
        split_path = _resolve_official_split_path(dataset_name, root, split, split_file)
        split_map = load_json_split(split_path)
        if split not in split_map:
            raise ValueError(f"Split '{split}' not found in {split_path}.")
        ids = _normalize_id_set(split_map[split])
        return _filter_by_sample_ids(records, ids, split_name=split)

    if dataset_name in PROJECT_SPLIT_DATASETS:
        if split not in {"train", "val"}:
            raise ValueError(
                f"{dataset_name} currently supports project splits: train, val."
            )
        records = select_project_split_base_records(dataset_name, records)
        split_path = Path(split_file) if split_file is not None else get_project_split_file(root, split)
        # ``split_path`` is only read when an explicit ``split_file`` was given;
        # otherwise the ids are loaded through ``load_project_split_ids``.
        raw_ids = load_project_split_ids(root, split) if split_file is None else load_id_txt(split_path)
        ids = _as_exact_id_set(raw_ids)
        return _filter_by_sample_ids(records, ids, split_name=split)

    filtered = _filter_by_existing_split(records, split)
    if filtered:
        return filtered
    raise ValueError(
        f"No split handler registered for dataset '{dataset_name}' and split '{split}'."
    )


def build_record_dataset(
    dataset_name: str,
    root: str | Path,
    *,
    split: str | None = None,
    split_file: str | Path | None = None,
    augmentation_config: dict[str, Any] | None = None,
    image_transform=None,
    mask_transform=None,
) -> SegmentationRecordDataset:
    """Index ``dataset_name`` under ``root`` and wrap it in a ``SegmentationRecordDataset``.

    When ``split`` is given, records are filtered with ``apply_official_split``.
    The joint transform is built from ``augmentation_config``.
    """
    records = build_dataset_index(dataset_name, root)
    if split is not None:
        records = apply_official_split(
            dataset_name=dataset_name,
            root=root,
            records=records,
            split=split,
            split_file=split_file,
        )
    return SegmentationRecordDataset(
        records=records,
        joint_transform=build_segmentation_augmentation(augmentation_config),
        image_transform=image_transform,
        mask_transform=mask_transform,
    )


def build_dataloader(
    dataset_name: str,
    root: str | Path,
    *,
    split: str | None = None,
    split_file: str | Path | None = None,
    batch_size: int = 4,
    shuffle: bool = False,
    num_workers: int = 0,
    augmentation_config: dict[str, Any] | None = None,
    image_transform=None,
    mask_transform=None,
    **loader_kwargs: Any,
) -> DataLoader:
    """Build a ``DataLoader`` over ``build_record_dataset(...)``.

    ``collate_fn`` defaults to ``record_collate_fn`` unless supplied in
    ``loader_kwargs``; remaining ``loader_kwargs`` are forwarded to ``DataLoader``.
    """
    dataset = build_record_dataset(
        dataset_name=dataset_name,
        root=root,
        split=split,
        split_file=split_file,
        augmentation_config=augmentation_config,
        image_transform=image_transform,
        mask_transform=mask_transform,
    )
    # Remove collate_fn before unpacking so it is never passed twice.
    collate_fn = loader_kwargs.pop("collate_fn", record_collate_fn)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        **loader_kwargs,
    )


__all__ = [
    "OFFICIAL_SPLIT_FILES",
    "apply_official_split",
    "build_record_dataset",
    "build_dataloader",
    "get_official_split_file",
    "list_supported_splits",
]