xnet_2d.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748
  1. from __future__ import annotations
  2. from collections.abc import Sequence
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from .layers_2d import Conv2dBN
  7. from .lib_mamba.vmamba import SS2D as VMambaSS2D
  8. class XNetStem2d(nn.Module):
  9. def __init__(self, in_channels: int, stem_channels: int, out_channels: int) -> None:
  10. super().__init__()
  11. self.block = nn.Sequential(
  12. Conv2dBN(in_channels, stem_channels, 3, 2, 1),
  13. nn.ReLU(inplace=True),
  14. Conv2dBN(stem_channels, stem_channels, 3, 1, 1, groups=stem_channels),
  15. nn.ReLU(inplace=True),
  16. Conv2dBN(stem_channels, out_channels, 1, 1, 0),
  17. nn.ReLU(inplace=True),
  18. Conv2dBN(out_channels, out_channels, 3, 2, 1),
  19. nn.ReLU(inplace=True),
  20. )
  21. def forward(self, x: torch.Tensor) -> torch.Tensor:
  22. return self.block(x)
  23. class XNetDownsample2d(nn.Module):
  24. def __init__(self, in_channels: int, out_channels: int, mode: str = "conv") -> None:
  25. super().__init__()
  26. if mode != "conv":
  27. raise ValueError(f"Unsupported downsample mode: {mode}")
  28. self.block = nn.Sequential(
  29. Conv2dBN(in_channels, out_channels, 3, 2, 1),
  30. nn.ReLU(inplace=True),
  31. )
  32. def forward(self, x: torch.Tensor) -> torch.Tensor:
  33. return self.block(x)
  34. class XLocalBranch2d(nn.Module):
  35. def __init__(self, channels: int) -> None:
  36. super().__init__()
  37. self.branch3 = nn.Sequential(
  38. Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
  39. nn.ReLU(inplace=True),
  40. Conv2dBN(channels, channels, 1, 1, 0),
  41. )
  42. self.branch5 = nn.Sequential(
  43. Conv2dBN(channels, channels, 5, 1, 2, groups=channels),
  44. nn.ReLU(inplace=True),
  45. Conv2dBN(channels, channels, 1, 1, 0),
  46. )
  47. def forward(self, x: torch.Tensor) -> torch.Tensor:
  48. return self.branch3(x) + self.branch5(x)
  49. class XHaarWaveletTransform2d(nn.Module):
  50. def __init__(self, channels: int) -> None:
  51. super().__init__()
  52. ll = torch.tensor([[0.5, 0.5], [0.5, 0.5]], dtype=torch.float32)
  53. lh = torch.tensor([[-0.5, -0.5], [0.5, 0.5]], dtype=torch.float32)
  54. hl = torch.tensor([[-0.5, 0.5], [-0.5, 0.5]], dtype=torch.float32)
  55. hh = torch.tensor([[0.5, -0.5], [-0.5, 0.5]], dtype=torch.float32)
  56. filt = torch.stack([ll, lh, hl, hh], dim=0).unsqueeze(1)
  57. self.register_buffer(
  58. "analysis_filter", filt.repeat(channels, 1, 1, 1), persistent=False
  59. )
  60. self.register_buffer(
  61. "synthesis_filter", filt.repeat(channels, 1, 1, 1), persistent=False
  62. )
  63. self.channels = channels
  64. def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  65. b, c, h, w = x.shape
  66. pad_h = h % 2
  67. pad_w = w % 2
  68. if pad_h or pad_w:
  69. x = F.pad(x, (0, pad_w, 0, pad_h))
  70. y = F.conv2d(x, self.analysis_filter, stride=2, groups=self.channels)
  71. y = y.view(b, c, 4, y.shape[-2], y.shape[-1])
  72. ll = y[:, :, 0]
  73. high = y[:, :, 1:].reshape(b, c * 3, y.shape[-2], y.shape[-1])
  74. return ll, high
  75. def inverse(
  76. self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
  77. ) -> torch.Tensor:
  78. b, c, h, w = ll.shape
  79. high = high.view(b, c, 3, h, w)
  80. y = torch.cat([ll.unsqueeze(2), high], dim=2).reshape(b, c * 4, h, w)
  81. x = F.conv_transpose2d(y, self.synthesis_filter, stride=2, groups=self.channels)
  82. return x[:, :, : output_size[0], : output_size[1]]
  83. class XWaveletBranch2d(nn.Module):
  84. def __init__(
  85. self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
  86. ) -> None:
  87. super().__init__()
  88. if wavelet_type != "haar":
  89. raise ValueError(f"Unsupported wavelet type: {wavelet_type}")
  90. if wavelet_level != 1:
  91. raise ValueError(
  92. "Initial XNet implementation only supports wavelet_level=1."
  93. )
  94. self.wavelet = XHaarWaveletTransform2d(channels)
  95. self.ll_proj = nn.Sequential(
  96. Conv2dBN(channels, channels, 3, 1, 1),
  97. nn.ReLU(inplace=True),
  98. )
  99. self.high_proj = nn.Sequential(
  100. Conv2dBN(channels * 3, channels * 3, 3, 1, 1, groups=channels * 3),
  101. nn.ReLU(inplace=True),
  102. Conv2dBN(channels * 3, channels * 3, 1, 1, 0),
  103. )
  104. self.out_proj = nn.Sequential(
  105. Conv2dBN(channels, channels, 1, 1, 0),
  106. nn.ReLU(inplace=True),
  107. )
  108. def forward(self, x: torch.Tensor) -> torch.Tensor:
  109. output_size = x.shape[-2:]
  110. ll, high = self.wavelet(x)
  111. ll = self.ll_proj(ll)
  112. high = self.high_proj(high)
  113. x = self.wavelet.inverse(ll, high, output_size=output_size)
  114. return self.out_proj(x)
  115. class XSSMGlobalBranch2d(nn.Module):
  116. def __init__(
  117. self,
  118. channels: int,
  119. global_ratio: float = 2.0,
  120. d_state: int = 16,
  121. forward_type: str = "v3",
  122. ssm_backend: str = "auto",
  123. ) -> None:
  124. super().__init__()
  125. hidden_ratio = max(global_ratio, 1.0)
  126. self.backend = ssm_backend
  127. self.pre = nn.Sequential(
  128. Conv2dBN(channels, channels, 1, 1, 0),
  129. nn.ReLU(inplace=True),
  130. )
  131. self.ssm = VMambaSS2D(
  132. d_model=channels,
  133. d_state=d_state,
  134. ssm_ratio=hidden_ratio,
  135. d_conv=3,
  136. dropout=0.0,
  137. initialize="v0",
  138. forward_type=forward_type,
  139. channel_first=True,
  140. )
  141. self.post = nn.Sequential(
  142. Conv2dBN(channels, channels, 1, 1, 0),
  143. nn.ReLU(inplace=True),
  144. )
  145. def forward(self, x: torch.Tensor) -> torch.Tensor:
  146. x = self.pre(x)
  147. prev_backend = None
  148. backend = self.backend.lower()
  149. if backend == "auto":
  150. backend = "oflex" if x.is_cuda else "torch"
  151. if backend == "oflex" and hasattr(self.ssm, "forward_core"):
  152. prev_backend = self.ssm.forward_core
  153. self.ssm.forward_core = lambda z, _core=prev_backend: _core(
  154. z,
  155. selective_scan_backend="oflex",
  156. scan_force_torch=False,
  157. )
  158. elif backend == "torch" and hasattr(self.ssm, "forward_core"):
  159. prev_backend = self.ssm.forward_core
  160. self.ssm.forward_core = lambda z, _core=prev_backend: _core(
  161. z,
  162. selective_scan_backend="torch",
  163. scan_force_torch=True,
  164. )
  165. try:
  166. x = self.ssm(x)
  167. finally:
  168. if prev_backend is not None:
  169. self.ssm.forward_core = prev_backend
  170. return self.post(x)
  171. class XGlobalBranch2d(nn.Module):
  172. def __init__(
  173. self,
  174. channels: int,
  175. global_ratio: float = 2.0,
  176. ssm_d_state: int = 16,
  177. ssm_forward_type: str = "v3",
  178. ssm_backend: str = "auto",
  179. ) -> None:
  180. super().__init__()
  181. self.ssm_branch = XSSMGlobalBranch2d(
  182. channels=channels,
  183. global_ratio=global_ratio,
  184. d_state=ssm_d_state,
  185. forward_type=ssm_forward_type,
  186. ssm_backend=ssm_backend,
  187. )
  188. def forward(self, x: torch.Tensor) -> torch.Tensor:
  189. return self.ssm_branch(x)
  190. class XBranchFusion2d(nn.Module):
  191. def __init__(self, channels: int, num_branches: int = 3) -> None:
  192. super().__init__()
  193. fused_channels = channels * num_branches
  194. hidden_channels = max(channels // 4, 8)
  195. self.fuse = nn.Sequential(
  196. Conv2dBN(fused_channels, channels, 1, 1, 0),
  197. nn.ReLU(inplace=True),
  198. )
  199. self.gate = nn.Sequential(
  200. nn.AdaptiveAvgPool2d(1),
  201. nn.Conv2d(fused_channels, hidden_channels, kernel_size=1, bias=True),
  202. nn.ReLU(inplace=True),
  203. nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True),
  204. nn.Sigmoid(),
  205. )
  206. def forward(self, branch_outputs: Sequence[torch.Tensor]) -> torch.Tensor:
  207. x_cat = torch.cat(list(branch_outputs), dim=1)
  208. x_fused = self.fuse(x_cat)
  209. gate = self.gate(x_cat)
  210. return x_fused * gate
  211. class XTEB2d(nn.Module):
  212. def __init__(
  213. self,
  214. channels: int,
  215. global_ratio: float = 2.0,
  216. wavelet_type: str = "haar",
  217. wavelet_level: int = 1,
  218. use_wavelet_branch: bool = True,
  219. use_global_branch: bool = True,
  220. ssm_d_state: int = 16,
  221. ssm_forward_type: str = "v3",
  222. ssm_backend: str = "auto",
  223. ) -> None:
  224. super().__init__()
  225. self.pre_norm = Conv2dBN(channels, channels, 1, 1, 0)
  226. self.local_branch = XLocalBranch2d(channels)
  227. self.wavelet_branch = (
  228. XWaveletBranch2d(
  229. channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
  230. )
  231. if use_wavelet_branch
  232. else nn.Identity()
  233. )
  234. self.global_branch = (
  235. XGlobalBranch2d(
  236. channels,
  237. global_ratio=global_ratio,
  238. ssm_d_state=ssm_d_state,
  239. ssm_forward_type=ssm_forward_type,
  240. ssm_backend=ssm_backend,
  241. )
  242. if use_global_branch
  243. else nn.Identity()
  244. )
  245. self.fusion = XBranchFusion2d(channels, num_branches=3)
  246. self.post = nn.Sequential(
  247. Conv2dBN(channels, channels, 3, 1, 1),
  248. nn.ReLU(inplace=True),
  249. Conv2dBN(channels, channels, 1, 1, 0, bn_weight_init=0.0),
  250. )
  251. self.ffn = nn.Sequential(
  252. Conv2dBN(channels, channels * 2, 1, 1, 0),
  253. nn.ReLU(inplace=True),
  254. Conv2dBN(channels * 2, channels, 1, 1, 0, bn_weight_init=0.0),
  255. )
  256. def forward(self, x: torch.Tensor) -> torch.Tensor:
  257. x_in = x
  258. x = self.pre_norm(x)
  259. x = x_in + self.post(
  260. self.fusion(
  261. [self.local_branch(x), self.wavelet_branch(x), self.global_branch(x)]
  262. )
  263. )
  264. return x + self.ffn(x)
  265. class XNetEncoderStage2d(nn.Module):
  266. def __init__(
  267. self,
  268. channels: int,
  269. depth: int,
  270. global_ratio: float = 2.0,
  271. wavelet_type: str = "haar",
  272. wavelet_level: int = 1,
  273. use_wavelet_branch: bool = True,
  274. use_global_branch: bool = True,
  275. ssm_d_state: int = 16,
  276. ssm_forward_type: str = "v3",
  277. ssm_backend: str = "auto",
  278. ) -> None:
  279. super().__init__()
  280. self.blocks = nn.Sequential(
  281. *[
  282. XTEB2d(
  283. channels=channels,
  284. global_ratio=global_ratio,
  285. wavelet_type=wavelet_type,
  286. wavelet_level=wavelet_level,
  287. use_wavelet_branch=use_wavelet_branch,
  288. use_global_branch=use_global_branch,
  289. ssm_d_state=ssm_d_state,
  290. ssm_forward_type=ssm_forward_type,
  291. ssm_backend=ssm_backend,
  292. )
  293. for _ in range(depth)
  294. ]
  295. )
  296. def forward(self, x: torch.Tensor) -> torch.Tensor:
  297. return self.blocks(x)
  298. class XNetEncoder2d(nn.Module):
  299. def __init__(
  300. self,
  301. in_channels: int,
  302. stem_channels: int,
  303. encoder_channels: Sequence[int],
  304. encoder_depths: Sequence[int],
  305. global_ratio: float = 2.0,
  306. wavelet_type: str = "haar",
  307. wavelet_level: int = 1,
  308. use_wavelet_branch: bool = True,
  309. use_global_branch_stage1: bool = False,
  310. ssm_d_state: int = 16,
  311. ssm_forward_type: str = "v3",
  312. ssm_backend: str = "auto",
  313. ) -> None:
  314. super().__init__()
  315. if len(encoder_channels) != 4 or len(encoder_depths) != 4:
  316. raise ValueError("XNetEncoder2d expects 4 encoder stages.")
  317. c1, c2, c3, c4 = encoder_channels
  318. d1, d2, d3, d4 = encoder_depths
  319. self.stem = XNetStem2d(in_channels, stem_channels, c1)
  320. self.stage1 = XNetEncoderStage2d(
  321. c1,
  322. d1,
  323. global_ratio,
  324. wavelet_type,
  325. wavelet_level,
  326. use_wavelet_branch=use_wavelet_branch,
  327. use_global_branch=use_global_branch_stage1,
  328. ssm_d_state=ssm_d_state,
  329. ssm_forward_type=ssm_forward_type,
  330. ssm_backend=ssm_backend,
  331. )
  332. self.down1 = XNetDownsample2d(c1, c2)
  333. self.stage2 = XNetEncoderStage2d(
  334. c2,
  335. d2,
  336. global_ratio,
  337. wavelet_type,
  338. wavelet_level,
  339. use_wavelet_branch,
  340. True,
  341. ssm_d_state=ssm_d_state,
  342. ssm_forward_type=ssm_forward_type,
  343. ssm_backend=ssm_backend,
  344. )
  345. self.down2 = XNetDownsample2d(c2, c3)
  346. self.stage3 = XNetEncoderStage2d(
  347. c3,
  348. d3,
  349. global_ratio,
  350. wavelet_type,
  351. wavelet_level,
  352. use_wavelet_branch,
  353. True,
  354. ssm_d_state=ssm_d_state,
  355. ssm_forward_type=ssm_forward_type,
  356. ssm_backend=ssm_backend,
  357. )
  358. self.down3 = XNetDownsample2d(c3, c4)
  359. self.stage4 = XNetEncoderStage2d(
  360. c4,
  361. d4,
  362. global_ratio,
  363. wavelet_type,
  364. wavelet_level,
  365. use_wavelet_branch,
  366. True,
  367. ssm_d_state=ssm_d_state,
  368. ssm_forward_type=ssm_forward_type,
  369. ssm_backend=ssm_backend,
  370. )
  371. self.stage_channels = list(encoder_channels)
  372. def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
  373. e1 = self.stage1(self.stem(x))
  374. e2 = self.stage2(self.down1(e1))
  375. e3 = self.stage3(self.down2(e2))
  376. e4 = self.stage4(self.down3(e3))
  377. return [e1, e2, e3, e4]
  378. class XGuideProjector2d(nn.Module):
  379. def __init__(
  380. self, in_channels: int, out_channels: int, mode: str = "affine"
  381. ) -> None:
  382. super().__init__()
  383. self.mode = mode
  384. if mode == "affine":
  385. self.proj = nn.Sequential(
  386. Conv2dBN(in_channels, out_channels * 2, 1, 1, 0),
  387. nn.ReLU(inplace=True),
  388. nn.Conv2d(out_channels * 2, out_channels * 2, kernel_size=1, bias=True),
  389. )
  390. elif mode == "feature":
  391. self.proj = nn.Sequential(
  392. Conv2dBN(in_channels, out_channels, 1, 1, 0),
  393. nn.ReLU(inplace=True),
  394. )
  395. else:
  396. raise ValueError(f"Unsupported guide mode: {mode}")
  397. def forward(
  398. self,
  399. x: torch.Tensor,
  400. target_size: tuple[int, int],
  401. ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
  402. x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
  403. x = self.proj(x)
  404. if self.mode == "affine":
  405. gamma, beta = torch.chunk(x, 2, dim=1)
  406. gamma = torch.sigmoid(gamma) + 0.5
  407. return gamma, beta
  408. return x
  409. class XSkipFusion2d(nn.Module):
  410. def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
  411. super().__init__()
  412. self.input_proj = nn.Sequential(
  413. Conv2dBN(in_channels, out_channels, 1, 1, 0),
  414. nn.ReLU(inplace=True),
  415. )
  416. self.skip_proj = nn.Sequential(
  417. Conv2dBN(skip_channels, out_channels, 1, 1, 0),
  418. nn.ReLU(inplace=True),
  419. )
  420. self.fuse = nn.Sequential(
  421. Conv2dBN(out_channels * 2, out_channels, 3, 1, 1),
  422. nn.ReLU(inplace=True),
  423. )
  424. def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
  425. x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
  426. x = self.input_proj(x)
  427. skip = self.skip_proj(skip)
  428. return self.fuse(torch.cat([x, skip], dim=1))
  429. class XGuideModulation2d(nn.Module):
  430. def __init__(self, channels: int, guide_mode: str = "affine") -> None:
  431. super().__init__()
  432. self.guide_mode = guide_mode
  433. if guide_mode == "feature":
  434. self.to_affine = nn.Conv2d(channels, channels * 2, kernel_size=1, bias=True)
  435. def forward(
  436. self,
  437. x: torch.Tensor,
  438. guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
  439. ) -> torch.Tensor:
  440. if self.guide_mode == "affine":
  441. gamma, beta = guide
  442. else:
  443. gamma, beta = torch.chunk(self.to_affine(guide), 2, dim=1)
  444. gamma = torch.sigmoid(gamma) + 0.5
  445. return gamma * x + beta
  446. class XFrequencyRefine2d(nn.Module):
  447. def __init__(self, channels: int) -> None:
  448. super().__init__()
  449. self.low_gate = nn.Sequential(
  450. nn.AdaptiveAvgPool2d(1),
  451. nn.Conv2d(channels, channels, kernel_size=1, bias=True),
  452. nn.Sigmoid(),
  453. )
  454. self.high_gate = nn.Sequential(
  455. nn.AdaptiveAvgPool2d(1),
  456. nn.Conv2d(channels, channels, kernel_size=1, bias=True),
  457. nn.Sigmoid(),
  458. )
  459. self.refine = nn.Sequential(
  460. Conv2dBN(channels, channels, 3, 1, 1, groups=channels),
  461. nn.ReLU(inplace=True),
  462. Conv2dBN(channels, channels, 1, 1, 0),
  463. )
  464. def forward(self, x: torch.Tensor) -> torch.Tensor:
  465. input_dtype = x.dtype
  466. if x.dtype != torch.float32:
  467. x = x.to(torch.float32)
  468. fft = torch.fft.rfft2(x, norm="ortho")
  469. low = fft.clone()
  470. h_freq, w_freq = low.shape[-2], low.shape[-1]
  471. low[:, :, h_freq // 4 :, :] = 0
  472. low[:, :, :, w_freq // 4 :] = 0
  473. high = fft - low
  474. low = low * self.low_gate(x)
  475. high = high * self.high_gate(x)
  476. out = torch.fft.irfft2(low + high, s=x.shape[-2:], norm="ortho")
  477. out = out.to(dtype=input_dtype)
  478. return self.refine(out)
  479. class XCRB2d(nn.Module):
  480. def __init__(
  481. self,
  482. in_channels: int,
  483. skip_channels: int,
  484. guide_channels: int,
  485. out_channels: int,
  486. guide_mode: str = "affine",
  487. use_frequency_refine: bool = True,
  488. ) -> None:
  489. super().__init__()
  490. self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
  491. self.guide_modulation = XGuideModulation2d(out_channels, guide_mode=guide_mode)
  492. self.frequency_refine = (
  493. XFrequencyRefine2d(out_channels) if use_frequency_refine else nn.Identity()
  494. )
  495. self.out_refine = nn.Sequential(
  496. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  497. nn.ReLU(inplace=True),
  498. Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
  499. )
  500. self.guide_channels = guide_channels
  501. def forward(
  502. self,
  503. x: torch.Tensor,
  504. skip: torch.Tensor,
  505. guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
  506. ) -> torch.Tensor:
  507. x = self.skip_fusion(x, skip)
  508. x = self.guide_modulation(x, guide)
  509. x = x + self.frequency_refine(x)
  510. return x + self.out_refine(x)
  511. class XNetHeadRefine2d(nn.Module):
  512. def __init__(self, channels: int, out_channels: int | None = None) -> None:
  513. super().__init__()
  514. if out_channels is None:
  515. out_channels = channels
  516. self.block = nn.Sequential(
  517. Conv2dBN(channels, out_channels, 3, 1, 1),
  518. nn.ReLU(inplace=True),
  519. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  520. nn.ReLU(inplace=True),
  521. )
  522. def forward(self, x: torch.Tensor) -> torch.Tensor:
  523. return self.block(x)
  524. class XNetDecoder2d(nn.Module):
  525. def __init__(
  526. self,
  527. encoder_channels: Sequence[int],
  528. decoder_channels: Sequence[int] = (128, 64, 32),
  529. guide_mode: str = "affine",
  530. use_frequency_refine: bool = True,
  531. out_channels: int | None = None,
  532. ) -> None:
  533. super().__init__()
  534. if len(encoder_channels) != 4:
  535. raise ValueError("XNetDecoder2d expects 4 encoder stages.")
  536. if len(decoder_channels) != 3:
  537. raise ValueError("XNetDecoder2d expects 3 decoder channels.")
  538. c1, c2, c3, c4 = encoder_channels
  539. d4, d3, d2 = decoder_channels
  540. self.guide4 = XGuideProjector2d(c4, d4, mode=guide_mode)
  541. self.guide3 = XGuideProjector2d(c3, d3, mode=guide_mode)
  542. self.guide2 = XGuideProjector2d(c2, d2, mode=guide_mode)
  543. self.dec4 = XCRB2d(
  544. c4,
  545. c3,
  546. d4,
  547. d4,
  548. guide_mode=guide_mode,
  549. use_frequency_refine=use_frequency_refine,
  550. )
  551. self.dec3 = XCRB2d(
  552. d4,
  553. c2,
  554. d3,
  555. d3,
  556. guide_mode=guide_mode,
  557. use_frequency_refine=use_frequency_refine,
  558. )
  559. self.dec2 = XCRB2d(
  560. d3,
  561. c1,
  562. d2,
  563. d2,
  564. guide_mode=guide_mode,
  565. use_frequency_refine=use_frequency_refine,
  566. )
  567. self.head_refine = XNetHeadRefine2d(d2, out_channels or d2)
  568. self.out_channels = out_channels or d2
  569. def forward(
  570. self,
  571. features: Sequence[torch.Tensor],
  572. ) -> tuple[
  573. torch.Tensor,
  574. list[torch.Tensor],
  575. list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]],
  576. ]:
  577. e1, e2, e3, e4 = features
  578. g4 = self.guide4(e4, target_size=e3.shape[-2:])
  579. d4 = self.dec4(e4, e3, g4)
  580. g3 = self.guide3(e3, target_size=e2.shape[-2:])
  581. d3 = self.dec3(d4, e2, g3)
  582. g2 = self.guide2(e2, target_size=e1.shape[-2:])
  583. d2 = self.dec2(d3, e1, g2)
  584. d1 = self.head_refine(d2)
  585. return d1, [d4, d3, d2, d1], [g4, g3, g2]
  586. class XNetSegHead2d(nn.Module):
  587. def __init__(
  588. self, in_channels: int, num_classes: int, upsample_scale: int = 4
  589. ) -> None:
  590. super().__init__()
  591. self.block = nn.Sequential(
  592. Conv2dBN(in_channels, in_channels, 3, 1, 1),
  593. nn.ReLU(inplace=True),
  594. nn.Conv2d(in_channels, num_classes, kernel_size=1, bias=True),
  595. )
  596. self.upsample_scale = upsample_scale
  597. def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
  598. x = self.block(x)
  599. return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
  600. class XNet2d(nn.Module):
  601. def __init__(
  602. self,
  603. in_channels: int,
  604. num_classes: int,
  605. encoder_channels: Sequence[int] = (32, 64, 128, 192),
  606. encoder_depths: Sequence[int] = (2, 2, 2, 2),
  607. decoder_channels: Sequence[int] = (128, 64, 32),
  608. stem_channels: int = 24,
  609. bottleneck_depth: int = 1,
  610. global_ratio: float = 2.0,
  611. wavelet_type: str = "haar",
  612. wavelet_level: int = 1,
  613. use_wavelet_branch: bool = True,
  614. use_global_branch_stage1: bool = False,
  615. ssm_d_state: int = 16,
  616. ssm_forward_type: str = "v3",
  617. ssm_backend: str = "auto",
  618. use_frequency_refine: bool = True,
  619. guide_mode: str = "affine",
  620. out_channels: int | None = None,
  621. ) -> None:
  622. super().__init__()
  623. self.encoder = XNetEncoder2d(
  624. in_channels=in_channels,
  625. stem_channels=stem_channels,
  626. encoder_channels=encoder_channels,
  627. encoder_depths=encoder_depths,
  628. global_ratio=global_ratio,
  629. wavelet_type=wavelet_type,
  630. wavelet_level=wavelet_level,
  631. use_wavelet_branch=use_wavelet_branch,
  632. use_global_branch_stage1=use_global_branch_stage1,
  633. ssm_d_state=ssm_d_state,
  634. ssm_forward_type=ssm_forward_type,
  635. ssm_backend=ssm_backend,
  636. )
  637. bottleneck_channels = encoder_channels[-1]
  638. self.bottleneck = nn.Sequential(
  639. *[
  640. XTEB2d(
  641. channels=bottleneck_channels,
  642. global_ratio=global_ratio,
  643. wavelet_type=wavelet_type,
  644. wavelet_level=wavelet_level,
  645. use_wavelet_branch=use_wavelet_branch,
  646. use_global_branch=True,
  647. ssm_d_state=ssm_d_state,
  648. ssm_forward_type=ssm_forward_type,
  649. ssm_backend=ssm_backend,
  650. )
  651. for _ in range(bottleneck_depth)
  652. ]
  653. )
  654. self.decoder = XNetDecoder2d(
  655. encoder_channels=encoder_channels,
  656. decoder_channels=decoder_channels,
  657. guide_mode=guide_mode,
  658. use_frequency_refine=use_frequency_refine,
  659. out_channels=out_channels,
  660. )
  661. head_in_channels = self.decoder.out_channels
  662. self.segmentation_head = XNetSegHead2d(head_in_channels, num_classes)
  663. def forward(
  664. self, x: torch.Tensor
  665. ) -> dict[
  666. str, torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]
  667. ]:
  668. encoder_features = self.encoder(x)
  669. encoder_features[-1] = self.bottleneck(encoder_features[-1])
  670. decoder_out, decoder_features, guides = self.decoder(encoder_features)
  671. output_size = x.shape[-2:]
  672. logits = self.segmentation_head(decoder_out, output_size=output_size)
  673. outputs: dict[
  674. str,
  675. torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]],
  676. ] = {
  677. "logits": logits,
  678. "seg_logits": logits,
  679. "encoder_features": encoder_features,
  680. "decoder_features": decoder_features,
  681. "guides": guides,
  682. }
  683. return outputs