| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- from __future__ import annotations
- import argparse
- import sys
- from pathlib import Path
ROOT_DIR = Path(__file__).resolve().parents[1]

# Prepend ROOT_DIR (the parent of this file's directory) to the import path,
# at most once, so the `lib` package imported below can be found there.
_root_entry = str(ROOT_DIR)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry
- from lib.trainers import build_trainer
- from lib.utils.config import apply_dotlist_overrides, load_yaml_config
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for the unified training entrypoint.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            previous zero-argument behaviour.

    Returns:
        Namespace with ``config`` (str), ``trainer`` (str or None) and
        ``set`` (list of ``key=value`` strings, or None when ``--set`` is
        never given).
    """
    parser = argparse.ArgumentParser(description="Unified training entrypoint.")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to yaml config.",
    )
    parser.add_argument(
        "--trainer",
        type=str,
        default=None,
        help="Override trainer name from config.",
    )
    # ``action="extend"`` accumulates values across repeated ``--set`` flags.
    # With the default ``store`` action, a later ``--set`` silently replaced
    # every override given by an earlier one.
    parser.add_argument(
        "--set",
        nargs="*",
        action="extend",
        default=None,
        help="Override config values with key=value pairs, e.g. train.epochs=2 model.load_weights=false",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Load the YAML config, apply CLI overrides, build the trainer and train.

    Config values are layered in this order: the YAML file, then the
    ``--set`` key=value overrides, then ``--trainer``. Because ``--trainer``
    is applied last, it presumably takes precedence over a
    ``--set trainer.name=...`` override.
    """
    args = parse_args()

    # Relative config paths are resolved against ROOT_DIR, not the current
    # working directory; absolute paths are used as given.
    config_arg = Path(args.config)
    cfg_path = config_arg if config_arg.is_absolute() else ROOT_DIR / config_arg
    cfg = load_yaml_config(cfg_path)
    cfg = apply_dotlist_overrides(cfg, args.set)

    if args.trainer is not None:
        # A bare ``trainer:`` key in YAML typically loads as None. In that
        # case ``setdefault`` would hand back None and the item assignment
        # below would raise TypeError, so an explicit null is treated the
        # same as a missing section.
        if cfg.get("trainer") is None:
            cfg["trainer"] = {}
        cfg["trainer"]["name"] = args.trainer

    trainer = build_trainer(cfg, args=args)
    trainer.train()


if __name__ == "__main__":
    main()
|