# test_xnet_2d.py (1.8 KB)
  1. from __future__ import annotations
  2. import importlib
  3. import sys
  4. import warnings
  5. import torch
  6. from torch import nn
  7. def test_importing_xnet2d_does_not_emit_deprecation_warnings() -> None:
  8. modules_to_clear = [
  9. name
  10. for name in sys.modules
  11. if name == "lib.modules.xnet_2d" or name.startswith("lib.modules.lib_mamba")
  12. ]
  13. for name in modules_to_clear:
  14. sys.modules.pop(name, None)
  15. with warnings.catch_warnings(record=True) as caught:
  16. warnings.simplefilter("always", DeprecationWarning)
  17. importlib.import_module("lib.modules.xnet_2d")
  18. assert not [
  19. warning
  20. for warning in caught
  21. if issubclass(warning.category, DeprecationWarning)
  22. ]
  23. def test_xnet2d_forward_preserves_segmentation_shape() -> None:
  24. from lib.modules.xnet_2d import XNet2d, XTEB2d
  25. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  26. model = XNet2d(
  27. in_channels=3,
  28. num_classes=1,
  29. encoder_channels=(8, 16, 24, 32),
  30. encoder_depths=(1, 1, 1, 1),
  31. decoder_channels=(24, 16, 8),
  32. stem_channels=8,
  33. bottleneck_depth=1,
  34. global_ratio=1.0,
  35. use_wavelet_branch=True,
  36. use_global_branch_stage1=False,
  37. ssm_d_state=1,
  38. ssm_backend="torch",
  39. use_frequency_refine=True,
  40. learnable_low_freq_radius=False,
  41. ).to(device)
  42. if device.type == "cpu":
  43. for module in model.modules():
  44. if isinstance(module, XTEB2d):
  45. module.global_branch = nn.Identity()
  46. model.eval()
  47. x = torch.randn(2, 3, 64, 64, device=device)
  48. with torch.no_grad():
  49. outputs = model(x)
  50. assert outputs["seg_logits"].shape == (2, 1, 64, 64)
  51. assert outputs["logits"].shape == outputs["seg_logits"].shape