| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- from __future__ import annotations
- import argparse
- import sys
- from collections import Counter
- from pathlib import Path
# Directory one level above the folder that contains this script (presumably
# the repository root -- confirm against the project layout). It is placed at
# the front of sys.path so the `lib` package imported below can be resolved
# when this file is executed directly as a script.
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))
- from lib.data import (
- PROJECT_SPLIT_DATASETS,
- build_record_dataset,
- generate_project_splits,
- )
def main() -> None:
    """Generate (or reuse) project train/val split files and print a summary.

    Parses the command-line options, calls ``generate_project_splits`` for the
    requested dataset and prints the size of every entry in the mapping it
    returns. It then builds the ``train`` and ``val`` record datasets and
    prints, for each, the dataset length and the number of records per class
    name (records without a class name are counted under ``"none"``).

    Raises:
        ValueError: If ``--dataset`` is not listed in
            ``PROJECT_SPLIT_DATASETS``.
    """
    parser = argparse.ArgumentParser(description="Generate project train/val split files.")
    parser.add_argument("--dataset", required=True, help="Dataset name")
    parser.add_argument("--root", required=True, help="Dataset root")
    parser.add_argument("--val-ratio", type=float, default=0.2, help="Validation ratio")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--no-stratify",
        action="store_true",
        help="Disable class-wise stratified split when class labels exist",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Regenerate split files even when matching files already exist.",
    )
    args = parser.parse_args()

    if args.dataset not in PROJECT_SPLIT_DATASETS:
        raise ValueError(f"Dataset '{args.dataset}' does not require project split generation.")

    # Both CLI flags are negative switches, so they are inverted here to match
    # the positive keyword arguments of generate_project_splits.
    split_map = generate_project_splits(
        dataset_name=args.dataset,
        root=args.root,
        val_ratio=args.val_ratio,
        seed=args.seed,
        stratify_by_class=not args.no_stratify,
        reuse_existing=not args.overwrite,
    )
    print("generated/loaded:", {key: len(value) for key, value in split_map.items()})

    for split_name in ["train", "val"]:
        dataset = build_record_dataset(args.dataset, args.root, split=split_name)
        # Iterate the records directly instead of indexing through range(len(...)).
        class_counter = Counter(record.class_name or "none" for record in dataset.records)
        print(split_name, "dataset_len=", len(dataset), "classes=", dict(class_counter))
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|