run.py 1.0 KB

1234567891011121314151617181920212223242526272829
  1. import argparse
  2. from configs import get_cfg
  3. from util.net import init_training
  4. from util.util import run_pre, init_checkpoint
  5. from trainer import get_trainer
  6. import warnings
  7. warnings.filterwarnings("ignore")
  8. def main():
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument('-c', '--cfg_path', default='configs/mobilemamba/mobilemamba_t2.py')
  11. parser.add_argument('-m', '--mode', default='train', choices=['train', 'test', 'test_net', 'ft', 'search'])
  12. parser.add_argument('--sleep', type=int, default=-1)
  13. parser.add_argument('--memory', type=int, default=-1)
  14. parser.add_argument('--dist_url', default='env://', type=str, help='url used to set up distributed training')
  15. parser.add_argument('--logger_rank', default=0, type=int, help='GPU id to use.')
  16. parser.add_argument('opts', help='path.key=value', default=None, nargs=argparse.REMAINDER,)
  17. cfg_terminal = parser.parse_args()
  18. cfg = get_cfg(cfg_terminal)
  19. run_pre(cfg)
  20. init_training(cfg)
  21. init_checkpoint(cfg)
  22. trainer = get_trainer(cfg)
  23. trainer.run()
  24. if __name__ == '__main__':
  25. main()