clip2mmseg.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from collections import OrderedDict
  5. import mmengine
  6. import torch
  7. from mmengine.runner import CheckpointLoader
  8. def convert_vitlayer(paras):
  9. new_para_name = ''
  10. if paras[0] == 'ln_1':
  11. new_para_name = '.'.join(['ln1'] + paras[1:])
  12. elif paras[0] == 'attn':
  13. new_para_name = '.'.join(['attn.attn'] + paras[1:])
  14. elif paras[0] == 'ln_2':
  15. new_para_name = '.'.join(['ln2'] + paras[1:])
  16. elif paras[0] == 'mlp':
  17. if paras[1] == 'c_fc':
  18. new_para_name = '.'.join(['ffn.layers.0.0'] + paras[-1:])
  19. else:
  20. new_para_name = '.'.join(['ffn.layers.1'] + paras[-1:])
  21. else:
  22. print(f'Wrong for {paras}')
  23. return new_para_name
  24. def convert_translayer(paras):
  25. new_para_name = ''
  26. if paras[0] == 'attn':
  27. new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
  28. elif paras[0] == 'ln_1':
  29. new_para_name = '.'.join(['norms.0'] + paras[1:])
  30. elif paras[0] == 'ln_2':
  31. new_para_name = '.'.join(['norms.1'] + paras[1:])
  32. elif paras[0] == 'mlp':
  33. if paras[1] == 'c_fc':
  34. new_para_name = '.'.join(['ffns.0.layers.0.0'] + paras[2:])
  35. elif paras[1] == 'c_proj':
  36. new_para_name = '.'.join(['ffns.0.layers.1'] + paras[2:])
  37. else:
  38. print(f'Wrong for {paras}')
  39. else:
  40. print(f'Wrong for {paras}')
  41. return new_para_name
  42. def convert_key_name(ckpt, visual_split):
  43. new_ckpt = OrderedDict()
  44. for k, v in ckpt.items():
  45. key_list = k.split('.')
  46. if key_list[0] == 'visual':
  47. new_transform_name = 'image_encoder'
  48. if key_list[1] == 'class_embedding':
  49. new_name = '.'.join([new_transform_name, 'cls_token'])
  50. elif key_list[1] == 'positional_embedding':
  51. new_name = '.'.join([new_transform_name, 'pos_embed'])
  52. elif key_list[1] == 'conv1':
  53. new_name = '.'.join([
  54. new_transform_name, 'patch_embed.projection', key_list[2]
  55. ])
  56. elif key_list[1] == 'ln_pre':
  57. new_name = '.'.join(
  58. [new_transform_name, key_list[1], key_list[2]])
  59. elif key_list[1] == 'transformer':
  60. new_layer_name = 'layers'
  61. layer_index = key_list[3]
  62. paras = key_list[4:]
  63. if int(layer_index) < visual_split:
  64. new_para_name = convert_vitlayer(paras)
  65. new_name = '.'.join([
  66. new_transform_name, new_layer_name, layer_index,
  67. new_para_name
  68. ])
  69. else:
  70. new_para_name = convert_translayer(paras)
  71. new_transform_name = 'decode_head.rec_with_attnbias'
  72. new_layer_name = 'layers'
  73. layer_index = str(int(layer_index) - visual_split)
  74. new_name = '.'.join([
  75. new_transform_name, new_layer_name, layer_index,
  76. new_para_name
  77. ])
  78. elif key_list[1] == 'proj':
  79. new_name = 'decode_head.rec_with_attnbias.proj.weight'
  80. elif key_list[1] == 'ln_post':
  81. new_name = k.replace('visual', 'decode_head.rec_with_attnbias')
  82. else:
  83. print(f'pop parameter: {k}')
  84. continue
  85. else:
  86. text_encoder_name = 'text_encoder'
  87. if key_list[0] == 'transformer':
  88. layer_name = 'transformer'
  89. layer_index = key_list[2]
  90. paras = key_list[3:]
  91. new_para_name = convert_translayer(paras)
  92. new_name = '.'.join([
  93. text_encoder_name, layer_name, layer_index, new_para_name
  94. ])
  95. elif key_list[0] in [
  96. 'positional_embedding', 'text_projection', 'bg_embed',
  97. 'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
  98. ]:
  99. new_name = 'text_encoder.' + k
  100. else:
  101. print(f'pop parameter: {k}')
  102. continue
  103. new_ckpt[new_name] = v
  104. return new_ckpt
  105. def convert_tensor(ckpt):
  106. cls_token = ckpt['image_encoder.cls_token']
  107. new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
  108. ckpt['image_encoder.cls_token'] = new_cls_token
  109. pos_embed = ckpt['image_encoder.pos_embed']
  110. new_pos_embed = pos_embed.unsqueeze(0)
  111. ckpt['image_encoder.pos_embed'] = new_pos_embed
  112. proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
  113. new_proj_weight = proj_weight.transpose(1, 0)
  114. ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
  115. return ckpt
  116. def main():
  117. parser = argparse.ArgumentParser(
  118. description='Convert keys in timm pretrained vit models to '
  119. 'MMSegmentation style.')
  120. parser.add_argument('src', help='src model path or url')
  121. # The dst path must be a full path of the new checkpoint.
  122. parser.add_argument('dst', help='save path')
  123. args = parser.parse_args()
  124. if any([s in args.src for s in ['B-16', 'b16', 'base_patch16']]):
  125. visual_split = 9
  126. elif any([s in args.src for s in ['L-14', 'l14', 'large_patch14']]):
  127. visual_split = 18
  128. else:
  129. print('Make sure the clip model is ViT-B/16 or ViT-L/14!')
  130. visual_split = -1
  131. checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
  132. if isinstance(checkpoint, torch.jit.RecursiveScriptModule):
  133. state_dict = checkpoint.state_dict()
  134. else:
  135. if 'state_dict' in checkpoint:
  136. # timm checkpoint
  137. state_dict = checkpoint['state_dict']
  138. elif 'model' in checkpoint:
  139. # deit checkpoint
  140. state_dict = checkpoint['model']
  141. else:
  142. state_dict = checkpoint
  143. weight = convert_key_name(state_dict, visual_split)
  144. weight = convert_tensor(weight)
  145. mmengine.mkdir_or_exist(osp.dirname(args.dst))
  146. torch.save(weight, args.dst)
  147. if __name__ == '__main__':
  148. main()