from __future__ import annotations

from collections.abc import Sequence

import ptwt
import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers_2d import Conv2dBN
from .lib_mamba.vmamba import SS2D as VMambaSS2D


class XNetStem2d(nn.Module):
    """Stem that reduces spatial size by 4x while lifting features into stage 1.

    Two 3x3 convolutions with stride argument 2 (first and last) do the
    downsampling; a depthwise 3x3 + pointwise 1x1 pair mixes features in
    between. (Assumes ``Conv2dBN(in, out, kernel, stride, padding, ...)``
    positional order -- TODO confirm against ``layers_2d``.)
    """

    def __init__(self, in_channels: int, stem_channels: int, out_channels: int) -> None:
        super().__init__()
        self.block = nn.Sequential(
            Conv2dBN(in_channels, stem_channels, 3, 2, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(stem_channels, stem_channels, 3, 1, 1, groups=stem_channels),
            nn.ReLU(inplace=True),
            Conv2dBN(stem_channels, out_channels, 1, 1, 0),
            nn.ReLU(inplace=True),
            Conv2dBN(out_channels, out_channels, 3, 2, 1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class XNetDownsample2d(nn.Module):
    """Convolutional downsampling between encoder stages.

    Only ``mode="conv"`` (3x3 Conv2dBN with stride argument 2, then ReLU) is
    implemented; any other mode raises ``ValueError``.
    """

    def __init__(self, in_channels: int, out_channels: int, mode: str = "conv") -> None:
        super().__init__()
        if mode != "conv":
            raise ValueError(f"Unsupported downsample mode: {mode}")
        self.block = nn.Sequential(
            Conv2dBN(in_channels, out_channels, 3, 2, 1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class XLocalBranch2d(nn.Module):
    """Local branch: parallel depthwise 3x3 and 5x5 paths, summed.

    Each path is depthwise conv -> ReLU -> pointwise conv, capturing
    short-range texture at two kernel scales. Channel count is preserved.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.branch3 = nn.Sequential(
            Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
            nn.ReLU(inplace=True),
            Conv2dBN(channels, channels, 1, 1, 0),
        )
        self.branch5 = nn.Sequential(
            Conv2dBN(channels, channels, 5, 1, 2, groups=channels),
            nn.ReLU(inplace=True),
            Conv2dBN(channels, channels, 1, 1, 0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.branch3(x) + self.branch5(x)


class XWaveletTransform2d(nn.Module):
    """ptwt-based 2D wavelet decomposition / reconstruction.

    All transforms run in float32 with autocast disabled and the results are
    cast back to the caller's dtype. ``forward`` returns the approximation
    band and the three detail bands of ``coeffs[1]`` concatenated on dim 1
    (3x channels). ``inverse`` splits them again, reconstructs, and crops to
    ``output_size`` so odd input sizes round-trip to the exact original
    spatial shape.

    NOTE: only ``coeffs[1]`` (the coarsest detail level) is kept, so the
    round trip is only complete for ``wavelet_level=1``.
    """

    def __init__(
        self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
    ) -> None:
        super().__init__()
        self.channels = channels
        self.wavelet_type = wavelet_type
        self.wavelet_level = wavelet_level

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(ll, high)`` where ``high`` concatenates the detail bands."""
        original_dtype = x.dtype
        with torch.autocast(device_type=x.device.type, enabled=False):
            coeffs = ptwt.wavedec2(
                x.float(), self.wavelet_type, level=self.wavelet_level
            )
        ll = coeffs[0]
        high_parts = coeffs[1]
        high = torch.cat(high_parts, dim=1)
        return ll.to(original_dtype), high.to(original_dtype)

    def inverse(
        self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
    ) -> torch.Tensor:
        """Reconstruct from ``(ll, high)`` and crop to ``output_size`` (H, W)."""
        original_dtype = ll.dtype
        with torch.autocast(device_type=ll.device.type, enabled=False):
            lh, hl, hh = torch.chunk(high.float(), 3, dim=1)
            coeffs = [ll.float(), (lh, hl, hh)]
            x = ptwt.waverec2(coeffs, self.wavelet_type)
        # Reconstruction of an odd-sized input comes back one pixel larger;
        # crop to the requested size.
        x = x[:, :, : output_size[0], : output_size[1]]
        return x.to(original_dtype)


class XWaveletBranch2d(nn.Module):
    """Wavelet branch: process low/high frequency bands separately, then invert.

    The approximation band goes through a dense 3x3 conv; the concatenated
    detail bands go through a depthwise 3x3 + pointwise 1x1 pair. The result
    is reconstructed to the input's spatial size and projected with a 1x1
    conv.

    Raises:
        ValueError: if ``wavelet_type != "haar"`` or ``wavelet_level != 1``.
    """

    def __init__(
        self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
    ) -> None:
        super().__init__()
        if wavelet_type != "haar":
            raise ValueError(f"Unsupported wavelet type: {wavelet_type}")
        if wavelet_level != 1:
            raise ValueError(
                "Initial XNet implementation only supports wavelet_level=1."
            )
        self.wavelet = XWaveletTransform2d(
            channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
        )
        self.ll_proj = nn.Sequential(
            Conv2dBN(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
        )
        self.high_proj = nn.Sequential(
            Conv2dBN(channels * 3, channels * 3, 3, 1, 1, groups=channels * 3),
            nn.ReLU(inplace=True),
            Conv2dBN(channels * 3, channels * 3, 1, 1, 0),
        )
        self.out_proj = nn.Sequential(
            Conv2dBN(channels, channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output_size = x.shape[-2:]
        ll, high = self.wavelet(x)
        ll = self.ll_proj(ll)
        high = self.high_proj(high)
        x = self.wavelet.inverse(ll, high, output_size=output_size)
        return self.out_proj(x)


class XSSMGlobalBranch2d(nn.Module):
    """Global branch wrapping VMamba ``SS2D`` with a runtime scan backend switch.

    ``ssm_backend`` is matched case-insensitively:

    * ``"auto"``: ``"oflex"`` for CUDA inputs, ``"torch"`` otherwise.
    * ``"oflex"`` / ``"torch"``: for the duration of one forward call,
      ``ssm.forward_core`` is wrapped so it receives the matching
      ``selective_scan_backend`` / ``scan_force_torch`` keyword arguments,
      then restored.
    * anything else: ``forward_core`` is left untouched (SS2D's own default).
    """

    # Keyword arguments injected into SS2D.forward_core per backend.
    _BACKEND_SCAN_KWARGS = {
        "oflex": {"selective_scan_backend": "oflex", "scan_force_torch": False},
        "torch": {"selective_scan_backend": "torch", "scan_force_torch": True},
    }

    def __init__(
        self,
        channels: int,
        global_ratio: float = 2.0,
        d_state: int = 16,
        forward_type: str = "v3",
        ssm_backend: str = "auto",
    ) -> None:
        super().__init__()
        # SS2D expansion ratio is never allowed below 1.
        hidden_ratio = max(global_ratio, 1.0)
        self.backend = ssm_backend
        self.pre = nn.Sequential(
            Conv2dBN(channels, channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )
        self.ssm = VMambaSS2D(
            d_model=channels,
            d_state=d_state,
            ssm_ratio=hidden_ratio,
            d_conv=3,
            dropout=0.0,
            initialize="v0",
            forward_type=forward_type,
            channel_first=True,
        )
        self.post = nn.Sequential(
            Conv2dBN(channels, channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pre(x)
        backend = self.backend.lower()
        if backend == "auto":
            backend = "oflex" if x.is_cuda else "torch"

        scan_kwargs = self._BACKEND_SCAN_KWARGS.get(backend)
        prev_core = None
        if scan_kwargs is not None and hasattr(self.ssm, "forward_core"):
            prev_core = self.ssm.forward_core
            # Default args bind the current core/kwargs at definition time.
            self.ssm.forward_core = (
                lambda z, _core=prev_core, _kw=scan_kwargs: _core(z, **_kw)
            )
        try:
            x = self.ssm(x)
        finally:
            # Always restore, even if the scan raised.
            if prev_core is not None:
                self.ssm.forward_core = prev_core
        return self.post(x)


class XGlobalBranch2d(nn.Module):
    """Thin wrapper that currently delegates to :class:`XSSMGlobalBranch2d`."""

    def __init__(
        self,
        channels: int,
        global_ratio: float = 2.0,
        ssm_d_state: int = 16,
        ssm_forward_type: str = "v3",
        ssm_backend: str = "auto",
    ) -> None:
        super().__init__()
        self.ssm_branch = XSSMGlobalBranch2d(
            channels=channels,
            global_ratio=global_ratio,
            d_state=ssm_d_state,
            forward_type=ssm_forward_type,
            ssm_backend=ssm_backend,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ssm_branch(x)


class XBranchFusion2d(nn.Module):
    """Fuse ``num_branches`` same-shaped feature maps with a channel gate.

    The branches are concatenated on dim 1, projected back to ``channels``
    by a 1x1 conv, and multiplied by a sigmoid gate predicted from the
    globally pooled concatenation (squeeze-excitation style bottleneck of
    ``max(channels // 4, 8)`` hidden units).
    """

    def __init__(self, channels: int, num_branches: int = 3) -> None:
        super().__init__()
        fused_channels = channels * num_branches
        hidden_channels = max(channels // 4, 8)
        self.fuse = nn.Sequential(
            Conv2dBN(fused_channels, channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(fused_channels, hidden_channels, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, branch_outputs: Sequence[torch.Tensor]) -> torch.Tensor:
        x_cat = torch.cat(list(branch_outputs), dim=1)
        x_fused = self.fuse(x_cat)
        gate = self.gate(x_cat)
        return x_fused * gate


class XTEB2d(nn.Module):
    """XTEB block: local + wavelet + global branches, fused, with residuals.

    ``x -> pre_norm (1x1 Conv2dBN) -> [local, wavelet, global] -> fusion ->
    post`` is added to the block input, then a 1x1 FFN (2x expansion) is
    added as a second residual. The last Conv2dBN of ``post`` and ``ffn``
    uses ``bn_weight_init=0.0`` (presumably zero-initialising the BN scale so
    each residual starts near zero -- confirm in ``layers_2d``).

    A disabled wavelet/global branch is replaced by ``nn.Identity``, so the
    fusion always receives three tensors and the pre-normed input stands in
    for the disabled branch's output.
    """

    def __init__(
        self,
        channels: int,
        global_ratio: float = 2.0,
        wavelet_type: str = "haar",
        wavelet_level: int = 1,
        use_wavelet_branch: bool = True,
        use_global_branch: bool = True,
        ssm_d_state: int = 16,
        ssm_forward_type: str = "v3",
        ssm_backend: str = "auto",
    ) -> None:
        super().__init__()
        self.pre_norm = Conv2dBN(channels, channels, 1, 1, 0)
        self.local_branch = XLocalBranch2d(channels)
        self.wavelet_branch = (
            XWaveletBranch2d(
                channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
            )
            if use_wavelet_branch
            else nn.Identity()
        )
        self.global_branch = (
            XGlobalBranch2d(
                channels,
                global_ratio=global_ratio,
                ssm_d_state=ssm_d_state,
                ssm_forward_type=ssm_forward_type,
                ssm_backend=ssm_backend,
            )
            if use_global_branch
            else nn.Identity()
        )
        self.fusion = XBranchFusion2d(channels, num_branches=3)
        self.post = nn.Sequential(
            Conv2dBN(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(channels, channels, 1, 1, 0, bn_weight_init=0.0),
        )
        self.ffn = nn.Sequential(
            Conv2dBN(channels, channels * 2, 1, 1, 0),
            nn.ReLU(inplace=True),
            Conv2dBN(channels * 2, channels, 1, 1, 0, bn_weight_init=0.0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_in = x
        x = self.pre_norm(x)
        x = x_in + self.post(
            self.fusion(
                [self.local_branch(x), self.wavelet_branch(x), self.global_branch(x)]
            )
        )
        return x + self.ffn(x)


class XNetEncoderStage2d(nn.Module):
    """A stack of ``depth`` identically configured :class:`XTEB2d` blocks."""

    def __init__(
        self,
        channels: int,
        depth: int,
        global_ratio: float = 2.0,
        wavelet_type: str = "haar",
        wavelet_level: int = 1,
        use_wavelet_branch: bool = True,
        use_global_branch: bool = True,
        ssm_d_state: int = 16,
        ssm_forward_type: str = "v3",
        ssm_backend: str = "auto",
    ) -> None:
        super().__init__()
        self.blocks = nn.Sequential(
            *[
                XTEB2d(
                    channels=channels,
                    global_ratio=global_ratio,
                    wavelet_type=wavelet_type,
                    wavelet_level=wavelet_level,
                    use_wavelet_branch=use_wavelet_branch,
                    use_global_branch=use_global_branch,
                    ssm_d_state=ssm_d_state,
                    ssm_forward_type=ssm_forward_type,
                    ssm_backend=ssm_backend,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.blocks(x)


class XNetEncoder2d(nn.Module):
    """Four-stage feature-pyramid encoder.

    ``stem`` (4x down) -> stage1 -> down1 -> stage2 -> down2 -> stage3 ->
    down3 -> stage4. The global (SSM) branch is optional in stage 1
    (``use_global_branch_stage1``) and always enabled in stages 2-4.

    Raises:
        ValueError: if ``encoder_channels`` or ``encoder_depths`` does not
            have exactly 4 entries.
    """

    def __init__(
        self,
        in_channels: int,
        stem_channels: int,
        encoder_channels: Sequence[int],
        encoder_depths: Sequence[int],
        global_ratio: float = 2.0,
        wavelet_type: str = "haar",
        wavelet_level: int = 1,
        use_wavelet_branch: bool = True,
        use_global_branch_stage1: bool = False,
        ssm_d_state: int = 16,
        ssm_forward_type: str = "v3",
        ssm_backend: str = "auto",
    ) -> None:
        super().__init__()
        if len(encoder_channels) != 4 or len(encoder_depths) != 4:
            raise ValueError("XNetEncoder2d expects 4 encoder stages.")
        c1, c2, c3, c4 = encoder_channels
        d1, d2, d3, d4 = encoder_depths

        def make_stage(
            channels: int, depth: int, use_global_branch: bool
        ) -> XNetEncoderStage2d:
            # Every stage shares all settings except width, depth and the
            # global-branch toggle.
            return XNetEncoderStage2d(
                channels,
                depth,
                global_ratio=global_ratio,
                wavelet_type=wavelet_type,
                wavelet_level=wavelet_level,
                use_wavelet_branch=use_wavelet_branch,
                use_global_branch=use_global_branch,
                ssm_d_state=ssm_d_state,
                ssm_forward_type=ssm_forward_type,
                ssm_backend=ssm_backend,
            )

        # Attribute names (stem/stageN/downN) are state_dict keys; keep them.
        self.stem = XNetStem2d(in_channels, stem_channels, c1)
        self.stage1 = make_stage(c1, d1, use_global_branch_stage1)
        self.down1 = XNetDownsample2d(c1, c2)
        self.stage2 = make_stage(c2, d2, True)
        self.down2 = XNetDownsample2d(c2, c3)
        self.stage3 = make_stage(c3, d3, True)
        self.down3 = XNetDownsample2d(c3, c4)
        self.stage4 = make_stage(c4, d4, True)
        self.stage_channels = list(encoder_channels)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Return the four stage outputs ``[e1, e2, e3, e4]`` (fine to coarse)."""
        e1 = self.stage1(self.stem(x))
        e2 = self.stage2(self.down1(e1))
        e3 = self.stage3(self.down2(e2))
        e4 = self.stage4(self.down3(e3))
        return [e1, e2, e3, e4]


class XSkipFusion2d(nn.Module):
    """Align the decoder input to the skip's size, project both, and fuse.

    ``x`` is bilinearly resized to ``skip``'s spatial size, both are projected
    to ``out_channels`` with 1x1 convs, concatenated, and fused by a 3x3 conv.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
        super().__init__()
        self.input_proj = nn.Sequential(
            Conv2dBN(in_channels, out_channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )
        self.skip_proj = nn.Sequential(
            Conv2dBN(skip_channels, out_channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )
        self.fuse = nn.Sequential(
            Conv2dBN(out_channels * 2, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = self.input_proj(x)
        skip = self.skip_proj(skip)
        return self.fuse(torch.cat([x, skip], dim=1))


class XFrequencyRefine2d(nn.Module):
    """Re-weight low/high frequency bands of a feature map in the rFFT domain.

    The input is transformed with ``rfft2`` (float32, ortho norm) and split
    by an elliptical low-frequency mask around DC. Per-channel sigmoid gates,
    predicted from the mean spectrum magnitude inside / outside the mask,
    scale each band. The result is inverted back to the input size and passed
    through a depthwise + pointwise refine conv.

    Args:
        channels: number of feature channels.
        low_freq_radius_h / low_freq_radius_w: mask radii as a fraction of the
            number of frequency bins along H / W. Effective values are
            clamped to ``(1e-3, 0.5]`` for H and ``(1e-3, 1.0]`` for W.
        learnable_low_freq_radius: if True the radii are trainable parameters
            (stored as logits, mapped through ``sigmoid(.) * max_ratio``);
            otherwise they are fixed non-persistent buffers.

    Raises:
        ValueError: if either radius is not positive.
    """

    # Upper bounds of the normalized radii. The H axis of rfft2 is a full,
    # wrapped spectrum (distance to DC is at most h_freq / 2); the W axis is a
    # half spectrum starting at DC, so the radius may span all of its bins.
    _MAX_RATIO_H = 0.5
    _MAX_RATIO_W = 1.0
    # Slope of the sigmoid surrogate that supplies the radius gradient.
    _SOFT_MASK_SHARPNESS = 10.0

    def __init__(
        self,
        channels: int,
        low_freq_radius_h: float = 0.25,
        low_freq_radius_w: float = 0.25,
        learnable_low_freq_radius: bool = True,
    ) -> None:
        super().__init__()
        if low_freq_radius_h <= 0.0 or low_freq_radius_w <= 0.0:
            raise ValueError("Low-frequency radii must be positive.")
        # Gates are predicted from half-spectrum magnitude statistics instead of
        # directly reusing spatial-domain pooled features.
        self.low_gate = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )
        self.high_gate = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
            nn.ReLU(inplace=True),
            Conv2dBN(channels, channels, 1, 1, 0),
        )
        self.learnable_low_freq_radius = learnable_low_freq_radius
        if learnable_low_freq_radius:
            # Stored as logits so that _resolve_radius (sigmoid * max_ratio)
            # starts exactly at the configured radius, matching the fixed path.
            self.low_freq_radius_h = nn.Parameter(
                self._radius_to_logit(low_freq_radius_h, self._MAX_RATIO_H)
            )
            self.low_freq_radius_w = nn.Parameter(
                self._radius_to_logit(low_freq_radius_w, self._MAX_RATIO_W)
            )
        else:
            self.register_buffer(
                "low_freq_radius_h",
                torch.tensor(low_freq_radius_h, dtype=torch.float32),
                persistent=False,
            )
            self.register_buffer(
                "low_freq_radius_w",
                torch.tensor(low_freq_radius_w, dtype=torch.float32),
                persistent=False,
            )

    @staticmethod
    def _radius_to_logit(radius: float, max_ratio: float) -> torch.Tensor:
        """Invert ``sigmoid(logit) * max_ratio`` for an initial radius."""
        ratio = torch.tensor(radius / max_ratio, dtype=torch.float32)
        # eps clamps the ratio into (0, 1): radii >= max_ratio saturate to the
        # bound (as the fixed path's clamp does) instead of an infinite logit.
        return torch.logit(ratio, eps=1.0e-6)

    def _resolve_radius(
        self, value: torch.Tensor, max_ratio: float, device: torch.device
    ) -> torch.Tensor:
        """Map a stored radius (logit if learnable) to a clamped ratio."""
        radius = value.to(device=device, dtype=torch.float32)
        if self.learnable_low_freq_radius:
            radius = torch.sigmoid(radius) * max_ratio
        return torch.clamp(radius, min=1.0e-3, max=max_ratio)

    def _build_low_frequency_mask(
        self, h_freq: int, w_freq: int, device: torch.device
    ) -> torch.Tensor:
        """Return a ``(1, 1, h_freq, w_freq)`` low-frequency mask.

        Non-learnable: a boolean mask. Learnable: a float 0/1 mask whose
        forward values equal the hard mask but whose backward pass uses a
        sigmoid surrogate (straight-through estimator), so the radius
        parameters actually receive gradients.
        """
        y = torch.arange(h_freq, device=device, dtype=torch.float32)
        x = torch.arange(w_freq, device=device, dtype=torch.float32)
        # H bins wrap around (full FFT axis): distance to DC is min(k, H - k).
        y = torch.minimum(y, h_freq - y)
        radius_h = self._resolve_radius(
            self.low_freq_radius_h, self._MAX_RATIO_H, device
        ) * max(h_freq, 1)
        radius_w = self._resolve_radius(
            self.low_freq_radius_w, self._MAX_RATIO_W, device
        ) * max(w_freq, 1)
        # Radius is never smaller than one frequency bin.
        y = y / torch.clamp(radius_h, min=1.0)
        x = x / torch.clamp(radius_w, min=1.0)
        y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")
        dist_sq = y_grid.square() + x_grid.square()
        hard_mask = dist_sq <= 1.0
        if not self.learnable_low_freq_radius:
            return hard_mask.unsqueeze(0).unsqueeze(0)
        soft_mask = torch.sigmoid((1.0 - dist_sq) * self._SOFT_MASK_SHARPNESS)
        # soft - soft.detach() is exactly zero in the forward pass, so values
        # are identical to the hard mask while gradients flow through soft.
        mask = hard_mask.to(soft_mask.dtype) + (soft_mask - soft_mask.detach())
        return mask.unsqueeze(0).unsqueeze(0)

    @staticmethod
    def _apply_gate(gate: nn.Sequential, stats: torch.Tensor) -> torch.Tensor:
        """Run ``gate`` in its parameter dtype and return a float32 result.

        Spectrum statistics are float32; casting them to the conv weight dtype
        keeps half/double-precision modules working, and the float32 result
        multiplies cleanly with the complex64 spectrum.
        """
        gate_dtype = gate[0].weight.dtype
        return gate(stats.to(gate_dtype)).to(torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # rfft2 is computed in float32 regardless of the input dtype.
        if x.dtype != torch.float32:
            x = x.to(torch.float32)
        fft = torch.fft.rfft2(x, norm="ortho")
        h_freq, w_freq = fft.shape[-2], fft.shape[-1]
        low_mask = self._build_low_frequency_mask(h_freq, w_freq, fft.device).to(
            dtype=x.dtype
        )
        low = fft * low_mask
        high = fft - low
        magnitude = fft.abs()
        low_stats = (magnitude * low_mask).mean(dim=(-2, -1), keepdim=True)
        high_stats = (magnitude * (1.0 - low_mask)).mean(dim=(-2, -1), keepdim=True)
        low = low * self._apply_gate(self.low_gate, low_stats)
        high = high * self._apply_gate(self.high_gate, high_stats)
        out = torch.fft.irfft2(low + high, s=x.shape[-2:], norm="ortho")
        out = out.to(dtype=input_dtype)
        return self.refine(out)


class XCRB2d(nn.Module):
    """Decoder block: U-Net skip fusion -> frequency refine -> residual output.

    With ``use_frequency_refine=False`` the refine step is ``nn.Identity``,
    so ``x + frequency_refine(x)`` becomes ``2 * x``.
    """

    def __init__(
        self,
        in_channels: int,
        skip_channels: int,
        out_channels: int,
        use_frequency_refine: bool = True,
        low_freq_radius_h: float = 0.25,
        low_freq_radius_w: float = 0.25,
        learnable_low_freq_radius: bool = True,
    ) -> None:
        super().__init__()
        self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
        self.frequency_refine = (
            XFrequencyRefine2d(
                out_channels,
                low_freq_radius_h=low_freq_radius_h,
                low_freq_radius_w=low_freq_radius_w,
                learnable_low_freq_radius=learnable_low_freq_radius,
            )
            if use_frequency_refine
            else nn.Identity()
        )
        self.out_refine = nn.Sequential(
            Conv2dBN(out_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
        )

    def forward(
        self,
        x: torch.Tensor,
        skip: torch.Tensor,
    ) -> torch.Tensor:
        x = self.skip_fusion(x, skip)
        x = x + self.frequency_refine(x)
        return x + self.out_refine(x)


class XNetHeadRefine2d(nn.Module):
    """Two 3x3 Conv2dBN + ReLU layers; ``out_channels`` defaults to ``channels``."""

    def __init__(self, channels: int, out_channels: int | None = None) -> None:
        super().__init__()
        if out_channels is None:
            out_channels = channels
        self.block = nn.Sequential(
            Conv2dBN(channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(out_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class XNetDecoder2d(nn.Module):
    """Three-step U-Net decoder over the 4 encoder features plus a head refine.

    ``dec4(e4, e3) -> dec3(., e2) -> dec2(., e1) -> head_refine``. The output
    stays at the resolution of ``e1``. ``guide_mode`` is stored but not used
    by ``forward``; the returned guide list is always empty.

    Raises:
        ValueError: if ``encoder_channels`` does not have 4 entries or
            ``decoder_channels`` does not have 3 entries.
    """

    def __init__(
        self,
        encoder_channels: Sequence[int],
        decoder_channels: Sequence[int] = (128, 64, 32),
        guide_mode: str = "affine",
        use_frequency_refine: bool = True,
        low_freq_radius_h: float = 0.25,
        low_freq_radius_w: float = 0.25,
        learnable_low_freq_radius: bool = True,
        out_channels: int | None = None,
    ) -> None:
        super().__init__()
        if len(encoder_channels) != 4:
            raise ValueError("XNetDecoder2d expects 4 encoder stages.")
        if len(decoder_channels) != 3:
            raise ValueError("XNetDecoder2d expects 3 decoder channels.")
        c1, c2, c3, c4 = encoder_channels
        d4, d3, d2 = decoder_channels
        self.guide_mode = guide_mode
        self.dec4 = XCRB2d(
            c4,
            c3,
            d4,
            use_frequency_refine=use_frequency_refine,
            low_freq_radius_h=low_freq_radius_h,
            low_freq_radius_w=low_freq_radius_w,
            learnable_low_freq_radius=learnable_low_freq_radius,
        )
        self.dec3 = XCRB2d(
            d4,
            c2,
            d3,
            use_frequency_refine=use_frequency_refine,
            low_freq_radius_h=low_freq_radius_h,
            low_freq_radius_w=low_freq_radius_w,
            learnable_low_freq_radius=learnable_low_freq_radius,
        )
        self.dec2 = XCRB2d(
            d3,
            c1,
            d2,
            use_frequency_refine=use_frequency_refine,
            low_freq_radius_h=low_freq_radius_h,
            low_freq_radius_w=low_freq_radius_w,
            learnable_low_freq_radius=learnable_low_freq_radius,
        )
        # Explicit None check (same convention as XNetHeadRefine2d).
        resolved_out_channels = d2 if out_channels is None else out_channels
        self.head_refine = XNetHeadRefine2d(d2, resolved_out_channels)
        self.out_channels = resolved_out_channels

    def forward(
        self,
        features: Sequence[torch.Tensor],
    ) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
        """Return ``(d1, [d4, d3, d2, d1], guides)``; ``guides`` is empty."""
        e1, e2, e3, e4 = features
        d4 = self.dec4(e4, e3)
        d3 = self.dec3(d4, e2)
        d2 = self.dec2(d3, e1)
        d1 = self.head_refine(d2)
        return d1, [d4, d3, d2, d1], []


class XNetSegHead2d(nn.Module):
    """3x3 conv + 1x1 classifier, bilinearly resized to ``output_size``.

    ``upsample_scale`` is stored but not used by ``forward``; the resize
    target always comes from ``output_size``.
    """

    def __init__(
        self, in_channels: int, num_classes: int, upsample_scale: int = 4
    ) -> None:
        super().__init__()
        self.block = nn.Sequential(
            Conv2dBN(in_channels, in_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, num_classes, kernel_size=1, bias=True),
        )
        self.upsample_scale = upsample_scale

    def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
        x = self.block(x)
        return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)


class XNet2d(nn.Module):
    """XNet 2D segmentation network: encoder -> bottleneck -> decoder -> head.

    ``forward`` returns a dict with:

    * ``"logits"`` / ``"seg_logits"``: the same tensor, resized to the
      input's spatial size.
    * ``"encoder_features"``: ``[e1, e2, e3, e4]`` where ``e4`` has been
      replaced by the bottleneck output.
    * ``"decoder_features"``: ``[d4, d3, d2, d1]`` from the decoder.
    * ``"guides"``: the decoder's guide list (currently always empty).
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        encoder_channels: Sequence[int] = (32, 64, 128, 192),
        encoder_depths: Sequence[int] = (2, 2, 2, 2),
        decoder_channels: Sequence[int] = (128, 64, 32),
        stem_channels: int = 24,
        bottleneck_depth: int = 1,
        global_ratio: float = 2.0,
        wavelet_type: str = "haar",
        wavelet_level: int = 1,
        use_wavelet_branch: bool = True,
        use_global_branch_stage1: bool = False,
        ssm_d_state: int = 16,
        ssm_forward_type: str = "v3",
        ssm_backend: str = "auto",
        use_frequency_refine: bool = True,
        low_freq_radius_h: float = 0.25,
        low_freq_radius_w: float = 0.25,
        learnable_low_freq_radius: bool = True,
        guide_mode: str = "affine",
        out_channels: int | None = None,
    ) -> None:
        super().__init__()
        self.encoder = XNetEncoder2d(
            in_channels=in_channels,
            stem_channels=stem_channels,
            encoder_channels=encoder_channels,
            encoder_depths=encoder_depths,
            global_ratio=global_ratio,
            wavelet_type=wavelet_type,
            wavelet_level=wavelet_level,
            use_wavelet_branch=use_wavelet_branch,
            use_global_branch_stage1=use_global_branch_stage1,
            ssm_d_state=ssm_d_state,
            ssm_forward_type=ssm_forward_type,
            ssm_backend=ssm_backend,
        )
        bottleneck_channels = encoder_channels[-1]
        # bottleneck_depth == 0 yields an empty Sequential (identity).
        self.bottleneck = nn.Sequential(
            *[
                XTEB2d(
                    channels=bottleneck_channels,
                    global_ratio=global_ratio,
                    wavelet_type=wavelet_type,
                    wavelet_level=wavelet_level,
                    use_wavelet_branch=use_wavelet_branch,
                    use_global_branch=True,
                    ssm_d_state=ssm_d_state,
                    ssm_forward_type=ssm_forward_type,
                    ssm_backend=ssm_backend,
                )
                for _ in range(bottleneck_depth)
            ]
        )
        self.decoder = XNetDecoder2d(
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            guide_mode=guide_mode,
            use_frequency_refine=use_frequency_refine,
            low_freq_radius_h=low_freq_radius_h,
            low_freq_radius_w=low_freq_radius_w,
            learnable_low_freq_radius=learnable_low_freq_radius,
            out_channels=out_channels,
        )
        head_in_channels = self.decoder.out_channels
        self.segmentation_head = XNetSegHead2d(head_in_channels, num_classes)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
        encoder_features = self.encoder(x)
        encoder_features[-1] = self.bottleneck(encoder_features[-1])
        decoder_out, decoder_features, guides = self.decoder(encoder_features)
        output_size = x.shape[-2:]
        logits = self.segmentation_head(decoder_out, output_size=output_size)
        outputs: dict[str, torch.Tensor | list[torch.Tensor]] = {
            "logits": logits,
            "seg_logits": logits,
            "encoder_features": encoder_features,
            "decoder_features": decoder_features,
            "guides": guides,
        }
        return outputs