xnet_2d.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845
  1. from __future__ import annotations
  2. from collections.abc import Sequence
  3. import ptwt
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from .layers_2d import Conv2dBN
  8. from .lib_mamba.vmamba import SS2D as VMambaSS2D
  9. class XNetStem2d(nn.Module):
  10. # Stem reduces spatial size by 4x while lifting features into encoder stage 1.
  11. def __init__(self, in_channels: int, stem_channels: int, out_channels: int) -> None:
  12. super().__init__()
  13. self.block = nn.Sequential(
  14. Conv2dBN(in_channels, stem_channels, 3, 2, 1),
  15. nn.ReLU(inplace=True),
  16. Conv2dBN(stem_channels, stem_channels, 3, 1, 1, groups=stem_channels),
  17. nn.ReLU(inplace=True),
  18. Conv2dBN(stem_channels, out_channels, 1, 1, 0),
  19. nn.ReLU(inplace=True),
  20. Conv2dBN(out_channels, out_channels, 3, 2, 1),
  21. nn.ReLU(inplace=True),
  22. )
  23. def forward(self, x: torch.Tensor) -> torch.Tensor:
  24. return self.block(x)
  25. class XNetDownsample2d(nn.Module):
  26. def __init__(self, in_channels: int, out_channels: int, mode: str = "conv") -> None:
  27. super().__init__()
  28. if mode != "conv":
  29. raise ValueError(f"Unsupported downsample mode: {mode}")
  30. self.block = nn.Sequential(
  31. Conv2dBN(in_channels, out_channels, 3, 2, 1),
  32. nn.ReLU(inplace=True),
  33. )
  34. def forward(self, x: torch.Tensor) -> torch.Tensor:
  35. return self.block(x)
  36. class XLocalBranch2d(nn.Module):
  37. # Parallel depthwise branches capture short-range texture at two kernel scales.
  38. def __init__(self, channels: int) -> None:
  39. super().__init__()
  40. self.branch3 = nn.Sequential(
  41. Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
  42. nn.ReLU(inplace=True),
  43. Conv2dBN(channels, channels, 1, 1, 0),
  44. )
  45. self.branch5 = nn.Sequential(
  46. Conv2dBN(channels, channels, 5, 1, 2, groups=channels),
  47. nn.ReLU(inplace=True),
  48. Conv2dBN(channels, channels, 1, 1, 0),
  49. )
  50. def forward(self, x: torch.Tensor) -> torch.Tensor:
  51. return self.branch3(x) + self.branch5(x)
  52. class XWaveletTransform2d(nn.Module):
  53. # ptwt-based wavelet decomposition/reconstruction with explicit crop so odd
  54. # input sizes round-trip to the exact original spatial shape.
  55. def __init__(
  56. self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
  57. ) -> None:
  58. super().__init__()
  59. self.channels = channels
  60. self.wavelet_type = wavelet_type
  61. self.wavelet_level = wavelet_level
  62. def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  63. original_dtype = x.dtype
  64. with torch.autocast(device_type=x.device.type, enabled=False):
  65. coeffs = ptwt.wavedec2(
  66. x.float(), self.wavelet_type, level=self.wavelet_level
  67. )
  68. ll = coeffs[0]
  69. high_parts = coeffs[1]
  70. high = torch.cat(high_parts, dim=1)
  71. return ll.to(original_dtype), high.to(original_dtype)
  72. def inverse(
  73. self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
  74. ) -> torch.Tensor:
  75. original_dtype = ll.dtype
  76. with torch.autocast(device_type=ll.device.type, enabled=False):
  77. lh, hl, hh = torch.chunk(high.float(), 3, dim=1)
  78. coeffs = [ll.float(), (lh, hl, hh)]
  79. x = ptwt.waverec2(coeffs, self.wavelet_type)
  80. x = x[:, :, : output_size[0], : output_size[1]]
  81. return x.to(original_dtype)
  82. class XWaveletBranch2d(nn.Module):
  83. # The wavelet branch learns on low/high-frequency components separately and
  84. # then reconstructs back to the original feature size.
  85. def __init__(
  86. self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
  87. ) -> None:
  88. super().__init__()
  89. if wavelet_type != "haar":
  90. raise ValueError(f"Unsupported wavelet type: {wavelet_type}")
  91. if wavelet_level != 1:
  92. raise ValueError(
  93. "Initial XNet implementation only supports wavelet_level=1."
  94. )
  95. self.wavelet = XWaveletTransform2d(
  96. channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
  97. )
  98. self.ll_proj = nn.Sequential(
  99. Conv2dBN(channels, channels, 3, 1, 1),
  100. nn.ReLU(inplace=True),
  101. )
  102. self.high_proj = nn.Sequential(
  103. Conv2dBN(channels * 3, channels * 3, 3, 1, 1, groups=channels * 3),
  104. nn.ReLU(inplace=True),
  105. Conv2dBN(channels * 3, channels * 3, 1, 1, 0),
  106. )
  107. self.out_proj = nn.Sequential(
  108. Conv2dBN(channels, channels, 1, 1, 0),
  109. nn.ReLU(inplace=True),
  110. )
  111. def forward(self, x: torch.Tensor) -> torch.Tensor:
  112. output_size = x.shape[-2:]
  113. ll, high = self.wavelet(x)
  114. ll = self.ll_proj(ll)
  115. high = self.high_proj(high)
  116. x = self.wavelet.inverse(ll, high, output_size=output_size)
  117. return self.out_proj(x)
  118. class XSSMGlobalBranch2d(nn.Module):
  119. # The global branch wraps VMamba and switches scan backend at runtime.
  120. def __init__(
  121. self,
  122. channels: int,
  123. global_ratio: float = 2.0,
  124. d_state: int = 16,
  125. forward_type: str = "v3",
  126. ssm_backend: str = "auto",
  127. ) -> None:
  128. super().__init__()
  129. hidden_ratio = max(global_ratio, 1.0)
  130. self.backend = ssm_backend
  131. self.pre = nn.Sequential(
  132. Conv2dBN(channels, channels, 1, 1, 0),
  133. nn.ReLU(inplace=True),
  134. )
  135. self.ssm = VMambaSS2D(
  136. d_model=channels,
  137. d_state=d_state,
  138. ssm_ratio=hidden_ratio,
  139. d_conv=3,
  140. dropout=0.0,
  141. initialize="v0",
  142. forward_type=forward_type,
  143. channel_first=True,
  144. )
  145. self.post = nn.Sequential(
  146. Conv2dBN(channels, channels, 1, 1, 0),
  147. nn.ReLU(inplace=True),
  148. )
  149. def forward(self, x: torch.Tensor) -> torch.Tensor:
  150. x = self.pre(x)
  151. prev_backend = None
  152. backend = self.backend.lower()
  153. if backend == "auto":
  154. backend = "oflex" if x.is_cuda else "torch"
  155. if backend == "oflex" and hasattr(self.ssm, "forward_core"):
  156. prev_backend = self.ssm.forward_core
  157. self.ssm.forward_core = lambda z, _core=prev_backend: _core(
  158. z,
  159. selective_scan_backend="oflex",
  160. scan_force_torch=False,
  161. )
  162. elif backend == "torch" and hasattr(self.ssm, "forward_core"):
  163. prev_backend = self.ssm.forward_core
  164. self.ssm.forward_core = lambda z, _core=prev_backend: _core(
  165. z,
  166. selective_scan_backend="torch",
  167. scan_force_torch=True,
  168. )
  169. try:
  170. x = self.ssm(x)
  171. finally:
  172. if prev_backend is not None:
  173. self.ssm.forward_core = prev_backend
  174. return self.post(x)
  175. class XGlobalBranch2d(nn.Module):
  176. def __init__(
  177. self,
  178. channels: int,
  179. global_ratio: float = 2.0,
  180. ssm_d_state: int = 16,
  181. ssm_forward_type: str = "v3",
  182. ssm_backend: str = "auto",
  183. ) -> None:
  184. super().__init__()
  185. self.ssm_branch = XSSMGlobalBranch2d(
  186. channels=channels,
  187. global_ratio=global_ratio,
  188. d_state=ssm_d_state,
  189. forward_type=ssm_forward_type,
  190. ssm_backend=ssm_backend,
  191. )
  192. def forward(self, x: torch.Tensor) -> torch.Tensor:
  193. return self.ssm_branch(x)
  194. class XBranchFusion2d(nn.Module):
  195. def __init__(self, channels: int, num_branches: int = 3) -> None:
  196. super().__init__()
  197. fused_channels = channels * num_branches
  198. hidden_channels = max(channels // 4, 8)
  199. self.fuse = nn.Sequential(
  200. Conv2dBN(fused_channels, channels, 1, 1, 0),
  201. nn.ReLU(inplace=True),
  202. )
  203. self.gate = nn.Sequential(
  204. nn.AdaptiveAvgPool2d(1),
  205. nn.Conv2d(fused_channels, hidden_channels, kernel_size=1, bias=True),
  206. nn.ReLU(inplace=True),
  207. nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True),
  208. nn.Sigmoid(),
  209. )
  210. def forward(self, branch_outputs: Sequence[torch.Tensor]) -> torch.Tensor:
  211. x_cat = torch.cat(list(branch_outputs), dim=1)
  212. x_fused = self.fuse(x_cat)
  213. gate = self.gate(x_cat)
  214. return x_fused * gate
  215. class XTEB2d(nn.Module):
  216. # XTEB fuses local, wavelet, and global branches with residual post/ffn blocks.
  217. def __init__(
  218. self,
  219. channels: int,
  220. global_ratio: float = 2.0,
  221. wavelet_type: str = "haar",
  222. wavelet_level: int = 1,
  223. use_wavelet_branch: bool = True,
  224. use_global_branch: bool = True,
  225. ssm_d_state: int = 16,
  226. ssm_forward_type: str = "v3",
  227. ssm_backend: str = "auto",
  228. ) -> None:
  229. super().__init__()
  230. self.pre_norm = Conv2dBN(channels, channels, 1, 1, 0)
  231. self.local_branch = XLocalBranch2d(channels)
  232. self.wavelet_branch = (
  233. XWaveletBranch2d(
  234. channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
  235. )
  236. if use_wavelet_branch
  237. else nn.Identity()
  238. )
  239. self.global_branch = (
  240. XGlobalBranch2d(
  241. channels,
  242. global_ratio=global_ratio,
  243. ssm_d_state=ssm_d_state,
  244. ssm_forward_type=ssm_forward_type,
  245. ssm_backend=ssm_backend,
  246. )
  247. if use_global_branch
  248. else nn.Identity()
  249. )
  250. self.fusion = XBranchFusion2d(channels, num_branches=3)
  251. self.post = nn.Sequential(
  252. Conv2dBN(channels, channels, 3, 1, 1),
  253. nn.ReLU(inplace=True),
  254. Conv2dBN(channels, channels, 1, 1, 0, bn_weight_init=0.0),
  255. )
  256. self.ffn = nn.Sequential(
  257. Conv2dBN(channels, channels * 2, 1, 1, 0),
  258. nn.ReLU(inplace=True),
  259. Conv2dBN(channels * 2, channels, 1, 1, 0, bn_weight_init=0.0),
  260. )
  261. def forward(self, x: torch.Tensor) -> torch.Tensor:
  262. x_in = x
  263. x = self.pre_norm(x)
  264. x = x_in + self.post(
  265. self.fusion(
  266. [self.local_branch(x), self.wavelet_branch(x), self.global_branch(x)]
  267. )
  268. )
  269. return x + self.ffn(x)
  270. class XNetEncoderStage2d(nn.Module):
  271. def __init__(
  272. self,
  273. channels: int,
  274. depth: int,
  275. global_ratio: float = 2.0,
  276. wavelet_type: str = "haar",
  277. wavelet_level: int = 1,
  278. use_wavelet_branch: bool = True,
  279. use_global_branch: bool = True,
  280. ssm_d_state: int = 16,
  281. ssm_forward_type: str = "v3",
  282. ssm_backend: str = "auto",
  283. ) -> None:
  284. super().__init__()
  285. self.blocks = nn.Sequential(
  286. *[
  287. XTEB2d(
  288. channels=channels,
  289. global_ratio=global_ratio,
  290. wavelet_type=wavelet_type,
  291. wavelet_level=wavelet_level,
  292. use_wavelet_branch=use_wavelet_branch,
  293. use_global_branch=use_global_branch,
  294. ssm_d_state=ssm_d_state,
  295. ssm_forward_type=ssm_forward_type,
  296. ssm_backend=ssm_backend,
  297. )
  298. for _ in range(depth)
  299. ]
  300. )
  301. def forward(self, x: torch.Tensor) -> torch.Tensor:
  302. return self.blocks(x)
  303. class XNetEncoder2d(nn.Module):
  304. # The encoder is a 4-stage feature pyramid with optional stage-1 global branch.
  305. def __init__(
  306. self,
  307. in_channels: int,
  308. stem_channels: int,
  309. encoder_channels: Sequence[int],
  310. encoder_depths: Sequence[int],
  311. global_ratio: float = 2.0,
  312. wavelet_type: str = "haar",
  313. wavelet_level: int = 1,
  314. use_wavelet_branch: bool = True,
  315. use_global_branch_stage1: bool = False,
  316. ssm_d_state: int = 16,
  317. ssm_forward_type: str = "v3",
  318. ssm_backend: str = "auto",
  319. ) -> None:
  320. super().__init__()
  321. if len(encoder_channels) != 4 or len(encoder_depths) != 4:
  322. raise ValueError("XNetEncoder2d expects 4 encoder stages.")
  323. c1, c2, c3, c4 = encoder_channels
  324. d1, d2, d3, d4 = encoder_depths
  325. self.stem = XNetStem2d(in_channels, stem_channels, c1)
  326. self.stage1 = XNetEncoderStage2d(
  327. c1,
  328. d1,
  329. global_ratio,
  330. wavelet_type,
  331. wavelet_level,
  332. use_wavelet_branch=use_wavelet_branch,
  333. use_global_branch=use_global_branch_stage1,
  334. ssm_d_state=ssm_d_state,
  335. ssm_forward_type=ssm_forward_type,
  336. ssm_backend=ssm_backend,
  337. )
  338. self.down1 = XNetDownsample2d(c1, c2)
  339. self.stage2 = XNetEncoderStage2d(
  340. c2,
  341. d2,
  342. global_ratio,
  343. wavelet_type,
  344. wavelet_level,
  345. use_wavelet_branch,
  346. True,
  347. ssm_d_state=ssm_d_state,
  348. ssm_forward_type=ssm_forward_type,
  349. ssm_backend=ssm_backend,
  350. )
  351. self.down2 = XNetDownsample2d(c2, c3)
  352. self.stage3 = XNetEncoderStage2d(
  353. c3,
  354. d3,
  355. global_ratio,
  356. wavelet_type,
  357. wavelet_level,
  358. use_wavelet_branch,
  359. True,
  360. ssm_d_state=ssm_d_state,
  361. ssm_forward_type=ssm_forward_type,
  362. ssm_backend=ssm_backend,
  363. )
  364. self.down3 = XNetDownsample2d(c3, c4)
  365. self.stage4 = XNetEncoderStage2d(
  366. c4,
  367. d4,
  368. global_ratio,
  369. wavelet_type,
  370. wavelet_level,
  371. use_wavelet_branch,
  372. True,
  373. ssm_d_state=ssm_d_state,
  374. ssm_forward_type=ssm_forward_type,
  375. ssm_backend=ssm_backend,
  376. )
  377. self.stage_channels = list(encoder_channels)
  378. def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
  379. e1 = self.stage1(self.stem(x))
  380. e2 = self.stage2(self.down1(e1))
  381. e3 = self.stage3(self.down2(e2))
  382. e4 = self.stage4(self.down3(e3))
  383. return [e1, e2, e3, e4]
  384. class XGuideProjector2d(nn.Module):
  385. # Guides are projected from encoder features and aligned to decoder resolution.
  386. def __init__(
  387. self, in_channels: int, out_channels: int, mode: str = "affine"
  388. ) -> None:
  389. super().__init__()
  390. self.mode = mode
  391. if mode == "affine":
  392. self.proj = nn.Sequential(
  393. Conv2dBN(in_channels, out_channels * 2, 1, 1, 0),
  394. nn.ReLU(inplace=True),
  395. nn.Conv2d(out_channels * 2, out_channels * 2, kernel_size=1, bias=True),
  396. )
  397. elif mode == "feature":
  398. self.proj = nn.Sequential(
  399. Conv2dBN(in_channels, out_channels, 1, 1, 0),
  400. nn.ReLU(inplace=True),
  401. )
  402. else:
  403. raise ValueError(f"Unsupported guide mode: {mode}")
  404. def forward(
  405. self,
  406. x: torch.Tensor,
  407. target_size: tuple[int, int],
  408. ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
  409. x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
  410. x = self.proj(x)
  411. if self.mode == "affine":
  412. gamma, beta = torch.chunk(x, 2, dim=1)
  413. gamma = torch.sigmoid(gamma) + 0.5
  414. return gamma, beta
  415. return x
  416. class XSkipFusion2d(nn.Module):
  417. # Decoder input and skip feature are aligned, projected, and fused together.
  418. def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
  419. super().__init__()
  420. self.input_proj = nn.Sequential(
  421. Conv2dBN(in_channels, out_channels, 1, 1, 0),
  422. nn.ReLU(inplace=True),
  423. )
  424. self.skip_proj = nn.Sequential(
  425. Conv2dBN(skip_channels, out_channels, 1, 1, 0),
  426. nn.ReLU(inplace=True),
  427. )
  428. self.fuse = nn.Sequential(
  429. Conv2dBN(out_channels * 2, out_channels, 3, 1, 1),
  430. nn.ReLU(inplace=True),
  431. )
  432. def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
  433. x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
  434. x = self.input_proj(x)
  435. skip = self.skip_proj(skip)
  436. return self.fuse(torch.cat([x, skip], dim=1))
  437. class XGuideModulation2d(nn.Module):
  438. # Apply either direct affine guide or feature-to-affine modulation.
  439. def __init__(self, channels: int, guide_mode: str = "affine") -> None:
  440. super().__init__()
  441. self.guide_mode = guide_mode
  442. if guide_mode == "feature":
  443. self.to_affine = nn.Conv2d(channels, channels * 2, kernel_size=1, bias=True)
  444. def forward(
  445. self,
  446. x: torch.Tensor,
  447. guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
  448. ) -> torch.Tensor:
  449. if self.guide_mode == "affine":
  450. gamma, beta = guide
  451. else:
  452. gamma, beta = torch.chunk(self.to_affine(guide), 2, dim=1)
  453. gamma = torch.sigmoid(gamma) + 0.5
  454. return gamma * x + beta
  455. class XFrequencyRefine2d(nn.Module):
  456. def __init__(
  457. self,
  458. channels: int,
  459. low_freq_radius_h: float = 0.25,
  460. low_freq_radius_w: float = 0.25,
  461. learnable_low_freq_radius: bool = True,
  462. ) -> None:
  463. super().__init__()
  464. if low_freq_radius_h <= 0.0 or low_freq_radius_w <= 0.0:
  465. raise ValueError("Low-frequency radii must be positive.")
  466. # Gates are predicted from half-spectrum magnitude statistics instead of
  467. # directly reusing spatial-domain pooled features.
  468. self.low_gate = nn.Sequential(
  469. nn.Conv2d(channels, channels, kernel_size=1, bias=True),
  470. nn.Sigmoid(),
  471. )
  472. self.high_gate = nn.Sequential(
  473. nn.Conv2d(channels, channels, kernel_size=1, bias=True),
  474. nn.Sigmoid(),
  475. )
  476. self.refine = nn.Sequential(
  477. Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
  478. nn.ReLU(inplace=True),
  479. Conv2dBN(channels, channels, 1, 1, 0),
  480. )
  481. self.learnable_low_freq_radius = learnable_low_freq_radius
  482. if learnable_low_freq_radius:
  483. self.low_freq_radius_h = nn.Parameter(
  484. torch.tensor(low_freq_radius_h, dtype=torch.float32)
  485. )
  486. self.low_freq_radius_w = nn.Parameter(
  487. torch.tensor(low_freq_radius_w, dtype=torch.float32)
  488. )
  489. else:
  490. self.register_buffer(
  491. "low_freq_radius_h",
  492. torch.tensor(low_freq_radius_h, dtype=torch.float32),
  493. persistent=False,
  494. )
  495. self.register_buffer(
  496. "low_freq_radius_w",
  497. torch.tensor(low_freq_radius_w, dtype=torch.float32),
  498. persistent=False,
  499. )
  500. def _resolve_radius(
  501. self, value: torch.Tensor, max_ratio: float, device: torch.device
  502. ) -> torch.Tensor:
  503. radius = value.to(device=device, dtype=torch.float32)
  504. if self.learnable_low_freq_radius:
  505. radius = torch.sigmoid(radius) * max_ratio
  506. return torch.clamp(radius, min=1.0e-3, max=max_ratio)
  507. def _build_low_frequency_mask(
  508. self, h_freq: int, w_freq: int, device: torch.device
  509. ) -> torch.Tensor:
  510. y = torch.arange(h_freq, device=device, dtype=torch.float32)
  511. x = torch.arange(w_freq, device=device, dtype=torch.float32)
  512. y = torch.minimum(y, h_freq - y)
  513. radius_h = self._resolve_radius(self.low_freq_radius_h, 0.5, device) * max(
  514. h_freq, 1
  515. )
  516. radius_w = self._resolve_radius(self.low_freq_radius_w, 1.0, device) * max(
  517. w_freq, 1
  518. )
  519. y = y / torch.clamp(radius_h, min=1.0)
  520. x = x / torch.clamp(radius_w, min=1.0)
  521. y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")
  522. mask = (y_grid.square() + x_grid.square()) <= 1.0
  523. return mask.unsqueeze(0).unsqueeze(0)
  524. def forward(self, x: torch.Tensor) -> torch.Tensor:
  525. input_dtype = x.dtype
  526. if x.dtype != torch.float32:
  527. x = x.to(torch.float32)
  528. fft = torch.fft.rfft2(x, norm="ortho")
  529. h_freq, w_freq = fft.shape[-2], fft.shape[-1]
  530. low_mask = self._build_low_frequency_mask(h_freq, w_freq, fft.device).to(
  531. dtype=x.dtype
  532. )
  533. low = fft * low_mask
  534. high = fft - low
  535. magnitude = fft.abs()
  536. low_stats = (magnitude * low_mask).mean(dim=(-2, -1), keepdim=True)
  537. high_stats = (magnitude * (1.0 - low_mask)).mean(dim=(-2, -1), keepdim=True)
  538. low = low * self.low_gate(low_stats)
  539. high = high * self.high_gate(high_stats)
  540. out = torch.fft.irfft2(low + high, s=x.shape[-2:], norm="ortho")
  541. out = out.to(dtype=input_dtype)
  542. return self.refine(out)
  543. class XCRB2d(nn.Module):
  544. # Decoder block: skip fusion -> guide modulation -> frequency refine -> residual output.
  545. def __init__(
  546. self,
  547. in_channels: int,
  548. skip_channels: int,
  549. guide_channels: int,
  550. out_channels: int,
  551. guide_mode: str = "affine",
  552. use_frequency_refine: bool = True,
  553. low_freq_radius_h: float = 0.25,
  554. low_freq_radius_w: float = 0.25,
  555. learnable_low_freq_radius: bool = True,
  556. ) -> None:
  557. super().__init__()
  558. self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
  559. self.guide_modulation = XGuideModulation2d(out_channels, guide_mode=guide_mode)
  560. self.frequency_refine = (
  561. XFrequencyRefine2d(
  562. out_channels,
  563. low_freq_radius_h=low_freq_radius_h,
  564. low_freq_radius_w=low_freq_radius_w,
  565. learnable_low_freq_radius=learnable_low_freq_radius,
  566. )
  567. if use_frequency_refine
  568. else nn.Identity()
  569. )
  570. self.out_refine = nn.Sequential(
  571. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  572. nn.ReLU(inplace=True),
  573. Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
  574. )
  575. self.guide_channels = guide_channels
  576. def forward(
  577. self,
  578. x: torch.Tensor,
  579. skip: torch.Tensor,
  580. guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
  581. ) -> torch.Tensor:
  582. x = self.skip_fusion(x, skip)
  583. x = self.guide_modulation(x, guide)
  584. x = x + self.frequency_refine(x)
  585. return x + self.out_refine(x)
  586. class XNetHeadRefine2d(nn.Module):
  587. def __init__(self, channels: int, out_channels: int | None = None) -> None:
  588. super().__init__()
  589. if out_channels is None:
  590. out_channels = channels
  591. self.block = nn.Sequential(
  592. Conv2dBN(channels, out_channels, 3, 1, 1),
  593. nn.ReLU(inplace=True),
  594. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  595. nn.ReLU(inplace=True),
  596. )
  597. def forward(self, x: torch.Tensor) -> torch.Tensor:
  598. return self.block(x)
  599. class XNetDecoder2d(nn.Module):
  600. def __init__(
  601. self,
  602. encoder_channels: Sequence[int],
  603. decoder_channels: Sequence[int] = (128, 64, 32),
  604. guide_mode: str = "affine",
  605. use_frequency_refine: bool = True,
  606. low_freq_radius_h: float = 0.25,
  607. low_freq_radius_w: float = 0.25,
  608. learnable_low_freq_radius: bool = True,
  609. out_channels: int | None = None,
  610. ) -> None:
  611. super().__init__()
  612. if len(encoder_channels) != 4:
  613. raise ValueError("XNetDecoder2d expects 4 encoder stages.")
  614. if len(decoder_channels) != 3:
  615. raise ValueError("XNetDecoder2d expects 3 decoder channels.")
  616. c1, c2, c3, c4 = encoder_channels
  617. d4, d3, d2 = decoder_channels
  618. self.guide4 = XGuideProjector2d(c4, d4, mode=guide_mode)
  619. self.guide3 = XGuideProjector2d(c3, d3, mode=guide_mode)
  620. self.guide2 = XGuideProjector2d(c2, d2, mode=guide_mode)
  621. self.dec4 = XCRB2d(
  622. c4,
  623. c3,
  624. d4,
  625. d4,
  626. guide_mode=guide_mode,
  627. use_frequency_refine=use_frequency_refine,
  628. low_freq_radius_h=low_freq_radius_h,
  629. low_freq_radius_w=low_freq_radius_w,
  630. learnable_low_freq_radius=learnable_low_freq_radius,
  631. )
  632. self.dec3 = XCRB2d(
  633. d4,
  634. c2,
  635. d3,
  636. d3,
  637. guide_mode=guide_mode,
  638. use_frequency_refine=use_frequency_refine,
  639. low_freq_radius_h=low_freq_radius_h,
  640. low_freq_radius_w=low_freq_radius_w,
  641. learnable_low_freq_radius=learnable_low_freq_radius,
  642. )
  643. self.dec2 = XCRB2d(
  644. d3,
  645. c1,
  646. d2,
  647. d2,
  648. guide_mode=guide_mode,
  649. use_frequency_refine=use_frequency_refine,
  650. low_freq_radius_h=low_freq_radius_h,
  651. low_freq_radius_w=low_freq_radius_w,
  652. learnable_low_freq_radius=learnable_low_freq_radius,
  653. )
  654. self.head_refine = XNetHeadRefine2d(d2, out_channels or d2)
  655. self.out_channels = out_channels or d2
  656. def forward(
  657. self,
  658. features: Sequence[torch.Tensor],
  659. ) -> tuple[
  660. torch.Tensor,
  661. list[torch.Tensor],
  662. list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]],
  663. ]:
  664. e1, e2, e3, e4 = features
  665. g4 = self.guide4(e4, target_size=e3.shape[-2:])
  666. d4 = self.dec4(e4, e3, g4)
  667. g3 = self.guide3(e3, target_size=e2.shape[-2:])
  668. d3 = self.dec3(d4, e2, g3)
  669. g2 = self.guide2(e2, target_size=e1.shape[-2:])
  670. d2 = self.dec2(d3, e1, g2)
  671. d1 = self.head_refine(d2)
  672. return d1, [d4, d3, d2, d1], [g4, g3, g2]
  673. class XNetSegHead2d(nn.Module):
  674. def __init__(
  675. self, in_channels: int, num_classes: int, upsample_scale: int = 4
  676. ) -> None:
  677. super().__init__()
  678. self.block = nn.Sequential(
  679. Conv2dBN(in_channels, in_channels, 3, 1, 1),
  680. nn.ReLU(inplace=True),
  681. nn.Conv2d(in_channels, num_classes, kernel_size=1, bias=True),
  682. )
  683. self.upsample_scale = upsample_scale
  684. def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
  685. x = self.block(x)
  686. return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
  687. class XNet2d(nn.Module):
  688. def __init__(
  689. self,
  690. in_channels: int,
  691. num_classes: int,
  692. encoder_channels: Sequence[int] = (32, 64, 128, 192),
  693. encoder_depths: Sequence[int] = (2, 2, 2, 2),
  694. decoder_channels: Sequence[int] = (128, 64, 32),
  695. stem_channels: int = 24,
  696. bottleneck_depth: int = 1,
  697. global_ratio: float = 2.0,
  698. wavelet_type: str = "haar",
  699. wavelet_level: int = 1,
  700. use_wavelet_branch: bool = True,
  701. use_global_branch_stage1: bool = False,
  702. ssm_d_state: int = 16,
  703. ssm_forward_type: str = "v3",
  704. ssm_backend: str = "auto",
  705. use_frequency_refine: bool = True,
  706. low_freq_radius_h: float = 0.25,
  707. low_freq_radius_w: float = 0.25,
  708. learnable_low_freq_radius: bool = True,
  709. guide_mode: str = "affine",
  710. out_channels: int | None = None,
  711. ) -> None:
  712. super().__init__()
  713. self.encoder = XNetEncoder2d(
  714. in_channels=in_channels,
  715. stem_channels=stem_channels,
  716. encoder_channels=encoder_channels,
  717. encoder_depths=encoder_depths,
  718. global_ratio=global_ratio,
  719. wavelet_type=wavelet_type,
  720. wavelet_level=wavelet_level,
  721. use_wavelet_branch=use_wavelet_branch,
  722. use_global_branch_stage1=use_global_branch_stage1,
  723. ssm_d_state=ssm_d_state,
  724. ssm_forward_type=ssm_forward_type,
  725. ssm_backend=ssm_backend,
  726. )
  727. bottleneck_channels = encoder_channels[-1]
  728. self.bottleneck = nn.Sequential(
  729. *[
  730. XTEB2d(
  731. channels=bottleneck_channels,
  732. global_ratio=global_ratio,
  733. wavelet_type=wavelet_type,
  734. wavelet_level=wavelet_level,
  735. use_wavelet_branch=use_wavelet_branch,
  736. use_global_branch=True,
  737. ssm_d_state=ssm_d_state,
  738. ssm_forward_type=ssm_forward_type,
  739. ssm_backend=ssm_backend,
  740. )
  741. for _ in range(bottleneck_depth)
  742. ]
  743. )
  744. self.decoder = XNetDecoder2d(
  745. encoder_channels=encoder_channels,
  746. decoder_channels=decoder_channels,
  747. guide_mode=guide_mode,
  748. use_frequency_refine=use_frequency_refine,
  749. low_freq_radius_h=low_freq_radius_h,
  750. low_freq_radius_w=low_freq_radius_w,
  751. learnable_low_freq_radius=learnable_low_freq_radius,
  752. out_channels=out_channels,
  753. )
  754. head_in_channels = self.decoder.out_channels
  755. self.segmentation_head = XNetSegHead2d(head_in_channels, num_classes)
  756. def forward(
  757. self, x: torch.Tensor
  758. ) -> dict[
  759. str, torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]
  760. ]:
  761. encoder_features = self.encoder(x)
  762. encoder_features[-1] = self.bottleneck(encoder_features[-1])
  763. decoder_out, decoder_features, guides = self.decoder(encoder_features)
  764. output_size = x.shape[-2:]
  765. logits = self.segmentation_head(decoder_out, output_size=output_size)
  766. outputs: dict[
  767. str,
  768. torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]],
  769. ] = {
  770. "logits": logits,
  771. "seg_logits": logits,
  772. "encoder_features": encoder_features,
  773. "decoder_features": decoder_features,
  774. "guides": guides,
  775. }
  776. return outputs