swin2mmseg.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from collections import OrderedDict
  5. import mmengine
  6. import torch
  7. from mmengine.runner import CheckpointLoader
  8. def convert_swin(ckpt):
  9. new_ckpt = OrderedDict()
  10. def correct_unfold_reduction_order(x):
  11. out_channel, in_channel = x.shape
  12. x = x.reshape(out_channel, 4, in_channel // 4)
  13. x = x[:, [0, 2, 1, 3], :].transpose(1,
  14. 2).reshape(out_channel, in_channel)
  15. return x
  16. def correct_unfold_norm_order(x):
  17. in_channel = x.shape[0]
  18. x = x.reshape(4, in_channel // 4)
  19. x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
  20. return x
  21. for k, v in ckpt.items():
  22. if k.startswith('head'):
  23. continue
  24. elif k.startswith('layers'):
  25. new_v = v
  26. if 'attn.' in k:
  27. new_k = k.replace('attn.', 'attn.w_msa.')
  28. elif 'mlp.' in k:
  29. if 'mlp.fc1.' in k:
  30. new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
  31. elif 'mlp.fc2.' in k:
  32. new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
  33. else:
  34. new_k = k.replace('mlp.', 'ffn.')
  35. elif 'downsample' in k:
  36. new_k = k
  37. if 'reduction.' in k:
  38. new_v = correct_unfold_reduction_order(v)
  39. elif 'norm.' in k:
  40. new_v = correct_unfold_norm_order(v)
  41. else:
  42. new_k = k
  43. new_k = new_k.replace('layers', 'stages', 1)
  44. elif k.startswith('patch_embed'):
  45. new_v = v
  46. if 'proj' in k:
  47. new_k = k.replace('proj', 'projection')
  48. else:
  49. new_k = k
  50. else:
  51. new_v = v
  52. new_k = k
  53. new_ckpt[new_k] = new_v
  54. return new_ckpt
  55. def main():
  56. parser = argparse.ArgumentParser(
  57. description='Convert keys in official pretrained swin models to'
  58. 'MMSegmentation style.')
  59. parser.add_argument('src', help='src model path or url')
  60. # The dst path must be a full path of the new checkpoint.
  61. parser.add_argument('dst', help='save path')
  62. args = parser.parse_args()
  63. checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
  64. if 'state_dict' in checkpoint:
  65. state_dict = checkpoint['state_dict']
  66. elif 'model' in checkpoint:
  67. state_dict = checkpoint['model']
  68. else:
  69. state_dict = checkpoint
  70. weight = convert_swin(state_dict)
  71. mmengine.mkdir_or_exist(osp.dirname(args.dst))
  72. torch.save(weight, args.dst)
  73. if __name__ == '__main__':
  74. main()