| 1234567891011121314151617181920212223242526272829303132333435363738 |
- from __future__ import annotations
- import torch
- from torch import nn
- from lib.modules.xnet_2d import XNet2d, XTEB2d
def test_xnet2d_forward_preserves_segmentation_shape() -> None:
    """Check that a small XNet2d keeps the input's spatial size in its outputs.

    A deliberately tiny configuration is built and run once in inference
    mode (``eval()`` + ``torch.no_grad()``) on a random 2x3x64x64 batch.
    The test then asserts that:

    * ``outputs["seg_logits"]`` has shape ``(batch, num_classes, H, W)``, and
    * ``outputs["logits"]`` has exactly the same shape as ``seg_logits``.

    When CUDA is unavailable, the ``global_branch`` attribute of every
    ``XTEB2d`` submodule is replaced by ``nn.Identity()`` before the forward
    pass. The global branch is therefore only exercised on GPU runs.
    NOTE(review): the reason for this CPU swap (speed vs. backend limitation)
    is not visible here — confirm.
    """
    on_gpu = torch.cuda.is_available()
    target = torch.device("cuda" if on_gpu else "cpu")

    config = dict(
        in_channels=3,
        num_classes=1,
        encoder_channels=(8, 16, 24, 32),
        encoder_depths=(1, 1, 1, 1),
        decoder_channels=(24, 16, 8),
        stem_channels=8,
        bottleneck_depth=1,
        global_ratio=1.0,
        use_wavelet_branch=True,
        use_global_branch_stage1=False,
        ssm_d_state=1,
        ssm_backend="torch",
        use_frequency_refine=True,
        learnable_low_freq_radius=False,
    )
    net = XNet2d(**config).to(target)

    if target.type == "cpu":
        # Patch lazily while walking the module tree. The generator expression
        # keeps the same traversal as iterating ``net.modules()`` directly.
        xteb_blocks = (m for m in net.modules() if isinstance(m, XTEB2d))
        for block in xteb_blocks:
            block.global_branch = nn.Identity()

    # Switch to eval mode only after patching, so the swapped-in modules
    # are also put into eval mode.
    net.eval()

    batch, height, width = 2, 64, 64
    sample = torch.randn(batch, 3, height, width, device=target)
    with torch.no_grad():
        result = net(sample)

    seg = result["seg_logits"]
    assert seg.shape == (batch, 1, height, width)
    assert result["logits"].shape == seg.shape
|