xnet_2d.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761
  1. from __future__ import annotations
  2. from collections.abc import Sequence
  3. import ptwt
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from .layers_2d import Conv2dBN
  8. from .lib_mamba.vmamba import SS2D as VMambaSS2D
  9. class XNetStem2d(nn.Module):
  10. # Stem reduces spatial size by 4x while lifting features into encoder stage 1.
  11. def __init__(self, in_channels: int, stem_channels: int, out_channels: int) -> None:
  12. super().__init__()
  13. self.block = nn.Sequential(
  14. Conv2dBN(in_channels, stem_channels, 3, 2, 1),
  15. nn.ReLU(inplace=True),
  16. Conv2dBN(stem_channels, stem_channels, 3, 1, 1, groups=stem_channels),
  17. nn.ReLU(inplace=True),
  18. Conv2dBN(stem_channels, out_channels, 1, 1, 0),
  19. nn.ReLU(inplace=True),
  20. Conv2dBN(out_channels, out_channels, 3, 2, 1),
  21. nn.ReLU(inplace=True),
  22. )
  23. def forward(self, x: torch.Tensor) -> torch.Tensor:
  24. return self.block(x)
  25. class XNetDownsample2d(nn.Module):
  26. def __init__(self, in_channels: int, out_channels: int, mode: str = "conv") -> None:
  27. super().__init__()
  28. if mode != "conv":
  29. raise ValueError(f"Unsupported downsample mode: {mode}")
  30. self.block = nn.Sequential(
  31. Conv2dBN(in_channels, out_channels, 3, 2, 1),
  32. nn.ReLU(inplace=True),
  33. )
  34. def forward(self, x: torch.Tensor) -> torch.Tensor:
  35. return self.block(x)
  36. class XLocalBranch2d(nn.Module):
  37. # Parallel depthwise branches capture short-range texture at two kernel scales.
  38. def __init__(self, channels: int) -> None:
  39. super().__init__()
  40. self.branch3 = nn.Sequential(
  41. Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
  42. nn.ReLU(inplace=True),
  43. Conv2dBN(channels, channels, 1, 1, 0),
  44. )
  45. self.branch5 = nn.Sequential(
  46. Conv2dBN(channels, channels, 5, 1, 2, groups=channels),
  47. nn.ReLU(inplace=True),
  48. Conv2dBN(channels, channels, 1, 1, 0),
  49. )
  50. def forward(self, x: torch.Tensor) -> torch.Tensor:
  51. return self.branch3(x) + self.branch5(x)
  52. class XWaveletTransform2d(nn.Module):
  53. # ptwt-based wavelet decomposition/reconstruction with explicit crop so odd
  54. # input sizes round-trip to the exact original spatial shape.
  55. def __init__(
  56. self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
  57. ) -> None:
  58. super().__init__()
  59. self.channels = channels
  60. self.wavelet_type = wavelet_type
  61. self.wavelet_level = wavelet_level
  62. def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  63. original_dtype = x.dtype
  64. with torch.autocast(device_type=x.device.type, enabled=False):
  65. coeffs = ptwt.wavedec2(
  66. x.float(), self.wavelet_type, level=self.wavelet_level
  67. )
  68. ll = coeffs[0]
  69. high_parts = coeffs[1]
  70. high = torch.cat(high_parts, dim=1)
  71. return ll.to(original_dtype), high.to(original_dtype)
  72. def inverse(
  73. self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
  74. ) -> torch.Tensor:
  75. original_dtype = ll.dtype
  76. with torch.autocast(device_type=ll.device.type, enabled=False):
  77. lh, hl, hh = torch.chunk(high.float(), 3, dim=1)
  78. coeffs = [ll.float(), (lh, hl, hh)]
  79. x = ptwt.waverec2(coeffs, self.wavelet_type)
  80. x = x[:, :, : output_size[0], : output_size[1]]
  81. return x.to(original_dtype)
  82. class XWaveletBranch2d(nn.Module):
  83. # The wavelet branch learns on low/high-frequency components separately and
  84. # then reconstructs back to the original feature size.
  85. def __init__(
  86. self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
  87. ) -> None:
  88. super().__init__()
  89. if wavelet_type != "haar":
  90. raise ValueError(f"Unsupported wavelet type: {wavelet_type}")
  91. if wavelet_level != 1:
  92. raise ValueError(
  93. "Initial XNet implementation only supports wavelet_level=1."
  94. )
  95. self.wavelet = XWaveletTransform2d(
  96. channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
  97. )
  98. self.ll_proj = nn.Sequential(
  99. Conv2dBN(channels, channels, 3, 1, 1),
  100. nn.ReLU(inplace=True),
  101. )
  102. self.high_proj = nn.Sequential(
  103. Conv2dBN(channels * 3, channels * 3, 3, 1, 1, groups=channels * 3),
  104. nn.ReLU(inplace=True),
  105. Conv2dBN(channels * 3, channels * 3, 1, 1, 0),
  106. )
  107. self.out_proj = nn.Sequential(
  108. Conv2dBN(channels, channels, 1, 1, 0),
  109. nn.ReLU(inplace=True),
  110. )
  111. def forward(self, x: torch.Tensor) -> torch.Tensor:
  112. output_size = x.shape[-2:]
  113. ll, high = self.wavelet(x)
  114. ll = self.ll_proj(ll)
  115. high = self.high_proj(high)
  116. x = self.wavelet.inverse(ll, high, output_size=output_size)
  117. return self.out_proj(x)
  118. class XSSMGlobalBranch2d(nn.Module):
  119. # The global branch wraps VMamba and switches scan backend at runtime.
  120. def __init__(
  121. self,
  122. channels: int,
  123. global_ratio: float = 2.0,
  124. d_state: int = 16,
  125. forward_type: str = "v3",
  126. ssm_backend: str = "auto",
  127. ) -> None:
  128. super().__init__()
  129. hidden_ratio = max(global_ratio, 1.0)
  130. self.backend = ssm_backend
  131. self.pre = nn.Sequential(
  132. Conv2dBN(channels, channels, 1, 1, 0),
  133. nn.ReLU(inplace=True),
  134. )
  135. self.ssm = VMambaSS2D(
  136. d_model=channels,
  137. d_state=d_state,
  138. ssm_ratio=hidden_ratio,
  139. d_conv=3,
  140. dropout=0.0,
  141. initialize="v0",
  142. forward_type=forward_type,
  143. channel_first=True,
  144. )
  145. self.post = nn.Sequential(
  146. Conv2dBN(channels, channels, 1, 1, 0),
  147. nn.ReLU(inplace=True),
  148. )
  149. def forward(self, x: torch.Tensor) -> torch.Tensor:
  150. x = self.pre(x)
  151. prev_backend = None
  152. backend = self.backend.lower()
  153. if backend == "auto":
  154. backend = "oflex" if x.is_cuda else "torch"
  155. if backend == "oflex" and hasattr(self.ssm, "forward_core"):
  156. prev_backend = self.ssm.forward_core
  157. self.ssm.forward_core = lambda z, _core=prev_backend: _core(
  158. z,
  159. selective_scan_backend="oflex",
  160. scan_force_torch=False,
  161. )
  162. elif backend == "torch" and hasattr(self.ssm, "forward_core"):
  163. prev_backend = self.ssm.forward_core
  164. self.ssm.forward_core = lambda z, _core=prev_backend: _core(
  165. z,
  166. selective_scan_backend="torch",
  167. scan_force_torch=True,
  168. )
  169. try:
  170. x = self.ssm(x)
  171. finally:
  172. if prev_backend is not None:
  173. self.ssm.forward_core = prev_backend
  174. return self.post(x)
  175. class XGlobalBranch2d(nn.Module):
  176. def __init__(
  177. self,
  178. channels: int,
  179. global_ratio: float = 2.0,
  180. ssm_d_state: int = 16,
  181. ssm_forward_type: str = "v3",
  182. ssm_backend: str = "auto",
  183. ) -> None:
  184. super().__init__()
  185. self.ssm_branch = XSSMGlobalBranch2d(
  186. channels=channels,
  187. global_ratio=global_ratio,
  188. d_state=ssm_d_state,
  189. forward_type=ssm_forward_type,
  190. ssm_backend=ssm_backend,
  191. )
  192. def forward(self, x: torch.Tensor) -> torch.Tensor:
  193. return self.ssm_branch(x)
  194. class XBranchFusion2d(nn.Module):
  195. def __init__(self, channels: int, num_branches: int = 3) -> None:
  196. super().__init__()
  197. fused_channels = channels * num_branches
  198. hidden_channels = max(channels // 4, 8)
  199. self.fuse = nn.Sequential(
  200. Conv2dBN(fused_channels, channels, 1, 1, 0),
  201. nn.ReLU(inplace=True),
  202. )
  203. self.gate = nn.Sequential(
  204. nn.AdaptiveAvgPool2d(1),
  205. nn.Conv2d(fused_channels, hidden_channels, kernel_size=1, bias=True),
  206. nn.ReLU(inplace=True),
  207. nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True),
  208. nn.Sigmoid(),
  209. )
  210. def forward(self, branch_outputs: Sequence[torch.Tensor]) -> torch.Tensor:
  211. x_cat = torch.cat(list(branch_outputs), dim=1)
  212. x_fused = self.fuse(x_cat)
  213. gate = self.gate(x_cat)
  214. return x_fused * gate
  215. class XTEB2d(nn.Module):
  216. # XTEB fuses local, wavelet, and global branches with residual post/ffn blocks.
  217. def __init__(
  218. self,
  219. channels: int,
  220. global_ratio: float = 2.0,
  221. wavelet_type: str = "haar",
  222. wavelet_level: int = 1,
  223. use_wavelet_branch: bool = True,
  224. use_global_branch: bool = True,
  225. ssm_d_state: int = 16,
  226. ssm_forward_type: str = "v3",
  227. ssm_backend: str = "auto",
  228. ) -> None:
  229. super().__init__()
  230. self.pre_norm = Conv2dBN(channels, channels, 1, 1, 0)
  231. self.local_branch = XLocalBranch2d(channels)
  232. self.wavelet_branch = (
  233. XWaveletBranch2d(
  234. channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
  235. )
  236. if use_wavelet_branch
  237. else nn.Identity()
  238. )
  239. self.global_branch = (
  240. XGlobalBranch2d(
  241. channels,
  242. global_ratio=global_ratio,
  243. ssm_d_state=ssm_d_state,
  244. ssm_forward_type=ssm_forward_type,
  245. ssm_backend=ssm_backend,
  246. )
  247. if use_global_branch
  248. else nn.Identity()
  249. )
  250. self.fusion = XBranchFusion2d(channels, num_branches=3)
  251. self.post = nn.Sequential(
  252. Conv2dBN(channels, channels, 3, 1, 1),
  253. nn.ReLU(inplace=True),
  254. Conv2dBN(channels, channels, 1, 1, 0, bn_weight_init=0.0),
  255. )
  256. self.ffn = nn.Sequential(
  257. Conv2dBN(channels, channels * 2, 1, 1, 0),
  258. nn.ReLU(inplace=True),
  259. Conv2dBN(channels * 2, channels, 1, 1, 0, bn_weight_init=0.0),
  260. )
  261. def forward(self, x: torch.Tensor) -> torch.Tensor:
  262. x_in = x
  263. x = self.pre_norm(x)
  264. x = x_in + self.post(
  265. self.fusion(
  266. [self.local_branch(x), self.wavelet_branch(x), self.global_branch(x)]
  267. )
  268. )
  269. return x + self.ffn(x)
  270. class XNetEncoderStage2d(nn.Module):
  271. def __init__(
  272. self,
  273. channels: int,
  274. depth: int,
  275. global_ratio: float = 2.0,
  276. wavelet_type: str = "haar",
  277. wavelet_level: int = 1,
  278. use_wavelet_branch: bool = True,
  279. use_global_branch: bool = True,
  280. ssm_d_state: int = 16,
  281. ssm_forward_type: str = "v3",
  282. ssm_backend: str = "auto",
  283. ) -> None:
  284. super().__init__()
  285. self.blocks = nn.Sequential(
  286. *[
  287. XTEB2d(
  288. channels=channels,
  289. global_ratio=global_ratio,
  290. wavelet_type=wavelet_type,
  291. wavelet_level=wavelet_level,
  292. use_wavelet_branch=use_wavelet_branch,
  293. use_global_branch=use_global_branch,
  294. ssm_d_state=ssm_d_state,
  295. ssm_forward_type=ssm_forward_type,
  296. ssm_backend=ssm_backend,
  297. )
  298. for _ in range(depth)
  299. ]
  300. )
  301. def forward(self, x: torch.Tensor) -> torch.Tensor:
  302. return self.blocks(x)
  303. class XNetEncoder2d(nn.Module):
  304. # The encoder is a 4-stage feature pyramid with optional stage-1 global branch.
  305. def __init__(
  306. self,
  307. in_channels: int,
  308. stem_channels: int,
  309. encoder_channels: Sequence[int],
  310. encoder_depths: Sequence[int],
  311. global_ratio: float = 2.0,
  312. wavelet_type: str = "haar",
  313. wavelet_level: int = 1,
  314. use_wavelet_branch: bool = True,
  315. use_global_branch_stage1: bool = False,
  316. ssm_d_state: int = 16,
  317. ssm_forward_type: str = "v3",
  318. ssm_backend: str = "auto",
  319. ) -> None:
  320. super().__init__()
  321. if len(encoder_channels) != 4 or len(encoder_depths) != 4:
  322. raise ValueError("XNetEncoder2d expects 4 encoder stages.")
  323. c1, c2, c3, c4 = encoder_channels
  324. d1, d2, d3, d4 = encoder_depths
  325. self.stem = XNetStem2d(in_channels, stem_channels, c1)
  326. self.stage1 = XNetEncoderStage2d(
  327. c1,
  328. d1,
  329. global_ratio,
  330. wavelet_type,
  331. wavelet_level,
  332. use_wavelet_branch=use_wavelet_branch,
  333. use_global_branch=use_global_branch_stage1,
  334. ssm_d_state=ssm_d_state,
  335. ssm_forward_type=ssm_forward_type,
  336. ssm_backend=ssm_backend,
  337. )
  338. self.down1 = XNetDownsample2d(c1, c2)
  339. self.stage2 = XNetEncoderStage2d(
  340. c2,
  341. d2,
  342. global_ratio,
  343. wavelet_type,
  344. wavelet_level,
  345. use_wavelet_branch,
  346. True,
  347. ssm_d_state=ssm_d_state,
  348. ssm_forward_type=ssm_forward_type,
  349. ssm_backend=ssm_backend,
  350. )
  351. self.down2 = XNetDownsample2d(c2, c3)
  352. self.stage3 = XNetEncoderStage2d(
  353. c3,
  354. d3,
  355. global_ratio,
  356. wavelet_type,
  357. wavelet_level,
  358. use_wavelet_branch,
  359. True,
  360. ssm_d_state=ssm_d_state,
  361. ssm_forward_type=ssm_forward_type,
  362. ssm_backend=ssm_backend,
  363. )
  364. self.down3 = XNetDownsample2d(c3, c4)
  365. self.stage4 = XNetEncoderStage2d(
  366. c4,
  367. d4,
  368. global_ratio,
  369. wavelet_type,
  370. wavelet_level,
  371. use_wavelet_branch,
  372. True,
  373. ssm_d_state=ssm_d_state,
  374. ssm_forward_type=ssm_forward_type,
  375. ssm_backend=ssm_backend,
  376. )
  377. self.stage_channels = list(encoder_channels)
  378. def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
  379. e1 = self.stage1(self.stem(x))
  380. e2 = self.stage2(self.down1(e1))
  381. e3 = self.stage3(self.down2(e2))
  382. e4 = self.stage4(self.down3(e3))
  383. return [e1, e2, e3, e4]
  384. class XSkipFusion2d(nn.Module):
  385. # Decoder input and skip feature are aligned, projected, and fused together.
  386. def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
  387. super().__init__()
  388. self.input_proj = nn.Sequential(
  389. Conv2dBN(in_channels, out_channels, 1, 1, 0),
  390. nn.ReLU(inplace=True),
  391. )
  392. self.skip_proj = nn.Sequential(
  393. Conv2dBN(skip_channels, out_channels, 1, 1, 0),
  394. nn.ReLU(inplace=True),
  395. )
  396. self.fuse = nn.Sequential(
  397. Conv2dBN(out_channels * 2, out_channels, 3, 1, 1),
  398. nn.ReLU(inplace=True),
  399. )
  400. def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
  401. x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
  402. x = self.input_proj(x)
  403. skip = self.skip_proj(skip)
  404. return self.fuse(torch.cat([x, skip], dim=1))
  405. class XFrequencyRefine2d(nn.Module):
  406. def __init__(
  407. self,
  408. channels: int,
  409. low_freq_radius_h: float = 0.25,
  410. low_freq_radius_w: float = 0.25,
  411. learnable_low_freq_radius: bool = True,
  412. ) -> None:
  413. super().__init__()
  414. if low_freq_radius_h <= 0.0 or low_freq_radius_w <= 0.0:
  415. raise ValueError("Low-frequency radii must be positive.")
  416. # Gates are predicted from half-spectrum magnitude statistics instead of
  417. # directly reusing spatial-domain pooled features.
  418. self.low_gate = nn.Sequential(
  419. nn.Conv2d(channels, channels, kernel_size=1, bias=True),
  420. nn.Sigmoid(),
  421. )
  422. self.high_gate = nn.Sequential(
  423. nn.Conv2d(channels, channels, kernel_size=1, bias=True),
  424. nn.Sigmoid(),
  425. )
  426. self.refine = nn.Sequential(
  427. Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
  428. nn.ReLU(inplace=True),
  429. Conv2dBN(channels, channels, 1, 1, 0),
  430. )
  431. self.learnable_low_freq_radius = learnable_low_freq_radius
  432. if learnable_low_freq_radius:
  433. self.low_freq_radius_h = nn.Parameter(
  434. torch.tensor(low_freq_radius_h, dtype=torch.float32)
  435. )
  436. self.low_freq_radius_w = nn.Parameter(
  437. torch.tensor(low_freq_radius_w, dtype=torch.float32)
  438. )
  439. else:
  440. self.register_buffer(
  441. "low_freq_radius_h",
  442. torch.tensor(low_freq_radius_h, dtype=torch.float32),
  443. persistent=False,
  444. )
  445. self.register_buffer(
  446. "low_freq_radius_w",
  447. torch.tensor(low_freq_radius_w, dtype=torch.float32),
  448. persistent=False,
  449. )
  450. def _resolve_radius(
  451. self, value: torch.Tensor, max_ratio: float, device: torch.device
  452. ) -> torch.Tensor:
  453. radius = value.to(device=device, dtype=torch.float32)
  454. if self.learnable_low_freq_radius:
  455. radius = torch.sigmoid(radius) * max_ratio
  456. return torch.clamp(radius, min=1.0e-3, max=max_ratio)
  457. def _build_low_frequency_mask(
  458. self, h_freq: int, w_freq: int, device: torch.device
  459. ) -> torch.Tensor:
  460. y = torch.arange(h_freq, device=device, dtype=torch.float32)
  461. x = torch.arange(w_freq, device=device, dtype=torch.float32)
  462. y = torch.minimum(y, h_freq - y)
  463. radius_h = self._resolve_radius(self.low_freq_radius_h, 0.5, device) * max(
  464. h_freq, 1
  465. )
  466. radius_w = self._resolve_radius(self.low_freq_radius_w, 1.0, device) * max(
  467. w_freq, 1
  468. )
  469. y = y / torch.clamp(radius_h, min=1.0)
  470. x = x / torch.clamp(radius_w, min=1.0)
  471. y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")
  472. mask = (y_grid.square() + x_grid.square()) <= 1.0
  473. return mask.unsqueeze(0).unsqueeze(0)
  474. def forward(self, x: torch.Tensor) -> torch.Tensor:
  475. input_dtype = x.dtype
  476. if x.dtype != torch.float32:
  477. x = x.to(torch.float32)
  478. fft = torch.fft.rfft2(x, norm="ortho")
  479. h_freq, w_freq = fft.shape[-2], fft.shape[-1]
  480. low_mask = self._build_low_frequency_mask(h_freq, w_freq, fft.device).to(
  481. dtype=x.dtype
  482. )
  483. low = fft * low_mask
  484. high = fft - low
  485. magnitude = fft.abs()
  486. low_stats = (magnitude * low_mask).mean(dim=(-2, -1), keepdim=True)
  487. high_stats = (magnitude * (1.0 - low_mask)).mean(dim=(-2, -1), keepdim=True)
  488. low = low * self.low_gate(low_stats)
  489. high = high * self.high_gate(high_stats)
  490. out = torch.fft.irfft2(low + high, s=x.shape[-2:], norm="ortho")
  491. out = out.to(dtype=input_dtype)
  492. return self.refine(out)
  493. class XCRB2d(nn.Module):
  494. # Decoder block: U-Net skip fusion -> frequency refine -> residual output.
  495. def __init__(
  496. self,
  497. in_channels: int,
  498. skip_channels: int,
  499. out_channels: int,
  500. use_frequency_refine: bool = True,
  501. low_freq_radius_h: float = 0.25,
  502. low_freq_radius_w: float = 0.25,
  503. learnable_low_freq_radius: bool = True,
  504. ) -> None:
  505. super().__init__()
  506. self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
  507. self.frequency_refine = (
  508. XFrequencyRefine2d(
  509. out_channels,
  510. low_freq_radius_h=low_freq_radius_h,
  511. low_freq_radius_w=low_freq_radius_w,
  512. learnable_low_freq_radius=learnable_low_freq_radius,
  513. )
  514. if use_frequency_refine
  515. else nn.Identity()
  516. )
  517. self.out_refine = nn.Sequential(
  518. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  519. nn.ReLU(inplace=True),
  520. Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
  521. )
  522. def forward(
  523. self,
  524. x: torch.Tensor,
  525. skip: torch.Tensor,
  526. ) -> torch.Tensor:
  527. x = self.skip_fusion(x, skip)
  528. x = x + self.frequency_refine(x)
  529. return x + self.out_refine(x)
  530. class XNetHeadRefine2d(nn.Module):
  531. def __init__(self, channels: int, out_channels: int | None = None) -> None:
  532. super().__init__()
  533. if out_channels is None:
  534. out_channels = channels
  535. self.block = nn.Sequential(
  536. Conv2dBN(channels, out_channels, 3, 1, 1),
  537. nn.ReLU(inplace=True),
  538. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  539. nn.ReLU(inplace=True),
  540. )
  541. def forward(self, x: torch.Tensor) -> torch.Tensor:
  542. return self.block(x)
  543. class XNetDecoder2d(nn.Module):
  544. def __init__(
  545. self,
  546. encoder_channels: Sequence[int],
  547. decoder_channels: Sequence[int] = (128, 64, 32),
  548. guide_mode: str = "affine",
  549. use_frequency_refine: bool = True,
  550. low_freq_radius_h: float = 0.25,
  551. low_freq_radius_w: float = 0.25,
  552. learnable_low_freq_radius: bool = True,
  553. out_channels: int | None = None,
  554. ) -> None:
  555. super().__init__()
  556. if len(encoder_channels) != 4:
  557. raise ValueError("XNetDecoder2d expects 4 encoder stages.")
  558. if len(decoder_channels) != 3:
  559. raise ValueError("XNetDecoder2d expects 3 decoder channels.")
  560. c1, c2, c3, c4 = encoder_channels
  561. d4, d3, d2 = decoder_channels
  562. self.guide_mode = guide_mode
  563. self.dec4 = XCRB2d(
  564. c4,
  565. c3,
  566. d4,
  567. use_frequency_refine=use_frequency_refine,
  568. low_freq_radius_h=low_freq_radius_h,
  569. low_freq_radius_w=low_freq_radius_w,
  570. learnable_low_freq_radius=learnable_low_freq_radius,
  571. )
  572. self.dec3 = XCRB2d(
  573. d4,
  574. c2,
  575. d3,
  576. use_frequency_refine=use_frequency_refine,
  577. low_freq_radius_h=low_freq_radius_h,
  578. low_freq_radius_w=low_freq_radius_w,
  579. learnable_low_freq_radius=learnable_low_freq_radius,
  580. )
  581. self.dec2 = XCRB2d(
  582. d3,
  583. c1,
  584. d2,
  585. use_frequency_refine=use_frequency_refine,
  586. low_freq_radius_h=low_freq_radius_h,
  587. low_freq_radius_w=low_freq_radius_w,
  588. learnable_low_freq_radius=learnable_low_freq_radius,
  589. )
  590. self.head_refine = XNetHeadRefine2d(d2, out_channels or d2)
  591. self.out_channels = out_channels or d2
  592. def forward(
  593. self,
  594. features: Sequence[torch.Tensor],
  595. ) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
  596. e1, e2, e3, e4 = features
  597. d4 = self.dec4(e4, e3)
  598. d3 = self.dec3(d4, e2)
  599. d2 = self.dec2(d3, e1)
  600. d1 = self.head_refine(d2)
  601. return d1, [d4, d3, d2, d1], []
  602. class XNetSegHead2d(nn.Module):
  603. def __init__(
  604. self, in_channels: int, num_classes: int, upsample_scale: int = 4
  605. ) -> None:
  606. super().__init__()
  607. self.block = nn.Sequential(
  608. Conv2dBN(in_channels, in_channels, 3, 1, 1),
  609. nn.ReLU(inplace=True),
  610. nn.Conv2d(in_channels, num_classes, kernel_size=1, bias=True),
  611. )
  612. self.upsample_scale = upsample_scale
  613. def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
  614. x = self.block(x)
  615. return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
  616. class XNet2d(nn.Module):
  617. def __init__(
  618. self,
  619. in_channels: int,
  620. num_classes: int,
  621. encoder_channels: Sequence[int] = (32, 64, 128, 192),
  622. encoder_depths: Sequence[int] = (2, 2, 2, 2),
  623. decoder_channels: Sequence[int] = (128, 64, 32),
  624. stem_channels: int = 24,
  625. bottleneck_depth: int = 1,
  626. global_ratio: float = 2.0,
  627. wavelet_type: str = "haar",
  628. wavelet_level: int = 1,
  629. use_wavelet_branch: bool = True,
  630. use_global_branch_stage1: bool = False,
  631. ssm_d_state: int = 16,
  632. ssm_forward_type: str = "v3",
  633. ssm_backend: str = "auto",
  634. use_frequency_refine: bool = True,
  635. low_freq_radius_h: float = 0.25,
  636. low_freq_radius_w: float = 0.25,
  637. learnable_low_freq_radius: bool = True,
  638. guide_mode: str = "affine",
  639. out_channels: int | None = None,
  640. ) -> None:
  641. super().__init__()
  642. self.encoder = XNetEncoder2d(
  643. in_channels=in_channels,
  644. stem_channels=stem_channels,
  645. encoder_channels=encoder_channels,
  646. encoder_depths=encoder_depths,
  647. global_ratio=global_ratio,
  648. wavelet_type=wavelet_type,
  649. wavelet_level=wavelet_level,
  650. use_wavelet_branch=use_wavelet_branch,
  651. use_global_branch_stage1=use_global_branch_stage1,
  652. ssm_d_state=ssm_d_state,
  653. ssm_forward_type=ssm_forward_type,
  654. ssm_backend=ssm_backend,
  655. )
  656. bottleneck_channels = encoder_channels[-1]
  657. self.bottleneck = nn.Sequential(
  658. *[
  659. XTEB2d(
  660. channels=bottleneck_channels,
  661. global_ratio=global_ratio,
  662. wavelet_type=wavelet_type,
  663. wavelet_level=wavelet_level,
  664. use_wavelet_branch=use_wavelet_branch,
  665. use_global_branch=True,
  666. ssm_d_state=ssm_d_state,
  667. ssm_forward_type=ssm_forward_type,
  668. ssm_backend=ssm_backend,
  669. )
  670. for _ in range(bottleneck_depth)
  671. ]
  672. )
  673. self.decoder = XNetDecoder2d(
  674. encoder_channels=encoder_channels,
  675. decoder_channels=decoder_channels,
  676. guide_mode=guide_mode,
  677. use_frequency_refine=use_frequency_refine,
  678. low_freq_radius_h=low_freq_radius_h,
  679. low_freq_radius_w=low_freq_radius_w,
  680. learnable_low_freq_radius=learnable_low_freq_radius,
  681. out_channels=out_channels,
  682. )
  683. head_in_channels = self.decoder.out_channels
  684. self.segmentation_head = XNetSegHead2d(head_in_channels, num_classes)
  685. def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
  686. encoder_features = self.encoder(x)
  687. encoder_features[-1] = self.bottleneck(encoder_features[-1])
  688. decoder_out, decoder_features, guides = self.decoder(encoder_features)
  689. output_size = x.shape[-2:]
  690. logits = self.segmentation_head(decoder_out, output_size=output_size)
  691. outputs: dict[str, torch.Tensor | list[torch.Tensor]] = {
  692. "logits": logits,
  693. "seg_logits": logits,
  694. "encoder_features": encoder_features,
  695. "decoder_features": decoder_features,
  696. "guides": guides,
  697. }
  698. return outputs