from __future__ import annotations import torch from torch import nn from lib.modules.xnet_2d import XNet2d, XTEB2d def test_xnet2d_forward_preserves_segmentation_shape() -> None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = XNet2d( in_channels=3, num_classes=1, encoder_channels=(8, 16, 24, 32), encoder_depths=(1, 1, 1, 1), decoder_channels=(24, 16, 8), stem_channels=8, bottleneck_depth=1, global_ratio=1.0, use_wavelet_branch=True, use_global_branch_stage1=False, ssm_d_state=1, ssm_backend="torch", use_frequency_refine=True, learnable_low_freq_radius=False, ).to(device) if device.type == "cpu": for module in model.modules(): if isinstance(module, XTEB2d): module.global_branch = nn.Identity() model.eval() x = torch.randn(2, 3, 64, 64, device=device) with torch.no_grad(): outputs = model(x) assert outputs["seg_logits"].shape == (2, 1, 64, 64) assert outputs["logits"].shape == outputs["seg_logits"].shape