| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- from __future__ import annotations
- import importlib
- import sys
- import warnings
- import torch
- from torch import nn
- def test_importing_xnet2d_does_not_emit_deprecation_warnings() -> None:
- modules_to_clear = [
- name
- for name in sys.modules
- if name == "lib.modules.xnet_2d" or name.startswith("lib.modules.lib_mamba")
- ]
- for name in modules_to_clear:
- sys.modules.pop(name, None)
- with warnings.catch_warnings(record=True) as caught:
- warnings.simplefilter("always", DeprecationWarning)
- importlib.import_module("lib.modules.xnet_2d")
- assert not [
- warning
- for warning in caught
- if issubclass(warning.category, DeprecationWarning)
- ]
def test_xnet2d_forward_preserves_segmentation_shape() -> None:
    """A forward pass returns logits at the input's batch size and resolution.

    Builds a small ``XNet2d`` (channel widths 8-32, depth 1 everywhere) and feeds
    it a batch of two 3x64x64 inputs. Both ``outputs["seg_logits"]`` and
    ``outputs["logits"]`` must come back with shape ``(2, 1, 64, 64)``:
    batch size, ``num_classes`` and the 64x64 input size.
    """
    from lib.modules.xnet_2d import XNet2d, XTEB2d

    run_on_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if run_on_gpu else "cpu")

    model_kwargs = dict(
        in_channels=3,
        num_classes=1,
        encoder_channels=(8, 16, 24, 32),
        encoder_depths=(1, 1, 1, 1),
        decoder_channels=(24, 16, 8),
        stem_channels=8,
        bottleneck_depth=1,
        global_ratio=1.0,
        use_wavelet_branch=True,
        use_global_branch_stage1=False,
        ssm_d_state=1,
        ssm_backend="torch",
        use_frequency_refine=True,
        learnable_low_freq_radius=False,
    )
    model = XNet2d(**model_kwargs).to(device)

    if not run_on_gpu:
        # NOTE(review): on CPU every XTEB2d's global branch is replaced by an
        # identity, presumably to keep the CPU run tractable — confirm.
        xteb_blocks = (module for module in model.modules() if isinstance(module, XTEB2d))
        for block in xteb_blocks:
            block.global_branch = nn.Identity()

    model.eval()
    batch = torch.randn(2, 3, 64, 64, device=device)
    with torch.no_grad():
        outputs = model(batch)

    seg_shape = outputs["seg_logits"].shape
    assert seg_shape == (2, 1, 64, 64)
    assert outputs["logits"].shape == seg_shape
def test_xnet2d_decoder_uses_plain_unet_skip_connections() -> None:
    """The decoder must not contain any ``guide*`` submodule at any depth.

    ``named_modules()`` yields dotted paths relative to the decoder, e.g.
    ``"stages.0.guide_attn"``, so every path component is checked. Testing only
    ``name.startswith("guide")`` would miss guide modules nested below the
    decoder's direct children.
    """
    from lib.modules.xnet_2d import XNet2d

    model = XNet2d(
        in_channels=3,
        num_classes=1,
        encoder_channels=(8, 16, 24, 32),
        encoder_depths=(1, 1, 1, 1),
        decoder_channels=(24, 16, 8),
        stem_channels=8,
        bottleneck_depth=1,
        use_global_branch_stage1=False,
        ssm_d_state=1,
        ssm_backend="torch",
    )

    # The root module is reported under the empty name "", which never matches.
    guide_modules = [
        name
        for name, _ in model.decoder.named_modules()
        if any(part.startswith("guide") for part in name.split("."))
    ]
    assert not guide_modules, f"decoder contains guide modules: {guide_modules}"
|