| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845 |
- from __future__ import annotations
- from collections.abc import Sequence
- import ptwt
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .layers_2d import Conv2dBN
- from .lib_mamba.vmamba import SS2D as VMambaSS2D
class XNetStem2d(nn.Module):
    """Input stem that shrinks the feature map 4x before encoder stage 1.

    Four ``Conv2dBN`` layers, each followed by an in-place ReLU:

    1. ``(in_channels -> stem_channels, 3, 2, 1)``, the first reduction;
    2. depthwise ``(stem_channels, 3, 1, 1, groups=stem_channels)``;
    3. pointwise ``(stem_channels -> out_channels, 1, 1, 0)`` projection;
    4. ``(out_channels -> out_channels, 3, 2, 1)``, the second reduction.

    The positional ``Conv2dBN`` arguments are assumed to be
    ``(in, out, kernel, stride, padding)``. TODO confirm in ``layers_2d``.
    """

    def __init__(self, in_channels: int, stem_channels: int, out_channels: int) -> None:
        super().__init__()
        convs = [
            Conv2dBN(in_channels, stem_channels, 3, 2, 1),
            Conv2dBN(stem_channels, stem_channels, 3, 1, 1, groups=stem_channels),
            Conv2dBN(stem_channels, out_channels, 1, 1, 0),
            Conv2dBN(out_channels, out_channels, 3, 2, 1),
        ]
        # Interleave an in-place ReLU after every conv (conv, relu, conv, relu, ...).
        layers: list[nn.Module] = []
        for conv in convs:
            layers.extend((conv, nn.ReLU(inplace=True)))
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        stem_out = self.block(x)
        return stem_out
class XNetDownsample2d(nn.Module):
    """Transition between encoder stages: one ``Conv2dBN(..., 3, 2, 1)`` + ReLU.

    The ``(3, 2, 1)`` arguments presumably mean kernel 3 / stride 2 /
    padding 1, i.e. a 2x spatial reduction. TODO confirm against
    ``Conv2dBN``.

    Args:
        in_channels: width of the incoming stage output.
        out_channels: width expected by the next stage.
        mode: downsampling strategy; only ``"conv"`` is implemented.

    Raises:
        ValueError: for any ``mode`` other than ``"conv"``.
    """

    def __init__(self, in_channels: int, out_channels: int, mode: str = "conv") -> None:
        super().__init__()
        supported_modes = ("conv",)
        if mode not in supported_modes:
            raise ValueError(f"Unsupported downsample mode: {mode}")
        reduce_conv = Conv2dBN(in_channels, out_channels, 3, 2, 1)
        self.block = nn.Sequential(reduce_conv, nn.ReLU(inplace=True))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        reduced = self.block(x)
        return reduced
class XLocalBranch2d(nn.Module):
    """Local-texture branch: sum of two depthwise-separable conv paths.

    Each path is depthwise ``k x k`` ``Conv2dBN`` -> ReLU -> pointwise 1x1
    ``Conv2dBN``, with padding ``k // 2``, so both paths keep the input size
    and can be added. Kernels are 3 (``branch3``) and 5 (``branch5``).
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.branch3 = self._make_branch(channels, 3)
        self.branch5 = self._make_branch(channels, 5)

    @staticmethod
    def _make_branch(channels: int, kernel_size: int) -> nn.Sequential:
        """Build one depthwise(k) -> ReLU -> pointwise(1x1) path."""
        return nn.Sequential(
            Conv2dBN(channels, channels, kernel_size, 1, kernel_size // 2, groups=channels),
            nn.ReLU(inplace=True),
            Conv2dBN(channels, channels, 1, 1, 0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        narrow = self.branch3(x)
        wide = self.branch5(x)
        return narrow + wide
class XWaveletTransform2d(nn.Module):
    """Single-level 2D discrete wavelet transform built on ``ptwt``.

    ``forward`` splits a feature map into its approximation (LL) band and the
    three detail bands concatenated along dim 1. ``inverse`` reassembles
    them and crops the result to a caller-supplied spatial size, so odd input
    sizes round-trip to the exact original shape. Both directions run in
    float32 with autocast disabled and cast the result back to the input
    dtype.

    Args:
        channels: number of feature channels. It is stored on the module but
            not used by the transform itself.
        wavelet_type: wavelet name passed straight to ``ptwt``.
        wavelet_level: decomposition depth; only ``1`` is supported.

    Raises:
        ValueError: if ``wavelet_level`` is not ``1``.
    """

    def __init__(
        self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
    ) -> None:
        super().__init__()
        # forward() keeps only coeffs[1] and inverse() rebuilds from a single
        # (ll, details) pair. A deeper decomposition would come back at the
        # wrong resolution without any error, so reject it up front.
        if wavelet_level != 1:
            raise ValueError(
                f"XWaveletTransform2d only supports wavelet_level=1, got {wavelet_level}."
            )
        self.channels = channels
        self.wavelet_type = wavelet_type
        self.wavelet_level = wavelet_level

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Split ``x`` into ``(ll, high)``.

        ``ll`` is the approximation band. ``high`` is the three detail bands
        concatenated along dim 1, in the order ptwt returns them. This
        assumes ``x`` is (batch, channels, H, W); TODO confirm for the
        installed ptwt version.
        """
        original_dtype = x.dtype
        # Wavelet filtering is done in float32 even under mixed precision.
        with torch.autocast(device_type=x.device.type, enabled=False):
            coeffs = ptwt.wavedec2(
                x.float(), self.wavelet_type, level=self.wavelet_level
            )
            ll = coeffs[0]
            high = torch.cat(coeffs[1], dim=1)
        return ll.to(original_dtype), high.to(original_dtype)

    def inverse(
        self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
    ) -> torch.Tensor:
        """Rebuild a feature map from ``(ll, high)`` and crop to ``output_size``.

        ``high`` is split back into three equal channel chunks, in the same
        order ``forward`` concatenated them. The crop removes the padding
        the reconstruction carries for odd-sized inputs. Returns ``ll.dtype``.
        """
        original_dtype = ll.dtype
        with torch.autocast(device_type=ll.device.type, enabled=False):
            lh, hl, hh = torch.chunk(high.float(), 3, dim=1)
            coeffs = [ll.float(), (lh, hl, hh)]
            x = ptwt.waverec2(coeffs, self.wavelet_type)
            x = x[:, :, : output_size[0], : output_size[1]]
        return x.to(original_dtype)
class XWaveletBranch2d(nn.Module):
    """Frequency branch: process wavelet bands separately, then invert.

    The input is decomposed by ``XWaveletTransform2d``. The LL band goes
    through a 3x3 ``Conv2dBN`` + ReLU. The concatenated detail bands
    (``3 * channels`` wide) go through a depthwise 3x3 + pointwise 1x1
    ``Conv2dBN`` pair. Both are then inverse-transformed back to the input's
    spatial size and finished with a 1x1 ``Conv2dBN`` + ReLU.

    Raises:
        ValueError: if ``wavelet_type`` is not ``"haar"`` (checked first) or
            ``wavelet_level`` is not ``1``.
    """

    def __init__(
        self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
    ) -> None:
        super().__init__()
        # Checked in order; the first failing condition wins.
        checks = (
            (wavelet_type != "haar", f"Unsupported wavelet type: {wavelet_type}"),
            (
                wavelet_level != 1,
                "Initial XNet implementation only supports wavelet_level=1.",
            ),
        )
        for failed, message in checks:
            if failed:
                raise ValueError(message)

        self.wavelet = XWaveletTransform2d(
            channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
        )
        self.ll_proj = nn.Sequential(
            Conv2dBN(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
        )
        detail_channels = channels * 3  # LH, HL, HH stacked on the channel axis
        self.high_proj = nn.Sequential(
            Conv2dBN(detail_channels, detail_channels, 3, 1, 1, groups=detail_channels),
            nn.ReLU(inplace=True),
            Conv2dBN(detail_channels, detail_channels, 1, 1, 0),
        )
        self.out_proj = nn.Sequential(
            Conv2dBN(channels, channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        spatial_size = x.shape[-2:]
        approx, details = self.wavelet(x)
        rebuilt = self.wavelet.inverse(
            self.ll_proj(approx), self.high_proj(details), output_size=spatial_size
        )
        return self.out_proj(rebuilt)
class XSSMGlobalBranch2d(nn.Module):
    """Global-context branch: 1x1 ``Conv2dBN``+ReLU -> VMamba ``SS2D`` -> 1x1 ``Conv2dBN``+ReLU.

    ``SS2D`` is built with ``ssm_ratio=max(global_ratio, 1.0)`` and
    ``channel_first=True``. The latter presumably means it consumes the
    (N, C, H, W) tensor directly; TODO confirm in ``lib_mamba.vmamba``.

    ``ssm_backend`` picks the selective-scan implementation for each call.
    It is matched case-insensitively:

    * ``"auto"``: use ``"oflex"`` for CUDA inputs and ``"torch"`` otherwise.
    * ``"oflex"`` / ``"torch"``: for the duration of the call, replace
      ``self.ssm.forward_core`` with a wrapper. The wrapper passes
      ``selective_scan_backend=<name>`` and ``scan_force_torch`` (True only
      for ``"torch"``). The original attribute is restored in a ``finally``.
    * Any other value: ``forward_core`` is left untouched.

    NOTE(review): the temporary swap mutates ``self.ssm``, so sharing one
    instance across concurrent forward calls is presumably unsafe.
    """

    def __init__(
        self,
        channels: int,
        global_ratio: float = 2.0,
        d_state: int = 16,
        forward_type: str = "v3",
        ssm_backend: str = "auto",
    ) -> None:
        super().__init__()
        hidden_ratio = max(global_ratio, 1.0)  # never shrink the SSM width below 1x
        self.backend = ssm_backend
        self.pre = nn.Sequential(
            Conv2dBN(channels, channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )
        self.ssm = VMambaSS2D(
            d_model=channels,
            d_state=d_state,
            ssm_ratio=hidden_ratio,
            d_conv=3,
            dropout=0.0,
            initialize="v0",
            forward_type=forward_type,
            channel_first=True,
        )
        self.post = nn.Sequential(
            Conv2dBN(channels, channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pre(x)
        chosen = self.backend.lower()
        if chosen == "auto":
            chosen = "oflex" if x.is_cuda else "torch"

        saved_core = None
        if chosen in ("oflex", "torch") and hasattr(self.ssm, "forward_core"):
            saved_core = self.ssm.forward_core
            use_torch_scan = chosen == "torch"

            def routed_core(z):
                return saved_core(
                    z,
                    selective_scan_backend=chosen,
                    scan_force_torch=use_torch_scan,
                )

            self.ssm.forward_core = routed_core

        try:
            x = self.ssm(x)
        finally:
            # Put the original core back even if the scan raised.
            if saved_core is not None:
                self.ssm.forward_core = saved_core
        return self.post(x)
class XGlobalBranch2d(nn.Module):
    """Global-context branch used by ``XTEB2d``.

    This is currently a thin wrapper around ``XSSMGlobalBranch2d``. The
    ``ssm_``-prefixed arguments are forwarded under the inner module's names
    (``d_state``, ``forward_type``); ``ssm_backend`` keeps its name.
    """

    def __init__(
        self,
        channels: int,
        global_ratio: float = 2.0,
        ssm_d_state: int = 16,
        ssm_forward_type: str = "v3",
        ssm_backend: str = "auto",
    ) -> None:
        super().__init__()
        inner_kwargs = dict(
            channels=channels,
            global_ratio=global_ratio,
            d_state=ssm_d_state,
            forward_type=ssm_forward_type,
            ssm_backend=ssm_backend,
        )
        self.ssm_branch = XSSMGlobalBranch2d(**inner_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        global_features = self.ssm_branch(x)
        return global_features
class XBranchFusion2d(nn.Module):
    """Fuse several same-width branch outputs into one ``channels``-wide map.

    The branch outputs are concatenated on dim 1. ``fuse`` (1x1
    ``Conv2dBN`` + ReLU) projects the concatenation back to ``channels``.
    ``gate`` (global average pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid,
    hidden width ``max(channels // 4, 8)``) produces a per-channel weight
    from the same concatenation, and the result is ``fuse(cat) * gate(cat)``.

    ``num_branches`` only sizes the layers; ``forward`` accepts whatever
    sequence it is given.
    """

    def __init__(self, channels: int, num_branches: int = 3) -> None:
        super().__init__()
        concat_width = num_branches * channels
        squeeze_width = max(channels // 4, 8)
        self.fuse = nn.Sequential(
            Conv2dBN(concat_width, channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(concat_width, squeeze_width, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeeze_width, channels, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, branch_outputs: Sequence[torch.Tensor]) -> torch.Tensor:
        stacked = torch.cat([*branch_outputs], dim=1)
        return self.fuse(stacked) * self.gate(stacked)
class XTEB2d(nn.Module):
    """Encoder block mixing local, wavelet and global branches.

    ``forward`` computes::

        y   = pre_norm(x)
        x'  = x + post(fusion([local(y), wavelet(y), global(y)]))
        out = x' + ffn(x')

    A disabled wavelet or global branch is ``nn.Identity``, so it feeds
    ``y`` itself into the fusion, which is always sized for three inputs.
    The last ``Conv2dBN`` of ``post`` and ``ffn`` is built with
    ``bn_weight_init=0.0``. That presumably zero-initialises its BN scale so
    both residual updates start near zero; TODO confirm in ``layers_2d``.
    """

    def __init__(
        self,
        channels: int,
        global_ratio: float = 2.0,
        wavelet_type: str = "haar",
        wavelet_level: int = 1,
        use_wavelet_branch: bool = True,
        use_global_branch: bool = True,
        ssm_d_state: int = 16,
        ssm_forward_type: str = "v3",
        ssm_backend: str = "auto",
    ) -> None:
        super().__init__()
        self.pre_norm = Conv2dBN(channels, channels, 1, 1, 0)
        self.local_branch = XLocalBranch2d(channels)
        if use_wavelet_branch:
            self.wavelet_branch = XWaveletBranch2d(
                channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
            )
        else:
            self.wavelet_branch = nn.Identity()
        if use_global_branch:
            self.global_branch = XGlobalBranch2d(
                channels,
                global_ratio=global_ratio,
                ssm_d_state=ssm_d_state,
                ssm_forward_type=ssm_forward_type,
                ssm_backend=ssm_backend,
            )
        else:
            self.global_branch = nn.Identity()
        self.fusion = XBranchFusion2d(channels, num_branches=3)
        self.post = nn.Sequential(
            Conv2dBN(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(channels, channels, 1, 1, 0, bn_weight_init=0.0),
        )
        expanded = channels * 2  # FFN expansion width
        self.ffn = nn.Sequential(
            Conv2dBN(channels, expanded, 1, 1, 0),
            nn.ReLU(inplace=True),
            Conv2dBN(expanded, channels, 1, 1, 0, bn_weight_init=0.0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        normed = self.pre_norm(x)
        branch_outputs = [
            self.local_branch(normed),
            self.wavelet_branch(normed),
            self.global_branch(normed),
        ]
        mixed = shortcut + self.post(self.fusion(branch_outputs))
        return mixed + self.ffn(mixed)
class XNetEncoderStage2d(nn.Module):
    """A stack of ``depth`` identically configured ``XTEB2d`` blocks.

    All keyword arguments are forwarded unchanged to every block.
    ``depth == 0`` gives an empty ``nn.Sequential``, which returns its input.
    """

    def __init__(
        self,
        channels: int,
        depth: int,
        global_ratio: float = 2.0,
        wavelet_type: str = "haar",
        wavelet_level: int = 1,
        use_wavelet_branch: bool = True,
        use_global_branch: bool = True,
        ssm_d_state: int = 16,
        ssm_forward_type: str = "v3",
        ssm_backend: str = "auto",
    ) -> None:
        super().__init__()
        block_kwargs = dict(
            channels=channels,
            global_ratio=global_ratio,
            wavelet_type=wavelet_type,
            wavelet_level=wavelet_level,
            use_wavelet_branch=use_wavelet_branch,
            use_global_branch=use_global_branch,
            ssm_d_state=ssm_d_state,
            ssm_forward_type=ssm_forward_type,
            ssm_backend=ssm_backend,
        )
        self.blocks = nn.Sequential(*(XTEB2d(**block_kwargs) for _ in range(depth)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        stage_out = self.blocks(x)
        return stage_out
class XNetEncoder2d(nn.Module):
    """Four-stage encoder producing the feature pyramid ``[e1, e2, e3, e4]``.

    ``stem`` feeds ``stage1``; each later stage is preceded by an
    ``XNetDownsample2d``. Stage 1 enables its global (SSM) branch only when
    ``use_global_branch_stage1`` is set; stages 2-4 always enable it. The
    other block options are shared by all stages.

    Raises:
        ValueError: unless ``encoder_channels`` and ``encoder_depths`` both
            have exactly four entries.
    """

    def __init__(
        self,
        in_channels: int,
        stem_channels: int,
        encoder_channels: Sequence[int],
        encoder_depths: Sequence[int],
        global_ratio: float = 2.0,
        wavelet_type: str = "haar",
        wavelet_level: int = 1,
        use_wavelet_branch: bool = True,
        use_global_branch_stage1: bool = False,
        ssm_d_state: int = 16,
        ssm_forward_type: str = "v3",
        ssm_backend: str = "auto",
    ) -> None:
        super().__init__()
        if len(encoder_channels) != 4 or len(encoder_depths) != 4:
            raise ValueError("XNetEncoder2d expects 4 encoder stages.")
        c1, c2, c3, c4 = encoder_channels
        d1, d2, d3, d4 = encoder_depths
        shared = dict(
            global_ratio=global_ratio,
            wavelet_type=wavelet_type,
            wavelet_level=wavelet_level,
            use_wavelet_branch=use_wavelet_branch,
            ssm_d_state=ssm_d_state,
            ssm_forward_type=ssm_forward_type,
            ssm_backend=ssm_backend,
        )
        # Registration order (stem, stage1, down1, stage2, ...) is kept so
        # parameter order and initialisation match stage-by-stage.
        self.stem = XNetStem2d(in_channels, stem_channels, c1)
        self.stage1 = XNetEncoderStage2d(
            c1, d1, use_global_branch=use_global_branch_stage1, **shared
        )
        self.down1 = XNetDownsample2d(c1, c2)
        self.stage2 = XNetEncoderStage2d(c2, d2, use_global_branch=True, **shared)
        self.down2 = XNetDownsample2d(c2, c3)
        self.stage3 = XNetEncoderStage2d(c3, d3, use_global_branch=True, **shared)
        self.down3 = XNetDownsample2d(c3, c4)
        self.stage4 = XNetEncoderStage2d(c4, d4, use_global_branch=True, **shared)
        self.stage_channels = list(encoder_channels)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        current = self.stage1(self.stem(x))
        pyramid = [current]
        transitions = (
            (self.down1, self.stage2),
            (self.down2, self.stage3),
            (self.down3, self.stage4),
        )
        for down, stage in transitions:
            current = stage(down(current))
            pyramid.append(current)
        return pyramid
- class XGuideProjector2d(nn.Module):
- # Guides are projected from encoder features and aligned to decoder resolution.
- def __init__(
- self, in_channels: int, out_channels: int, mode: str = "affine"
- ) -> None:
- super().__init__()
- self.mode = mode
- if mode == "affine":
- self.proj = nn.Sequential(
- Conv2dBN(in_channels, out_channels * 2, 1, 1, 0),
- nn.ReLU(inplace=True),
- nn.Conv2d(out_channels * 2, out_channels * 2, kernel_size=1, bias=True),
- )
- elif mode == "feature":
- self.proj = nn.Sequential(
- Conv2dBN(in_channels, out_channels, 1, 1, 0),
- nn.ReLU(inplace=True),
- )
- else:
- raise ValueError(f"Unsupported guide mode: {mode}")
- def forward(
- self,
- x: torch.Tensor,
- target_size: tuple[int, int],
- ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
- x = self.proj(x)
- if self.mode == "affine":
- gamma, beta = torch.chunk(x, 2, dim=1)
- gamma = torch.sigmoid(gamma) + 0.5
- return gamma, beta
- return x
class XSkipFusion2d(nn.Module):
    """Merge a decoder input with an encoder skip connection.

    The decoder input is bilinearly resized to the skip's spatial size. Both
    tensors are projected to ``out_channels`` (1x1 ``Conv2dBN`` + ReLU),
    concatenated, and fused by a 3x3-style ``Conv2dBN`` + ReLU.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
        super().__init__()
        self.input_proj = nn.Sequential(
            Conv2dBN(in_channels, out_channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )
        self.skip_proj = nn.Sequential(
            Conv2dBN(skip_channels, out_channels, 1, 1, 0),
            nn.ReLU(inplace=True),
        )
        self.fuse = nn.Sequential(
            Conv2dBN(2 * out_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        aligned = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        merged = torch.cat([self.input_proj(aligned), self.skip_proj(skip)], dim=1)
        return self.fuse(merged)
- class XGuideModulation2d(nn.Module):
- # Apply either direct affine guide or feature-to-affine modulation.
- def __init__(self, channels: int, guide_mode: str = "affine") -> None:
- super().__init__()
- self.guide_mode = guide_mode
- if guide_mode == "feature":
- self.to_affine = nn.Conv2d(channels, channels * 2, kernel_size=1, bias=True)
- def forward(
- self,
- x: torch.Tensor,
- guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
- ) -> torch.Tensor:
- if self.guide_mode == "affine":
- gamma, beta = guide
- else:
- gamma, beta = torch.chunk(self.to_affine(guide), 2, dim=1)
- gamma = torch.sigmoid(gamma) + 0.5
- return gamma * x + beta
class XFrequencyRefine2d(nn.Module):
    """Reweight the low/high spatial-frequency bands of a feature map, then refine.

    ``forward`` does the following:

    1. Take the ``rfft2`` of the input in float32.
    2. Split the spectrum with an elliptical low-frequency mask.
    3. Scale each band by a per-channel sigmoid gate predicted from that
       band's mean spectral magnitude.
    4. Invert with ``irfft2`` back to the input size and dtype.
    5. Finish with a depthwise 3x3 + pointwise 1x1 ``Conv2dBN`` refinement.

    Args:
        channels: number of feature channels; the gates are per channel.
        low_freq_radius_h: initial low-pass radius along the height axis, as
            a fraction of the FFT height. It is capped at ``0.5`` because
            that axis is folded around DC.
        low_freq_radius_w: initial low-pass radius along the width axis, as a
            fraction of the one-sided rfft width. It is capped at ``1.0``.
        learnable_low_freq_radius: make the radii trainable parameters. They
            are stored as logits and decoded with ``sigmoid(p) * max_ratio``.
            The constructor encodes the requested radii, so the initial mask
            matches the fixed-radius mode. When False, the radii are
            non-persistent buffers used directly.

    Raises:
        ValueError: if either radius is not positive.
    """

    # Steepness of the sigmoid surrogate that carries gradients through the
    # hard low-frequency mask when the radii are learnable.
    _MASK_SURROGATE_SHARPNESS = 10.0

    def __init__(
        self,
        channels: int,
        low_freq_radius_h: float = 0.25,
        low_freq_radius_w: float = 0.25,
        learnable_low_freq_radius: bool = True,
    ) -> None:
        super().__init__()
        if low_freq_radius_h <= 0.0 or low_freq_radius_w <= 0.0:
            raise ValueError("Low-frequency radii must be positive.")
        # Gates are predicted from half-spectrum magnitude statistics instead of
        # directly reusing spatial-domain pooled features.
        self.low_gate = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )
        self.high_gate = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
            nn.ReLU(inplace=True),
            Conv2dBN(channels, channels, 1, 1, 0),
        )
        self.learnable_low_freq_radius = learnable_low_freq_radius
        if learnable_low_freq_radius:
            # _resolve_radius decodes learnable values as sigmoid(p) * max_ratio,
            # so store the logit that reproduces the requested radius. Storing
            # the raw radius would start at sigmoid(0.25) * max_ratio instead.
            self.low_freq_radius_h = nn.Parameter(
                self._encode_radius(low_freq_radius_h, 0.5)
            )
            self.low_freq_radius_w = nn.Parameter(
                self._encode_radius(low_freq_radius_w, 1.0)
            )
        else:
            self.register_buffer(
                "low_freq_radius_h",
                torch.tensor(low_freq_radius_h, dtype=torch.float32),
                persistent=False,
            )
            self.register_buffer(
                "low_freq_radius_w",
                torch.tensor(low_freq_radius_w, dtype=torch.float32),
                persistent=False,
            )

    @staticmethod
    def _encode_radius(radius: float, max_ratio: float) -> torch.Tensor:
        """Return the 0-dim logit ``p`` with ``sigmoid(p) * max_ratio == radius``.

        Radii at or above ``max_ratio`` are clamped just below it, mirroring
        the ``max_ratio`` clamp applied in the fixed-radius path.
        """
        ratio = torch.tensor(radius / max_ratio, dtype=torch.float32)
        return torch.logit(ratio, eps=1.0e-6)

    def _resolve_radius(
        self, value: torch.Tensor, max_ratio: float, device: torch.device
    ) -> torch.Tensor:
        """Decode a stored radius into a fraction in ``[1e-3, max_ratio]``."""
        radius = value.to(device=device, dtype=torch.float32)
        if self.learnable_low_freq_radius:
            radius = torch.sigmoid(radius) * max_ratio
        return torch.clamp(radius, min=1.0e-3, max=max_ratio)

    def _build_low_frequency_mask(
        self, h_freq: int, w_freq: int, device: torch.device
    ) -> torch.Tensor:
        """Return a ``(1, 1, h_freq, w_freq)`` float32 mask: 1 inside the low-pass ellipse.

        Rows form a full (wrapped) FFT axis, so row ``k`` is at distance
        ``min(k, h_freq - k)`` from DC. Columns are the one-sided rfft axis
        and are indexed from DC. When the radii are learnable and autograd is
        on, a straight-through sigmoid surrogate is attached. Forward values
        stay exactly the hard 0/1 mask, but the radius parameters receive
        gradients; the bare ``<=`` comparison would give them none.
        """
        y = torch.arange(h_freq, device=device, dtype=torch.float32)
        x = torch.arange(w_freq, device=device, dtype=torch.float32)
        y = torch.minimum(y, h_freq - y)
        radius_h = self._resolve_radius(self.low_freq_radius_h, 0.5, device) * max(
            h_freq, 1
        )
        radius_w = self._resolve_radius(self.low_freq_radius_w, 1.0, device) * max(
            w_freq, 1
        )
        y = y / torch.clamp(radius_h, min=1.0)
        x = x / torch.clamp(radius_w, min=1.0)
        y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")
        dist_sq = y_grid.square() + x_grid.square()
        mask = (dist_sq <= 1.0).to(torch.float32)
        if self.learnable_low_freq_radius and torch.is_grad_enabled():
            soft = torch.sigmoid((1.0 - dist_sq) * self._MASK_SURROGATE_SHARPNESS)
            # (soft - soft.detach()) is exactly 0 in value but carries soft's gradient.
            mask = mask + (soft - soft.detach())
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # Spectral work is always done in float32.
        if x.dtype != torch.float32:
            x = x.to(torch.float32)
        fft = torch.fft.rfft2(x, norm="ortho")
        h_freq, w_freq = fft.shape[-2], fft.shape[-1]
        low_mask = self._build_low_frequency_mask(h_freq, w_freq, fft.device).to(
            dtype=x.dtype
        )
        low = fft * low_mask
        high = fft - low
        magnitude = fft.abs()
        # Band statistics average over all bins (bins outside the band count
        # as 0), giving one value per channel with keepdim spatial dims.
        low_stats = (magnitude * low_mask).mean(dim=(-2, -1), keepdim=True)
        high_stats = (magnitude * (1.0 - low_mask)).mean(dim=(-2, -1), keepdim=True)
        low = low * self.low_gate(low_stats)
        high = high * self.high_gate(high_stats)
        out = torch.fft.irfft2(low + high, s=x.shape[-2:], norm="ortho")
        out = out.to(dtype=input_dtype)
        return self.refine(out)
- class XCRB2d(nn.Module):
- # Decoder block: skip fusion -> guide modulation -> frequency refine -> residual output.
- def __init__(
- self,
- in_channels: int,
- skip_channels: int,
- guide_channels: int,
- out_channels: int,
- guide_mode: str = "affine",
- use_frequency_refine: bool = True,
- low_freq_radius_h: float = 0.25,
- low_freq_radius_w: float = 0.25,
- learnable_low_freq_radius: bool = True,
- ) -> None:
- super().__init__()
- self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
- self.guide_modulation = XGuideModulation2d(out_channels, guide_mode=guide_mode)
- self.frequency_refine = (
- XFrequencyRefine2d(
- out_channels,
- low_freq_radius_h=low_freq_radius_h,
- low_freq_radius_w=low_freq_radius_w,
- learnable_low_freq_radius=learnable_low_freq_radius,
- )
- if use_frequency_refine
- else nn.Identity()
- )
- self.out_refine = nn.Sequential(
- Conv2dBN(out_channels, out_channels, 3, 1, 1),
- nn.ReLU(inplace=True),
- Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
- )
- self.guide_channels = guide_channels
- def forward(
- self,
- x: torch.Tensor,
- skip: torch.Tensor,
- guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
- ) -> torch.Tensor:
- x = self.skip_fusion(x, skip)
- x = self.guide_modulation(x, guide)
- x = x + self.frequency_refine(x)
- return x + self.out_refine(x)
- class XNetHeadRefine2d(nn.Module):
- def __init__(self, channels: int, out_channels: int | None = None) -> None:
- super().__init__()
- if out_channels is None:
- out_channels = channels
- self.block = nn.Sequential(
- Conv2dBN(channels, out_channels, 3, 1, 1),
- nn.ReLU(inplace=True),
- Conv2dBN(out_channels, out_channels, 3, 1, 1),
- nn.ReLU(inplace=True),
- )
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.block(x)
- class XNetDecoder2d(nn.Module):
- def __init__(
- self,
- encoder_channels: Sequence[int],
- decoder_channels: Sequence[int] = (128, 64, 32),
- guide_mode: str = "affine",
- use_frequency_refine: bool = True,
- low_freq_radius_h: float = 0.25,
- low_freq_radius_w: float = 0.25,
- learnable_low_freq_radius: bool = True,
- out_channels: int | None = None,
- ) -> None:
- super().__init__()
- if len(encoder_channels) != 4:
- raise ValueError("XNetDecoder2d expects 4 encoder stages.")
- if len(decoder_channels) != 3:
- raise ValueError("XNetDecoder2d expects 3 decoder channels.")
- c1, c2, c3, c4 = encoder_channels
- d4, d3, d2 = decoder_channels
- self.guide4 = XGuideProjector2d(c4, d4, mode=guide_mode)
- self.guide3 = XGuideProjector2d(c3, d3, mode=guide_mode)
- self.guide2 = XGuideProjector2d(c2, d2, mode=guide_mode)
- self.dec4 = XCRB2d(
- c4,
- c3,
- d4,
- d4,
- guide_mode=guide_mode,
- use_frequency_refine=use_frequency_refine,
- low_freq_radius_h=low_freq_radius_h,
- low_freq_radius_w=low_freq_radius_w,
- learnable_low_freq_radius=learnable_low_freq_radius,
- )
- self.dec3 = XCRB2d(
- d4,
- c2,
- d3,
- d3,
- guide_mode=guide_mode,
- use_frequency_refine=use_frequency_refine,
- low_freq_radius_h=low_freq_radius_h,
- low_freq_radius_w=low_freq_radius_w,
- learnable_low_freq_radius=learnable_low_freq_radius,
- )
- self.dec2 = XCRB2d(
- d3,
- c1,
- d2,
- d2,
- guide_mode=guide_mode,
- use_frequency_refine=use_frequency_refine,
- low_freq_radius_h=low_freq_radius_h,
- low_freq_radius_w=low_freq_radius_w,
- learnable_low_freq_radius=learnable_low_freq_radius,
- )
- self.head_refine = XNetHeadRefine2d(d2, out_channels or d2)
- self.out_channels = out_channels or d2
- def forward(
- self,
- features: Sequence[torch.Tensor],
- ) -> tuple[
- torch.Tensor,
- list[torch.Tensor],
- list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]],
- ]:
- e1, e2, e3, e4 = features
- g4 = self.guide4(e4, target_size=e3.shape[-2:])
- d4 = self.dec4(e4, e3, g4)
- g3 = self.guide3(e3, target_size=e2.shape[-2:])
- d3 = self.dec3(d4, e2, g3)
- g2 = self.guide2(e2, target_size=e1.shape[-2:])
- d2 = self.dec2(d3, e1, g2)
- d1 = self.head_refine(d2)
- return d1, [d4, d3, d2, d1], [g4, g3, g2]
class XNetSegHead2d(nn.Module):
    """Segmentation head: 3x3-style ``Conv2dBN`` + ReLU, 1x1 classifier, resize.

    The class logits are bilinearly resized to ``output_size``.
    ``upsample_scale`` is stored but not used by ``forward``; the output
    size always comes from the caller.
    """

    def __init__(
        self, in_channels: int, num_classes: int, upsample_scale: int = 4
    ) -> None:
        super().__init__()
        classifier = nn.Conv2d(in_channels, num_classes, kernel_size=1, bias=True)
        self.block = nn.Sequential(
            Conv2dBN(in_channels, in_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            classifier,
        )
        self.upsample_scale = upsample_scale

    def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
        class_logits = self.block(x)
        return F.interpolate(
            class_logits, size=output_size, mode="bilinear", align_corners=False
        )
- class XNet2d(nn.Module):
- def __init__(
- self,
- in_channels: int,
- num_classes: int,
- encoder_channels: Sequence[int] = (32, 64, 128, 192),
- encoder_depths: Sequence[int] = (2, 2, 2, 2),
- decoder_channels: Sequence[int] = (128, 64, 32),
- stem_channels: int = 24,
- bottleneck_depth: int = 1,
- global_ratio: float = 2.0,
- wavelet_type: str = "haar",
- wavelet_level: int = 1,
- use_wavelet_branch: bool = True,
- use_global_branch_stage1: bool = False,
- ssm_d_state: int = 16,
- ssm_forward_type: str = "v3",
- ssm_backend: str = "auto",
- use_frequency_refine: bool = True,
- low_freq_radius_h: float = 0.25,
- low_freq_radius_w: float = 0.25,
- learnable_low_freq_radius: bool = True,
- guide_mode: str = "affine",
- out_channels: int | None = None,
- ) -> None:
- super().__init__()
- self.encoder = XNetEncoder2d(
- in_channels=in_channels,
- stem_channels=stem_channels,
- encoder_channels=encoder_channels,
- encoder_depths=encoder_depths,
- global_ratio=global_ratio,
- wavelet_type=wavelet_type,
- wavelet_level=wavelet_level,
- use_wavelet_branch=use_wavelet_branch,
- use_global_branch_stage1=use_global_branch_stage1,
- ssm_d_state=ssm_d_state,
- ssm_forward_type=ssm_forward_type,
- ssm_backend=ssm_backend,
- )
- bottleneck_channels = encoder_channels[-1]
- self.bottleneck = nn.Sequential(
- *[
- XTEB2d(
- channels=bottleneck_channels,
- global_ratio=global_ratio,
- wavelet_type=wavelet_type,
- wavelet_level=wavelet_level,
- use_wavelet_branch=use_wavelet_branch,
- use_global_branch=True,
- ssm_d_state=ssm_d_state,
- ssm_forward_type=ssm_forward_type,
- ssm_backend=ssm_backend,
- )
- for _ in range(bottleneck_depth)
- ]
- )
- self.decoder = XNetDecoder2d(
- encoder_channels=encoder_channels,
- decoder_channels=decoder_channels,
- guide_mode=guide_mode,
- use_frequency_refine=use_frequency_refine,
- low_freq_radius_h=low_freq_radius_h,
- low_freq_radius_w=low_freq_radius_w,
- learnable_low_freq_radius=learnable_low_freq_radius,
- out_channels=out_channels,
- )
- head_in_channels = self.decoder.out_channels
- self.segmentation_head = XNetSegHead2d(head_in_channels, num_classes)
- def forward(
- self, x: torch.Tensor
- ) -> dict[
- str, torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]
- ]:
- encoder_features = self.encoder(x)
- encoder_features[-1] = self.bottleneck(encoder_features[-1])
- decoder_out, decoder_features, guides = self.decoder(encoder_features)
- output_size = x.shape[-2:]
- logits = self.segmentation_head(decoder_out, output_size=output_size)
- outputs: dict[
- str,
- torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]],
- ] = {
- "logits": logits,
- "seg_logits": logits,
- "encoder_features": encoder_features,
- "decoder_features": decoder_features,
- "guides": guides,
- }
- return outputs
|