"""Smoke tests for ``lib.modules.xnet_2d``: clean import and output shapes."""

from __future__ import annotations

import importlib
import sys
import warnings

import torch
from torch import nn


def test_importing_xnet2d_does_not_emit_deprecation_warnings() -> None:
    """Import ``lib.modules.xnet_2d`` afresh and check no DeprecationWarning fires."""
    # Evict cached copies so the import below re-executes module-level code;
    # a module already in ``sys.modules`` would be returned without running.
    stale_names = []
    for mod_name in sys.modules:
        if mod_name == "lib.modules.xnet_2d":
            stale_names.append(mod_name)
        elif mod_name.startswith("lib.modules.lib_mamba"):
            stale_names.append(mod_name)
    for mod_name in stale_names:
        sys.modules.pop(mod_name, None)

    with warnings.catch_warnings(record=True) as recorded:
        # "always" ensures a DeprecationWarning is recorded even if the
        # active filters would otherwise suppress or de-duplicate it.
        warnings.simplefilter("always", DeprecationWarning)
        importlib.import_module("lib.modules.xnet_2d")

    deprecations = [
        entry for entry in recorded if issubclass(entry.category, DeprecationWarning)
    ]
    assert not deprecations


def test_xnet2d_forward_preserves_segmentation_shape() -> None:
    """Run a small ``XNet2d`` forward pass and check the logits' shapes.

    The input is a ``(2, 3, 64, 64)`` batch; ``seg_logits`` must come back as
    ``(2, 1, 64, 64)`` and ``logits`` must have the same shape.
    """
    from lib.modules.xnet_2d import XNet2d, XTEB2d

    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")

    # Deliberately tiny configuration so the forward pass stays cheap.
    config = {
        "in_channels": 3,
        "num_classes": 1,
        "encoder_channels": (8, 16, 24, 32),
        "encoder_depths": (1, 1, 1, 1),
        "decoder_channels": (24, 16, 8),
        "stem_channels": 8,
        "bottleneck_depth": 1,
        "global_ratio": 1.0,
        "use_wavelet_branch": True,
        "use_global_branch_stage1": False,
        "ssm_d_state": 1,
        "ssm_backend": "torch",
        "use_frequency_refine": True,
        "learnable_low_freq_radius": False,
    }
    net = XNet2d(**config).to(target)

    if target.type == "cpu":
        # On CPU every XTEB2d block gets its global branch replaced by a no-op.
        # NOTE(review): presumably the real branch cannot run on CPU — confirm.
        # Iterate lazily (no list()) so the swap happens during the traversal.
        for block in net.modules():
            if not isinstance(block, XTEB2d):
                continue
            block.global_branch = nn.Identity()

    net.eval()
    batch = torch.randn(2, 3, 64, 64, device=target)
    with torch.no_grad():
        result = net(batch)

    seg_shape = result["seg_logits"].shape
    assert seg_shape == (2, 1, 64, 64)
    assert result["logits"].shape == result["seg_logits"].shape