xnet_2d_zh.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980
  1. from __future__ import annotations
  2. from collections.abc import Sequence
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. import ptwt
  7. from .layers_2d import Conv2dBN
  8. from .lib_mamba.vmamba import SS2D as VMambaSS2D
  9. """
  10. ## 完成的修改
  11. ### 1. 小波变换模块迁移至 ptwt
  12. - **替换 `XHaarWaveletTransform2d` → `XWaveletTransform2d`**:使用 `ptwt.wavedec2` / `ptwt.waverec2` 实现可逆小波变换
  13. - **优势**:
  14. - 支持任意 pywt 兼容小波(haar, db4, sym2, db6 等),通过 `wavelet_type` 参数切换
  15. - 自动处理边界对齐,无需手动 padding/cropping
  16. - 代码更简洁,无手工卷积滤波器
  17. - **`XWaveletBranch2d`** 已更新引用新类,移除了 wavelet 类型限制检查
  18. ### 2. XFrequencyRefine2d 频率域精炼模块分析
  19. **发现的问题与修复:**
  20. - **原代码 FFT 低频掩码位置错误**:未使用 `fftshift`,直接在左上角做十字掩码,与真实低频位置(四角)不匹配
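  21. - **已修复**:使用 `fftshift` → 圆形低频掩码 → `ifftshift` 还原的流程。注意:`rfft2` 的最后一维只包含非负频率(DC 位于索引 0),不能对该维做 `fftshift`,否则 Nyquist 附近的高频会被误判为低频;低频掩码应按"行方向环绕距离 + 列索引"到 DC 的距离构造。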
  22. **设计合理性评估:**
  23. | 方面 | 评价 |
  24. |------|------|
  25. | 低频/高频分离 | ✅ 圆形掩码合理,可调节半径 |
  26. | 门控机制 | ⚠️ 门控值来自空间域而非频域,可能损失频域选择性 |
  27. | 通道注意力 | ✅ 每个通道独立门控,灵活 |
  28. | 重建精度 | ✅ 正交归一化 FFT + 完整频域保留 |
  29. | 计算开销 | ⚠️ meshgrid 每步计算,可缓存优化 |
  30. **改进建议:**
  31. 1. 门控可改为频域计算(对 `|fft|` 做平均池化)而非空间域
  32. 2. 低频半径可改为可学习参数
  33. 3. meshgrid 可缓存为 buffer 避免重复计算
  34. ### 验证结果
  35. 所有模块测试通过,小波分解→重建误差 < 1e-4,输出形状一致。
  36. """
  37. # ============================================================
  38. # 核心架构:XNet2D 医学图像分割网络
  39. # 业务意图:针对超声等医学图像分割任务,融合局部纹理、频率域、全局序列建模三重能力
  40. # 设计约束:
  41. # - 2D 张量通道优先 (N,C,H,W)
  42. # - 所有可逆变换需支持 inverse 恢复原始空间尺寸
  43. # - SSM 后端可切换:GPU→oflex,CPU→torch
  44. # ============================================================
  45. # --------------------------------------------------------------------------
  46. # XNetStem2d:输入茎(Stem)
  47. # 为什么:将单张输入图快速降采样 4 倍 (H/4, W/4),并逐步提升通道维度
  48. # 关键行为:
  49. # - 两次步幅为 2 的卷积实现 4 倍下采样
  50. # - 中间嵌入 depthwise 卷积增强局部通道交互
  51. # --------------------------------------------------------------------------
  52. class XNetStem2d(nn.Module):
  53. def __init__(self, in_channels: int, stem_channels: int, out_channels: int) -> None:
  54. super().__init__()
  55. self.block = nn.Sequential(
  56. Conv2dBN(in_channels, stem_channels, 3, 2, 1), # 首次下采样 H/2, W/2
  57. nn.ReLU(inplace=True),
  58. Conv2dBN(
  59. stem_channels, stem_channels, 3, 1, 1, groups=stem_channels
  60. ), # depthwise 局部特征增强
  61. nn.ReLU(inplace=True),
  62. Conv2dBN(stem_channels, out_channels, 1, 1, 0), # 通道升维
  63. nn.ReLU(inplace=True),
  64. Conv2dBN(out_channels, out_channels, 3, 2, 1), # 二次下采样 H/4, W/4
  65. nn.ReLU(inplace=True),
  66. )
  67. def forward(self, x: torch.Tensor) -> torch.Tensor:
  68. return self.block(x)
  69. # --------------------------------------------------------------------------
  70. # XNetDownsample2d:阶段间下采样器
  71. # 为什么:在编码器各阶段之间平滑过渡,降低空间分辨率同时增加通道数
  72. # 关键行为:
  73. # - 仅支持 conv 模式(扩展点由子类控制)
  74. # --------------------------------------------------------------------------
  75. class XNetDownsample2d(nn.Module):
  76. def __init__(self, in_channels: int, out_channels: int, mode: str = "conv") -> None:
  77. super().__init__()
  78. if mode != "conv":
  79. raise ValueError(f"Unsupported downsample mode: {mode}")
  80. self.block = nn.Sequential(
  81. Conv2dBN(in_channels, out_channels, 3, 2, 1), # H/2, W/2 下采样
  82. nn.ReLU(inplace=True),
  83. )
  84. def forward(self, x: torch.Tensor) -> torch.Tensor:
  85. return self.block(x)
  86. # --------------------------------------------------------------------------
  87. # XLocalBranch2d:局部感受野分支
  88. # 为什么:并行捕获 3×3 和 5×5 多尺度局部纹理,对医学图像边缘/细微结构敏感
  89. # 关键行为:
  90. # - 两组 depthwise 卷积 + 1×1 通道压缩
  91. # - 输出直接相加(残差式局部特征累积)
  92. # --------------------------------------------------------------------------
  93. class XLocalBranch2d(nn.Module):
  94. def __init__(self, channels: int) -> None:
  95. super().__init__()
  96. self.branch3 = nn.Sequential(
  97. Conv2dBN(
  98. channels, channels, 3, 1, 1, groups=channels
  99. ), # 3×3 depthwise 局部感受野
  100. nn.ReLU(inplace=True),
  101. Conv2dBN(channels, channels, 1, 1, 0), # 1×1 通道重映射
  102. )
  103. self.branch5 = nn.Sequential(
  104. Conv2dBN(
  105. channels, channels, 5, 1, 2, groups=channels
  106. ), # 5×5 depthwise 更大感受野
  107. nn.ReLU(inplace=True),
  108. Conv2dBN(channels, channels, 1, 1, 0),
  109. )
  110. def forward(self, x: torch.Tensor) -> torch.Tensor:
  111. return self.branch3(x) + self.branch5(x) # 多尺度局部特征融合
  112. # --------------------------------------------------------------------------
  113. # XWaveletTransform2d:基于 ptwt 的 2D 小波变换
  114. # 为什么:将特征分解为低频近似 (LL) 与高频细节 (LH, HL, HH),便于频率域操作
  115. # 关键行为:
  116. # - 使用 ptwt.wavedec2 / ptwt.waverec2 实现可逆小波分解与重建
  117. # - 支持任意 pywt 兼容小波(haar, db4, sym2 等)
  118. # - 输出格式:(ll_coeff, (lh_coeff, hl_coeff, hh_coeff))
  119. # --------------------------------------------------------------------------
  120. class XWaveletTransform2d(nn.Module):
  121. def __init__(self, wavelet: str = "haar", level: int = 1) -> None:
  122. super().__init__()
  123. self.wavelet = wavelet
  124. self.level = level
  125. def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  126. """
  127. 分解输入张量。
  128. Returns:
  129. ll: 低频近似系数 [B, C, H', W']
  130. high: 高频细节张量,拼接 LH/HL/HH 为 [B, C*3, H', W']
  131. """
  132. coeffs = ptwt.wavedec2(x, self.wavelet, level=self.level)
  133. ll = coeffs[0] # 低频近似
  134. detail_tuple = coeffs[1] # (lh, hl, hh) 元组
  135. high = torch.cat([detail_tuple[0], detail_tuple[1], detail_tuple[2]], dim=1)
  136. return ll, high
  137. def inverse(
  138. self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
  139. ) -> torch.Tensor:
  140. """
  141. 从低频和高频系数重建原始张量。
  142. Args:
  143. ll: 低频近似系数
  144. high: 高频细节张量 [B, C*3, H', W']
  145. output_size: 目标输出尺寸 (H, W)
  146. """
  147. lh = high[:, 0 : high.shape[1] // 3]
  148. hl = high[:, high.shape[1] // 3 : 2 * high.shape[1] // 3]
  149. hh = high[:, 2 * high.shape[1] // 3 :]
  150. coeffs = [ll, (lh, hl, hh)]
  151. # ptwt.waverec2 自动处理边界对齐,无需手动裁剪
  152. return ptwt.waverec2(coeffs, self.wavelet)
  153. # --------------------------------------------------------------------------
  154. # XWaveletBranch2d:小波分支
  155. # 为什么:对小波分解后的低频和高频分别做特征学习,再重建回空间域
  156. # 关键行为:
  157. # - 当前仅支持 level=1(设计约束);小波类型可为任意 pywt 兼容小波(haar、db4、sym2 等)
  158. # - 高频通道数 = channels * 3,需单独投影
  159. # --------------------------------------------------------------------------
  160. class XWaveletBranch2d(nn.Module):
  161. def __init__(
  162. self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
  163. ) -> None:
  164. super().__init__()
  165. self.wavelet = XWaveletTransform2d(wavelet=wavelet_type, level=wavelet_level)
  166. # 低频通道投影
  167. self.ll_proj = nn.Sequential(
  168. Conv2dBN(channels, channels, 3, 1, 1),
  169. nn.ReLU(inplace=True),
  170. )
  171. # 高频通道投影(depthwise 处理多高频分量)
  172. self.high_proj = nn.Sequential(
  173. Conv2dBN(channels * 3, channels * 3, 3, 1, 1, groups=channels * 3),
  174. nn.ReLU(inplace=True),
  175. Conv2dBN(channels * 3, channels * 3, 1, 1, 0),
  176. )
  177. # 重建后输出投影
  178. self.out_proj = nn.Sequential(
  179. Conv2dBN(channels, channels, 1, 1, 0),
  180. nn.ReLU(inplace=True),
  181. )
  182. def forward(self, x: torch.Tensor) -> torch.Tensor:
  183. output_size = x.shape[-2:]
  184. ll, high = self.wavelet(x) # 分解
  185. ll = self.ll_proj(ll)
  186. high = self.high_proj(high)
  187. x = self.wavelet.inverse(ll, high, output_size=output_size) # 重建
  188. return self.out_proj(x)
  189. # --------------------------------------------------------------------------
  190. # XSSMGlobalBranch2d:SSM 全局分支(核心:VMamba SS2D)
  191. # 为什么:用 State Space Model 捕获长程依赖,弥补卷积局部感受野不足
  192. # 关键行为:
  193. # - 自动选择后端:CUDA→oflex(快速),否则→torch(兼容)
  194. # - 通过 monkey-patch forward_core 动态切换 scan 策略
  195. # - 用完后恢复原始 forward_core 避免状态污染
  196. # --------------------------------------------------------------------------
  197. class XSSMGlobalBranch2d(nn.Module):
  198. def __init__(
  199. self,
  200. channels: int,
  201. global_ratio: float = 2.0,
  202. d_state: int = 16,
  203. forward_type: str = "v3",
  204. ssm_backend: str = "auto",
  205. ) -> None:
  206. super().__init__()
  207. hidden_ratio = max(global_ratio, 1.0) # SSM 隐层缩放比例
  208. self.backend = ssm_backend
  209. self.pre = nn.Sequential(
  210. Conv2dBN(channels, channels, 1, 1, 0), # 预投影归一化
  211. nn.ReLU(inplace=True),
  212. )
  213. self.ssm = VMambaSS2D(
  214. d_model=channels,
  215. d_state=d_state,
  216. ssm_ratio=hidden_ratio,
  217. d_conv=3,
  218. dropout=0.0,
  219. initialize="v0",
  220. forward_type=forward_type,
  221. channel_first=True,
  222. )
  223. self.post = nn.Sequential(
  224. Conv2dBN(channels, channels, 1, 1, 0), # 后投影归一化
  225. nn.ReLU(inplace=True),
  226. )
  227. def forward(self, x: torch.Tensor) -> torch.Tensor:
  228. x = self.pre(x)
  229. prev_backend = None
  230. backend = self.backend.lower()
  231. if backend == "auto":
  232. backend = "oflex" if x.is_cuda else "torch"
  233. # 动态切换 SSM 后端(避免修改全局配置)
  234. if backend == "oflex" and hasattr(self.ssm, "forward_core"):
  235. prev_backend = self.ssm.forward_core
  236. self.ssm.forward_core = lambda z, _core=prev_backend: _core(
  237. z,
  238. selective_scan_backend="oflex",
  239. scan_force_torch=False,
  240. )
  241. elif backend == "torch" and hasattr(self.ssm, "forward_core"):
  242. prev_backend = self.ssm.forward_core
  243. self.ssm.forward_core = lambda z, _core=prev_backend: _core(
  244. z,
  245. selective_scan_backend="torch",
  246. scan_force_torch=True,
  247. )
  248. try:
  249. x = self.ssm(x) # SSM 全局建模
  250. finally:
  251. if prev_backend is not None:
  252. self.ssm.forward_core = prev_backend # 恢复原始后端
  253. return self.post(x)
  254. # --------------------------------------------------------------------------
  255. # XGlobalBranch2d:全局分支包装器
  256. # 为什么:提供统一接口,将 SSM 分支暴露为可开关的模块
  257. # --------------------------------------------------------------------------
  258. class XGlobalBranch2d(nn.Module):
  259. def __init__(
  260. self,
  261. channels: int,
  262. global_ratio: float = 2.0,
  263. ssm_d_state: int = 16,
  264. ssm_forward_type: str = "v3",
  265. ssm_backend: str = "auto",
  266. ) -> None:
  267. super().__init__()
  268. self.ssm_branch = XSSMGlobalBranch2d(
  269. channels=channels,
  270. global_ratio=global_ratio,
  271. d_state=ssm_d_state,
  272. forward_type=ssm_forward_type,
  273. ssm_backend=ssm_backend,
  274. )
  275. def forward(self, x: torch.Tensor) -> torch.Tensor:
  276. return self.ssm_branch(x)
  277. # --------------------------------------------------------------------------
  278. # XBranchFusion2d:多分支特征融合
  279. # 为什么:将局部/小波/全局三个分支的输出自适应加权融合
  280. # 关键行为:
  281. # - 通道拼接 → 1×1 压缩 → 通道注意力门控(Channel Attention Gate)
  282. # - 门控值经 Sigmoid 后与融合特征逐元素相乘
  283. # --------------------------------------------------------------------------
  284. class XBranchFusion2d(nn.Module):
  285. def __init__(self, channels: int, num_branches: int = 3) -> None:
  286. super().__init__()
  287. fused_channels = channels * num_branches
  288. hidden_channels = max(channels // 4, 8) # 门控网络隐藏维度
  289. self.fuse = nn.Sequential(
  290. Conv2dBN(fused_channels, channels, 1, 1, 0), # 通道降维融合
  291. nn.ReLU(inplace=True),
  292. )
  293. # 通道注意力门控
  294. self.gate = nn.Sequential(
  295. nn.AdaptiveAvgPool2d(1), # 全局平均池化 → 空间不变
  296. nn.Conv2d(fused_channels, hidden_channels, kernel_size=1, bias=True),
  297. nn.ReLU(inplace=True),
  298. nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True),
  299. nn.Sigmoid(), # 门控值 [0, 1]
  300. )
  301. def forward(self, branch_outputs: Sequence[torch.Tensor]) -> torch.Tensor:
  302. x_cat = torch.cat(list(branch_outputs), dim=1) # 拼接所有分支
  303. x_fused = self.fuse(x_cat)
  304. gate = self.gate(x_cat) # 计算通道门控
  305. return x_fused * gate # 门控加权融合
  306. # --------------------------------------------------------------------------
  307. # XTEB2d:X-Tri-Enhance-Block (2D) — 核心构建块
  308. # 为什么:将局部、小波、全局三个分支并行融合,并叠加 FFN 残差
  309. # 关键行为:
  310. # - pre_norm:先做 1×1 投影再输入多分支
  311. # - fusion:XBranchFusion2d 自适应融合三分支
  312. # - post + FFN:双层残差连接(post-fusion + FFN)
  313. # --------------------------------------------------------------------------
  314. class XTEB2d(nn.Module):
  315. def __init__(
  316. self,
  317. channels: int,
  318. global_ratio: float = 2.0,
  319. wavelet_type: str = "haar",
  320. wavelet_level: int = 1,
  321. use_wavelet_branch: bool = True,
  322. use_global_branch: bool = True,
  323. ssm_d_state: int = 16,
  324. ssm_forward_type: str = "v3",
  325. ssm_backend: str = "auto",
  326. ) -> None:
  327. super().__init__()
  328. self.pre_norm = Conv2dBN(channels, channels, 1, 1, 0) # 预投影
  329. self.local_branch = XLocalBranch2d(channels) # 局部分支(始终启用)
  330. # 小波分支(可开关)
  331. self.wavelet_branch = (
  332. XWaveletBranch2d(
  333. channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
  334. )
  335. if use_wavelet_branch
  336. else nn.Identity()
  337. )
  338. # 全局 SSM 分支(可开关)
  339. self.global_branch = (
  340. XGlobalBranch2d(
  341. channels,
  342. global_ratio=global_ratio,
  343. ssm_d_state=ssm_d_state,
  344. ssm_forward_type=ssm_forward_type,
  345. ssm_backend=ssm_backend,
  346. )
  347. if use_global_branch
  348. else nn.Identity()
  349. )
  350. self.fusion = XBranchFusion2d(channels, num_branches=3) # 三分支融合
  351. # 后处理残差块
  352. self.post = nn.Sequential(
  353. Conv2dBN(channels, channels, 3, 1, 1),
  354. nn.ReLU(inplace=True),
  355. Conv2dBN(channels, channels, 1, 1, 0, bn_weight_init=0.0), # 零初始化
  356. )
  357. # FFN 残差块
  358. self.ffn = nn.Sequential(
  359. Conv2dBN(channels, channels * 2, 1, 1, 0), # 通道扩展
  360. nn.ReLU(inplace=True),
  361. Conv2dBN(channels * 2, channels, 1, 1, 0, bn_weight_init=0.0), # 零初始化
  362. )
  363. def forward(self, x: torch.Tensor) -> torch.Tensor:
  364. x_in = x
  365. x = self.pre_norm(x)
  366. # 三分支并行 + 融合 + 残差
  367. x = x_in + self.post(
  368. self.fusion(
  369. [self.local_branch(x), self.wavelet_branch(x), self.global_branch(x)]
  370. )
  371. )
  372. # FFN 残差
  373. return x + self.ffn(x)
  374. # --------------------------------------------------------------------------
  375. # XNetEncoderStage2d:编码器阶段
  376. # 为什么:堆叠多个 XTEB2d 块作为单一编码器层级
  377. # --------------------------------------------------------------------------
  378. class XNetEncoderStage2d(nn.Module):
  379. def __init__(
  380. self,
  381. channels: int,
  382. depth: int,
  383. global_ratio: float = 2.0,
  384. wavelet_type: str = "haar",
  385. wavelet_level: int = 1,
  386. use_wavelet_branch: bool = True,
  387. use_global_branch: bool = True,
  388. ssm_d_state: int = 16,
  389. ssm_forward_type: str = "v3",
  390. ssm_backend: str = "auto",
  391. ) -> None:
  392. super().__init__()
  393. self.blocks = nn.Sequential(
  394. *[
  395. XTEB2d(
  396. channels=channels,
  397. global_ratio=global_ratio,
  398. wavelet_type=wavelet_type,
  399. wavelet_level=wavelet_level,
  400. use_wavelet_branch=use_wavelet_branch,
  401. use_global_branch=use_global_branch,
  402. ssm_d_state=ssm_d_state,
  403. ssm_forward_type=ssm_forward_type,
  404. ssm_backend=ssm_backend,
  405. )
  406. for _ in range(depth)
  407. ]
  408. )
  409. def forward(self, x: torch.Tensor) -> torch.Tensor:
  410. return self.blocks(x)
  411. # --------------------------------------------------------------------------
  412. # XNetEncoder2d:完整编码器
  413. # 为什么:Stem + 4 个阶段 + 3 个下采样 → 多尺度特征金字塔 [e1, e2, e3, e4]
  414. # 关键约束:
  415. # - 阶段数固定为 4(由构造函数校验)
  416. # - Stage1 默认关闭全局 SSM(浅层特征不适合长程建模)
  417. # - stage_channels 属性暴露各阶段输出通道数
  418. # --------------------------------------------------------------------------
  419. class XNetEncoder2d(nn.Module):
  420. def __init__(
  421. self,
  422. in_channels: int,
  423. stem_channels: int,
  424. encoder_channels: Sequence[int],
  425. encoder_depths: Sequence[int],
  426. global_ratio: float = 2.0,
  427. wavelet_type: str = "haar",
  428. wavelet_level: int = 1,
  429. use_wavelet_branch: bool = True,
  430. use_global_branch_stage1: bool = False,
  431. ssm_d_state: int = 16,
  432. ssm_forward_type: str = "v3",
  433. ssm_backend: str = "auto",
  434. ) -> None:
  435. super().__init__()
  436. if len(encoder_channels) != 4 or len(encoder_depths) != 4:
  437. raise ValueError("XNetEncoder2d expects 4 encoder stages.")
  438. c1, c2, c3, c4 = encoder_channels
  439. d1, d2, d3, d4 = encoder_depths
  440. self.stem = XNetStem2d(in_channels, stem_channels, c1)
  441. # Stage 1:浅层,可选关闭全局分支
  442. self.stage1 = XNetEncoderStage2d(
  443. c1,
  444. d1,
  445. global_ratio,
  446. wavelet_type,
  447. wavelet_level,
  448. use_wavelet_branch=use_wavelet_branch,
  449. use_global_branch=use_global_branch_stage1,
  450. ssm_d_state=ssm_d_state,
  451. ssm_forward_type=ssm_forward_type,
  452. ssm_backend=ssm_backend,
  453. )
  454. self.down1 = XNetDownsample2d(c1, c2)
  455. # Stage 2-4:始终启用全局分支
  456. self.stage2 = XNetEncoderStage2d(
  457. c2,
  458. d2,
  459. global_ratio,
  460. wavelet_type,
  461. wavelet_level,
  462. use_wavelet_branch,
  463. True,
  464. ssm_d_state=ssm_d_state,
  465. ssm_forward_type=ssm_forward_type,
  466. ssm_backend=ssm_backend,
  467. )
  468. self.down2 = XNetDownsample2d(c2, c3)
  469. self.stage3 = XNetEncoderStage2d(
  470. c3,
  471. d3,
  472. global_ratio,
  473. wavelet_type,
  474. wavelet_level,
  475. use_wavelet_branch,
  476. True,
  477. ssm_d_state=ssm_d_state,
  478. ssm_forward_type=ssm_forward_type,
  479. ssm_backend=ssm_backend,
  480. )
  481. self.down3 = XNetDownsample2d(c3, c4)
  482. self.stage4 = XNetEncoderStage2d(
  483. c4,
  484. d4,
  485. global_ratio,
  486. wavelet_type,
  487. wavelet_level,
  488. use_wavelet_branch,
  489. True,
  490. ssm_d_state=ssm_d_state,
  491. ssm_forward_type=ssm_forward_type,
  492. ssm_backend=ssm_backend,
  493. )
  494. self.stage_channels = list(encoder_channels) # 暴露各阶段通道数
  495. def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
  496. e1 = self.stage1(self.stem(x)) # 浅层特征
  497. e2 = self.stage2(self.down1(e1)) # 中层特征
  498. e3 = self.stage3(self.down2(e2)) # 深层特征
  499. e4 = self.stage4(self.down3(e3)) # 最深特征
  500. return [e1, e2, e3, e4] # 多尺度特征金字塔
  501. # --------------------------------------------------------------------------
  502. # XGuideProjector2d:引导投影器
  503. # 为什么:从编码器特征生成引导信号(guide),用于解码器的自适应调制
  504. # 关键行为:
  505. # - affine 模式:输出 (gamma, beta) 用于仿射调制
  506. # - feature 模式:直接输出特征
  507. # --------------------------------------------------------------------------
  508. class XGuideProjector2d(nn.Module):
  509. def __init__(
  510. self, in_channels: int, out_channels: int, mode: str = "affine"
  511. ) -> None:
  512. super().__init__()
  513. self.mode = mode
  514. if mode == "affine":
  515. # 输出双倍通道 → 后续拆分为 gamma 和 beta
  516. self.proj = nn.Sequential(
  517. Conv2dBN(in_channels, out_channels * 2, 1, 1, 0),
  518. nn.ReLU(inplace=True),
  519. nn.Conv2d(out_channels * 2, out_channels * 2, kernel_size=1, bias=True),
  520. )
  521. elif mode == "feature":
  522. self.proj = nn.Sequential(
  523. Conv2dBN(in_channels, out_channels, 1, 1, 0),
  524. nn.ReLU(inplace=True),
  525. )
  526. else:
  527. raise ValueError(f"Unsupported guide mode: {mode}")
  528. def forward(
  529. self,
  530. x: torch.Tensor,
  531. target_size: tuple[int, int],
  532. ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
  533. # 插值到目标尺寸(guide 需要与解码器特征空间对齐)
  534. x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
  535. x = self.proj(x)
  536. if self.mode == "affine":
  537. gamma, beta = torch.chunk(x, 2, dim=1) # 拆分为仿射参数
  538. gamma = torch.sigmoid(gamma) + 0.5 # gamma 偏置到 [0.5, 1.5]
  539. return gamma, beta
  540. return x
  541. # --------------------------------------------------------------------------
  542. # XSkipFusion2d:跳跃连接融合
  543. # 为什么:将编码器特征与解码器特征融合后传入
  544. # 关键行为:
  545. # - 分别投影输入和跳跃特征到相同维度
  546. # - 拼接 + 3×3 卷积融合
  547. # --------------------------------------------------------------------------
  548. class XSkipFusion2d(nn.Module):
  549. def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
  550. super().__init__()
  551. self.input_proj = nn.Sequential(
  552. Conv2dBN(in_channels, out_channels, 1, 1, 0), # 解码器特征投影
  553. nn.ReLU(inplace=True),
  554. )
  555. self.skip_proj = nn.Sequential(
  556. Conv2dBN(skip_channels, out_channels, 1, 1, 0), # 跳跃特征投影
  557. nn.ReLU(inplace=True),
  558. )
  559. self.fuse = nn.Sequential(
  560. Conv2dBN(out_channels * 2, out_channels, 3, 1, 1), # 拼接后融合
  561. nn.ReLU(inplace=True),
  562. )
  563. def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
  564. # 双线性插值对齐空间尺寸
  565. x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
  566. x = self.input_proj(x)
  567. skip = self.skip_proj(skip)
  568. return self.fuse(torch.cat([x, skip], dim=1)) # 通道拼接融合
  569. # --------------------------------------------------------------------------
  570. # XGuideModulation2d:引导调制器
  571. # 为什么:对特征应用仿射调制 (gamma * x + beta) 或特征驱动调制
  572. # --------------------------------------------------------------------------
  573. class XGuideModulation2d(nn.Module):
  574. def __init__(self, channels: int, guide_mode: str = "affine") -> None:
  575. super().__init__()
  576. self.guide_mode = guide_mode
  577. if guide_mode == "feature":
  578. # feature 模式下先将 guide 转为仿射参数
  579. self.to_affine = nn.Conv2d(channels, channels * 2, kernel_size=1, bias=True)
  580. def forward(
  581. self,
  582. x: torch.Tensor,
  583. guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
  584. ) -> torch.Tensor:
  585. if self.guide_mode == "affine":
  586. gamma, beta = guide # 直接使用仿射参数
  587. else:
  588. gamma, beta = torch.chunk(self.to_affine(guide), 2, dim=1)
  589. gamma = torch.sigmoid(gamma) + 0.5
  590. return gamma * x + beta # 仿射调制
  591. # --------------------------------------------------------------------------
  592. # XFrequencyRefine2d:频率域精炼
  593. # 为什么:在频域对低频/高频分别应用门控,增强关键频率成分
  594. # 关键行为:
  595. # - FFT → 低频中心保留 + 高频带通 → 逆 FFT
  596. # - 门控由自适应平均池化生成
  597. # --------------------------------------------------------------------------
  598. class XFrequencyRefine2d(nn.Module):
  599. def __init__(self, channels: int) -> None:
  600. super().__init__()
  601. # 低频门控
  602. self.low_gate = nn.Sequential(
  603. nn.AdaptiveAvgPool2d(1),
  604. nn.Conv2d(channels, channels, kernel_size=1, bias=True),
  605. nn.Sigmoid(),
  606. )
  607. # 高频门控
  608. self.high_gate = nn.Sequential(
  609. nn.AdaptiveAvgPool2d(1),
  610. nn.Conv2d(channels, channels, kernel_size=1, bias=True),
  611. nn.Sigmoid(),
  612. )
  613. # 频域精炼后的空间域细化
  614. self.refine = nn.Sequential(
  615. Conv2dBN(
  616. channels, channels, 3, 1, 1, groups=channels
  617. ), # depthwise 局部细化
  618. nn.ReLU(inplace=True),
  619. Conv2dBN(channels, channels, 1, 1, 0),
  620. )
  621. def forward(self, x: torch.Tensor) -> torch.Tensor:
  622. input_dtype = x.dtype
  623. if x.dtype != torch.float32:
  624. x = x.to(torch.float32) # FFT 需要 float32 精度
  625. fft = torch.fft.rfft2(x, norm="ortho") # 实值 FFT
  626. h_freq, w_freq = fft.shape[-2], fft.shape[-1]
  627. # 构建圆形低频掩码(中心位于四个角:FFT 未 shift 时低频在四角)
  628. # 使用 fftshift 将低频移至中心,应用掩码后再 ifftshift 还原
  629. fft_shifted = torch.fft.fftshift(fft, dim=(-2, -1))
  630. low = fft_shifted.clone()
  631. # 圆形低频掩码:保留中心区域
  632. radius_h = h_freq // 4
  633. radius_w = w_freq // 4
  634. y_grid, x_grid = torch.meshgrid(
  635. torch.arange(h_freq, device=fft.device),
  636. torch.arange(w_freq, device=fft.device),
  637. indexing="ij",
  638. )
  639. center_y, center_x = h_freq // 2, w_freq // 2
  640. mask = (y_grid - center_y) ** 2 + (x_grid - center_x) ** 2 <= max(
  641. radius_h, radius_w
  642. ) ** 2
  643. mask = mask.unsqueeze(0).unsqueeze(0).expand(fft.shape[0], fft.shape[1], -1, -1)
  644. low = low * mask # 低频分量
  645. high = fft_shifted - low # 高频 = 全部 - 低频
  646. # 还原到原始 FFT 坐标系
  647. low = torch.fft.ifftshift(low, dim=(-2, -1))
  648. high = torch.fft.ifftshift(high, dim=(-2, -1))
  649. # 应用通道门控(门控值来自空间域)
  650. low = low * self.low_gate(x)
  651. high = high * self.high_gate(x)
  652. out = torch.fft.irfft2(low + high, s=x.shape[-2:], norm="ortho") # 逆 FFT
  653. out = out.to(dtype=input_dtype)
  654. return self.refine(out) # 空间域细化
  655. # --------------------------------------------------------------------------
  656. # XCRB2d:X-ResBlock with Guide (2D) — 解码器核心块
  657. # 为什么:融合跳跃连接 + 引导调制 + 频率精炼,是解码器重建的基础单元
  658. # 数据流:
  659. # 输入特征 → SkipFusion → GuideModulation → FrequencyRefine → OutRefine
  660. # 每步均有残差连接
  661. # --------------------------------------------------------------------------
  662. class XCRB2d(nn.Module):
  663. def __init__(
  664. self,
  665. in_channels: int,
  666. skip_channels: int,
  667. guide_channels: int,
  668. out_channels: int,
  669. guide_mode: str = "affine",
  670. use_frequency_refine: bool = True,
  671. ) -> None:
  672. super().__init__()
  673. self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
  674. self.guide_modulation = XGuideModulation2d(out_channels, guide_mode=guide_mode)
  675. self.frequency_refine = (
  676. XFrequencyRefine2d(out_channels) if use_frequency_refine else nn.Identity()
  677. )
  678. # 输出细化(零初始化末尾以渐进学习)
  679. self.out_refine = nn.Sequential(
  680. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  681. nn.ReLU(inplace=True),
  682. Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
  683. )
  684. self.guide_channels = guide_channels
  685. def forward(
  686. self,
  687. x: torch.Tensor,
  688. skip: torch.Tensor,
  689. guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
  690. ) -> torch.Tensor:
  691. x = self.skip_fusion(x, skip) # 跳跃融合
  692. x = self.guide_modulation(x, guide) # 引导调制
  693. x = x + self.frequency_refine(x) # 频率精炼残差
  694. return x + self.out_refine(x) # 输出细化残差
  695. # --------------------------------------------------------------------------
  696. # XNetHeadRefine2d:特征精炼头
  697. # 为什么:在解码器末端做最后的特征增强
  698. # --------------------------------------------------------------------------
  699. class XNetHeadRefine2d(nn.Module):
  700. def __init__(self, channels: int, out_channels: int | None = None) -> None:
  701. super().__init__()
  702. if out_channels is None:
  703. out_channels = channels
  704. self.block = nn.Sequential(
  705. Conv2dBN(channels, out_channels, 3, 1, 1),
  706. nn.ReLU(inplace=True),
  707. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  708. nn.ReLU(inplace=True),
  709. )
  710. def forward(self, x: torch.Tensor) -> torch.Tensor:
  711. return self.block(x)
  712. # --------------------------------------------------------------------------
  713. # XNetDecoder2d:完整解码器
  714. # 为什么:从最深特征 e4 逐步上采样,逐层引入引导信号和跳跃连接
  715. # 关键数据流:
  716. # e4 → guide4 → dec4 → guide3 → dec3 → guide2 → dec2 → head_refine
  717. # 返回:输出特征、所有解码特征、所有引导信号(供损失函数使用)
  718. # --------------------------------------------------------------------------
  719. class XNetDecoder2d(nn.Module):
  720. def __init__(
  721. self,
  722. encoder_channels: Sequence[int],
  723. decoder_channels: Sequence[int] = (128, 64, 32),
  724. guide_mode: str = "affine",
  725. use_frequency_refine: bool = True,
  726. out_channels: int | None = None,
  727. ) -> None:
  728. super().__init__()
  729. if len(encoder_channels) != 4:
  730. raise ValueError("XNetDecoder2d expects 4 encoder stages.")
  731. if len(decoder_channels) != 3:
  732. raise ValueError("XNetDecoder2d expects 3 decoder channels.")
  733. c1, c2, c3, c4 = encoder_channels
  734. d4, d3, d2 = decoder_channels
  735. # 引导投影器(从编码器特征生成 guide)
  736. self.guide4 = XGuideProjector2d(c4, d4, mode=guide_mode)
  737. self.guide3 = XGuideProjector2d(c3, d3, mode=guide_mode)
  738. self.guide2 = XGuideProjector2d(c2, d2, mode=guide_mode)
  739. # 解码块(逐层降通道 + 跳跃融合)
  740. self.dec4 = XCRB2d(
  741. c4,
  742. c3,
  743. d4,
  744. d4,
  745. guide_mode=guide_mode,
  746. use_frequency_refine=use_frequency_refine,
  747. )
  748. self.dec3 = XCRB2d(
  749. d4,
  750. c2,
  751. d3,
  752. d3,
  753. guide_mode=guide_mode,
  754. use_frequency_refine=use_frequency_refine,
  755. )
  756. self.dec2 = XCRB2d(
  757. d3,
  758. c1,
  759. d2,
  760. d2,
  761. guide_mode=guide_mode,
  762. use_frequency_refine=use_frequency_refine,
  763. )
  764. self.head_refine = XNetHeadRefine2d(d2, out_channels or d2)
  765. self.out_channels = out_channels or d2
  766. def forward(
  767. self,
  768. features: Sequence[torch.Tensor],
  769. ) -> tuple[
  770. torch.Tensor,
  771. list[torch.Tensor],
  772. list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]],
  773. ]:
  774. e1, e2, e3, e4 = features
  775. # 从深到浅逐层解码
  776. g4 = self.guide4(e4, target_size=e3.shape[-2:]) # 从 e4 生成 guide
  777. d4 = self.dec4(e4, e3, g4) # 解码 + 跳跃 e3
  778. g3 = self.guide3(e3, target_size=e2.shape[-2:])
  779. d3 = self.dec3(d4, e2, g3) # 解码 + 跳跃 e2
  780. g2 = self.guide2(e2, target_size=e1.shape[-2:])
  781. d2 = self.dec2(d3, e1, g2) # 解码 + 跳跃 e1
  782. d1 = self.head_refine(d2) # 最终精炼
  783. # 返回解码输出、中间特征(用于辅助损失)、引导信号
  784. return d1, [d4, d3, d2, d1], [g4, g3, g2]
  785. # --------------------------------------------------------------------------
  786. # XNetSegHead2d:分割头
  787. # 为什么:将最终特征映射为 logits 图,并上采样到原始输入尺寸
  788. # --------------------------------------------------------------------------
  789. class XNetSegHead2d(nn.Module):
  790. def __init__(
  791. self, in_channels: int, num_classes: int, upsample_scale: int = 4
  792. ) -> None:
  793. super().__init__()
  794. self.block = nn.Sequential(
  795. Conv2dBN(in_channels, in_channels, 3, 1, 1),
  796. nn.ReLU(inplace=True),
  797. nn.Conv2d(
  798. in_channels, num_classes, kernel_size=1, bias=True
  799. ), # 映射到类别数
  800. )
  801. self.upsample_scale = upsample_scale
  802. def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
  803. x = self.block(x)
  804. # 双线性上采样到目标尺寸(推理时传入原始输入 H, W)
  805. return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
  806. # ==========================================================================
  807. # XNet2d:完整网络(编码器 + Bottleneck + 解码器 + 分割头)
  808. # 架构概览:
  809. # 输入 → Stem → [Stage1 ↓ Stage2 ↓ Stage3 ↓ Stage4] → Bottleneck
#     → [dec4 → dec3 → dec2] → Head → Logits
  811. # 业务特点:
  812. # - 编码器浅层(Stage1)默认关闭 SSM 以降低计算开销
  813. # - 解码器逐层注入 guide 信号,实现自适应特征调制
  814. # - 每个解码块支持频率精炼,增强医学图像细节保留
  815. # ==========================================================================
  816. class XNet2d(nn.Module):
  817. def __init__(
  818. self,
  819. in_channels: int,
  820. num_classes: int,
  821. encoder_channels: Sequence[int] = (32, 64, 128, 192),
  822. encoder_depths: Sequence[int] = (2, 2, 2, 2),
  823. decoder_channels: Sequence[int] = (128, 64, 32),
  824. stem_channels: int = 24,
  825. bottleneck_depth: int = 1,
  826. global_ratio: float = 2.0,
  827. wavelet_type: str = "haar",
  828. wavelet_level: int = 1,
  829. use_wavelet_branch: bool = True,
  830. use_global_branch_stage1: bool = False,
  831. ssm_d_state: int = 16,
  832. ssm_forward_type: str = "v3",
  833. ssm_backend: str = "auto",
  834. use_frequency_refine: bool = True,
  835. guide_mode: str = "affine",
  836. out_channels: int | None = None,
  837. ) -> None:
  838. super().__init__()
  839. # 编码器:多尺度特征金字塔
  840. self.encoder = XNetEncoder2d(
  841. in_channels=in_channels,
  842. stem_channels=stem_channels,
  843. encoder_channels=encoder_channels,
  844. encoder_depths=encoder_depths,
  845. global_ratio=global_ratio,
  846. wavelet_type=wavelet_type,
  847. wavelet_level=wavelet_level,
  848. use_wavelet_branch=use_wavelet_branch,
  849. use_global_branch_stage1=use_global_branch_stage1,
  850. ssm_d_state=ssm_d_state,
  851. ssm_forward_type=ssm_forward_type,
  852. ssm_backend=ssm_backend,
  853. )
  854. # Bottleneck:最深特征进一步建模
  855. bottleneck_channels = encoder_channels[-1]
  856. self.bottleneck = nn.Sequential(
  857. *[
  858. XTEB2d(
  859. channels=bottleneck_channels,
  860. global_ratio=global_ratio,
  861. wavelet_type=wavelet_type,
  862. wavelet_level=wavelet_level,
  863. use_wavelet_branch=use_wavelet_branch,
  864. use_global_branch=True, # bottleneck 始终启用全局分支
  865. ssm_d_state=ssm_d_state,
  866. ssm_forward_type=ssm_forward_type,
  867. ssm_backend=ssm_backend,
  868. )
  869. for _ in range(bottleneck_depth)
  870. ]
  871. )
  872. # 解码器
  873. self.decoder = XNetDecoder2d(
  874. encoder_channels=encoder_channels,
  875. decoder_channels=decoder_channels,
  876. guide_mode=guide_mode,
  877. use_frequency_refine=use_frequency_refine,
  878. out_channels=out_channels,
  879. )
  880. # 分割头
  881. head_in_channels = self.decoder.out_channels
  882. self.segmentation_head = XNetSegHead2d(head_in_channels, num_classes)
  883. def forward(
  884. self, x: torch.Tensor
  885. ) -> dict[
  886. str, torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]
  887. ]:
  888. encoder_features = self.encoder(x) # 多尺度特征 [e1, e2, e3, e4]
  889. encoder_features[-1] = self.bottleneck(encoder_features[-1]) # bottleneck
  890. decoder_out, decoder_features, guides = self.decoder(encoder_features) # 解码
  891. output_size = x.shape[-2:]
  892. logits = self.segmentation_head(
  893. decoder_out, output_size=output_size
  894. ) # 分割 logits
  895. # 返回字典:包含 logits、中间特征(用于辅助损失)、引导信号
  896. outputs: dict[
  897. str,
  898. torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]],
  899. ] = {
  900. "logits": logits,
  901. "seg_logits": logits,
  902. "encoder_features": encoder_features,
  903. "decoder_features": decoder_features,
  904. "guides": guides,
  905. }
  906. return outputs