publish_model.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import subprocess
  4. import torch
  5. from mmengine.logging import print_log
  6. from mmengine.utils import digit_version
  7. def parse_args():
  8. parser = argparse.ArgumentParser(
  9. description='Process a checkpoint to be published')
  10. parser.add_argument('in_file', help='input checkpoint filename')
  11. parser.add_argument('out_file', help='output checkpoint filename')
  12. parser.add_argument(
  13. '--save-keys',
  14. nargs='+',
  15. type=str,
  16. default=['meta', 'state_dict'],
  17. help='keys to save in the published checkpoint')
  18. args = parser.parse_args()
  19. return args
  20. def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']):
  21. checkpoint = torch.load(in_file, map_location='cpu')
  22. # only keep `meta` and `state_dict` for smaller file size
  23. ckpt_keys = list(checkpoint.keys())
  24. for k in ckpt_keys:
  25. if k not in save_keys:
  26. print_log(
  27. f'Key `{k}` will be removed because it is not in '
  28. f'save_keys. If you want to keep it, '
  29. f'please set --save-keys.',
  30. logger='current')
  31. checkpoint.pop(k, None)
  32. # if it is necessary to remove some sensitive data in checkpoint['meta'],
  33. # add the code here.
  34. if digit_version(torch.__version__) >= digit_version('1.6'):
  35. torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
  36. else:
  37. torch.save(checkpoint, out_file)
  38. sha = subprocess.check_output(['sha256sum', out_file]).decode()
  39. if out_file.endswith('.pth'):
  40. out_file_name = out_file[:-4]
  41. else:
  42. out_file_name = out_file
  43. final_file = out_file_name + f'-{sha[:8]}.pth'
  44. subprocess.Popen(['mv', out_file, final_file])
  45. print_log(
  46. f'The published model is saved at {final_file}.', logger='current')
  47. def main():
  48. args = parse_args()
  49. process_checkpoint(args.in_file, args.out_file, args.save_keys)
  50. if __name__ == '__main__':
  51. main()