import warnings

import torch

from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2

# NOTE(review): this silences *all* warnings for the whole process as soon as
# this module is imported, not only warnings raised here. Kept as-is for
# backward compatibility with callers that rely on the quiet output.
warnings.filterwarnings("ignore")

# The only ``size`` value for which ``swin_v2`` loads pretrained weights.
SWINV2_BASE_384 = "swinv2_base_patch4_window12to24_192to384_22kto1k_ft"


def swin_v2(size, checkpoint_path, img_size=384, **kwargs):
    """Build a SwinTransformerV2 on CUDA and optionally load pretrained weights.

    The architecture hyperparameters are fixed below (window_size=12,
    embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], plus the
    dropout rates) and do not depend on ``size``. ``size`` only controls
    whether a checkpoint is loaded.

    NOTE(review): these hyperparameters look like a SwinV2-Small layout, while
    the checkpoint name refers to a "base", window-24 model. With
    ``strict=True`` a mismatched checkpoint raises on load. Confirm that the
    local ``SwinTransformerV2`` / checkpoint pair really matches.

    Args:
        size: Model identifier. When it equals ``SWINV2_BASE_384``, weights
            are loaded from ``checkpoint_path``. Any other value silently
            returns the randomly initialised model.
        checkpoint_path: Path to a file readable by ``torch.load`` whose
            ``'model'`` entry is a state dict for this architecture.
        img_size: Input resolution forwarded to ``SwinTransformerV2``.
        **kwargs: Extra keyword arguments forwarded to ``SwinTransformerV2``.

    Returns:
        The ``SwinTransformerV2`` instance, moved to the current CUDA device.

    Raises:
        KeyError: if the checkpoint has no ``'model'`` entry.
        RuntimeError: if the state dict does not match the model exactly
            (``strict=True``).
    """
    model = SwinTransformerV2(
        img_size=img_size,
        window_size=12,
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        drop_rate=0.05,
        attn_drop_rate=0.05,
        drop_path_rate=0.1,
        **kwargs,
    ).cuda()

    if size == SWINV2_BASE_384:
        # Deserialize onto CPU: a checkpoint saved from another GPU index would
        # otherwise fail to load here, and the weights would briefly occupy GPU
        # memory twice. load_state_dict copies the values into the model's
        # (CUDA) parameters.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint['model'], strict=True)
        print(f"Loaded model from {checkpoint_path}")

    return model


if __name__ == "__main__":
    model_test = swin_v2(
        SWINV2_BASE_384,
        "../weights/swin_transformer_v2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth",
    )

    # Inference-only smoke test: turn off dropout/drop-path and autograd so the
    # forward pass does not keep activations for a backward pass.
    model_test.eval()
    with torch.no_grad():
        features = model_test.forward_multiscale_features(
            torch.randn(1, 3, 384, 384).cuda()
        )
        for i, feature in enumerate(features):
            print(f"Feature {i} shape: {feature.shape}")

    # "参数量" means "number of parameters".
    print("参数量: ", sum(p.numel() for p in model_test.parameters()))