train.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from __future__ import annotations
  2. import argparse
  3. import sys
  4. from pathlib import Path
  5. ROOT_DIR = Path(__file__).resolve().parents[1]
  6. if str(ROOT_DIR) not in sys.path:
  7. sys.path.insert(0, str(ROOT_DIR))
  8. from lib.trainers import build_trainer
  9. from lib.utils.config import apply_dotlist_overrides, load_yaml_config
  10. def parse_args() -> argparse.Namespace:
  11. parser = argparse.ArgumentParser(description="Unified training entrypoint.")
  12. parser.add_argument(
  13. "--config",
  14. type=str,
  15. required=True,
  16. help="Path to yaml config.",
  17. )
  18. parser.add_argument(
  19. "--trainer",
  20. type=str,
  21. default=None,
  22. help="Override trainer name from config.",
  23. )
  24. parser.add_argument(
  25. "--set",
  26. nargs="*",
  27. default=None,
  28. help="Override config values with key=value pairs, e.g. train.epochs=2 model.use_wavelet_branch=false",
  29. )
  30. return parser.parse_args()
  31. def main() -> None:
  32. args = parse_args()
  33. cfg_path = ROOT_DIR / args.config if not Path(args.config).is_absolute() else Path(args.config)
  34. cfg = load_yaml_config(cfg_path)
  35. cfg = apply_dotlist_overrides(cfg, args.set)
  36. if args.trainer is not None:
  37. cfg.setdefault("trainer", {})
  38. cfg["trainer"]["name"] = args.trainer
  39. trainer = build_trainer(cfg, args=args)
  40. trainer.train()
  41. if __name__ == "__main__":
  42. main()