|
@@ -0,0 +1,980 @@
|
|
|
|
|
+from __future__ import annotations
|
|
|
|
|
+
|
|
|
|
|
+from collections.abc import Sequence
|
|
|
|
|
+
|
|
|
|
|
+import torch
|
|
|
|
|
+import torch.nn as nn
|
|
|
|
|
+import torch.nn.functional as F
|
|
|
|
|
+
|
|
|
|
|
+import ptwt
|
|
|
|
|
+
|
|
|
|
|
+from .layers_2d import Conv2dBN
|
|
|
|
|
+from .lib_mamba.vmamba import SS2D as VMambaSS2D
|
|
|
|
|
+
|
|
|
|
|
+"""
|
|
|
|
|
+## 完成的修改
|
|
|
|
|
+
|
|
|
|
|
+### 1. 小波变换模块迁移至 ptwt
|
|
|
|
|
+- **替换 `XHaarWaveletTransform2d` → `XWaveletTransform2d`**:使用 `ptwt.wavedec2` / `ptwt.waverec2` 实现可逆小波变换
|
|
|
|
|
+- **优势**:
|
|
|
|
|
+ - 支持任意 pywt 兼容小波(haar, db4, sym2, db6 等),通过 `wavelet_type` 参数切换
|
|
|
|
|
+ - 自动处理边界对齐,无需手动 padding/cropping
|
|
|
|
|
+ - 代码更简洁,无手工卷积滤波器
|
|
|
|
|
+- **`XWaveletBranch2d`** 已更新引用新类,移除了 wavelet 类型限制检查
|
|
|
|
|
+
|
|
|
|
|
+### 2. XFrequencyRefine2d 频率域精炼模块分析
|
|
|
|
|
+
|
|
|
|
|
+**发现的问题与修复:**
|
|
|
|
|
+- **原代码 FFT 低频掩码位置错误**:未使用 `fftshift`,直接在左上角做十字掩码,与真实低频位置(四角)不匹配
|
|
|
|
|
+- **已修复**:使用 `fftshift` → 圆形低频掩码 → `ifftshift` 还原的正确流程
|
|
|
|
|
+
|
|
|
|
|
+**设计合理性评估:**
|
|
|
|
|
+| 方面 | 评价 |
|
|
|
|
|
+|------|------|
|
|
|
|
|
+| 低频/高频分离 | ✅ 圆形掩码合理,可调节半径 |
|
|
|
|
|
+| 门控机制 | ⚠️ 门控值来自空间域而非频域,可能损失频域选择性 |
|
|
|
|
|
+| 通道注意力 | ✅ 每个通道独立门控,灵活 |
|
|
|
|
|
+| 重建精度 | ✅ 正交归一化 FFT + 完整频域保留 |
|
|
|
|
|
+| 计算开销 | ⚠️ meshgrid 每步计算,可缓存优化 |
|
|
|
|
|
+
|
|
|
|
|
+**改进建议:**
|
|
|
|
|
+1. 门控可改为频域计算(对 `|fft|` 做平均池化)而非空间域
|
|
|
|
|
+2. 低频半径可改为可学习参数
|
|
|
|
|
+3. meshgrid 可缓存为 buffer 避免重复计算
|
|
|
|
|
+
|
|
|
|
|
+### 验证结果
|
|
|
|
|
+所有模块测试通过,小波分解→重建误差 < 1e-4,输出形状一致。
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+# 核心架构:XNet2D 医学图像分割网络
|
|
|
|
|
+# 业务意图:针对超声等医学图像分割任务,融合局部纹理、频率域、全局序列建模三重能力
|
|
|
|
|
+# 设计约束:
|
|
|
|
|
+# - 2D 张量通道优先 (N,C,H,W)
|
|
|
|
|
+# - 所有可逆变换需支持 inverse 恢复原始空间尺寸
|
|
|
|
|
+# - SSM 后端可切换:GPU→oflex,CPU→torch
|
|
|
|
|
+# ============================================================
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XNetStem2d:输入茎(Stem)
|
|
|
|
|
+# 为什么:将单张输入图快速降采样 4 倍 (H/4, W/4),并逐步提升通道维度
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - 两次步幅为 2 的卷积实现 4 倍下采样
|
|
|
|
|
+# - 中间嵌入 depthwise 卷积增强局部通道交互
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XNetStem2d(nn.Module):
|
|
|
|
|
+ def __init__(self, in_channels: int, stem_channels: int, out_channels: int) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.block = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(in_channels, stem_channels, 3, 2, 1), # 首次下采样 H/2, W/2
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(
|
|
|
|
|
+ stem_channels, stem_channels, 3, 1, 1, groups=stem_channels
|
|
|
|
|
+ ), # depthwise 局部特征增强
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(stem_channels, out_channels, 1, 1, 0), # 通道升维
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(out_channels, out_channels, 3, 2, 1), # 二次下采样 H/4, W/4
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ return self.block(x)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XNetDownsample2d:阶段间下采样器
|
|
|
|
|
+# 为什么:在编码器各阶段之间平滑过渡,降低空间分辨率同时增加通道数
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - 仅支持 conv 模式(扩展点由子类控制)
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XNetDownsample2d(nn.Module):
|
|
|
|
|
+ def __init__(self, in_channels: int, out_channels: int, mode: str = "conv") -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ if mode != "conv":
|
|
|
|
|
+ raise ValueError(f"Unsupported downsample mode: {mode}")
|
|
|
|
|
+ self.block = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(in_channels, out_channels, 3, 2, 1), # H/2, W/2 下采样
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ return self.block(x)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XLocalBranch2d:局部感受野分支
|
|
|
|
|
+# 为什么:并行捕获 3×3 和 5×5 多尺度局部纹理,对医学图像边缘/细微结构敏感
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - 两组 depthwise 卷积 + 1×1 通道压缩
|
|
|
|
|
+# - 输出直接相加(残差式局部特征累积)
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XLocalBranch2d(nn.Module):
|
|
|
|
|
+ def __init__(self, channels: int) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.branch3 = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(
|
|
|
|
|
+ channels, channels, 3, 1, 1, groups=channels
|
|
|
|
|
+ ), # 3×3 depthwise 局部感受野
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(channels, channels, 1, 1, 0), # 1×1 通道重映射
|
|
|
|
|
+ )
|
|
|
|
|
+ self.branch5 = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(
|
|
|
|
|
+ channels, channels, 5, 1, 2, groups=channels
|
|
|
|
|
+ ), # 5×5 depthwise 更大感受野
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(channels, channels, 1, 1, 0),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ return self.branch3(x) + self.branch5(x) # 多尺度局部特征融合
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XWaveletTransform2d:基于 ptwt 的 2D 小波变换
|
|
|
|
|
+# 为什么:将特征分解为低频近似 (LL) 与高频细节 (LH, HL, HH),便于频率域操作
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - 使用 ptwt.wavedec2 / ptwt.waverec2 实现可逆小波分解与重建
|
|
|
|
|
+# - 支持任意 pywt 兼容小波(haar, db4, sym2 等)
|
|
|
|
|
+# - 输出格式:(ll_coeff, (lh_coeff, hl_coeff, hh_coeff))
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XWaveletTransform2d(nn.Module):
|
|
|
|
|
+ def __init__(self, wavelet: str = "haar", level: int = 1) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.wavelet = wavelet
|
|
|
|
|
+ self.level = level
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 分解输入张量。
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ ll: 低频近似系数 [B, C, H', W']
|
|
|
|
|
+ high: 高频细节张量,拼接 LH/HL/HH 为 [B, C*3, H', W']
|
|
|
|
|
+ """
|
|
|
|
|
+ coeffs = ptwt.wavedec2(x, self.wavelet, level=self.level)
|
|
|
|
|
+ ll = coeffs[0] # 低频近似
|
|
|
|
|
+ detail_tuple = coeffs[1] # (lh, hl, hh) 元组
|
|
|
|
|
+ high = torch.cat([detail_tuple[0], detail_tuple[1], detail_tuple[2]], dim=1)
|
|
|
|
|
+ return ll, high
|
|
|
|
|
+
|
|
|
|
|
+ def inverse(
|
|
|
|
|
+ self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
|
|
|
|
|
+ ) -> torch.Tensor:
|
|
|
|
|
+ """
|
|
|
|
|
+ 从低频和高频系数重建原始张量。
|
|
|
|
|
+ Args:
|
|
|
|
|
+ ll: 低频近似系数
|
|
|
|
|
+ high: 高频细节张量 [B, C*3, H', W']
|
|
|
|
|
+ output_size: 目标输出尺寸 (H, W)
|
|
|
|
|
+ """
|
|
|
|
|
+ lh = high[:, 0 : high.shape[1] // 3]
|
|
|
|
|
+ hl = high[:, high.shape[1] // 3 : 2 * high.shape[1] // 3]
|
|
|
|
|
+ hh = high[:, 2 * high.shape[1] // 3 :]
|
|
|
|
|
+ coeffs = [ll, (lh, hl, hh)]
|
|
|
|
|
+ # ptwt.waverec2 自动处理边界对齐,无需手动裁剪
|
|
|
|
|
+ return ptwt.waverec2(coeffs, self.wavelet)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XWaveletBranch2d:小波分支
|
|
|
|
|
+# 为什么:对小波分解后的低频和高频分别做特征学习,再重建回空间域
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - 当前仅支持 Haar 小波和 level=1(设计约束)
|
|
|
|
|
+# - 高频通道数 = channels * 3,需单独投影
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XWaveletBranch2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.wavelet = XWaveletTransform2d(wavelet=wavelet_type, level=wavelet_level)
|
|
|
|
|
+ # 低频通道投影
|
|
|
|
|
+ self.ll_proj = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(channels, channels, 3, 1, 1),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+ # 高频通道投影(depthwise 处理多高频分量)
|
|
|
|
|
+ self.high_proj = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(channels * 3, channels * 3, 3, 1, 1, groups=channels * 3),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(channels * 3, channels * 3, 1, 1, 0),
|
|
|
|
|
+ )
|
|
|
|
|
+ # 重建后输出投影
|
|
|
|
|
+ self.out_proj = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(channels, channels, 1, 1, 0),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ output_size = x.shape[-2:]
|
|
|
|
|
+ ll, high = self.wavelet(x) # 分解
|
|
|
|
|
+ ll = self.ll_proj(ll)
|
|
|
|
|
+ high = self.high_proj(high)
|
|
|
|
|
+ x = self.wavelet.inverse(ll, high, output_size=output_size) # 重建
|
|
|
|
|
+ return self.out_proj(x)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XSSMGlobalBranch2d:SSM 全局分支(核心:VMamba SS2D)
|
|
|
|
|
+# 为什么:用 State Space Model 捕获长程依赖,弥补卷积局部感受野不足
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - 自动选择后端:CUDA→oflex(快速),否则→torch(兼容)
|
|
|
|
|
+# - 通过 monkey-patch forward_core 动态切换 scan 策略
|
|
|
|
|
+# - 用完后恢复原始 forward_core 避免状态污染
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XSSMGlobalBranch2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ channels: int,
|
|
|
|
|
+ global_ratio: float = 2.0,
|
|
|
|
|
+ d_state: int = 16,
|
|
|
|
|
+ forward_type: str = "v3",
|
|
|
|
|
+ ssm_backend: str = "auto",
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ hidden_ratio = max(global_ratio, 1.0) # SSM 隐层缩放比例
|
|
|
|
|
+ self.backend = ssm_backend
|
|
|
|
|
+ self.pre = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(channels, channels, 1, 1, 0), # 预投影归一化
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+ self.ssm = VMambaSS2D(
|
|
|
|
|
+ d_model=channels,
|
|
|
|
|
+ d_state=d_state,
|
|
|
|
|
+ ssm_ratio=hidden_ratio,
|
|
|
|
|
+ d_conv=3,
|
|
|
|
|
+ dropout=0.0,
|
|
|
|
|
+ initialize="v0",
|
|
|
|
|
+ forward_type=forward_type,
|
|
|
|
|
+ channel_first=True,
|
|
|
|
|
+ )
|
|
|
|
|
+ self.post = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(channels, channels, 1, 1, 0), # 后投影归一化
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ x = self.pre(x)
|
|
|
|
|
+ prev_backend = None
|
|
|
|
|
+ backend = self.backend.lower()
|
|
|
|
|
+ if backend == "auto":
|
|
|
|
|
+ backend = "oflex" if x.is_cuda else "torch"
|
|
|
|
|
+
|
|
|
|
|
+ # 动态切换 SSM 后端(避免修改全局配置)
|
|
|
|
|
+ if backend == "oflex" and hasattr(self.ssm, "forward_core"):
|
|
|
|
|
+ prev_backend = self.ssm.forward_core
|
|
|
|
|
+ self.ssm.forward_core = lambda z, _core=prev_backend: _core(
|
|
|
|
|
+ z,
|
|
|
|
|
+ selective_scan_backend="oflex",
|
|
|
|
|
+ scan_force_torch=False,
|
|
|
|
|
+ )
|
|
|
|
|
+ elif backend == "torch" and hasattr(self.ssm, "forward_core"):
|
|
|
|
|
+ prev_backend = self.ssm.forward_core
|
|
|
|
|
+ self.ssm.forward_core = lambda z, _core=prev_backend: _core(
|
|
|
|
|
+ z,
|
|
|
|
|
+ selective_scan_backend="torch",
|
|
|
|
|
+ scan_force_torch=True,
|
|
|
|
|
+ )
|
|
|
|
|
+ try:
|
|
|
|
|
+ x = self.ssm(x) # SSM 全局建模
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if prev_backend is not None:
|
|
|
|
|
+ self.ssm.forward_core = prev_backend # 恢复原始后端
|
|
|
|
|
+ return self.post(x)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XGlobalBranch2d:全局分支包装器
|
|
|
|
|
+# 为什么:提供统一接口,将 SSM 分支暴露为可开关的模块
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XGlobalBranch2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ channels: int,
|
|
|
|
|
+ global_ratio: float = 2.0,
|
|
|
|
|
+ ssm_d_state: int = 16,
|
|
|
|
|
+ ssm_forward_type: str = "v3",
|
|
|
|
|
+ ssm_backend: str = "auto",
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.ssm_branch = XSSMGlobalBranch2d(
|
|
|
|
|
+ channels=channels,
|
|
|
|
|
+ global_ratio=global_ratio,
|
|
|
|
|
+ d_state=ssm_d_state,
|
|
|
|
|
+ forward_type=ssm_forward_type,
|
|
|
|
|
+ ssm_backend=ssm_backend,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ return self.ssm_branch(x)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XBranchFusion2d:多分支特征融合
|
|
|
|
|
+# 为什么:将局部/小波/全局三个分支的输出自适应加权融合
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - 通道拼接 → 1×1 压缩 → 通道注意力门控(Channel Attention Gate)
|
|
|
|
|
+# - 门控值经 Sigmoid 后与融合特征逐元素相乘
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XBranchFusion2d(nn.Module):
|
|
|
|
|
+ def __init__(self, channels: int, num_branches: int = 3) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ fused_channels = channels * num_branches
|
|
|
|
|
+ hidden_channels = max(channels // 4, 8) # 门控网络隐藏维度
|
|
|
|
|
+ self.fuse = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(fused_channels, channels, 1, 1, 0), # 通道降维融合
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+ # 通道注意力门控
|
|
|
|
|
+ self.gate = nn.Sequential(
|
|
|
|
|
+ nn.AdaptiveAvgPool2d(1), # 全局平均池化 → 空间不变
|
|
|
|
|
+ nn.Conv2d(fused_channels, hidden_channels, kernel_size=1, bias=True),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True),
|
|
|
|
|
+ nn.Sigmoid(), # 门控值 [0, 1]
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, branch_outputs: Sequence[torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
+ x_cat = torch.cat(list(branch_outputs), dim=1) # 拼接所有分支
|
|
|
|
|
+ x_fused = self.fuse(x_cat)
|
|
|
|
|
+ gate = self.gate(x_cat) # 计算通道门控
|
|
|
|
|
+ return x_fused * gate # 门控加权融合
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XTEB2d:X-Tri-Enhance-Block (2D) — 核心构建块
|
|
|
|
|
+# 为什么:将局部、小波、全局三个分支并行融合,并叠加 FFN 残差
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - pre_norm:先做 1×1 投影再输入多分支
|
|
|
|
|
+# - fusion:XBranchFusion2d 自适应融合三分支
|
|
|
|
|
+# - post + FFN:双层残差连接(post-fusion + FFN)
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XTEB2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ channels: int,
|
|
|
|
|
+ global_ratio: float = 2.0,
|
|
|
|
|
+ wavelet_type: str = "haar",
|
|
|
|
|
+ wavelet_level: int = 1,
|
|
|
|
|
+ use_wavelet_branch: bool = True,
|
|
|
|
|
+ use_global_branch: bool = True,
|
|
|
|
|
+ ssm_d_state: int = 16,
|
|
|
|
|
+ ssm_forward_type: str = "v3",
|
|
|
|
|
+ ssm_backend: str = "auto",
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.pre_norm = Conv2dBN(channels, channels, 1, 1, 0) # 预投影
|
|
|
|
|
+ self.local_branch = XLocalBranch2d(channels) # 局部分支(始终启用)
|
|
|
|
|
+ # 小波分支(可开关)
|
|
|
|
|
+ self.wavelet_branch = (
|
|
|
|
|
+ XWaveletBranch2d(
|
|
|
|
|
+ channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
|
|
|
|
|
+ )
|
|
|
|
|
+ if use_wavelet_branch
|
|
|
|
|
+ else nn.Identity()
|
|
|
|
|
+ )
|
|
|
|
|
+ # 全局 SSM 分支(可开关)
|
|
|
|
|
+ self.global_branch = (
|
|
|
|
|
+ XGlobalBranch2d(
|
|
|
|
|
+ channels,
|
|
|
|
|
+ global_ratio=global_ratio,
|
|
|
|
|
+ ssm_d_state=ssm_d_state,
|
|
|
|
|
+ ssm_forward_type=ssm_forward_type,
|
|
|
|
|
+ ssm_backend=ssm_backend,
|
|
|
|
|
+ )
|
|
|
|
|
+ if use_global_branch
|
|
|
|
|
+ else nn.Identity()
|
|
|
|
|
+ )
|
|
|
|
|
+ self.fusion = XBranchFusion2d(channels, num_branches=3) # 三分支融合
|
|
|
|
|
+ # 后处理残差块
|
|
|
|
|
+ self.post = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(channels, channels, 3, 1, 1),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(channels, channels, 1, 1, 0, bn_weight_init=0.0), # 零初始化
|
|
|
|
|
+ )
|
|
|
|
|
+ # FFN 残差块
|
|
|
|
|
+ self.ffn = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(channels, channels * 2, 1, 1, 0), # 通道扩展
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(channels * 2, channels, 1, 1, 0, bn_weight_init=0.0), # 零初始化
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ x_in = x
|
|
|
|
|
+ x = self.pre_norm(x)
|
|
|
|
|
+ # 三分支并行 + 融合 + 残差
|
|
|
|
|
+ x = x_in + self.post(
|
|
|
|
|
+ self.fusion(
|
|
|
|
|
+ [self.local_branch(x), self.wavelet_branch(x), self.global_branch(x)]
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ # FFN 残差
|
|
|
|
|
+ return x + self.ffn(x)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XNetEncoderStage2d:编码器阶段
|
|
|
|
|
+# 为什么:堆叠多个 XTEB2d 块作为单一编码器层级
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XNetEncoderStage2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ channels: int,
|
|
|
|
|
+ depth: int,
|
|
|
|
|
+ global_ratio: float = 2.0,
|
|
|
|
|
+ wavelet_type: str = "haar",
|
|
|
|
|
+ wavelet_level: int = 1,
|
|
|
|
|
+ use_wavelet_branch: bool = True,
|
|
|
|
|
+ use_global_branch: bool = True,
|
|
|
|
|
+ ssm_d_state: int = 16,
|
|
|
|
|
+ ssm_forward_type: str = "v3",
|
|
|
|
|
+ ssm_backend: str = "auto",
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.blocks = nn.Sequential(
|
|
|
|
|
+ *[
|
|
|
|
|
+ XTEB2d(
|
|
|
|
|
+ channels=channels,
|
|
|
|
|
+ global_ratio=global_ratio,
|
|
|
|
|
+ wavelet_type=wavelet_type,
|
|
|
|
|
+ wavelet_level=wavelet_level,
|
|
|
|
|
+ use_wavelet_branch=use_wavelet_branch,
|
|
|
|
|
+ use_global_branch=use_global_branch,
|
|
|
|
|
+ ssm_d_state=ssm_d_state,
|
|
|
|
|
+ ssm_forward_type=ssm_forward_type,
|
|
|
|
|
+ ssm_backend=ssm_backend,
|
|
|
|
|
+ )
|
|
|
|
|
+ for _ in range(depth)
|
|
|
|
|
+ ]
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ return self.blocks(x)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XNetEncoder2d:完整编码器
|
|
|
|
|
+# 为什么:Stem + 4 个阶段 + 3 个下采样 → 多尺度特征金字塔 [e1, e2, e3, e4]
|
|
|
|
|
+# 关键约束:
|
|
|
|
|
+# - 阶段数固定为 4(由构造函数校验)
|
|
|
|
|
+# - Stage1 默认关闭全局 SSM(浅层特征不适合长程建模)
|
|
|
|
|
+# - stage_channels 属性暴露各阶段输出通道数
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XNetEncoder2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ in_channels: int,
|
|
|
|
|
+ stem_channels: int,
|
|
|
|
|
+ encoder_channels: Sequence[int],
|
|
|
|
|
+ encoder_depths: Sequence[int],
|
|
|
|
|
+ global_ratio: float = 2.0,
|
|
|
|
|
+ wavelet_type: str = "haar",
|
|
|
|
|
+ wavelet_level: int = 1,
|
|
|
|
|
+ use_wavelet_branch: bool = True,
|
|
|
|
|
+ use_global_branch_stage1: bool = False,
|
|
|
|
|
+ ssm_d_state: int = 16,
|
|
|
|
|
+ ssm_forward_type: str = "v3",
|
|
|
|
|
+ ssm_backend: str = "auto",
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ if len(encoder_channels) != 4 or len(encoder_depths) != 4:
|
|
|
|
|
+ raise ValueError("XNetEncoder2d expects 4 encoder stages.")
|
|
|
|
|
+ c1, c2, c3, c4 = encoder_channels
|
|
|
|
|
+ d1, d2, d3, d4 = encoder_depths
|
|
|
|
|
+ self.stem = XNetStem2d(in_channels, stem_channels, c1)
|
|
|
|
|
+ # Stage 1:浅层,可选关闭全局分支
|
|
|
|
|
+ self.stage1 = XNetEncoderStage2d(
|
|
|
|
|
+ c1,
|
|
|
|
|
+ d1,
|
|
|
|
|
+ global_ratio,
|
|
|
|
|
+ wavelet_type,
|
|
|
|
|
+ wavelet_level,
|
|
|
|
|
+ use_wavelet_branch=use_wavelet_branch,
|
|
|
|
|
+ use_global_branch=use_global_branch_stage1,
|
|
|
|
|
+ ssm_d_state=ssm_d_state,
|
|
|
|
|
+ ssm_forward_type=ssm_forward_type,
|
|
|
|
|
+ ssm_backend=ssm_backend,
|
|
|
|
|
+ )
|
|
|
|
|
+ self.down1 = XNetDownsample2d(c1, c2)
|
|
|
|
|
+ # Stage 2-4:始终启用全局分支
|
|
|
|
|
+ self.stage2 = XNetEncoderStage2d(
|
|
|
|
|
+ c2,
|
|
|
|
|
+ d2,
|
|
|
|
|
+ global_ratio,
|
|
|
|
|
+ wavelet_type,
|
|
|
|
|
+ wavelet_level,
|
|
|
|
|
+ use_wavelet_branch,
|
|
|
|
|
+ True,
|
|
|
|
|
+ ssm_d_state=ssm_d_state,
|
|
|
|
|
+ ssm_forward_type=ssm_forward_type,
|
|
|
|
|
+ ssm_backend=ssm_backend,
|
|
|
|
|
+ )
|
|
|
|
|
+ self.down2 = XNetDownsample2d(c2, c3)
|
|
|
|
|
+ self.stage3 = XNetEncoderStage2d(
|
|
|
|
|
+ c3,
|
|
|
|
|
+ d3,
|
|
|
|
|
+ global_ratio,
|
|
|
|
|
+ wavelet_type,
|
|
|
|
|
+ wavelet_level,
|
|
|
|
|
+ use_wavelet_branch,
|
|
|
|
|
+ True,
|
|
|
|
|
+ ssm_d_state=ssm_d_state,
|
|
|
|
|
+ ssm_forward_type=ssm_forward_type,
|
|
|
|
|
+ ssm_backend=ssm_backend,
|
|
|
|
|
+ )
|
|
|
|
|
+ self.down3 = XNetDownsample2d(c3, c4)
|
|
|
|
|
+ self.stage4 = XNetEncoderStage2d(
|
|
|
|
|
+ c4,
|
|
|
|
|
+ d4,
|
|
|
|
|
+ global_ratio,
|
|
|
|
|
+ wavelet_type,
|
|
|
|
|
+ wavelet_level,
|
|
|
|
|
+ use_wavelet_branch,
|
|
|
|
|
+ True,
|
|
|
|
|
+ ssm_d_state=ssm_d_state,
|
|
|
|
|
+ ssm_forward_type=ssm_forward_type,
|
|
|
|
|
+ ssm_backend=ssm_backend,
|
|
|
|
|
+ )
|
|
|
|
|
+ self.stage_channels = list(encoder_channels) # 暴露各阶段通道数
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
|
|
|
|
+ e1 = self.stage1(self.stem(x)) # 浅层特征
|
|
|
|
|
+ e2 = self.stage2(self.down1(e1)) # 中层特征
|
|
|
|
|
+ e3 = self.stage3(self.down2(e2)) # 深层特征
|
|
|
|
|
+ e4 = self.stage4(self.down3(e3)) # 最深特征
|
|
|
|
|
+ return [e1, e2, e3, e4] # 多尺度特征金字塔
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XGuideProjector2d:引导投影器
|
|
|
|
|
+# 为什么:从编码器特征生成引导信号(guide),用于解码器的自适应调制
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - affine 模式:输出 (gamma, beta) 用于仿射调制
|
|
|
|
|
+# - feature 模式:直接输出特征
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XGuideProjector2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self, in_channels: int, out_channels: int, mode: str = "affine"
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.mode = mode
|
|
|
|
|
+ if mode == "affine":
|
|
|
|
|
+ # 输出双倍通道 → 后续拆分为 gamma 和 beta
|
|
|
|
|
+ self.proj = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(in_channels, out_channels * 2, 1, 1, 0),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ nn.Conv2d(out_channels * 2, out_channels * 2, kernel_size=1, bias=True),
|
|
|
|
|
+ )
|
|
|
|
|
+ elif mode == "feature":
|
|
|
|
|
+ self.proj = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(in_channels, out_channels, 1, 1, 0),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise ValueError(f"Unsupported guide mode: {mode}")
|
|
|
|
|
+
|
|
|
|
|
+ def forward(
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: torch.Tensor,
|
|
|
|
|
+ target_size: tuple[int, int],
|
|
|
|
|
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
+ # 插值到目标尺寸(guide 需要与解码器特征空间对齐)
|
|
|
|
|
+ x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
|
|
|
|
|
+ x = self.proj(x)
|
|
|
|
|
+ if self.mode == "affine":
|
|
|
|
|
+ gamma, beta = torch.chunk(x, 2, dim=1) # 拆分为仿射参数
|
|
|
|
|
+ gamma = torch.sigmoid(gamma) + 0.5 # gamma 偏置到 [0.5, 1.5]
|
|
|
|
|
+ return gamma, beta
|
|
|
|
|
+ return x
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XSkipFusion2d:跳跃连接融合
|
|
|
|
|
+# 为什么:将编码器特征与解码器特征融合后传入
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - 分别投影输入和跳跃特征到相同维度
|
|
|
|
|
+# - 拼接 + 3×3 卷积融合
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XSkipFusion2d(nn.Module):
|
|
|
|
|
+ def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.input_proj = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(in_channels, out_channels, 1, 1, 0), # 解码器特征投影
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+ self.skip_proj = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(skip_channels, out_channels, 1, 1, 0), # 跳跃特征投影
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+ self.fuse = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(out_channels * 2, out_channels, 3, 1, 1), # 拼接后融合
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ # 双线性插值对齐空间尺寸
|
|
|
|
|
+ x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
|
|
|
|
|
+ x = self.input_proj(x)
|
|
|
|
|
+ skip = self.skip_proj(skip)
|
|
|
|
|
+ return self.fuse(torch.cat([x, skip], dim=1)) # 通道拼接融合
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XGuideModulation2d:引导调制器
|
|
|
|
|
+# 为什么:对特征应用仿射调制 (gamma * x + beta) 或特征驱动调制
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XGuideModulation2d(nn.Module):
|
|
|
|
|
+ def __init__(self, channels: int, guide_mode: str = "affine") -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.guide_mode = guide_mode
|
|
|
|
|
+ if guide_mode == "feature":
|
|
|
|
|
+ # feature 模式下先将 guide 转为仿射参数
|
|
|
|
|
+ self.to_affine = nn.Conv2d(channels, channels * 2, kernel_size=1, bias=True)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: torch.Tensor,
|
|
|
|
|
+ guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
|
|
|
|
+ ) -> torch.Tensor:
|
|
|
|
|
+ if self.guide_mode == "affine":
|
|
|
|
|
+ gamma, beta = guide # 直接使用仿射参数
|
|
|
|
|
+ else:
|
|
|
|
|
+ gamma, beta = torch.chunk(self.to_affine(guide), 2, dim=1)
|
|
|
|
|
+ gamma = torch.sigmoid(gamma) + 0.5
|
|
|
|
|
+ return gamma * x + beta # 仿射调制
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XFrequencyRefine2d:频率域精炼
|
|
|
|
|
+# 为什么:在频域对低频/高频分别应用门控,增强关键频率成分
|
|
|
|
|
+# 关键行为:
|
|
|
|
|
+# - FFT → 低频中心保留 + 高频带通 → 逆 FFT
|
|
|
|
|
+# - 门控由自适应平均池化生成
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XFrequencyRefine2d(nn.Module):
|
|
|
|
|
+ def __init__(self, channels: int) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ # 低频门控
|
|
|
|
|
+ self.low_gate = nn.Sequential(
|
|
|
|
|
+ nn.AdaptiveAvgPool2d(1),
|
|
|
|
|
+ nn.Conv2d(channels, channels, kernel_size=1, bias=True),
|
|
|
|
|
+ nn.Sigmoid(),
|
|
|
|
|
+ )
|
|
|
|
|
+ # 高频门控
|
|
|
|
|
+ self.high_gate = nn.Sequential(
|
|
|
|
|
+ nn.AdaptiveAvgPool2d(1),
|
|
|
|
|
+ nn.Conv2d(channels, channels, kernel_size=1, bias=True),
|
|
|
|
|
+ nn.Sigmoid(),
|
|
|
|
|
+ )
|
|
|
|
|
+ # 频域精炼后的空间域细化
|
|
|
|
|
+ self.refine = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(
|
|
|
|
|
+ channels, channels, 3, 1, 1, groups=channels
|
|
|
|
|
+ ), # depthwise 局部细化
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(channels, channels, 1, 1, 0),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ input_dtype = x.dtype
|
|
|
|
|
+ if x.dtype != torch.float32:
|
|
|
|
|
+ x = x.to(torch.float32) # FFT 需要 float32 精度
|
|
|
|
|
+ fft = torch.fft.rfft2(x, norm="ortho") # 实值 FFT
|
|
|
|
|
+ h_freq, w_freq = fft.shape[-2], fft.shape[-1]
|
|
|
|
|
+ # 构建圆形低频掩码(中心位于四个角:FFT 未 shift 时低频在四角)
|
|
|
|
|
+ # 使用 fftshift 将低频移至中心,应用掩码后再 ifftshift 还原
|
|
|
|
|
+ fft_shifted = torch.fft.fftshift(fft, dim=(-2, -1))
|
|
|
|
|
+ low = fft_shifted.clone()
|
|
|
|
|
+ # 圆形低频掩码:保留中心区域
|
|
|
|
|
+ radius_h = h_freq // 4
|
|
|
|
|
+ radius_w = w_freq // 4
|
|
|
|
|
+ y_grid, x_grid = torch.meshgrid(
|
|
|
|
|
+ torch.arange(h_freq, device=fft.device),
|
|
|
|
|
+ torch.arange(w_freq, device=fft.device),
|
|
|
|
|
+ indexing="ij",
|
|
|
|
|
+ )
|
|
|
|
|
+ center_y, center_x = h_freq // 2, w_freq // 2
|
|
|
|
|
+ mask = (y_grid - center_y) ** 2 + (x_grid - center_x) ** 2 <= max(
|
|
|
|
|
+ radius_h, radius_w
|
|
|
|
|
+ ) ** 2
|
|
|
|
|
+ mask = mask.unsqueeze(0).unsqueeze(0).expand(fft.shape[0], fft.shape[1], -1, -1)
|
|
|
|
|
+ low = low * mask # 低频分量
|
|
|
|
|
+ high = fft_shifted - low # 高频 = 全部 - 低频
|
|
|
|
|
+ # 还原到原始 FFT 坐标系
|
|
|
|
|
+ low = torch.fft.ifftshift(low, dim=(-2, -1))
|
|
|
|
|
+ high = torch.fft.ifftshift(high, dim=(-2, -1))
|
|
|
|
|
+ # 应用通道门控(门控值来自空间域)
|
|
|
|
|
+ low = low * self.low_gate(x)
|
|
|
|
|
+ high = high * self.high_gate(x)
|
|
|
|
|
+ out = torch.fft.irfft2(low + high, s=x.shape[-2:], norm="ortho") # 逆 FFT
|
|
|
|
|
+ out = out.to(dtype=input_dtype)
|
|
|
|
|
+ return self.refine(out) # 空间域细化
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XCRB2d:X-ResBlock with Guide (2D) — 解码器核心块
|
|
|
|
|
+# 为什么:融合跳跃连接 + 引导调制 + 频率精炼,是解码器重建的基础单元
|
|
|
|
|
+# 数据流:
|
|
|
|
|
+# 输入特征 → SkipFusion → GuideModulation → FrequencyRefine → OutRefine
|
|
|
|
|
+# 每步均有残差连接
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XCRB2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ in_channels: int,
|
|
|
|
|
+ skip_channels: int,
|
|
|
|
|
+ guide_channels: int,
|
|
|
|
|
+ out_channels: int,
|
|
|
|
|
+ guide_mode: str = "affine",
|
|
|
|
|
+ use_frequency_refine: bool = True,
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
|
|
|
|
|
+ self.guide_modulation = XGuideModulation2d(out_channels, guide_mode=guide_mode)
|
|
|
|
|
+ self.frequency_refine = (
|
|
|
|
|
+ XFrequencyRefine2d(out_channels) if use_frequency_refine else nn.Identity()
|
|
|
|
|
+ )
|
|
|
|
|
+ # 输出细化(零初始化末尾以渐进学习)
|
|
|
|
|
+ self.out_refine = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(out_channels, out_channels, 3, 1, 1),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
|
|
|
|
|
+ )
|
|
|
|
|
+ self.guide_channels = guide_channels
|
|
|
|
|
+
|
|
|
|
|
+ def forward(
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: torch.Tensor,
|
|
|
|
|
+ skip: torch.Tensor,
|
|
|
|
|
+ guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
|
|
|
|
+ ) -> torch.Tensor:
|
|
|
|
|
+ x = self.skip_fusion(x, skip) # 跳跃融合
|
|
|
|
|
+ x = self.guide_modulation(x, guide) # 引导调制
|
|
|
|
|
+ x = x + self.frequency_refine(x) # 频率精炼残差
|
|
|
|
|
+ return x + self.out_refine(x) # 输出细化残差
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XNetHeadRefine2d:特征精炼头
|
|
|
|
|
+# 为什么:在解码器末端做最后的特征增强
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XNetHeadRefine2d(nn.Module):
|
|
|
|
|
+ def __init__(self, channels: int, out_channels: int | None = None) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ if out_channels is None:
|
|
|
|
|
+ out_channels = channels
|
|
|
|
|
+ self.block = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(channels, out_channels, 3, 1, 1),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ Conv2dBN(out_channels, out_channels, 3, 1, 1),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ return self.block(x)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XNetDecoder2d:完整解码器
|
|
|
|
|
+# 为什么:从最深特征 e4 逐步上采样,逐层引入引导信号和跳跃连接
|
|
|
|
|
+# 关键数据流:
|
|
|
|
|
+# e4 → guide4 → dec4 → guide3 → dec3 → guide2 → dec2 → head_refine
|
|
|
|
|
+# 返回:输出特征、所有解码特征、所有引导信号(供损失函数使用)
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XNetDecoder2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ encoder_channels: Sequence[int],
|
|
|
|
|
+ decoder_channels: Sequence[int] = (128, 64, 32),
|
|
|
|
|
+ guide_mode: str = "affine",
|
|
|
|
|
+ use_frequency_refine: bool = True,
|
|
|
|
|
+ out_channels: int | None = None,
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ if len(encoder_channels) != 4:
|
|
|
|
|
+ raise ValueError("XNetDecoder2d expects 4 encoder stages.")
|
|
|
|
|
+ if len(decoder_channels) != 3:
|
|
|
|
|
+ raise ValueError("XNetDecoder2d expects 3 decoder channels.")
|
|
|
|
|
+ c1, c2, c3, c4 = encoder_channels
|
|
|
|
|
+ d4, d3, d2 = decoder_channels
|
|
|
|
|
+ # 引导投影器(从编码器特征生成 guide)
|
|
|
|
|
+ self.guide4 = XGuideProjector2d(c4, d4, mode=guide_mode)
|
|
|
|
|
+ self.guide3 = XGuideProjector2d(c3, d3, mode=guide_mode)
|
|
|
|
|
+ self.guide2 = XGuideProjector2d(c2, d2, mode=guide_mode)
|
|
|
|
|
+ # 解码块(逐层降通道 + 跳跃融合)
|
|
|
|
|
+ self.dec4 = XCRB2d(
|
|
|
|
|
+ c4,
|
|
|
|
|
+ c3,
|
|
|
|
|
+ d4,
|
|
|
|
|
+ d4,
|
|
|
|
|
+ guide_mode=guide_mode,
|
|
|
|
|
+ use_frequency_refine=use_frequency_refine,
|
|
|
|
|
+ )
|
|
|
|
|
+ self.dec3 = XCRB2d(
|
|
|
|
|
+ d4,
|
|
|
|
|
+ c2,
|
|
|
|
|
+ d3,
|
|
|
|
|
+ d3,
|
|
|
|
|
+ guide_mode=guide_mode,
|
|
|
|
|
+ use_frequency_refine=use_frequency_refine,
|
|
|
|
|
+ )
|
|
|
|
|
+ self.dec2 = XCRB2d(
|
|
|
|
|
+ d3,
|
|
|
|
|
+ c1,
|
|
|
|
|
+ d2,
|
|
|
|
|
+ d2,
|
|
|
|
|
+ guide_mode=guide_mode,
|
|
|
|
|
+ use_frequency_refine=use_frequency_refine,
|
|
|
|
|
+ )
|
|
|
|
|
+ self.head_refine = XNetHeadRefine2d(d2, out_channels or d2)
|
|
|
|
|
+ self.out_channels = out_channels or d2
|
|
|
|
|
+
|
|
|
|
|
+ def forward(
|
|
|
|
|
+ self,
|
|
|
|
|
+ features: Sequence[torch.Tensor],
|
|
|
|
|
+ ) -> tuple[
|
|
|
|
|
+ torch.Tensor,
|
|
|
|
|
+ list[torch.Tensor],
|
|
|
|
|
+ list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]],
|
|
|
|
|
+ ]:
|
|
|
|
|
+ e1, e2, e3, e4 = features
|
|
|
|
|
+ # 从深到浅逐层解码
|
|
|
|
|
+ g4 = self.guide4(e4, target_size=e3.shape[-2:]) # 从 e4 生成 guide
|
|
|
|
|
+ d4 = self.dec4(e4, e3, g4) # 解码 + 跳跃 e3
|
|
|
|
|
+ g3 = self.guide3(e3, target_size=e2.shape[-2:])
|
|
|
|
|
+ d3 = self.dec3(d4, e2, g3) # 解码 + 跳跃 e2
|
|
|
|
|
+ g2 = self.guide2(e2, target_size=e1.shape[-2:])
|
|
|
|
|
+ d2 = self.dec2(d3, e1, g2) # 解码 + 跳跃 e1
|
|
|
|
|
+ d1 = self.head_refine(d2) # 最终精炼
|
|
|
|
|
+ # 返回解码输出、中间特征(用于辅助损失)、引导信号
|
|
|
|
|
+ return d1, [d4, d3, d2, d1], [g4, g3, g2]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+# XNetSegHead2d:分割头
|
|
|
|
|
+# 为什么:将最终特征映射为 logits 图,并上采样到原始输入尺寸
|
|
|
|
|
+# --------------------------------------------------------------------------
|
|
|
|
|
+class XNetSegHead2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self, in_channels: int, num_classes: int, upsample_scale: int = 4
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.block = nn.Sequential(
|
|
|
|
|
+ Conv2dBN(in_channels, in_channels, 3, 1, 1),
|
|
|
|
|
+ nn.ReLU(inplace=True),
|
|
|
|
|
+ nn.Conv2d(
|
|
|
|
|
+ in_channels, num_classes, kernel_size=1, bias=True
|
|
|
|
|
+ ), # 映射到类别数
|
|
|
|
|
+ )
|
|
|
|
|
+ self.upsample_scale = upsample_scale
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
|
|
|
|
|
+ x = self.block(x)
|
|
|
|
|
+ # 双线性上采样到目标尺寸(推理时传入原始输入 H, W)
|
|
|
|
|
+ return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ==========================================================================
|
|
|
|
|
+# XNet2d:完整网络(编码器 + Bottleneck + 解码器 + 分割头)
|
|
|
|
|
+# 架构概览:
|
|
|
|
|
+# 输入 → Stem → [Stage1 ↓ Stage2 ↓ Stage3 ↓ Stage4] → Bottleneck
|
|
|
|
|
+# → [dec4 ← dec3 ← dec2] → Head → Logits
|
|
|
|
|
+# 业务特点:
|
|
|
|
|
+# - 编码器浅层(Stage1)默认关闭 SSM 以降低计算开销
|
|
|
|
|
+# - 解码器逐层注入 guide 信号,实现自适应特征调制
|
|
|
|
|
+# - 每个解码块支持频率精炼,增强医学图像细节保留
|
|
|
|
|
+# ==========================================================================
|
|
|
|
|
+class XNet2d(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ in_channels: int,
|
|
|
|
|
+ num_classes: int,
|
|
|
|
|
+ encoder_channels: Sequence[int] = (32, 64, 128, 192),
|
|
|
|
|
+ encoder_depths: Sequence[int] = (2, 2, 2, 2),
|
|
|
|
|
+ decoder_channels: Sequence[int] = (128, 64, 32),
|
|
|
|
|
+ stem_channels: int = 24,
|
|
|
|
|
+ bottleneck_depth: int = 1,
|
|
|
|
|
+ global_ratio: float = 2.0,
|
|
|
|
|
+ wavelet_type: str = "haar",
|
|
|
|
|
+ wavelet_level: int = 1,
|
|
|
|
|
+ use_wavelet_branch: bool = True,
|
|
|
|
|
+ use_global_branch_stage1: bool = False,
|
|
|
|
|
+ ssm_d_state: int = 16,
|
|
|
|
|
+ ssm_forward_type: str = "v3",
|
|
|
|
|
+ ssm_backend: str = "auto",
|
|
|
|
|
+ use_frequency_refine: bool = True,
|
|
|
|
|
+ guide_mode: str = "affine",
|
|
|
|
|
+ out_channels: int | None = None,
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ # 编码器:多尺度特征金字塔
|
|
|
|
|
+ self.encoder = XNetEncoder2d(
|
|
|
|
|
+ in_channels=in_channels,
|
|
|
|
|
+ stem_channels=stem_channels,
|
|
|
|
|
+ encoder_channels=encoder_channels,
|
|
|
|
|
+ encoder_depths=encoder_depths,
|
|
|
|
|
+ global_ratio=global_ratio,
|
|
|
|
|
+ wavelet_type=wavelet_type,
|
|
|
|
|
+ wavelet_level=wavelet_level,
|
|
|
|
|
+ use_wavelet_branch=use_wavelet_branch,
|
|
|
|
|
+ use_global_branch_stage1=use_global_branch_stage1,
|
|
|
|
|
+ ssm_d_state=ssm_d_state,
|
|
|
|
|
+ ssm_forward_type=ssm_forward_type,
|
|
|
|
|
+ ssm_backend=ssm_backend,
|
|
|
|
|
+ )
|
|
|
|
|
+ # Bottleneck:最深特征进一步建模
|
|
|
|
|
+ bottleneck_channels = encoder_channels[-1]
|
|
|
|
|
+ self.bottleneck = nn.Sequential(
|
|
|
|
|
+ *[
|
|
|
|
|
+ XTEB2d(
|
|
|
|
|
+ channels=bottleneck_channels,
|
|
|
|
|
+ global_ratio=global_ratio,
|
|
|
|
|
+ wavelet_type=wavelet_type,
|
|
|
|
|
+ wavelet_level=wavelet_level,
|
|
|
|
|
+ use_wavelet_branch=use_wavelet_branch,
|
|
|
|
|
+ use_global_branch=True, # bottleneck 始终启用全局分支
|
|
|
|
|
+ ssm_d_state=ssm_d_state,
|
|
|
|
|
+ ssm_forward_type=ssm_forward_type,
|
|
|
|
|
+ ssm_backend=ssm_backend,
|
|
|
|
|
+ )
|
|
|
|
|
+ for _ in range(bottleneck_depth)
|
|
|
|
|
+ ]
|
|
|
|
|
+ )
|
|
|
|
|
+ # 解码器
|
|
|
|
|
+ self.decoder = XNetDecoder2d(
|
|
|
|
|
+ encoder_channels=encoder_channels,
|
|
|
|
|
+ decoder_channels=decoder_channels,
|
|
|
|
|
+ guide_mode=guide_mode,
|
|
|
|
|
+ use_frequency_refine=use_frequency_refine,
|
|
|
|
|
+ out_channels=out_channels,
|
|
|
|
|
+ )
|
|
|
|
|
+ # 分割头
|
|
|
|
|
+ head_in_channels = self.decoder.out_channels
|
|
|
|
|
+ self.segmentation_head = XNetSegHead2d(head_in_channels, num_classes)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(
|
|
|
|
|
+ self, x: torch.Tensor
|
|
|
|
|
+ ) -> dict[
|
|
|
|
|
+ str, torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]
|
|
|
|
|
+ ]:
|
|
|
|
|
+ encoder_features = self.encoder(x) # 多尺度特征 [e1, e2, e3, e4]
|
|
|
|
|
+ encoder_features[-1] = self.bottleneck(encoder_features[-1]) # bottleneck
|
|
|
|
|
+ decoder_out, decoder_features, guides = self.decoder(encoder_features) # 解码
|
|
|
|
|
+ output_size = x.shape[-2:]
|
|
|
|
|
+ logits = self.segmentation_head(
|
|
|
|
|
+ decoder_out, output_size=output_size
|
|
|
|
|
+ ) # 分割 logits
|
|
|
|
|
+ # 返回字典:包含 logits、中间特征(用于辅助损失)、引导信号
|
|
|
|
|
+ outputs: dict[
|
|
|
|
|
+ str,
|
|
|
|
|
+ torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]],
|
|
|
|
|
+ ] = {
|
|
|
|
|
+ "logits": logits,
|
|
|
|
|
+ "seg_logits": logits,
|
|
|
|
|
+ "encoder_features": encoder_features,
|
|
|
|
|
+ "decoder_features": decoder_features,
|
|
|
|
|
+ "guides": guides,
|
|
|
|
|
+ }
|
|
|
|
|
+ return outputs
|