# mit2mmseg.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from collections import OrderedDict
  5. import mmengine
  6. import torch
  7. from mmengine.runner import CheckpointLoader
  8. def convert_mit(ckpt):
  9. new_ckpt = OrderedDict()
  10. # Process the concat between q linear weights and kv linear weights
  11. for k, v in ckpt.items():
  12. if k.startswith('head'):
  13. continue
  14. # patch embedding conversion
  15. elif k.startswith('patch_embed'):
  16. stage_i = int(k.split('.')[0].replace('patch_embed', ''))
  17. new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
  18. new_v = v
  19. if 'proj.' in new_k:
  20. new_k = new_k.replace('proj.', 'projection.')
  21. # transformer encoder layer conversion
  22. elif k.startswith('block'):
  23. stage_i = int(k.split('.')[0].replace('block', ''))
  24. new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
  25. new_v = v
  26. if 'attn.q.' in new_k:
  27. sub_item_k = k.replace('q.', 'kv.')
  28. new_k = new_k.replace('q.', 'attn.in_proj_')
  29. new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
  30. elif 'attn.kv.' in new_k:
  31. continue
  32. elif 'attn.proj.' in new_k:
  33. new_k = new_k.replace('proj.', 'attn.out_proj.')
  34. elif 'attn.sr.' in new_k:
  35. new_k = new_k.replace('sr.', 'sr.')
  36. elif 'mlp.' in new_k:
  37. string = f'{new_k}-'
  38. new_k = new_k.replace('mlp.', 'ffn.layers.')
  39. if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
  40. new_v = v.reshape((*v.shape, 1, 1))
  41. new_k = new_k.replace('fc1.', '0.')
  42. new_k = new_k.replace('dwconv.dwconv.', '1.')
  43. new_k = new_k.replace('fc2.', '4.')
  44. string += f'{new_k} {v.shape}-{new_v.shape}'
  45. # norm layer conversion
  46. elif k.startswith('norm'):
  47. stage_i = int(k.split('.')[0].replace('norm', ''))
  48. new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
  49. new_v = v
  50. else:
  51. new_k = k
  52. new_v = v
  53. new_ckpt[new_k] = new_v
  54. return new_ckpt
  55. def main():
  56. parser = argparse.ArgumentParser(
  57. description='Convert keys in official pretrained segformer to '
  58. 'MMSegmentation style.')
  59. parser.add_argument('src', help='src model path or url')
  60. # The dst path must be a full path of the new checkpoint.
  61. parser.add_argument('dst', help='save path')
  62. args = parser.parse_args()
  63. checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
  64. if 'state_dict' in checkpoint:
  65. state_dict = checkpoint['state_dict']
  66. elif 'model' in checkpoint:
  67. state_dict = checkpoint['model']
  68. else:
  69. state_dict = checkpoint
  70. weight = convert_mit(state_dict)
  71. mmengine.mkdir_or_exist(osp.dirname(args.dst))
  72. torch.save(weight, args.dst)
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()