# test_xnet_2d.py
  1. from __future__ import annotations
  2. import torch
  3. from torch import nn
  4. from lib.modules.xnet_2d import XNet2d, XTEB2d
  5. def test_xnet2d_forward_preserves_segmentation_shape() -> None:
  6. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  7. model = XNet2d(
  8. in_channels=3,
  9. num_classes=1,
  10. encoder_channels=(8, 16, 24, 32),
  11. encoder_depths=(1, 1, 1, 1),
  12. decoder_channels=(24, 16, 8),
  13. stem_channels=8,
  14. bottleneck_depth=1,
  15. global_ratio=1.0,
  16. use_wavelet_branch=True,
  17. use_global_branch_stage1=False,
  18. ssm_d_state=1,
  19. ssm_backend="torch",
  20. use_frequency_refine=True,
  21. learnable_low_freq_radius=False,
  22. ).to(device)
  23. if device.type == "cpu":
  24. for module in model.modules():
  25. if isinstance(module, XTEB2d):
  26. module.global_branch = nn.Identity()
  27. model.eval()
  28. x = torch.randn(2, 3, 64, 64, device=device)
  29. with torch.no_grad():
  30. outputs = model(x)
  31. assert outputs["seg_logits"].shape == (2, 1, 64, 64)
  32. assert outputs["logits"].shape == outputs["seg_logits"].shape