build_sam.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import torch
  6. from functools import partial
  7. from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
  8. def build_sam_vit_h(checkpoint=None):
  9. return _build_sam(
  10. encoder_embed_dim=1280,
  11. encoder_depth=32,
  12. encoder_num_heads=16,
  13. encoder_global_attn_indexes=[7, 15, 23, 31],
  14. checkpoint=checkpoint,
  15. )
  16. build_sam = build_sam_vit_h
  17. def build_sam_vit_l(checkpoint=None):
  18. return _build_sam(
  19. encoder_embed_dim=1024,
  20. encoder_depth=24,
  21. encoder_num_heads=16,
  22. encoder_global_attn_indexes=[5, 11, 17, 23],
  23. checkpoint=checkpoint,
  24. )
  25. def build_sam_vit_b(checkpoint=None):
  26. return _build_sam(
  27. encoder_embed_dim=768,
  28. encoder_depth=12,
  29. encoder_num_heads=12,
  30. encoder_global_attn_indexes=[2, 5, 8, 11],
  31. checkpoint=checkpoint,
  32. )
  33. sam_model_registry = {
  34. "default": build_sam_vit_h,
  35. "vit_h": build_sam_vit_h,
  36. "vit_l": build_sam_vit_l,
  37. "vit_b": build_sam_vit_b,
  38. }
  39. def _build_sam(
  40. encoder_embed_dim,
  41. encoder_depth,
  42. encoder_num_heads,
  43. encoder_global_attn_indexes,
  44. checkpoint=None,
  45. ):
  46. prompt_embed_dim = 256
  47. image_size = 1024
  48. vit_patch_size = 16
  49. image_embedding_size = image_size // vit_patch_size
  50. sam = Sam(
  51. image_encoder=ImageEncoderViT(
  52. depth=encoder_depth,
  53. embed_dim=encoder_embed_dim,
  54. img_size=image_size,
  55. mlp_ratio=4,
  56. norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
  57. num_heads=encoder_num_heads,
  58. patch_size=vit_patch_size,
  59. qkv_bias=True,
  60. use_rel_pos=True,
  61. global_attn_indexes=encoder_global_attn_indexes,
  62. window_size=14,
  63. out_chans=prompt_embed_dim,
  64. ),
  65. prompt_encoder=PromptEncoder(
  66. embed_dim=prompt_embed_dim,
  67. image_embedding_size=(image_embedding_size, image_embedding_size),
  68. input_image_size=(image_size, image_size),
  69. mask_in_chans=16,
  70. ),
  71. mask_decoder=MaskDecoder(
  72. num_multimask_outputs=3,
  73. transformer=TwoWayTransformer(
  74. depth=2,
  75. embedding_dim=prompt_embed_dim,
  76. mlp_dim=2048,
  77. num_heads=8,
  78. ),
  79. transformer_dim=prompt_embed_dim,
  80. iou_head_depth=3,
  81. iou_head_hidden_dim=256,
  82. ),
  83. pixel_mean=[123.675, 116.28, 103.53],
  84. pixel_std=[58.395, 57.12, 57.375],
  85. )
  86. sam.eval()
  87. if checkpoint is not None:
  88. with open(checkpoint, "rb") as f:
  89. state_dict = torch.load(f)
  90. sam.load_state_dict(state_dict)
  91. return sam