| 1234567891011121314151617181920212223242526272829303132 |
- import torch
- from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
- import warnings
# Install a process-wide "ignore" filter for every Warning subclass as soon as
# this module is imported.
# NOTE(review): this silences warnings raised by torch and every other library
# in the same process; consider scoping it with warnings.catch_warnings() at
# the call site instead.
warnings.simplefilter("ignore")
# Identifier of the only pretrained checkpoint this factory loads weights for.
_SWINV2_BASE_384 = "swinv2_base_patch4_window12to24_192to384_22kto1k_ft"


def _swinv2_arch(size):
    """Return a fresh dict of architecture kwargs for ``SwinTransformerV2``.

    For ``_SWINV2_BASE_384`` the shapes follow the official Swin Transformer
    config of that name (Swin-V2-B: embed dim 128, heads 4/8/16/32, window 24
    fine-tuned from window-12 pre-training). A strict ``load_state_dict`` only
    succeeds when the model is built with exactly these shapes. Any other
    ``size`` keeps the configuration this factory always used (embed dim 96,
    heads 3/6/12/24, window 12).
    """
    if size == _SWINV2_BASE_384:
        return dict(
            window_size=24,
            embed_dim=128,
            depths=[2, 2, 18, 2],
            num_heads=[4, 8, 16, 32],
            pretrained_window_sizes=[12, 12, 12, 6],
        )
    return dict(
        window_size=12,
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
    )


def swin_v2(size, checkpoint_path, img_size=384, *, device="cuda", **kwargs):
    """Build a Swin Transformer V2 backbone, loading pretrained weights if known.

    Args:
        size: Model/checkpoint identifier. Only
            ``"swinv2_base_patch4_window12to24_192to384_22kto1k_ft"`` triggers
            weight loading. For any other value a randomly initialised model
            is returned and ``checkpoint_path`` is ignored.
        checkpoint_path: Path to a checkpoint file whose ``"model"`` entry is
            a state dict for the selected architecture.
        img_size: Input resolution forwarded to ``SwinTransformerV2``.
        device: Device the model is moved to. The default ``"cuda"`` matches
            the previous unconditional ``.cuda()``.
        **kwargs: Extra ``SwinTransformerV2`` arguments. They override the
            defaults chosen here.

    Returns:
        The constructed model, on ``device``.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint does not
            match the model exactly (``strict=True``).
    """
    config = _swinv2_arch(size)
    config.update(drop_rate=0.05, attn_drop_rate=0.05, drop_path_rate=0.1)
    config.update(kwargs)  # caller-supplied values take precedence
    model = SwinTransformerV2(img_size=img_size, **config).to(device)

    if size == _SWINV2_BASE_384:
        # Deserialize on the CPU: the tensors are copied into the model's
        # parameters anyway. This avoids a second GPU copy and works even if
        # the file was saved on a device that is not available here.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"], strict=True)
        print(f"Loaded model from {checkpoint_path}")
    return model
if __name__ == "__main__":
    # Smoke test: build the pretrained base model and report feature shapes
    # and the parameter count.
    model_test = swin_v2(
        "swinv2_base_patch4_window12to24_192to384_22kto1k_ft",
        "../weights/swin_transformer_v2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth",
    )
    # The model is built with non-zero dropout / drop-path rates. eval()
    # turns those off so the forward pass is not stochastic. no_grad() skips
    # building an autograd graph that would never be back-propagated.
    model_test.eval()
    with torch.no_grad():
        features = model_test.forward_multiscale_features(
            torch.randn(1, 3, 384, 384).cuda()
        )
    for i, feature in enumerate(features):
        print(f"Feature {i} shape: {feature.shape}")
    # Total number of parameters (the label "参数量" means "parameter count").
    print("参数量: ", sum(p.numel() for p in model_test.parameters()))
|