vit2mmseg.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from collections import OrderedDict
  5. import mmengine
  6. import torch
  7. from mmengine.runner import CheckpointLoader
  8. def convert_vit(ckpt):
  9. new_ckpt = OrderedDict()
  10. for k, v in ckpt.items():
  11. if k.startswith('head'):
  12. continue
  13. if k.startswith('norm'):
  14. new_k = k.replace('norm.', 'ln1.')
  15. elif k.startswith('patch_embed'):
  16. if 'proj' in k:
  17. new_k = k.replace('proj', 'projection')
  18. else:
  19. new_k = k
  20. elif k.startswith('blocks'):
  21. if 'norm' in k:
  22. new_k = k.replace('norm', 'ln')
  23. elif 'mlp.fc1' in k:
  24. new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
  25. elif 'mlp.fc2' in k:
  26. new_k = k.replace('mlp.fc2', 'ffn.layers.1')
  27. elif 'attn.qkv' in k:
  28. new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
  29. elif 'attn.proj' in k:
  30. new_k = k.replace('attn.proj', 'attn.attn.out_proj')
  31. else:
  32. new_k = k
  33. new_k = new_k.replace('blocks.', 'layers.')
  34. else:
  35. new_k = k
  36. new_ckpt[new_k] = v
  37. return new_ckpt
  38. def main():
  39. parser = argparse.ArgumentParser(
  40. description='Convert keys in timm pretrained vit models to '
  41. 'MMSegmentation style.')
  42. parser.add_argument('src', help='src model path or url')
  43. # The dst path must be a full path of the new checkpoint.
  44. parser.add_argument('dst', help='save path')
  45. args = parser.parse_args()
  46. checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
  47. if 'state_dict' in checkpoint:
  48. # timm checkpoint
  49. state_dict = checkpoint['state_dict']
  50. elif 'model' in checkpoint:
  51. # deit checkpoint
  52. state_dict = checkpoint['model']
  53. else:
  54. state_dict = checkpoint
  55. weight = convert_vit(state_dict)
  56. mmengine.mkdir_or_exist(osp.dirname(args.dst))
  57. torch.save(weight, args.dst)
  58. if __name__ == '__main__':
  59. main()