# twins2mmseg.py
# Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from collections import OrderedDict
  5. import mmengine
  6. import torch
  7. from mmengine.runner import CheckpointLoader
  8. def convert_twins(args, ckpt):
  9. new_ckpt = OrderedDict()
  10. for k, v in list(ckpt.items()):
  11. new_v = v
  12. if k.startswith('head'):
  13. continue
  14. elif k.startswith('patch_embeds'):
  15. if 'proj.' in k:
  16. new_k = k.replace('proj.', 'projection.')
  17. else:
  18. new_k = k
  19. elif k.startswith('blocks'):
  20. # Union
  21. if 'attn.q.' in k:
  22. new_k = k.replace('q.', 'attn.in_proj_')
  23. new_v = torch.cat([v, ckpt[k.replace('attn.q.', 'attn.kv.')]],
  24. dim=0)
  25. elif 'mlp.fc1' in k:
  26. new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
  27. elif 'mlp.fc2' in k:
  28. new_k = k.replace('mlp.fc2', 'ffn.layers.1')
  29. # Only pcpvt
  30. elif args.model == 'pcpvt':
  31. if 'attn.proj.' in k:
  32. new_k = k.replace('proj.', 'attn.out_proj.')
  33. else:
  34. new_k = k
  35. # Only svt
  36. else:
  37. if 'attn.proj.' in k:
  38. k_lst = k.split('.')
  39. if int(k_lst[2]) % 2 == 1:
  40. new_k = k.replace('proj.', 'attn.out_proj.')
  41. else:
  42. new_k = k
  43. else:
  44. new_k = k
  45. new_k = new_k.replace('blocks.', 'layers.')
  46. elif k.startswith('pos_block'):
  47. new_k = k.replace('pos_block', 'position_encodings')
  48. if 'proj.0.' in new_k:
  49. new_k = new_k.replace('proj.0.', 'proj.')
  50. else:
  51. new_k = k
  52. if 'attn.kv.' not in k:
  53. new_ckpt[new_k] = new_v
  54. return new_ckpt
  55. def main():
  56. parser = argparse.ArgumentParser(
  57. description='Convert keys in timm pretrained vit models to '
  58. 'MMSegmentation style.')
  59. parser.add_argument('src', help='src model path or url')
  60. # The dst path must be a full path of the new checkpoint.
  61. parser.add_argument('dst', help='save path')
  62. parser.add_argument('model', help='model: pcpvt or svt')
  63. args = parser.parse_args()
  64. checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
  65. if 'state_dict' in checkpoint:
  66. # timm checkpoint
  67. state_dict = checkpoint['state_dict']
  68. else:
  69. state_dict = checkpoint
  70. weight = convert_twins(args, state_dict)
  71. mmengine.mkdir_or_exist(osp.dirname(args.dst))
  72. torch.save(weight, args.dst)
  73. if __name__ == '__main__':
  74. main()