# xnet_2d.py
from __future__ import annotations

from collections.abc import Sequence

import ptwt
import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers_2d import Conv2dBN
from .lib_mamba.vmamba import SS2D as VMambaSS2D
  9. class XNetStem2d(nn.Module):
  10. # Stem reduces spatial size by 4x while lifting features into encoder stage 1.
  11. def __init__(self, in_channels: int, stem_channels: int, out_channels: int) -> None:
  12. super().__init__()
  13. self.block = nn.Sequential(
  14. Conv2dBN(in_channels, stem_channels, 3, 2, 1),
  15. nn.ReLU(inplace=True),
  16. Conv2dBN(stem_channels, stem_channels, 3, 1, 1, groups=stem_channels),
  17. nn.ReLU(inplace=True),
  18. Conv2dBN(stem_channels, out_channels, 1, 1, 0),
  19. nn.ReLU(inplace=True),
  20. Conv2dBN(out_channels, out_channels, 3, 2, 1),
  21. nn.ReLU(inplace=True),
  22. )
  23. def forward(self, x: torch.Tensor) -> torch.Tensor:
  24. return self.block(x)
  25. class XNetDownsample2d(nn.Module):
  26. def __init__(self, in_channels: int, out_channels: int, mode: str = "conv") -> None:
  27. super().__init__()
  28. if mode != "conv":
  29. raise ValueError(f"Unsupported downsample mode: {mode}")
  30. self.block = nn.Sequential(
  31. Conv2dBN(in_channels, out_channels, 3, 2, 1),
  32. nn.ReLU(inplace=True),
  33. )
  34. def forward(self, x: torch.Tensor) -> torch.Tensor:
  35. return self.block(x)
  36. class XLocalBranch2d(nn.Module):
  37. # Parallel depthwise branches capture short-range texture at two kernel scales.
  38. def __init__(self, channels: int) -> None:
  39. super().__init__()
  40. self.branch3 = nn.Sequential(
  41. Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
  42. nn.ReLU(inplace=True),
  43. Conv2dBN(channels, channels, 1, 1, 0),
  44. )
  45. self.branch5 = nn.Sequential(
  46. Conv2dBN(channels, channels, 5, 1, 2, groups=channels),
  47. nn.ReLU(inplace=True),
  48. Conv2dBN(channels, channels, 1, 1, 0),
  49. )
  50. def forward(self, x: torch.Tensor) -> torch.Tensor:
  51. return self.branch3(x) + self.branch5(x)
  52. class XWaveletTransform2d(nn.Module):
  53. # ptwt-based wavelet decomposition/reconstruction with explicit crop so odd
  54. # input sizes round-trip to the exact original spatial shape.
  55. def __init__(
  56. self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
  57. ) -> None:
  58. super().__init__()
  59. self.channels = channels
  60. self.wavelet_type = wavelet_type
  61. self.wavelet_level = wavelet_level
  62. def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  63. coeffs = ptwt.wavedec2(x, self.wavelet_type, level=self.wavelet_level)
  64. ll = coeffs[0]
  65. high_parts = coeffs[1]
  66. high = torch.cat(high_parts, dim=1)
  67. return ll, high
  68. def inverse(
  69. self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
  70. ) -> torch.Tensor:
  71. lh, hl, hh = torch.chunk(high, 3, dim=1)
  72. coeffs = [ll, (lh, hl, hh)]
  73. x = ptwt.waverec2(coeffs, self.wavelet_type)
  74. return x[:, :, : output_size[0], : output_size[1]]
  75. class XWaveletBranch2d(nn.Module):
  76. # The wavelet branch learns on low/high-frequency components separately and
  77. # then reconstructs back to the original feature size.
  78. def __init__(
  79. self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
  80. ) -> None:
  81. super().__init__()
  82. if wavelet_type != "haar":
  83. raise ValueError(f"Unsupported wavelet type: {wavelet_type}")
  84. if wavelet_level != 1:
  85. raise ValueError(
  86. "Initial XNet implementation only supports wavelet_level=1."
  87. )
  88. self.wavelet = XWaveletTransform2d(
  89. channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
  90. )
  91. self.ll_proj = nn.Sequential(
  92. Conv2dBN(channels, channels, 3, 1, 1),
  93. nn.ReLU(inplace=True),
  94. )
  95. self.high_proj = nn.Sequential(
  96. Conv2dBN(channels * 3, channels * 3, 3, 1, 1, groups=channels * 3),
  97. nn.ReLU(inplace=True),
  98. Conv2dBN(channels * 3, channels * 3, 1, 1, 0),
  99. )
  100. self.out_proj = nn.Sequential(
  101. Conv2dBN(channels, channels, 1, 1, 0),
  102. nn.ReLU(inplace=True),
  103. )
  104. def forward(self, x: torch.Tensor) -> torch.Tensor:
  105. output_size = x.shape[-2:]
  106. ll, high = self.wavelet(x)
  107. ll = self.ll_proj(ll)
  108. high = self.high_proj(high)
  109. x = self.wavelet.inverse(ll, high, output_size=output_size)
  110. return self.out_proj(x)
  111. class XSSMGlobalBranch2d(nn.Module):
  112. # The global branch wraps VMamba and switches scan backend at runtime.
  113. def __init__(
  114. self,
  115. channels: int,
  116. global_ratio: float = 2.0,
  117. d_state: int = 16,
  118. forward_type: str = "v3",
  119. ssm_backend: str = "auto",
  120. ) -> None:
  121. super().__init__()
  122. hidden_ratio = max(global_ratio, 1.0)
  123. self.backend = ssm_backend
  124. self.pre = nn.Sequential(
  125. Conv2dBN(channels, channels, 1, 1, 0),
  126. nn.ReLU(inplace=True),
  127. )
  128. self.ssm = VMambaSS2D(
  129. d_model=channels,
  130. d_state=d_state,
  131. ssm_ratio=hidden_ratio,
  132. d_conv=3,
  133. dropout=0.0,
  134. initialize="v0",
  135. forward_type=forward_type,
  136. channel_first=True,
  137. )
  138. self.post = nn.Sequential(
  139. Conv2dBN(channels, channels, 1, 1, 0),
  140. nn.ReLU(inplace=True),
  141. )
  142. def forward(self, x: torch.Tensor) -> torch.Tensor:
  143. x = self.pre(x)
  144. prev_backend = None
  145. backend = self.backend.lower()
  146. if backend == "auto":
  147. backend = "oflex" if x.is_cuda else "torch"
  148. if backend == "oflex" and hasattr(self.ssm, "forward_core"):
  149. prev_backend = self.ssm.forward_core
  150. self.ssm.forward_core = lambda z, _core=prev_backend: _core(
  151. z,
  152. selective_scan_backend="oflex",
  153. scan_force_torch=False,
  154. )
  155. elif backend == "torch" and hasattr(self.ssm, "forward_core"):
  156. prev_backend = self.ssm.forward_core
  157. self.ssm.forward_core = lambda z, _core=prev_backend: _core(
  158. z,
  159. selective_scan_backend="torch",
  160. scan_force_torch=True,
  161. )
  162. try:
  163. x = self.ssm(x)
  164. finally:
  165. if prev_backend is not None:
  166. self.ssm.forward_core = prev_backend
  167. return self.post(x)
  168. class XGlobalBranch2d(nn.Module):
  169. def __init__(
  170. self,
  171. channels: int,
  172. global_ratio: float = 2.0,
  173. ssm_d_state: int = 16,
  174. ssm_forward_type: str = "v3",
  175. ssm_backend: str = "auto",
  176. ) -> None:
  177. super().__init__()
  178. self.ssm_branch = XSSMGlobalBranch2d(
  179. channels=channels,
  180. global_ratio=global_ratio,
  181. d_state=ssm_d_state,
  182. forward_type=ssm_forward_type,
  183. ssm_backend=ssm_backend,
  184. )
  185. def forward(self, x: torch.Tensor) -> torch.Tensor:
  186. return self.ssm_branch(x)
  187. class XBranchFusion2d(nn.Module):
  188. def __init__(self, channels: int, num_branches: int = 3) -> None:
  189. super().__init__()
  190. fused_channels = channels * num_branches
  191. hidden_channels = max(channels // 4, 8)
  192. self.fuse = nn.Sequential(
  193. Conv2dBN(fused_channels, channels, 1, 1, 0),
  194. nn.ReLU(inplace=True),
  195. )
  196. self.gate = nn.Sequential(
  197. nn.AdaptiveAvgPool2d(1),
  198. nn.Conv2d(fused_channels, hidden_channels, kernel_size=1, bias=True),
  199. nn.ReLU(inplace=True),
  200. nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True),
  201. nn.Sigmoid(),
  202. )
  203. def forward(self, branch_outputs: Sequence[torch.Tensor]) -> torch.Tensor:
  204. x_cat = torch.cat(list(branch_outputs), dim=1)
  205. x_fused = self.fuse(x_cat)
  206. gate = self.gate(x_cat)
  207. return x_fused * gate
  208. class XTEB2d(nn.Module):
  209. # XTEB fuses local, wavelet, and global branches with residual post/ffn blocks.
  210. def __init__(
  211. self,
  212. channels: int,
  213. global_ratio: float = 2.0,
  214. wavelet_type: str = "haar",
  215. wavelet_level: int = 1,
  216. use_wavelet_branch: bool = True,
  217. use_global_branch: bool = True,
  218. ssm_d_state: int = 16,
  219. ssm_forward_type: str = "v3",
  220. ssm_backend: str = "auto",
  221. ) -> None:
  222. super().__init__()
  223. self.pre_norm = Conv2dBN(channels, channels, 1, 1, 0)
  224. self.local_branch = XLocalBranch2d(channels)
  225. self.wavelet_branch = (
  226. XWaveletBranch2d(
  227. channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
  228. )
  229. if use_wavelet_branch
  230. else nn.Identity()
  231. )
  232. self.global_branch = (
  233. XGlobalBranch2d(
  234. channels,
  235. global_ratio=global_ratio,
  236. ssm_d_state=ssm_d_state,
  237. ssm_forward_type=ssm_forward_type,
  238. ssm_backend=ssm_backend,
  239. )
  240. if use_global_branch
  241. else nn.Identity()
  242. )
  243. self.fusion = XBranchFusion2d(channels, num_branches=3)
  244. self.post = nn.Sequential(
  245. Conv2dBN(channels, channels, 3, 1, 1),
  246. nn.ReLU(inplace=True),
  247. Conv2dBN(channels, channels, 1, 1, 0, bn_weight_init=0.0),
  248. )
  249. self.ffn = nn.Sequential(
  250. Conv2dBN(channels, channels * 2, 1, 1, 0),
  251. nn.ReLU(inplace=True),
  252. Conv2dBN(channels * 2, channels, 1, 1, 0, bn_weight_init=0.0),
  253. )
  254. def forward(self, x: torch.Tensor) -> torch.Tensor:
  255. x_in = x
  256. x = self.pre_norm(x)
  257. x = x_in + self.post(
  258. self.fusion(
  259. [self.local_branch(x), self.wavelet_branch(x), self.global_branch(x)]
  260. )
  261. )
  262. return x + self.ffn(x)
  263. class XNetEncoderStage2d(nn.Module):
  264. def __init__(
  265. self,
  266. channels: int,
  267. depth: int,
  268. global_ratio: float = 2.0,
  269. wavelet_type: str = "haar",
  270. wavelet_level: int = 1,
  271. use_wavelet_branch: bool = True,
  272. use_global_branch: bool = True,
  273. ssm_d_state: int = 16,
  274. ssm_forward_type: str = "v3",
  275. ssm_backend: str = "auto",
  276. ) -> None:
  277. super().__init__()
  278. self.blocks = nn.Sequential(
  279. *[
  280. XTEB2d(
  281. channels=channels,
  282. global_ratio=global_ratio,
  283. wavelet_type=wavelet_type,
  284. wavelet_level=wavelet_level,
  285. use_wavelet_branch=use_wavelet_branch,
  286. use_global_branch=use_global_branch,
  287. ssm_d_state=ssm_d_state,
  288. ssm_forward_type=ssm_forward_type,
  289. ssm_backend=ssm_backend,
  290. )
  291. for _ in range(depth)
  292. ]
  293. )
  294. def forward(self, x: torch.Tensor) -> torch.Tensor:
  295. return self.blocks(x)
  296. class XNetEncoder2d(nn.Module):
  297. # The encoder is a 4-stage feature pyramid with optional stage-1 global branch.
  298. def __init__(
  299. self,
  300. in_channels: int,
  301. stem_channels: int,
  302. encoder_channels: Sequence[int],
  303. encoder_depths: Sequence[int],
  304. global_ratio: float = 2.0,
  305. wavelet_type: str = "haar",
  306. wavelet_level: int = 1,
  307. use_wavelet_branch: bool = True,
  308. use_global_branch_stage1: bool = False,
  309. ssm_d_state: int = 16,
  310. ssm_forward_type: str = "v3",
  311. ssm_backend: str = "auto",
  312. ) -> None:
  313. super().__init__()
  314. if len(encoder_channels) != 4 or len(encoder_depths) != 4:
  315. raise ValueError("XNetEncoder2d expects 4 encoder stages.")
  316. c1, c2, c3, c4 = encoder_channels
  317. d1, d2, d3, d4 = encoder_depths
  318. self.stem = XNetStem2d(in_channels, stem_channels, c1)
  319. self.stage1 = XNetEncoderStage2d(
  320. c1,
  321. d1,
  322. global_ratio,
  323. wavelet_type,
  324. wavelet_level,
  325. use_wavelet_branch=use_wavelet_branch,
  326. use_global_branch=use_global_branch_stage1,
  327. ssm_d_state=ssm_d_state,
  328. ssm_forward_type=ssm_forward_type,
  329. ssm_backend=ssm_backend,
  330. )
  331. self.down1 = XNetDownsample2d(c1, c2)
  332. self.stage2 = XNetEncoderStage2d(
  333. c2,
  334. d2,
  335. global_ratio,
  336. wavelet_type,
  337. wavelet_level,
  338. use_wavelet_branch,
  339. True,
  340. ssm_d_state=ssm_d_state,
  341. ssm_forward_type=ssm_forward_type,
  342. ssm_backend=ssm_backend,
  343. )
  344. self.down2 = XNetDownsample2d(c2, c3)
  345. self.stage3 = XNetEncoderStage2d(
  346. c3,
  347. d3,
  348. global_ratio,
  349. wavelet_type,
  350. wavelet_level,
  351. use_wavelet_branch,
  352. True,
  353. ssm_d_state=ssm_d_state,
  354. ssm_forward_type=ssm_forward_type,
  355. ssm_backend=ssm_backend,
  356. )
  357. self.down3 = XNetDownsample2d(c3, c4)
  358. self.stage4 = XNetEncoderStage2d(
  359. c4,
  360. d4,
  361. global_ratio,
  362. wavelet_type,
  363. wavelet_level,
  364. use_wavelet_branch,
  365. True,
  366. ssm_d_state=ssm_d_state,
  367. ssm_forward_type=ssm_forward_type,
  368. ssm_backend=ssm_backend,
  369. )
  370. self.stage_channels = list(encoder_channels)
  371. def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
  372. e1 = self.stage1(self.stem(x))
  373. e2 = self.stage2(self.down1(e1))
  374. e3 = self.stage3(self.down2(e2))
  375. e4 = self.stage4(self.down3(e3))
  376. return [e1, e2, e3, e4]
  377. class XGuideProjector2d(nn.Module):
  378. # Guides are projected from encoder features and aligned to decoder resolution.
  379. def __init__(
  380. self, in_channels: int, out_channels: int, mode: str = "affine"
  381. ) -> None:
  382. super().__init__()
  383. self.mode = mode
  384. if mode == "affine":
  385. self.proj = nn.Sequential(
  386. Conv2dBN(in_channels, out_channels * 2, 1, 1, 0),
  387. nn.ReLU(inplace=True),
  388. nn.Conv2d(out_channels * 2, out_channels * 2, kernel_size=1, bias=True),
  389. )
  390. elif mode == "feature":
  391. self.proj = nn.Sequential(
  392. Conv2dBN(in_channels, out_channels, 1, 1, 0),
  393. nn.ReLU(inplace=True),
  394. )
  395. else:
  396. raise ValueError(f"Unsupported guide mode: {mode}")
  397. def forward(
  398. self,
  399. x: torch.Tensor,
  400. target_size: tuple[int, int],
  401. ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
  402. x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
  403. x = self.proj(x)
  404. if self.mode == "affine":
  405. gamma, beta = torch.chunk(x, 2, dim=1)
  406. gamma = torch.sigmoid(gamma) + 0.5
  407. return gamma, beta
  408. return x
  409. class XSkipFusion2d(nn.Module):
  410. # Decoder input and skip feature are aligned, projected, and fused together.
  411. def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
  412. super().__init__()
  413. self.input_proj = nn.Sequential(
  414. Conv2dBN(in_channels, out_channels, 1, 1, 0),
  415. nn.ReLU(inplace=True),
  416. )
  417. self.skip_proj = nn.Sequential(
  418. Conv2dBN(skip_channels, out_channels, 1, 1, 0),
  419. nn.ReLU(inplace=True),
  420. )
  421. self.fuse = nn.Sequential(
  422. Conv2dBN(out_channels * 2, out_channels, 3, 1, 1),
  423. nn.ReLU(inplace=True),
  424. )
  425. def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
  426. x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
  427. x = self.input_proj(x)
  428. skip = self.skip_proj(skip)
  429. return self.fuse(torch.cat([x, skip], dim=1))
  430. class XGuideModulation2d(nn.Module):
  431. # Apply either direct affine guide or feature-to-affine modulation.
  432. def __init__(self, channels: int, guide_mode: str = "affine") -> None:
  433. super().__init__()
  434. self.guide_mode = guide_mode
  435. if guide_mode == "feature":
  436. self.to_affine = nn.Conv2d(channels, channels * 2, kernel_size=1, bias=True)
  437. def forward(
  438. self,
  439. x: torch.Tensor,
  440. guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
  441. ) -> torch.Tensor:
  442. if self.guide_mode == "affine":
  443. gamma, beta = guide
  444. else:
  445. gamma, beta = torch.chunk(self.to_affine(guide), 2, dim=1)
  446. gamma = torch.sigmoid(gamma) + 0.5
  447. return gamma * x + beta
  448. class XFrequencyRefine2d(nn.Module):
  449. def __init__(
  450. self,
  451. channels: int,
  452. low_freq_radius_h: float = 0.25,
  453. low_freq_radius_w: float = 0.25,
  454. learnable_low_freq_radius: bool = True,
  455. ) -> None:
  456. super().__init__()
  457. if low_freq_radius_h <= 0.0 or low_freq_radius_w <= 0.0:
  458. raise ValueError("Low-frequency radii must be positive.")
  459. # Gates are predicted from half-spectrum magnitude statistics instead of
  460. # directly reusing spatial-domain pooled features.
  461. self.low_gate = nn.Sequential(
  462. nn.Conv2d(channels, channels, kernel_size=1, bias=True),
  463. nn.Sigmoid(),
  464. )
  465. self.high_gate = nn.Sequential(
  466. nn.Conv2d(channels, channels, kernel_size=1, bias=True),
  467. nn.Sigmoid(),
  468. )
  469. self.refine = nn.Sequential(
  470. Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
  471. nn.ReLU(inplace=True),
  472. Conv2dBN(channels, channels, 1, 1, 0),
  473. )
  474. self.learnable_low_freq_radius = learnable_low_freq_radius
  475. if learnable_low_freq_radius:
  476. self.low_freq_radius_h = nn.Parameter(
  477. torch.tensor(low_freq_radius_h, dtype=torch.float32)
  478. )
  479. self.low_freq_radius_w = nn.Parameter(
  480. torch.tensor(low_freq_radius_w, dtype=torch.float32)
  481. )
  482. else:
  483. self.register_buffer(
  484. "low_freq_radius_h",
  485. torch.tensor(low_freq_radius_h, dtype=torch.float32),
  486. persistent=False,
  487. )
  488. self.register_buffer(
  489. "low_freq_radius_w",
  490. torch.tensor(low_freq_radius_w, dtype=torch.float32),
  491. persistent=False,
  492. )
  493. def _resolve_radius(
  494. self, value: torch.Tensor, max_ratio: float, device: torch.device
  495. ) -> torch.Tensor:
  496. radius = value.to(device=device, dtype=torch.float32)
  497. if self.learnable_low_freq_radius:
  498. radius = torch.sigmoid(radius) * max_ratio
  499. return torch.clamp(radius, min=1.0e-3, max=max_ratio)
  500. def _build_low_frequency_mask(
  501. self, h_freq: int, w_freq: int, device: torch.device
  502. ) -> torch.Tensor:
  503. y = torch.arange(h_freq, device=device, dtype=torch.float32)
  504. x = torch.arange(w_freq, device=device, dtype=torch.float32)
  505. y = torch.minimum(y, h_freq - y)
  506. radius_h = self._resolve_radius(self.low_freq_radius_h, 0.5, device) * max(
  507. h_freq, 1
  508. )
  509. radius_w = self._resolve_radius(self.low_freq_radius_w, 1.0, device) * max(
  510. w_freq, 1
  511. )
  512. y = y / torch.clamp(radius_h, min=1.0)
  513. x = x / torch.clamp(radius_w, min=1.0)
  514. y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")
  515. mask = (y_grid.square() + x_grid.square()) <= 1.0
  516. return mask.unsqueeze(0).unsqueeze(0)
  517. def forward(self, x: torch.Tensor) -> torch.Tensor:
  518. input_dtype = x.dtype
  519. if x.dtype != torch.float32:
  520. x = x.to(torch.float32)
  521. fft = torch.fft.rfft2(x, norm="ortho")
  522. h_freq, w_freq = fft.shape[-2], fft.shape[-1]
  523. low_mask = self._build_low_frequency_mask(h_freq, w_freq, fft.device).to(
  524. dtype=x.dtype
  525. )
  526. low = fft * low_mask
  527. high = fft - low
  528. magnitude = fft.abs()
  529. low_stats = (magnitude * low_mask).mean(dim=(-2, -1), keepdim=True)
  530. high_stats = (magnitude * (1.0 - low_mask)).mean(dim=(-2, -1), keepdim=True)
  531. low = low * self.low_gate(low_stats)
  532. high = high * self.high_gate(high_stats)
  533. out = torch.fft.irfft2(low + high, s=x.shape[-2:], norm="ortho")
  534. out = out.to(dtype=input_dtype)
  535. return self.refine(out)
  536. class XCRB2d(nn.Module):
  537. # Decoder block: skip fusion -> guide modulation -> frequency refine -> residual output.
  538. def __init__(
  539. self,
  540. in_channels: int,
  541. skip_channels: int,
  542. guide_channels: int,
  543. out_channels: int,
  544. guide_mode: str = "affine",
  545. use_frequency_refine: bool = True,
  546. low_freq_radius_h: float = 0.25,
  547. low_freq_radius_w: float = 0.25,
  548. learnable_low_freq_radius: bool = True,
  549. ) -> None:
  550. super().__init__()
  551. self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
  552. self.guide_modulation = XGuideModulation2d(out_channels, guide_mode=guide_mode)
  553. self.frequency_refine = (
  554. XFrequencyRefine2d(
  555. out_channels,
  556. low_freq_radius_h=low_freq_radius_h,
  557. low_freq_radius_w=low_freq_radius_w,
  558. learnable_low_freq_radius=learnable_low_freq_radius,
  559. )
  560. if use_frequency_refine
  561. else nn.Identity()
  562. )
  563. self.out_refine = nn.Sequential(
  564. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  565. nn.ReLU(inplace=True),
  566. Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
  567. )
  568. self.guide_channels = guide_channels
  569. def forward(
  570. self,
  571. x: torch.Tensor,
  572. skip: torch.Tensor,
  573. guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
  574. ) -> torch.Tensor:
  575. x = self.skip_fusion(x, skip)
  576. x = self.guide_modulation(x, guide)
  577. x = x + self.frequency_refine(x)
  578. return x + self.out_refine(x)
  579. class XNetHeadRefine2d(nn.Module):
  580. def __init__(self, channels: int, out_channels: int | None = None) -> None:
  581. super().__init__()
  582. if out_channels is None:
  583. out_channels = channels
  584. self.block = nn.Sequential(
  585. Conv2dBN(channels, out_channels, 3, 1, 1),
  586. nn.ReLU(inplace=True),
  587. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  588. nn.ReLU(inplace=True),
  589. )
  590. def forward(self, x: torch.Tensor) -> torch.Tensor:
  591. return self.block(x)
  592. class XNetDecoder2d(nn.Module):
  593. def __init__(
  594. self,
  595. encoder_channels: Sequence[int],
  596. decoder_channels: Sequence[int] = (128, 64, 32),
  597. guide_mode: str = "affine",
  598. use_frequency_refine: bool = True,
  599. low_freq_radius_h: float = 0.25,
  600. low_freq_radius_w: float = 0.25,
  601. learnable_low_freq_radius: bool = True,
  602. out_channels: int | None = None,
  603. ) -> None:
  604. super().__init__()
  605. if len(encoder_channels) != 4:
  606. raise ValueError("XNetDecoder2d expects 4 encoder stages.")
  607. if len(decoder_channels) != 3:
  608. raise ValueError("XNetDecoder2d expects 3 decoder channels.")
  609. c1, c2, c3, c4 = encoder_channels
  610. d4, d3, d2 = decoder_channels
  611. self.guide4 = XGuideProjector2d(c4, d4, mode=guide_mode)
  612. self.guide3 = XGuideProjector2d(c3, d3, mode=guide_mode)
  613. self.guide2 = XGuideProjector2d(c2, d2, mode=guide_mode)
  614. self.dec4 = XCRB2d(
  615. c4,
  616. c3,
  617. d4,
  618. d4,
  619. guide_mode=guide_mode,
  620. use_frequency_refine=use_frequency_refine,
  621. low_freq_radius_h=low_freq_radius_h,
  622. low_freq_radius_w=low_freq_radius_w,
  623. learnable_low_freq_radius=learnable_low_freq_radius,
  624. )
  625. self.dec3 = XCRB2d(
  626. d4,
  627. c2,
  628. d3,
  629. d3,
  630. guide_mode=guide_mode,
  631. use_frequency_refine=use_frequency_refine,
  632. low_freq_radius_h=low_freq_radius_h,
  633. low_freq_radius_w=low_freq_radius_w,
  634. learnable_low_freq_radius=learnable_low_freq_radius,
  635. )
  636. self.dec2 = XCRB2d(
  637. d3,
  638. c1,
  639. d2,
  640. d2,
  641. guide_mode=guide_mode,
  642. use_frequency_refine=use_frequency_refine,
  643. low_freq_radius_h=low_freq_radius_h,
  644. low_freq_radius_w=low_freq_radius_w,
  645. learnable_low_freq_radius=learnable_low_freq_radius,
  646. )
  647. self.head_refine = XNetHeadRefine2d(d2, out_channels or d2)
  648. self.out_channels = out_channels or d2
  649. def forward(
  650. self,
  651. features: Sequence[torch.Tensor],
  652. ) -> tuple[
  653. torch.Tensor,
  654. list[torch.Tensor],
  655. list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]],
  656. ]:
  657. e1, e2, e3, e4 = features
  658. g4 = self.guide4(e4, target_size=e3.shape[-2:])
  659. d4 = self.dec4(e4, e3, g4)
  660. g3 = self.guide3(e3, target_size=e2.shape[-2:])
  661. d3 = self.dec3(d4, e2, g3)
  662. g2 = self.guide2(e2, target_size=e1.shape[-2:])
  663. d2 = self.dec2(d3, e1, g2)
  664. d1 = self.head_refine(d2)
  665. return d1, [d4, d3, d2, d1], [g4, g3, g2]
  666. class XNetSegHead2d(nn.Module):
  667. def __init__(
  668. self, in_channels: int, num_classes: int, upsample_scale: int = 4
  669. ) -> None:
  670. super().__init__()
  671. self.block = nn.Sequential(
  672. Conv2dBN(in_channels, in_channels, 3, 1, 1),
  673. nn.ReLU(inplace=True),
  674. nn.Conv2d(in_channels, num_classes, kernel_size=1, bias=True),
  675. )
  676. self.upsample_scale = upsample_scale
  677. def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
  678. x = self.block(x)
  679. return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
  680. class XNet2d(nn.Module):
  681. def __init__(
  682. self,
  683. in_channels: int,
  684. num_classes: int,
  685. encoder_channels: Sequence[int] = (32, 64, 128, 192),
  686. encoder_depths: Sequence[int] = (2, 2, 2, 2),
  687. decoder_channels: Sequence[int] = (128, 64, 32),
  688. stem_channels: int = 24,
  689. bottleneck_depth: int = 1,
  690. global_ratio: float = 2.0,
  691. wavelet_type: str = "haar",
  692. wavelet_level: int = 1,
  693. use_wavelet_branch: bool = True,
  694. use_global_branch_stage1: bool = False,
  695. ssm_d_state: int = 16,
  696. ssm_forward_type: str = "v3",
  697. ssm_backend: str = "auto",
  698. use_frequency_refine: bool = True,
  699. low_freq_radius_h: float = 0.25,
  700. low_freq_radius_w: float = 0.25,
  701. learnable_low_freq_radius: bool = True,
  702. guide_mode: str = "affine",
  703. out_channels: int | None = None,
  704. ) -> None:
  705. super().__init__()
  706. self.encoder = XNetEncoder2d(
  707. in_channels=in_channels,
  708. stem_channels=stem_channels,
  709. encoder_channels=encoder_channels,
  710. encoder_depths=encoder_depths,
  711. global_ratio=global_ratio,
  712. wavelet_type=wavelet_type,
  713. wavelet_level=wavelet_level,
  714. use_wavelet_branch=use_wavelet_branch,
  715. use_global_branch_stage1=use_global_branch_stage1,
  716. ssm_d_state=ssm_d_state,
  717. ssm_forward_type=ssm_forward_type,
  718. ssm_backend=ssm_backend,
  719. )
  720. bottleneck_channels = encoder_channels[-1]
  721. self.bottleneck = nn.Sequential(
  722. *[
  723. XTEB2d(
  724. channels=bottleneck_channels,
  725. global_ratio=global_ratio,
  726. wavelet_type=wavelet_type,
  727. wavelet_level=wavelet_level,
  728. use_wavelet_branch=use_wavelet_branch,
  729. use_global_branch=True,
  730. ssm_d_state=ssm_d_state,
  731. ssm_forward_type=ssm_forward_type,
  732. ssm_backend=ssm_backend,
  733. )
  734. for _ in range(bottleneck_depth)
  735. ]
  736. )
  737. self.decoder = XNetDecoder2d(
  738. encoder_channels=encoder_channels,
  739. decoder_channels=decoder_channels,
  740. guide_mode=guide_mode,
  741. use_frequency_refine=use_frequency_refine,
  742. low_freq_radius_h=low_freq_radius_h,
  743. low_freq_radius_w=low_freq_radius_w,
  744. learnable_low_freq_radius=learnable_low_freq_radius,
  745. out_channels=out_channels,
  746. )
  747. head_in_channels = self.decoder.out_channels
  748. self.segmentation_head = XNetSegHead2d(head_in_channels, num_classes)
  749. def forward(
  750. self, x: torch.Tensor
  751. ) -> dict[
  752. str, torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]
  753. ]:
  754. encoder_features = self.encoder(x)
  755. encoder_features[-1] = self.bottleneck(encoder_features[-1])
  756. decoder_out, decoder_features, guides = self.decoder(encoder_features)
  757. output_size = x.shape[-2:]
  758. logits = self.segmentation_head(decoder_out, output_size=output_size)
  759. outputs: dict[
  760. str,
  761. torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]],
  762. ] = {
  763. "logits": logits,
  764. "seg_logits": logits,
  765. "encoder_features": encoder_features,
  766. "decoder_features": decoder_features,
  767. "guides": guides,
  768. }
  769. return outputs