# test_v1.py 781 B

  1. import os
  2. # 加入当前目录的父目录到系统路径,以便导入lib模块
  3. current_dir = os.path.dirname(os.path.abspath(__file__))
  4. parent_dir = os.path.dirname(current_dir)
  5. if parent_dir not in os.sys.path:
  6. os.sys.path.insert(0, parent_dir)
  7. import torch
  8. from lib.modules import XNet2d
  9. assert torch.cuda.is_available(), "CUDA is not available"
  10. model = (
  11. XNet2d(
  12. in_channels=3,
  13. num_classes=1,
  14. encoder_channels=(32, 64, 128, 192),
  15. encoder_depths=(2, 2, 2, 2),
  16. decoder_channels=(128, 64, 32),
  17. ssm_backend="auto",
  18. )
  19. .cuda()
  20. .eval()
  21. )
  22. x = torch.randn(1, 3, 128, 128, device="cuda")
  23. with torch.no_grad():
  24. y = model(x)
  25. print("keys:", sorted(y.keys()))
  26. print("seg_logits:", tuple(y["seg_logits"].shape))