# vitjax2mmseg.py (4.6 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. import mmengine
  5. import numpy as np
  6. import torch
  7. def vit_jax_to_torch(jax_weights, num_layer=12):
  8. torch_weights = dict()
  9. # patch embedding
  10. conv_filters = jax_weights['embedding/kernel']
  11. conv_filters = conv_filters.permute(3, 2, 0, 1)
  12. torch_weights['patch_embed.projection.weight'] = conv_filters
  13. torch_weights['patch_embed.projection.bias'] = jax_weights[
  14. 'embedding/bias']
  15. # pos embedding
  16. torch_weights['pos_embed'] = jax_weights[
  17. 'Transformer/posembed_input/pos_embedding']
  18. # cls token
  19. torch_weights['cls_token'] = jax_weights['cls']
  20. # head
  21. torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale']
  22. torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias']
  23. # transformer blocks
  24. for i in range(num_layer):
  25. jax_block = f'Transformer/encoderblock_{i}'
  26. torch_block = f'layers.{i}'
  27. # attention norm
  28. torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[
  29. f'{jax_block}/LayerNorm_0/scale']
  30. torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[
  31. f'{jax_block}/LayerNorm_0/bias']
  32. # attention
  33. query_weight = jax_weights[
  34. f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel']
  35. query_bias = jax_weights[
  36. f'{jax_block}/MultiHeadDotProductAttention_1/query/bias']
  37. key_weight = jax_weights[
  38. f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel']
  39. key_bias = jax_weights[
  40. f'{jax_block}/MultiHeadDotProductAttention_1/key/bias']
  41. value_weight = jax_weights[
  42. f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel']
  43. value_bias = jax_weights[
  44. f'{jax_block}/MultiHeadDotProductAttention_1/value/bias']
  45. qkv_weight = torch.from_numpy(
  46. np.stack((query_weight, key_weight, value_weight), 1))
  47. qkv_weight = torch.flatten(qkv_weight, start_dim=1)
  48. qkv_bias = torch.from_numpy(
  49. np.stack((query_bias, key_bias, value_bias), 0))
  50. qkv_bias = torch.flatten(qkv_bias, start_dim=0)
  51. torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight
  52. torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias
  53. to_out_weight = jax_weights[
  54. f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel']
  55. to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1)
  56. torch_weights[
  57. f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight
  58. torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[
  59. f'{jax_block}/MultiHeadDotProductAttention_1/out/bias']
  60. # mlp norm
  61. torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[
  62. f'{jax_block}/LayerNorm_2/scale']
  63. torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[
  64. f'{jax_block}/LayerNorm_2/bias']
  65. # mlp
  66. torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[
  67. f'{jax_block}/MlpBlock_3/Dense_0/kernel']
  68. torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[
  69. f'{jax_block}/MlpBlock_3/Dense_0/bias']
  70. torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[
  71. f'{jax_block}/MlpBlock_3/Dense_1/kernel']
  72. torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[
  73. f'{jax_block}/MlpBlock_3/Dense_1/bias']
  74. # transpose weights
  75. for k, v in torch_weights.items():
  76. if 'weight' in k and 'patch_embed' not in k and 'ln' not in k:
  77. v = v.permute(1, 0)
  78. torch_weights[k] = v
  79. return torch_weights
  80. def main():
  81. # stole refactoring code from Robin Strudel, thanks
  82. parser = argparse.ArgumentParser(
  83. description='Convert keys from jax official pretrained vit models to '
  84. 'MMSegmentation style.')
  85. parser.add_argument('src', help='src model path or url')
  86. # The dst path must be a full path of the new checkpoint.
  87. parser.add_argument('dst', help='save path')
  88. args = parser.parse_args()
  89. jax_weights = np.load(args.src)
  90. jax_weights_tensor = {}
  91. for key in jax_weights.files:
  92. value = torch.from_numpy(jax_weights[key])
  93. jax_weights_tensor[key] = value
  94. if 'L_16-i21k' in args.src:
  95. num_layer = 24
  96. else:
  97. num_layer = 12
  98. torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
  99. mmengine.mkdir_or_exist(osp.dirname(args.dst))
  100. torch.save(torch_weights, args.dst)
  101. if __name__ == '__main__':
  102. main()