| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- from __future__ import annotations
- from pathlib import Path
- from typing import Any
- from torch.utils.data import DataLoader
- from .augment import build_segmentation_augmentation
- from .builder import build_dataset_index
- from .collate import record_collate_fn
- from .datasets import SegmentationRecordDataset
- from .project_splits import (
- PROJECT_SPLIT_DATASETS,
- get_project_split_file,
- load_project_split_ids,
- select_project_split_base_records,
- )
- from .records import SegSampleRecord
- from .splits import load_id_txt, load_json_split
# Official split definition files, keyed by dataset name and then split name.
# Each value is a file name relative to the dataset root (see
# ``get_official_split_file``, which joins it onto the root).
#
# TN3K and TG3K keep every split inside one JSON file, so all of their split
# names point at the same file.
OFFICIAL_SPLIT_FILES: dict[str, dict[str, str]] = {
    "OTU_2d": dict(train="train.txt", val="val.txt"),
    "TN3K": dict.fromkeys(("train", "val", "test"), "tn3k-trainval.json"),
    "TG3K": dict.fromkeys(("train", "val", "test"), "tg3k-trainval.json"),
}
- def _normalize_id_set(values: list[str]) -> set[str]:
- normalized = set()
- for item in values:
- normalized.add(item)
- try:
- normalized.add(f"{int(item):04d}")
- except ValueError:
- pass
- return normalized
- def _as_exact_id_set(values: list[str]) -> set[str]:
- return {item for item in values}
def _clone_record(record: SegSampleRecord, split_name: str | None) -> SegSampleRecord:
    """Build a new ``SegSampleRecord`` from *record* with ``split`` replaced.

    All fields are carried over unchanged except ``split``, which becomes
    *split_name*. ``meta`` is passed through ``dict(...)`` so the clone gets
    its own top-level mapping rather than sharing the source record's.
    """
    field_values = {
        "dataset_name": record.dataset_name,
        "image_path": record.image_path,
        "mask_path": record.mask_path,
        "annotation_path": record.annotation_path,
        "split": split_name,
        "sample_id": record.sample_id,
        "class_name": record.class_name,
        "meta": dict(record.meta),
    }
    return SegSampleRecord(**field_values)
- def _filter_by_sample_ids(records: list[SegSampleRecord], sample_ids: set[str], split_name: str) -> list[SegSampleRecord]:
- filtered = []
- for record in records:
- if record.sample_id in sample_ids:
- filtered.append(_clone_record(record, split_name))
- return filtered
- def _filter_by_existing_split(records: list[SegSampleRecord], split: str) -> list[SegSampleRecord]:
- return [_clone_record(record, split) for record in records if record.split == split]
def get_official_split_file(
    dataset_name: str,
    root: str | Path,
    split: str,
) -> Path | None:
    """Return the path of the official split file for a dataset split.

    The file name is looked up in ``OFFICIAL_SPLIT_FILES`` and joined onto
    *root*. Returns ``None`` when the dataset has no entry or the entry does
    not list *split*. The path is not checked for existence.
    """
    file_name = OFFICIAL_SPLIT_FILES.get(dataset_name, {}).get(split)
    return None if file_name is None else Path(root) / file_name
def list_supported_splits(dataset_name: str) -> list[str]:
    """Return the split names that can be requested for *dataset_name*.

    Datasets listed in ``OFFICIAL_SPLIT_FILES`` report the split names of
    their entry. Datasets in ``PROJECT_SPLIT_DATASETS`` report ``train`` and
    ``val``. Any other dataset yields an empty list.
    """
    if dataset_name in OFFICIAL_SPLIT_FILES:
        return list(OFFICIAL_SPLIT_FILES[dataset_name])
    return ["train", "val"] if dataset_name in PROJECT_SPLIT_DATASETS else []
def _resolve_official_split_path(
    dataset_name: str,
    root: Path,
    split: str,
    split_file: str | Path | None,
) -> Path:
    """Return the split file to read: the caller's override, else the official one."""
    if split_file is not None:
        return Path(split_file)
    official_path = get_official_split_file(dataset_name, root, split)
    if official_path is None:
        # Without this guard ``None`` would be handed to the split loader.
        raise ValueError(
            f"Dataset '{dataset_name}' has no official split file for split '{split}'."
        )
    return official_path


def _load_json_split_ids(split_path: Path, split: str) -> set[str]:
    """Load a JSON split map and return the normalized IDs listed under *split*."""
    split_map = load_json_split(split_path)
    if split not in split_map:
        raise ValueError(f"Split '{split}' not found in {split_path}.")
    return _normalize_id_set(split_map[split])


def apply_official_split(
    dataset_name: str,
    root: str | Path,
    records: list[SegSampleRecord],
    split: str,
    *,
    split_file: str | Path | None = None,
) -> list[SegSampleRecord]:
    """Select the records that belong to *split* and tag them with it.

    The selection rule depends on the dataset:

    * ``OTU_2d``: IDs come from a text file; only ``train`` and ``val`` are
      accepted.
    * ``TN3K``: ``test`` returns the records already tagged ``test`` (any
      *split_file* is ignored). Other splits read IDs from a JSON split map
      and select only among records tagged ``trainval``.
    * ``TG3K``: IDs come from a JSON split map and are matched against all
      records.
    * Datasets in ``PROJECT_SPLIT_DATASETS``: only ``train`` and ``val`` are
      accepted; IDs are matched exactly, without numeric normalization.
    * Any other dataset: records already tagged with *split* are returned.

    Args:
        dataset_name: Name of the dataset the records belong to.
        root: Dataset root directory, used to locate the default split file.
        records: Indexed records to select from. They are not modified; the
            result contains clones.
        split: Name of the requested split.
        split_file: Optional file to read instead of the default split file.

    Returns:
        Clones of the selected records with ``split`` set to *split*.

    Raises:
        ValueError: If the split is not supported for the dataset, if no
            official split file is registered for it and *split_file* is not
            given, if the JSON split map lacks the split, or if no handler
            applies and no record is already tagged with *split*.
    """
    root = Path(root)

    if dataset_name == "OTU_2d":
        if split not in {"train", "val"}:
            raise ValueError("OTU_2d currently supports official splits: train, val.")
        split_path = _resolve_official_split_path(dataset_name, root, split, split_file)
        ids = _normalize_id_set(load_id_txt(split_path))
        return _filter_by_sample_ids(records, ids, split_name=split)

    if dataset_name == "TN3K":
        if split == "test":
            return _filter_by_existing_split(records, "test")
        split_path = _resolve_official_split_path(dataset_name, root, split, split_file)
        ids = _load_json_split_ids(split_path, split)
        # The JSON map only partitions the records indexed as "trainval".
        trainval_records = [record for record in records if record.split == "trainval"]
        return _filter_by_sample_ids(trainval_records, ids, split_name=split)

    if dataset_name == "TG3K":
        split_path = _resolve_official_split_path(dataset_name, root, split, split_file)
        ids = _load_json_split_ids(split_path, split)
        return _filter_by_sample_ids(records, ids, split_name=split)

    if dataset_name in PROJECT_SPLIT_DATASETS:
        if split not in {"train", "val"}:
            raise ValueError(
                f"{dataset_name} currently supports project splits: train, val."
            )
        records = select_project_split_base_records(dataset_name, records)
        if split_file is None:
            raw_ids = load_project_split_ids(root, split)
        else:
            raw_ids = load_id_txt(Path(split_file))
        return _filter_by_sample_ids(records, _as_exact_id_set(raw_ids), split_name=split)

    # No dedicated handler: fall back to whatever split the index assigned.
    filtered = _filter_by_existing_split(records, split)
    if filtered:
        return filtered
    raise ValueError(
        f"No split handler registered for dataset '{dataset_name}' and split '{split}'."
    )
def build_record_dataset(
    dataset_name: str,
    root: str | Path,
    *,
    split: str | None = None,
    split_file: str | Path | None = None,
    augmentation_config: dict[str, Any] | None = None,
    image_transform=None,
    mask_transform=None,
) -> SegmentationRecordDataset:
    """Index a dataset, optionally restrict it to one split, and wrap it.

    Args:
        dataset_name: Dataset to index via ``build_dataset_index``.
        root: Dataset root directory.
        split: Split to select with ``apply_official_split``. When ``None``
            every indexed record is kept and *split_file* is unused.
        split_file: Optional split file forwarded to ``apply_official_split``.
        augmentation_config: Passed to ``build_segmentation_augmentation`` to
            build the dataset's joint transform.
        image_transform: Forwarded unchanged to the dataset.
        mask_transform: Forwarded unchanged to the dataset.

    Returns:
        A ``SegmentationRecordDataset`` over the selected records.
    """
    selected_records = build_dataset_index(dataset_name, root)
    if split is not None:
        selected_records = apply_official_split(
            dataset_name,
            root,
            selected_records,
            split,
            split_file=split_file,
        )

    joint_transform = build_segmentation_augmentation(augmentation_config)
    return SegmentationRecordDataset(
        records=selected_records,
        joint_transform=joint_transform,
        image_transform=image_transform,
        mask_transform=mask_transform,
    )
def build_dataloader(
    dataset_name: str,
    root: str | Path,
    *,
    split: str | None = None,
    split_file: str | Path | None = None,
    batch_size: int = 4,
    shuffle: bool = False,
    num_workers: int = 0,
    augmentation_config: dict[str, Any] | None = None,
    image_transform=None,
    mask_transform=None,
    **loader_kwargs: Any,
) -> DataLoader:
    """Build a ``DataLoader`` over the record dataset for a dataset split.

    The dataset is created by ``build_record_dataset`` from the dataset,
    split, augmentation and transform arguments. *batch_size*, *shuffle* and
    *num_workers* are passed to ``DataLoader``, as are any extra
    *loader_kwargs*. A ``collate_fn`` given in *loader_kwargs* replaces the
    default ``record_collate_fn``.
    """
    # Take collate_fn out of the extras so it is not passed twice below.
    collate_fn = loader_kwargs.pop("collate_fn", record_collate_fn)

    record_dataset = build_record_dataset(
        dataset_name,
        root,
        split=split,
        split_file=split_file,
        augmentation_config=augmentation_config,
        image_transform=image_transform,
        mask_transform=mask_transform,
    )
    return DataLoader(
        record_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        **loader_kwargs,
    )
# Public API of this module. The underscore-prefixed helpers are deliberately
# left out.
__all__ = [
    "OFFICIAL_SPLIT_FILES",
    "apply_official_split",
    "build_record_dataset",
    "build_dataloader",
    "get_official_split_file",
    "list_supported_splits",
]
|