# test.py
  1. import torch
  2. from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
  3. import warnings
  4. warnings.filterwarnings("ignore")
  5. def swin_v2(size, checkpoint_path, img_size=384, **kwargs):
  6. model = SwinTransformerV2(img_size=img_size,
  7. window_size=12,
  8. embed_dim=96,
  9. depths=[2, 2, 18, 2],
  10. num_heads=[3, 6, 12, 24],
  11. drop_rate=0.05,
  12. attn_drop_rate=0.05,
  13. drop_path_rate=0.1, **kwargs).cuda()
  14. if size == "swinv2_base_patch4_window12to24_192to384_22kto1k_ft":
  15. checkpoint = torch.load(checkpoint_path)
  16. model.load_state_dict(checkpoint['model'], strict=True)
  17. print(f"Loaded model from {checkpoint_path}")
  18. return model
  19. if __name__ == "__main__":
  20. model_test = swin_v2("swinv2_base_patch4_window12to24_192to384_22kto1k_ft",
  21. "../weights/swin_transformer_v2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth")
  22. # print(model)
  23. features = model_test.forward_multiscale_features(torch.randn(1, 3, 384, 384).cuda())
  24. for i, feature in enumerate(features):
  25. print(f"Feature {i} shape: {feature.shape}")
  26. print("参数量: ", sum(p.numel() for p in model_test.parameters()))