san2mmseg.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from collections import OrderedDict
  5. import mmengine
  6. import torch
  7. from mmengine.runner import CheckpointLoader
  8. def convert_key_name(ckpt):
  9. new_ckpt = OrderedDict()
  10. for k, v in ckpt.items():
  11. key_list = k.split('.')
  12. if key_list[0] == 'clip_visual_extractor':
  13. new_transform_name = 'image_encoder'
  14. if key_list[1] == 'class_embedding':
  15. new_name = '.'.join([new_transform_name, 'cls_token'])
  16. elif key_list[1] == 'positional_embedding':
  17. new_name = '.'.join([new_transform_name, 'pos_embed'])
  18. elif key_list[1] == 'conv1':
  19. new_name = '.'.join([
  20. new_transform_name, 'patch_embed.projection', key_list[2]
  21. ])
  22. elif key_list[1] == 'ln_pre':
  23. new_name = '.'.join(
  24. [new_transform_name, key_list[1], key_list[2]])
  25. elif key_list[1] == 'resblocks':
  26. new_layer_name = 'layers'
  27. layer_index = key_list[2]
  28. paras = key_list[3:]
  29. if paras[0] == 'ln_1':
  30. new_para_name = '.'.join(['ln1'] + key_list[4:])
  31. elif paras[0] == 'attn':
  32. new_para_name = '.'.join(['attn.attn'] + key_list[4:])
  33. elif paras[0] == 'ln_2':
  34. new_para_name = '.'.join(['ln2'] + key_list[4:])
  35. elif paras[0] == 'mlp':
  36. if paras[1] == 'c_fc':
  37. new_para_name = '.'.join(['ffn.layers.0.0'] +
  38. key_list[-1:])
  39. else:
  40. new_para_name = '.'.join(['ffn.layers.1'] +
  41. key_list[-1:])
  42. new_name = '.'.join([
  43. new_transform_name, new_layer_name, layer_index,
  44. new_para_name
  45. ])
  46. elif key_list[0] == 'side_adapter_network':
  47. decode_head_name = 'decode_head'
  48. module_name = 'side_adapter_network'
  49. if key_list[1] == 'vit_model':
  50. if key_list[2] == 'blocks':
  51. layer_name = 'encode_layers'
  52. layer_index = key_list[3]
  53. paras = key_list[4:]
  54. if paras[0] == 'norm1':
  55. new_para_name = '.'.join(['ln1'] + key_list[5:])
  56. elif paras[0] == 'attn':
  57. new_para_name = '.'.join(key_list[4:])
  58. new_para_name = new_para_name.replace(
  59. 'attn.qkv.', 'attn.attn.in_proj_')
  60. new_para_name = new_para_name.replace(
  61. 'attn.proj', 'attn.attn.out_proj')
  62. elif paras[0] == 'norm2':
  63. new_para_name = '.'.join(['ln2'] + key_list[5:])
  64. elif paras[0] == 'mlp':
  65. new_para_name = '.'.join(['ffn'] + key_list[5:])
  66. new_para_name = new_para_name.replace(
  67. 'fc1', 'layers.0.0')
  68. new_para_name = new_para_name.replace(
  69. 'fc2', 'layers.1')
  70. else:
  71. print(f'Wrong for {k}')
  72. new_name = '.'.join([
  73. decode_head_name, module_name, layer_name, layer_index,
  74. new_para_name
  75. ])
  76. elif key_list[2] == 'pos_embed':
  77. new_name = '.'.join(
  78. [decode_head_name, module_name, 'pos_embed'])
  79. elif key_list[2] == 'patch_embed':
  80. new_name = '.'.join([
  81. decode_head_name, module_name, 'patch_embed',
  82. 'projection', key_list[4]
  83. ])
  84. else:
  85. print(f'Wrong for {k}')
  86. elif key_list[1] == 'query_embed' or key_list[
  87. 1] == 'query_pos_embed':
  88. new_name = '.'.join(
  89. [decode_head_name, module_name, key_list[1]])
  90. elif key_list[1] == 'fusion_layers':
  91. layer_name = 'conv_clips'
  92. layer_index = key_list[2][-1]
  93. paras = '.'.join(key_list[3:])
  94. new_para_name = paras.replace('input_proj.0', '0')
  95. new_para_name = new_para_name.replace('input_proj.1', '1.conv')
  96. new_name = '.'.join([
  97. decode_head_name, module_name, layer_name, layer_index,
  98. new_para_name
  99. ])
  100. elif key_list[1] == 'mask_decoder':
  101. new_name = 'decode_head.' + k
  102. else:
  103. print(f'Wrong for {k}')
  104. elif key_list[0] == 'clip_rec_head':
  105. module_name = 'rec_with_attnbias'
  106. if key_list[1] == 'proj':
  107. new_name = '.'.join(
  108. [decode_head_name, module_name, 'proj.weight'])
  109. elif key_list[1] == 'ln_post':
  110. new_name = '.'.join(
  111. [decode_head_name, module_name, 'ln_post', key_list[2]])
  112. elif key_list[1] == 'resblocks':
  113. new_layer_name = 'layers'
  114. layer_index = key_list[2]
  115. paras = key_list[3:]
  116. if paras[0] == 'ln_1':
  117. new_para_name = '.'.join(['norms.0'] + paras[1:])
  118. elif paras[0] == 'attn':
  119. new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
  120. elif paras[0] == 'ln_2':
  121. new_para_name = '.'.join(['norms.1'] + paras[1:])
  122. elif paras[0] == 'mlp':
  123. if paras[1] == 'c_fc':
  124. new_para_name = '.'.join(['ffns.0.layers.0.0'] +
  125. paras[2:])
  126. elif paras[1] == 'c_proj':
  127. new_para_name = '.'.join(['ffns.0.layers.1'] +
  128. paras[2:])
  129. else:
  130. print(f'Wrong for {k}')
  131. new_name = '.'.join([
  132. decode_head_name, module_name, new_layer_name, layer_index,
  133. new_para_name
  134. ])
  135. else:
  136. print(f'Wrong for {k}')
  137. elif key_list[0] == 'ov_classifier':
  138. text_encoder_name = 'text_encoder'
  139. if key_list[1] == 'transformer':
  140. layer_name = 'transformer'
  141. layer_index = key_list[3]
  142. paras = key_list[4:]
  143. if paras[0] == 'attn':
  144. new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
  145. elif paras[0] == 'ln_1':
  146. new_para_name = '.'.join(['norms.0'] + paras[1:])
  147. elif paras[0] == 'ln_2':
  148. new_para_name = '.'.join(['norms.1'] + paras[1:])
  149. elif paras[0] == 'mlp':
  150. if paras[1] == 'c_fc':
  151. new_para_name = '.'.join(['ffns.0.layers.0.0'] +
  152. paras[2:])
  153. elif paras[1] == 'c_proj':
  154. new_para_name = '.'.join(['ffns.0.layers.1'] +
  155. paras[2:])
  156. else:
  157. print(f'Wrong for {k}')
  158. else:
  159. print(f'Wrong for {k}')
  160. new_name = '.'.join([
  161. text_encoder_name, layer_name, layer_index, new_para_name
  162. ])
  163. elif key_list[1] in [
  164. 'positional_embedding', 'text_projection', 'bg_embed',
  165. 'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
  166. ]:
  167. new_name = k.replace('ov_classifier', 'text_encoder')
  168. else:
  169. print(f'Wrong for {k}')
  170. elif key_list[0] == 'criterion':
  171. new_name = k
  172. else:
  173. print(f'Wrong for {k}')
  174. new_ckpt[new_name] = v
  175. return new_ckpt
  176. def convert_tensor(ckpt):
  177. cls_token = ckpt['image_encoder.cls_token']
  178. new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
  179. ckpt['image_encoder.cls_token'] = new_cls_token
  180. pos_embed = ckpt['image_encoder.pos_embed']
  181. new_pos_embed = pos_embed.unsqueeze(0)
  182. ckpt['image_encoder.pos_embed'] = new_pos_embed
  183. proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
  184. new_proj_weight = proj_weight.transpose(1, 0)
  185. ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
  186. return ckpt
  187. def main():
  188. parser = argparse.ArgumentParser(
  189. description='Convert keys in timm pretrained vit models to '
  190. 'MMSegmentation style.')
  191. parser.add_argument('src', help='src model path or url')
  192. # The dst path must be a full path of the new checkpoint.
  193. parser.add_argument('dst', help='save path')
  194. args = parser.parse_args()
  195. checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
  196. if 'state_dict' in checkpoint:
  197. # timm checkpoint
  198. state_dict = checkpoint['state_dict']
  199. elif 'model' in checkpoint:
  200. # deit checkpoint
  201. state_dict = checkpoint['model']
  202. else:
  203. state_dict = checkpoint
  204. weight = convert_key_name(state_dict)
  205. weight = convert_tensor(weight)
  206. mmengine.mkdir_or_exist(osp.dirname(args.dst))
  207. torch.save(weight, args.dst)
  208. if __name__ == '__main__':
  209. main()