XNet_method.md 6.3 KB

XNet2d 方法定义与当前实现

更新日期:2026-05-31 当前状态:已去掉旧版 X-shaped 斜向 guide 流,decoder 恢复为普通 U-Net 同尺度 skip 连接。

1. 方法总述

XNet2d 是当前项目用于 2D 超声图像分割的主模型。当前实现可以概括为:

CNN-Wavelet-VMamba encoder + plain U-Net skip decoder + segmentation head

模型保留的核心设计是:

  1. encoder 使用 XTEB2d,在每个尺度内融合局部纹理、小波频率信息和 VMamba-style SS2D 全局建模。
  2. bottleneck 继续使用 XTEB2d 加强最深层语义表征。
  3. decoder 使用普通 U-Net 式逐级上采样和同尺度 skip 融合。
  4. decoder block 内可选使用频率细化模块,帮助恢复低对比度边界和局部细节。
  5. segmentation head 将 H/4 x W/4 的 decoder 输出上采样回输入分辨率。

旧版设计中的 XGuideProjector2dXGuideModulation2d 和 diagonal guide path 已从当前代码中移除。配置文件中仍可能保留 guide_mode 字段,但它只用于兼容旧 YAML,不参与当前前向计算。

2. 符号与张量尺寸

设输入图像为:

I ∈ R^{B×C_in×H×W}

4 级 encoder 输出:

E1 ∈ R^{B×C1×H/4 ×W/4}
E2 ∈ R^{B×C2×H/8 ×W/8}
E3 ∈ R^{B×C3×H/16×W/16}
E4 ∈ R^{B×C4×H/32×W/32}

默认通道配置:

C1 = 32, C2 = 64, C3 = 128, C4 = 192

decoder 输出:

D4 ∈ R^{B×128×H/16×W/16}
D3 ∈ R^{B×64 ×H/8 ×W/8}
D2 ∈ R^{B×32 ×H/4 ×W/4}
D1 ∈ R^{B×32 ×H/4 ×W/4}

最终分割输出:

Y ∈ R^{B×K×H×W}

其中 K 是输出类别数。当前二值分割配置通常取 K=1

3. 当前网络拓扑

Input
  |
  v
Stem
  |
  v
E1 -> Down1 -> E2 -> Down2 -> E3 -> Down3 -> E4
  |             |             |
  |             |             |
  |             |             v
  |             |        Bottleneck
  |             |             |
  |             |             v
  |             |        Dec4 + skip E3 -> D4
  |             |             |
  |             v             v
  |        Dec3 + skip E2 -> D3
  |             |
  v             v
Dec2 + skip E1 -> D2
  |
  v
HeadRefine -> SegHead -> logits

形式化写法:

B0 = Bottleneck(E4)
D4 = Dec4(B0, E3)
D3 = Dec3(D4, E2)
D2 = Dec2(D3, E1)
D1 = Refine(D2)
Y  = SegHead(D1)

这就是普通 U-Net 的同尺度 skip 连接方式,没有额外斜向 guide:

E3 -> D4
E2 -> D3
E1 -> D2

4. Encoder:XTEB2d

XTEB2d 是 encoder 的基本 block。它的职责是提取稳健的超声图像表征,包含三类信息:

  1. Local branch:深度可分离卷积,建模局部纹理、 speckle pattern 和边界细节。
  2. Wavelet branch:一级 Haar DWT/IDWT,分别处理低频结构和高频细节。
  3. Global branch:VMamba-style SS2D,建模长程依赖和全局结构一致性。

简化流程:

Input X
  |
  +-- local branch
  |
  +-- wavelet branch
  |
  +-- global SS2D branch
  |
  v
concat -> 1x1 fusion -> channel gate -> residual FFN -> Output

当前实现细节:

wavelet_type = haar
wavelet_level = 1
ssm_forward_type = v3
ssm_backend = auto

ssm_backend=auto 时,CUDA 上优先使用 oflex selective scan,CPU 上走 torch fallback。

5. Decoder:普通 U-Net Skip + Frequency Refine

当前 decoder block 是 XCRB2d。它不再是 cross-guided block,而是普通 reconstruction block:

decoder input
  |
  v
bilinear upsample to skip size
  |
  v
1x1 projection
  |
  +-----------------------------+
                                |
same-scale skip                 |
  |                             |
  v                             |
1x1 projection                  |
  |                             |
  +----------- concat ----------+
                  |
                  v
             3x3 fusion
                  |
                  v
        optional frequency refine
                  |
                  v
        residual spatial refine

数学上可写为:

U_i = P_u(Up(D_{i+1}))
S_i = P_s(E_i)
F_i = Conv3×3(concat[U_i, S_i])
R_i = FrequencyRefine(F_i)
D_i = F_i + R_i + SpatialRefine(F_i + R_i)

其中 FrequencyRefine 可通过配置关闭:

model:
  use_frequency_refine: false

6. Frequency Refine

XFrequencyRefine2d 用于 decoder 融合后的特征。它在频域中分离低频和高频成分,并用轻量 gate 做重加权。

流程:

feature F
  |
  v
cast to float32 if needed
  |
  v
rfft2
  |
  +-- low frequency mask
  |
  +-- high frequency residual
  |
  v
low/high learnable gates
  |
  v
irfft2
  |
  v
cast back to input dtype
  |
  v
depthwise conv refine

FFT 部分显式使用 float32,用于避免 AMP 下复杂半精度 FFT warning。

7. Forward 输出

XNet2d.forward(x) 返回:

{
    "logits": logits,
    "seg_logits": logits,
    "encoder_features": [e1, e2, e3, e4],
    "decoder_features": [d4, d3, d2, d1],
    "guides": [],
}

训练主链只使用:

outputs["seg_logits"]

encoder_featuresdecoder_features 用于调试、可视化和后续辅助分析。guides 当前固定为空列表,只是为了兼容旧调试接口。

8. 当前配置说明

典型模型配置:

model:
  in_channels: 3
  encoder_channels: [32, 64, 128, 192]
  encoder_depths: [2, 2, 2, 2]
  decoder_channels: [128, 64, 32]
  stem_channels: 24
  bottleneck_depth: 1
  global_ratio: 2.0
  wavelet_type: haar
  wavelet_level: 1
  use_wavelet_branch: true
  use_global_branch_stage1: false
  ssm_d_state: 16
  ssm_forward_type: v3
  ssm_backend: auto
  use_frequency_refine: true
  guide_mode: affine
  out_channels: null

注意:guide_mode 当前不再控制任何 decoder guide 模块,只是兼容旧配置字段。后续如果清理 YAML,可以删除该字段。

9. 建议消融实验

当前实现更适合做下面这些消融:

  1. 去掉 wavelet branch。
  2. 去掉 global SS2D branch。
  3. 去掉 decoder frequency refine。
  4. 调整 encoder depth,例如 [2,2,2,2][2,2,3,2]
  5. 调整 decoder channel,例如 [128,64,32] 与更宽配置。

10. 小结

当前 XNet2d 的方法定位应写成:

An asymmetric encoder-decoder segmentation network with tri-branch ultrasound feature encoding and a plain U-Net skip decoder.