# generate_project_split.py (1.9 KB)
  1. from __future__ import annotations
  2. import argparse
  3. import sys
  4. from collections import Counter
  5. from pathlib import Path
  6. ROOT_DIR = Path(__file__).resolve().parents[1]
  7. if str(ROOT_DIR) not in sys.path:
  8. sys.path.insert(0, str(ROOT_DIR))
  9. from lib.data import (
  10. PROJECT_SPLIT_DATASETS,
  11. build_record_dataset,
  12. generate_project_splits,
  13. )
  14. def main() -> None:
  15. parser = argparse.ArgumentParser(description="Generate project train/val split files.")
  16. parser.add_argument("--dataset", required=True, help="Dataset name")
  17. parser.add_argument("--root", required=True, help="Dataset root")
  18. parser.add_argument("--val-ratio", type=float, default=0.2, help="Validation ratio")
  19. parser.add_argument("--seed", type=int, default=42, help="Random seed")
  20. parser.add_argument(
  21. "--no-stratify",
  22. action="store_true",
  23. help="Disable class-wise stratified split when class labels exist",
  24. )
  25. parser.add_argument(
  26. "--overwrite",
  27. action="store_true",
  28. help="Regenerate split files even when matching files already exist.",
  29. )
  30. args = parser.parse_args()
  31. if args.dataset not in PROJECT_SPLIT_DATASETS:
  32. raise ValueError(f"Dataset '{args.dataset}' does not require project split generation.")
  33. split_map = generate_project_splits(
  34. dataset_name=args.dataset,
  35. root=args.root,
  36. val_ratio=args.val_ratio,
  37. seed=args.seed,
  38. stratify_by_class=not args.no_stratify,
  39. reuse_existing=not args.overwrite,
  40. )
  41. print("generated/loaded:", {key: len(value) for key, value in split_map.items()})
  42. for split_name in ["train", "val"]:
  43. dataset = build_record_dataset(args.dataset, args.root, split=split_name)
  44. class_counter = Counter(dataset.records[i].class_name or "none" for i in range(len(dataset.records)))
  45. print(split_name, "dataset_len=", len(dataset), "classes=", dict(class_counter))
  46. if __name__ == "__main__":
  47. main()