# fwta_2d.py -- Fourier-wavelet token aggregation (FWTA) for ViT patch-token grids.

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import ptwt
import torch
import torch.nn as nn
import torch.nn.functional as F
  8. def build_gaussian_lowpass(
  9. channels: int,
  10. sigma_ratio: float = 0.35,
  11. device: Optional[torch.device] = None,
  12. dtype: Optional[torch.dtype] = None,
  13. ) -> torch.Tensor:
  14. """
  15. 构建用于通道维度的 1D 高斯低通滤波器。
  16. Returns:
  17. Tensor of shape [1, 1, C].
  18. """
  19. sigma = max(channels * sigma_ratio, 1.0)
  20. center = (channels - 1) / 2.0
  21. coords = torch.arange(channels, device=device, dtype=dtype or torch.float32)
  22. kernel = torch.exp(-0.5 * ((coords - center) / sigma) ** 2)
  23. kernel = kernel / kernel.max().clamp_min(1e-6)
  24. return kernel.view(1, 1, channels)
  25. @dataclass
  26. class FWTADebug:
  27. initial_global_token: torch.Tensor
  28. fourier_score: torch.Tensor
  29. wavelet_score: torch.Tensor
  30. stability_prior: torch.Tensor
  31. saliency_prior: torch.Tensor
  32. fused_score: torch.Tensor
  33. gate: torch.Tensor
  34. pooled_token: torch.Tensor
  35. class FourierWaveletTokenAggregation(nn.Module):
  36. """
  37. 傅里叶 - 小波令牌聚合模块。
  38. Inputs:
  39. cls_token: [B, C]
  40. patch_tokens: [B, N, C]
  41. Output:
  42. cls_out: [B, C]
  43. gate: [B, N]
  44. Design:
  45. - Fourier branch estimates token-wise semantic stability.
  46. - Wavelet branch estimates token-wise structural saliency.
  47. - Fused score produces a softmax gate over tokens.
  48. - Weighted pooled token is added back to the CLS token by residual update.
  49. """
  50. def __init__(
  51. self,
  52. dim: int,
  53. grid_size: Tuple[int, int],
  54. wavelet: str = "haar",
  55. wavelet_level: int = 1,
  56. sigma_ratio: float = 0.35,
  57. tau_fourier: float = 0.15,
  58. gate_temperature: float = 1.0,
  59. residual_scale_init: float = 1.0,
  60. fusion_hidden_ratio: float = 0.5,
  61. use_cls_conditioning: bool = True,
  62. learnable_global_token: bool = True,
  63. global_token_use_image_conditioning: bool = True,
  64. eps: float = 1e-6,
  65. ) -> None:
  66. super().__init__()
  67. self.dim = dim
  68. self.grid_size = grid_size
  69. self.wavelet = wavelet
  70. self.wavelet_level = wavelet_level
  71. self.sigma_ratio = sigma_ratio
  72. self.tau_fourier = tau_fourier
  73. self.gate_temperature = gate_temperature
  74. self.use_cls_conditioning = use_cls_conditioning
  75. self.learnable_global_token = learnable_global_token
  76. self.global_token_use_image_conditioning = global_token_use_image_conditioning
  77. self.eps = eps
  78. hidden_dim = max(int(dim * fusion_hidden_ratio), 32)
  79. fuse_in_dim = 3 if use_cls_conditioning else 2
  80. self.score_fuser = nn.Sequential(
  81. nn.Linear(fuse_in_dim, hidden_dim),
  82. nn.GELU(),
  83. nn.Linear(hidden_dim, 1),
  84. )
  85. self.token_proj = nn.Sequential(
  86. nn.LayerNorm(dim),
  87. nn.Linear(dim, dim),
  88. nn.GELU(),
  89. nn.Linear(dim, dim),
  90. )
  91. self.out_norm = nn.LayerNorm(dim)
  92. self.residual_scale = nn.Parameter(torch.tensor(float(residual_scale_init)))
  93. self.base_global_token = nn.Parameter(torch.zeros(1, dim))
  94. nn.init.trunc_normal_(self.base_global_token, std=0.02)
  95. if learnable_global_token and global_token_use_image_conditioning:
  96. self.global_context_proj = nn.Sequential(
  97. nn.LayerNorm(dim),
  98. nn.Linear(dim, dim),
  99. nn.GELU(),
  100. nn.Linear(dim, dim),
  101. )
  102. self.global_token_norm = nn.LayerNorm(dim)
  103. elif learnable_global_token:
  104. self.global_context_proj = None
  105. self.global_token_norm = nn.LayerNorm(dim)
  106. else:
  107. self.global_context_proj = None
  108. self.global_token_norm = nn.Identity()
  109. # 学习系数以平衡粗结构、边缘线索和高频细节。
  110. # 注意:HH 子带不被预设为纯噪声,而是允许模型学习其正负贡献。
  111. self.wavelet_ll_weight = nn.Parameter(torch.tensor(1.0))
  112. self.wavelet_edge_weight = nn.Parameter(torch.tensor(0.5))
  113. self.wavelet_hh_weight = nn.Parameter(torch.tensor(-0.25))
  114. self.stability_fourier_weight = nn.Parameter(torch.tensor(0.7))
  115. self.stability_wavelet_weight = nn.Parameter(torch.tensor(0.3))
  116. self.saliency_wavelet_weight = nn.Parameter(torch.tensor(1.0))
  117. self.context_fourier_weight = nn.Parameter(torch.tensor(0.5))
  118. self.context_wavelet_weight = nn.Parameter(torch.tensor(0.5))
  119. self.alignment_residual_weight = nn.Parameter(torch.tensor(0.1))
  120. self.register_buffer("gaussian_kernel", build_gaussian_lowpass(dim, sigma_ratio), persistent=False)
  121. def forward(
  122. self,
  123. patch_tokens: torch.Tensor,
  124. cls_token: torch.Tensor | None = None,
  125. return_debug: bool = False,
  126. ):
  127. B, N, C = patch_tokens.shape
  128. H, W = self.grid_size
  129. if N != H * W:
  130. raise ValueError(f"patch count mismatch: got N={N}, expected H*W={H * W}")
  131. if C != self.dim:
  132. raise ValueError(f"channel mismatch: got C={C}, expected dim={self.dim}")
  133. fourier_score = self._fourier_stability_score(patch_tokens)
  134. wavelet_score = self._wavelet_saliency_score(patch_tokens)
  135. initial_global_token = self._build_global_token(
  136. patch_tokens,
  137. fourier_score=fourier_score,
  138. wavelet_score=wavelet_score,
  139. cls_token=cls_token,
  140. )
  141. stability_prior = self._build_stability_prior(fourier_score, wavelet_score)
  142. saliency_prior = self._build_saliency_prior(wavelet_score)
  143. fused_input = torch.stack([fourier_score, wavelet_score], dim=-1) # [B, N, 2]
  144. fused_score = self.score_fuser(fused_input).squeeze(-1) # [B, N]
  145. if self.use_cls_conditioning:
  146. cls_alignment = self._cls_alignment_score(initial_global_token.detach(), patch_tokens)
  147. fused_score = fused_score + self.alignment_residual_weight * cls_alignment
  148. gate = torch.softmax(fused_score / max(self.gate_temperature, self.eps), dim=1)
  149. pooled_token = torch.sum(gate.unsqueeze(-1) * patch_tokens, dim=1) # [B, C]
  150. pooled_token = self.token_proj(pooled_token)
  151. cls_out = initial_global_token + self.residual_scale * pooled_token
  152. cls_out = self.out_norm(cls_out)
  153. if return_debug:
  154. debug = FWTADebug(
  155. initial_global_token=initial_global_token,
  156. fourier_score=fourier_score,
  157. wavelet_score=wavelet_score,
  158. stability_prior=stability_prior,
  159. saliency_prior=saliency_prior,
  160. fused_score=fused_score,
  161. gate=gate,
  162. pooled_token=pooled_token,
  163. )
  164. return cls_out, gate, debug
  165. return cls_out, gate
  166. def get_stability_map(self, patch_tokens: torch.Tensor) -> torch.Tensor:
  167. """
  168. 为分割任务提供二维稳定性先验图接口。
  169. Returns:
  170. Tensor of shape [B, 1, H, W].
  171. """
  172. _, _, debug = self.forward(
  173. patch_tokens=patch_tokens,
  174. return_debug=True,
  175. )
  176. return self._score_to_map(debug.stability_prior, patch_tokens.shape[0])
  177. def forward_with_map(
  178. self,
  179. patch_tokens: torch.Tensor,
  180. cls_token: torch.Tensor | None = None,
  181. return_debug: bool = False,
  182. ):
  183. """
  184. 同时返回 CLS 更新结果、门控权重以及二维稳定性图。
  185. """
  186. outputs = self.forward(patch_tokens, cls_token=cls_token, return_debug=return_debug)
  187. H, W = self.grid_size
  188. if return_debug:
  189. cls_out, gate, debug = outputs
  190. stability_map = self._score_to_map(debug.stability_prior, patch_tokens.shape[0])
  191. saliency_map = self._score_to_map(debug.saliency_prior, patch_tokens.shape[0])
  192. return cls_out, gate, stability_map, saliency_map, debug
  193. cls_out, gate = outputs
  194. stability_map = self._score_to_map(self._build_stability_prior(
  195. self._fourier_stability_score(patch_tokens),
  196. self._wavelet_saliency_score(patch_tokens),
  197. ), patch_tokens.shape[0])
  198. saliency_map = self._score_to_map(self._build_saliency_prior(
  199. self._wavelet_saliency_score(patch_tokens)
  200. ), patch_tokens.shape[0])
  201. return cls_out, gate, stability_map, saliency_map
  202. def _build_global_token(
  203. self,
  204. patch_tokens: torch.Tensor,
  205. fourier_score: torch.Tensor,
  206. wavelet_score: torch.Tensor,
  207. cls_token: torch.Tensor | None = None,
  208. ) -> torch.Tensor:
  209. if cls_token is not None:
  210. return cls_token
  211. if not self.learnable_global_token:
  212. return patch_tokens.mean(dim=1)
  213. batch_size, _, channels = patch_tokens.shape
  214. token = self.base_global_token.expand(batch_size, channels)
  215. if self.global_context_proj is not None:
  216. pre_context_gate = self._build_context_gate(fourier_score, wavelet_score)
  217. image_context = torch.sum(pre_context_gate.unsqueeze(-1) * patch_tokens, dim=1)
  218. token = token + self.global_context_proj(image_context)
  219. return self.global_token_norm(token)
  220. def _fourier_stability_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
  221. """
  222. 通过通道级低通滤波后的变化量来评分令牌。
  223. Higher score => more stable token => more likely to carry coherent semantics.
  224. """
  225. kernel = self.gaussian_kernel.to(device=patch_tokens.device, dtype=patch_tokens.dtype)
  226. xf = torch.fft.fft(patch_tokens, dim=-1)
  227. xf = torch.fft.fftshift(xf, dim=-1)
  228. xf_low = xf * kernel
  229. xf_low = torch.fft.ifftshift(xf_low, dim=-1)
  230. x_low = torch.fft.ifft(xf_low, dim=-1).real
  231. delta = torch.mean(torch.abs(patch_tokens - x_low), dim=-1) # [B, N]
  232. score = torch.exp(-delta / max(self.tau_fourier, self.eps))
  233. return score
  234. def _wavelet_saliency_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
  235. """
  236. 使用 Token-Grid 小波分解来估计结构前景显著性。
  237. The patch tokens are treated as a low-resolution feature map [B, C, H, W].
  238. """
  239. B, N, C = patch_tokens.shape
  240. H, W = self.grid_size
  241. x2d = patch_tokens.transpose(1, 2).reshape(B, C, H, W)
  242. coeffs = ptwt.wavedec2(x2d, self.wavelet, level=self.wavelet_level)
  243. ll = coeffs[0]
  244. detail_coeffs = coeffs[1:]
  245. ll_energy = ll.abs().mean(dim=1, keepdim=True)
  246. ll_energy = F.interpolate(ll_energy, size=(H, W), mode="nearest")
  247. edge_energy = torch.zeros_like(ll_energy)
  248. hh_energy = torch.zeros_like(ll_energy)
  249. for level_detail in detail_coeffs:
  250. lh, hl, hh = level_detail
  251. level_edge = 0.5 * (lh.abs().mean(dim=1, keepdim=True) + hl.abs().mean(dim=1, keepdim=True))
  252. level_hh = hh.abs().mean(dim=1, keepdim=True)
  253. target_size = (H, W)
  254. level_edge = F.interpolate(level_edge, size=target_size, mode="nearest")
  255. level_hh = F.interpolate(level_hh, size=target_size, mode="nearest")
  256. edge_energy = edge_energy + level_edge
  257. hh_energy = hh_energy + level_hh
  258. raw_score = (
  259. self.wavelet_ll_weight * ll_energy
  260. + self.wavelet_edge_weight * edge_energy
  261. + self.wavelet_hh_weight * hh_energy
  262. )
  263. raw_score = raw_score.flatten(1) # [B, N]
  264. score = torch.sigmoid(raw_score)
  265. return score
  266. def _build_stability_prior(
  267. self,
  268. fourier_score: torch.Tensor,
  269. wavelet_score: torch.Tensor,
  270. ) -> torch.Tensor:
  271. raw = (
  272. self.stability_fourier_weight * fourier_score
  273. + self.stability_wavelet_weight * wavelet_score
  274. )
  275. return torch.sigmoid(raw)
  276. def _build_saliency_prior(self, wavelet_score: torch.Tensor) -> torch.Tensor:
  277. raw = self.saliency_wavelet_weight * wavelet_score
  278. return torch.sigmoid(raw)
  279. def _build_context_gate(
  280. self,
  281. fourier_score: torch.Tensor,
  282. wavelet_score: torch.Tensor,
  283. ) -> torch.Tensor:
  284. context_score = (
  285. self.context_fourier_weight * fourier_score
  286. + self.context_wavelet_weight * wavelet_score
  287. )
  288. return torch.softmax(context_score / max(self.gate_temperature, self.eps), dim=1)
  289. def _score_to_map(self, score: torch.Tensor, batch_size: int) -> torch.Tensor:
  290. H, W = self.grid_size
  291. return score.reshape(batch_size, 1, H, W)
  292. def _cls_alignment_score(self, cls_token: torch.Tensor, patch_tokens: torch.Tensor) -> torch.Tensor:
  293. """
  294. 可选稳定器:偏好已与现有 CLS 令牌对齐的令牌。
  295. 这有助于模块作为修正项而不是完全独立的分支发挥作用。
  296. """
  297. cls_norm = F.normalize(cls_token, dim=-1)
  298. patch_norm = F.normalize(patch_tokens, dim=-1)
  299. score = torch.sum(patch_norm * cls_norm.unsqueeze(1), dim=-1)
  300. score = 0.5 * (score + 1.0) # map cosine similarity from [-1, 1] to [0, 1]
  301. return score
  302. class ViTBlockWithFWTA(nn.Module):
  303. """
  304. 最小包装器,展示如何在 Transformer Block 后插入 FWTA。
  305. Expected input:
  306. x: [B, 1 + N, C]
  307. Output:
  308. x: [B, 1 + N, C]
  309. """
  310. def __init__(self, block: nn.Module, dim: int, grid_size: Tuple[int, int]) -> None:
  311. super().__init__()
  312. self.block = block
  313. self.fwta = FourierWaveletTokenAggregation(dim=dim, grid_size=grid_size)
  314. def forward(self, x: torch.Tensor):
  315. x = self.block(x)
  316. cls_token = x[:, 0]
  317. patch_tokens = x[:, 1:]
  318. cls_token, gate = self.fwta(cls_token, patch_tokens)
  319. x = torch.cat([cls_token.unsqueeze(1), patch_tokens], dim=1)
  320. return x, gate
  321. class FinalAggregatorWithFWTA(nn.Module):
  322. """
  323. 适用于 torchvision / timm 风格 ViT 的更简单变体:
  324. 保持所有 Encoder Block 不变,仅在最后应用 FWTA。
  325. """
  326. def __init__(self, dim: int, grid_size: Tuple[int, int], num_classes: int) -> None:
  327. super().__init__()
  328. self.fwta = FourierWaveletTokenAggregation(dim=dim, grid_size=grid_size)
  329. self.head = nn.Linear(dim, num_classes)
  330. def forward(self, encoder_tokens: torch.Tensor):
  331. cls_token = encoder_tokens[:, 0]
  332. patch_tokens = encoder_tokens[:, 1:]
  333. cls_token, gate = self.fwta(cls_token, patch_tokens)
  334. logits = self.head(cls_token)
  335. return logits, gate