beit2mmseg.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from collections import OrderedDict
  5. import mmengine
  6. import torch
  7. from mmengine.runner import CheckpointLoader
  8. def convert_beit(ckpt):
  9. new_ckpt = OrderedDict()
  10. for k, v in ckpt.items():
  11. if k.startswith('patch_embed'):
  12. new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
  13. new_ckpt[new_key] = v
  14. if k.startswith('blocks'):
  15. new_key = k.replace('blocks', 'layers')
  16. if 'norm' in new_key:
  17. new_key = new_key.replace('norm', 'ln')
  18. elif 'mlp.fc1' in new_key:
  19. new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
  20. elif 'mlp.fc2' in new_key:
  21. new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
  22. new_ckpt[new_key] = v
  23. else:
  24. new_key = k
  25. new_ckpt[new_key] = v
  26. return new_ckpt
  27. def main():
  28. parser = argparse.ArgumentParser(
  29. description='Convert keys in official pretrained beit models to'
  30. 'MMSegmentation style.')
  31. parser.add_argument('src', help='src model path or url')
  32. # The dst path must be a full path of the new checkpoint.
  33. parser.add_argument('dst', help='save path')
  34. args = parser.parse_args()
  35. checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
  36. if 'state_dict' in checkpoint:
  37. state_dict = checkpoint['state_dict']
  38. elif 'model' in checkpoint:
  39. state_dict = checkpoint['model']
  40. else:
  41. state_dict = checkpoint
  42. weight = convert_beit(state_dict)
  43. mmengine.mkdir_or_exist(osp.dirname(args.dst))
  44. torch.save(weight, args.dst)
  45. if __name__ == '__main__':
  46. main()