"""Command-line entry point that generates project train/val split files.

After generating (or reusing) the splits, the script rebuilds the ``train``
and ``val`` record datasets and prints their sizes and per-class counts.
"""

from __future__ import annotations

import argparse
import sys
from collections import Counter
from pathlib import Path

# Directory one level above the directory containing this file. It is put at
# the front of ``sys.path`` before the ``lib.data`` import below; presumably
# that is where the ``lib`` package lives -- confirm against the repo layout.
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

from lib.data import (  # noqa: E402  (must follow the sys.path adjustment)
    PROJECT_SPLIT_DATASETS,
    build_record_dataset,
    generate_project_splits,
)


def _build_parser() -> argparse.ArgumentParser:
    """Return the argument parser describing this script's command line."""
    parser = argparse.ArgumentParser(description="Generate project train/val split files.")
    parser.add_argument("--dataset", required=True, help="Dataset name")
    parser.add_argument("--root", required=True, help="Dataset root")
    parser.add_argument("--val-ratio", type=float, default=0.2, help="Validation ratio")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--no-stratify",
        action="store_true",
        help="Disable class-wise stratified split when class labels exist",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Regenerate split files even when matching files already exist.",
    )
    return parser


def main() -> None:
    """Parse the command line, generate the splits, and print a summary.

    Raises:
        ValueError: If ``--dataset`` is not in ``PROJECT_SPLIT_DATASETS``.
    """
    args = _build_parser().parse_args()

    if args.dataset not in PROJECT_SPLIT_DATASETS:
        raise ValueError(f"Dataset '{args.dataset}' does not require project split generation.")

    # The two CLI flags are negative switches, so they are inverted when
    # passed on as the positive ``stratify_by_class`` / ``reuse_existing``
    # keyword arguments.
    split_map = generate_project_splits(
        dataset_name=args.dataset,
        root=args.root,
        val_ratio=args.val_ratio,
        seed=args.seed,
        stratify_by_class=not args.no_stratify,
        reuse_existing=not args.overwrite,
    )
    print("generated/loaded:", {key: len(value) for key, value in split_map.items()})

    # Rebuild each split's dataset and report its size and class counts.
    for split_name in ("train", "val"):
        dataset = build_record_dataset(args.dataset, args.root, split=split_name)
        # Records with a falsy class_name are counted under the "none" key.
        class_counter = Counter(record.class_name or "none" for record in dataset.records)
        print(split_name, "dataset_len=", len(dataset), "classes=", dict(class_counter))


if __name__ == "__main__":
    main()