from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import ptwt
import torch
import torch.nn as nn
import torch.nn.functional as F


def build_gaussian_lowpass(
    channels: int,
    sigma_ratio: float = 0.35,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Build a 1D Gaussian low-pass mask over the channel dimension.

    Args:
        channels: Number of entries in the mask.
        sigma_ratio: Standard deviation expressed as a fraction of
            ``channels``; the resulting sigma is floored at 1.0.
        device: Device on which the mask is created.
        dtype: Dtype of the mask; ``torch.float32`` when omitted.

    Returns:
        Tensor of shape [1, 1, C] whose maximum value is normalised to 1.
    """
    sigma = max(channels * sigma_ratio, 1.0)
    # NOTE(review): the mask is centred at (C - 1) / 2, while the caller
    # applies it to an fftshift-ed spectrum. For even C this looks half a bin
    # off the zero-frequency index; left unchanged to keep numerics stable --
    # confirm whether that is intended.
    center = (channels - 1) / 2.0
    coords = torch.arange(channels, device=device, dtype=dtype or torch.float32)
    kernel = torch.exp(-0.5 * ((coords - center) / sigma) ** 2)
    kernel = kernel / kernel.max().clamp_min(1e-6)
    return kernel.view(1, 1, channels)


@dataclass
class FWTADebug:
    """Intermediate tensors returned by FWTA when ``return_debug=True``."""

    # Global token that the pooled token is added to.
    initial_global_token: torch.Tensor
    # Per-token score from the Fourier branch, [B, N].
    fourier_score: torch.Tensor
    # Per-token score from the wavelet branch, [B, N].
    wavelet_score: torch.Tensor
    # Sigmoid of the weighted Fourier + wavelet scores, [B, N].
    stability_prior: torch.Tensor
    # Sigmoid of the weighted wavelet score, [B, N].
    saliency_prior: torch.Tensor
    # Pre-softmax gate logits, [B, N].
    fused_score: torch.Tensor
    # Softmax over tokens of the fused score, [B, N].
    gate: torch.Tensor
    # Gate-weighted token sum after ``token_proj``, [B, C].
    pooled_token: torch.Tensor


class FourierWaveletTokenAggregation(nn.Module):
    """Fourier-wavelet token aggregation (FWTA) module.

    Inputs (in ``forward`` argument order):
        patch_tokens: [B, N, C]
        cls_token: [B, C], optional

    Output:
        cls_out: [B, C]
        gate: [B, N]

    Design:
        - Fourier branch estimates token-wise semantic stability.
        - Wavelet branch estimates token-wise structural saliency.
        - Fused score produces a softmax gate over tokens.
        - Weighted pooled token is added back to the CLS token by residual
          update.
    """

    def __init__(
        self,
        dim: int,
        grid_size: Tuple[int, int],
        wavelet: str = "haar",
        wavelet_level: int = 1,
        sigma_ratio: float = 0.35,
        tau_fourier: float = 0.15,
        gate_temperature: float = 1.0,
        residual_scale_init: float = 1.0,
        fusion_hidden_ratio: float = 0.5,
        use_cls_conditioning: bool = True,
        learnable_global_token: bool = True,
        global_token_use_image_conditioning: bool = True,
        eps: float = 1e-6,
    ) -> None:
        """Create the module.

        Args:
            dim: Channel width C of the tokens.
            grid_size: (H, W) layout of the patch tokens; N must equal H * W.
            wavelet: Wavelet name passed to ``ptwt.wavedec2``.
            wavelet_level: Decomposition level passed to ``ptwt.wavedec2``.
            sigma_ratio: Width of the Gaussian low-pass mask relative to C.
            tau_fourier: Temperature of the exponential in the Fourier score.
            gate_temperature: Softmax temperature of the token gates.
            residual_scale_init: Initial value of the learnable residual scale.
            fusion_hidden_ratio: Hidden width of the score fuser relative to
                ``dim`` (floored at 32).
            use_cls_conditioning: Add a cosine-alignment term between the
                global token and each patch token to the gate logits.
            learnable_global_token: Use the learnable base token (instead of
                the patch mean) when no CLS token is supplied.
            global_token_use_image_conditioning: Additionally condition the
                learnable base token on a gated summary of the patch tokens.
            eps: Lower bound used when dividing by temperatures.
        """
        super().__init__()
        self.dim = dim
        self.grid_size = grid_size
        self.wavelet = wavelet
        self.wavelet_level = wavelet_level
        self.sigma_ratio = sigma_ratio
        self.tau_fourier = tau_fourier
        self.gate_temperature = gate_temperature
        self.use_cls_conditioning = use_cls_conditioning
        self.learnable_global_token = learnable_global_token
        self.global_token_use_image_conditioning = global_token_use_image_conditioning
        self.eps = eps

        hidden_dim = max(int(dim * fusion_hidden_ratio), 32)
        # forward() always feeds exactly two features per token (Fourier and
        # wavelet score); the CLS alignment is added afterwards as a weighted
        # residual. The input width therefore must be 2 regardless of
        # ``use_cls_conditioning`` -- a width of 3 made every forward pass
        # fail with a shape mismatch in the default configuration.
        fuse_in_dim = 2
        self.score_fuser = nn.Sequential(
            nn.Linear(fuse_in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )
        self.token_proj = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        self.out_norm = nn.LayerNorm(dim)
        self.residual_scale = nn.Parameter(torch.tensor(float(residual_scale_init)))

        self.base_global_token = nn.Parameter(torch.zeros(1, dim))
        nn.init.trunc_normal_(self.base_global_token, std=0.02)

        if learnable_global_token and global_token_use_image_conditioning:
            self.global_context_proj = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim),
                nn.GELU(),
                nn.Linear(dim, dim),
            )
            self.global_token_norm = nn.LayerNorm(dim)
        elif learnable_global_token:
            self.global_context_proj = None
            self.global_token_norm = nn.LayerNorm(dim)
        else:
            self.global_context_proj = None
            self.global_token_norm = nn.Identity()

        # Learnable coefficients balancing coarse structure, edge cues and
        # high-frequency detail.
        # Note: the HH sub-band is not assumed to be pure noise; the model is
        # free to learn a positive or negative contribution for it.
        self.wavelet_ll_weight = nn.Parameter(torch.tensor(1.0))
        self.wavelet_edge_weight = nn.Parameter(torch.tensor(0.5))
        self.wavelet_hh_weight = nn.Parameter(torch.tensor(-0.25))
        self.stability_fourier_weight = nn.Parameter(torch.tensor(0.7))
        self.stability_wavelet_weight = nn.Parameter(torch.tensor(0.3))
        self.saliency_wavelet_weight = nn.Parameter(torch.tensor(1.0))
        self.context_fourier_weight = nn.Parameter(torch.tensor(0.5))
        self.context_wavelet_weight = nn.Parameter(torch.tensor(0.5))
        self.alignment_residual_weight = nn.Parameter(torch.tensor(0.1))

        # Non-persistent: rebuilt from ``dim`` / ``sigma_ratio`` on
        # construction, so it is not stored in the state dict.
        self.register_buffer(
            "gaussian_kernel",
            build_gaussian_lowpass(dim, sigma_ratio),
            persistent=False,
        )

    def _check_patch_tokens(self, patch_tokens: torch.Tensor) -> None:
        """Raise ``ValueError`` unless ``patch_tokens`` is [B, H*W, dim]."""
        if patch_tokens.dim() != 3:
            raise ValueError(
                f"patch_tokens must be [B, N, C], got shape {tuple(patch_tokens.shape)}"
            )
        _, N, C = patch_tokens.shape
        H, W = self.grid_size
        if N != H * W:
            raise ValueError(f"patch count mismatch: got N={N}, expected H*W={H * W}")
        if C != self.dim:
            raise ValueError(f"channel mismatch: got C={C}, expected dim={self.dim}")

    def forward(
        self,
        patch_tokens: torch.Tensor,
        cls_token: torch.Tensor | None = None,
        return_debug: bool = False,
    ):
        """Aggregate patch tokens into an updated global (CLS) token.

        Args:
            patch_tokens: Patch tokens of shape [B, N, C] with N = H * W.
            cls_token: Optional existing CLS token of shape [B, C]. When
                omitted, a global token is built internally.
            return_debug: Also return an ``FWTADebug`` with intermediates.

        Returns:
            ``(cls_out, gate)`` or ``(cls_out, gate, debug)`` where
            ``cls_out`` is [B, C] and ``gate`` is [B, N].

        Raises:
            ValueError: If the token count or channel width does not match
                the configured ``grid_size`` / ``dim``.
        """
        self._check_patch_tokens(patch_tokens)

        fourier_score = self._fourier_stability_score(patch_tokens)
        wavelet_score = self._wavelet_saliency_score(patch_tokens)
        initial_global_token = self._build_global_token(
            patch_tokens,
            fourier_score=fourier_score,
            wavelet_score=wavelet_score,
            cls_token=cls_token,
        )

        fused_input = torch.stack([fourier_score, wavelet_score], dim=-1)  # [B, N, 2]
        fused_score = self.score_fuser(fused_input).squeeze(-1)  # [B, N]

        if self.use_cls_conditioning:
            # The global token is detached so the alignment term does not
            # back-propagate into it.
            cls_alignment = self._cls_alignment_score(
                initial_global_token.detach(), patch_tokens
            )
            fused_score = fused_score + self.alignment_residual_weight * cls_alignment

        gate = torch.softmax(fused_score / max(self.gate_temperature, self.eps), dim=1)
        pooled_token = torch.sum(gate.unsqueeze(-1) * patch_tokens, dim=1)  # [B, C]
        pooled_token = self.token_proj(pooled_token)

        cls_out = initial_global_token + self.residual_scale * pooled_token
        cls_out = self.out_norm(cls_out)

        if not return_debug:
            return cls_out, gate

        # The priors do not influence cls_out / gate, so they are only
        # computed when the caller asks for the debug record.
        debug = FWTADebug(
            initial_global_token=initial_global_token,
            fourier_score=fourier_score,
            wavelet_score=wavelet_score,
            stability_prior=self._build_stability_prior(fourier_score, wavelet_score),
            saliency_prior=self._build_saliency_prior(wavelet_score),
            fused_score=fused_score,
            gate=gate,
            pooled_token=pooled_token,
        )
        return cls_out, gate, debug

    def get_stability_map(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        """Return the 2D stability prior map, e.g. for segmentation heads.

        Only the two scoring branches are evaluated; the gating and pooling
        path is not needed for the prior.

        Returns:
            Tensor of shape [B, 1, H, W].

        Raises:
            ValueError: If ``patch_tokens`` does not match ``grid_size`` /
                ``dim``.
        """
        self._check_patch_tokens(patch_tokens)
        stability_prior = self._build_stability_prior(
            self._fourier_stability_score(patch_tokens),
            self._wavelet_saliency_score(patch_tokens),
        )
        return self._score_to_map(stability_prior, patch_tokens.shape[0])

    def forward_with_map(
        self,
        patch_tokens: torch.Tensor,
        cls_token: torch.Tensor | None = None,
        return_debug: bool = False,
    ):
        """Return the CLS update, the gates and the 2D prior maps together.

        Returns:
            ``(cls_out, gate, stability_map, saliency_map)`` and, when
            ``return_debug`` is true, the ``FWTADebug`` record as a fifth
            element. Both maps have shape [B, 1, H, W].
        """
        # Always request the debug record internally: it already carries both
        # priors, so the Fourier / wavelet branches run once instead of being
        # recomputed for the maps.
        cls_out, gate, debug = self.forward(
            patch_tokens, cls_token=cls_token, return_debug=True
        )
        batch_size = patch_tokens.shape[0]
        stability_map = self._score_to_map(debug.stability_prior, batch_size)
        saliency_map = self._score_to_map(debug.saliency_prior, batch_size)
        if return_debug:
            return cls_out, gate, stability_map, saliency_map, debug
        return cls_out, gate, stability_map, saliency_map

    def _build_global_token(
        self,
        patch_tokens: torch.Tensor,
        fourier_score: torch.Tensor,
        wavelet_score: torch.Tensor,
        cls_token: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Pick or construct the global token that receives the residual.

        Priority: an explicit ``cls_token`` is returned unchanged; otherwise
        the patch mean is used when the learnable token is disabled;
        otherwise the learnable base token (optionally shifted by a projected,
        score-gated summary of the patches) is normalised and returned.
        """
        if cls_token is not None:
            return cls_token
        if not self.learnable_global_token:
            return patch_tokens.mean(dim=1)

        batch_size, _, channels = patch_tokens.shape
        token = self.base_global_token.expand(batch_size, channels)
        if self.global_context_proj is not None:
            pre_context_gate = self._build_context_gate(fourier_score, wavelet_score)
            image_context = torch.sum(pre_context_gate.unsqueeze(-1) * patch_tokens, dim=1)
            token = token + self.global_context_proj(image_context)
        return self.global_token_norm(token)

    def _fourier_stability_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        """Score tokens by how little a channel-wise low-pass filter changes them.

        Higher score => more stable token => more likely to carry coherent
        semantics.

        Returns:
            Tensor of shape [B, N] with values in (0, 1].
        """
        kernel = self.gaussian_kernel.to(device=patch_tokens.device, dtype=patch_tokens.dtype)
        xf = torch.fft.fft(patch_tokens, dim=-1)
        xf = torch.fft.fftshift(xf, dim=-1)
        xf_low = xf * kernel
        xf_low = torch.fft.ifftshift(xf_low, dim=-1)
        # Only the real part of the filtered signal is compared to the input.
        x_low = torch.fft.ifft(xf_low, dim=-1).real
        delta = torch.mean(torch.abs(patch_tokens - x_low), dim=-1)  # [B, N]
        score = torch.exp(-delta / max(self.tau_fourier, self.eps))
        return score

    def _wavelet_saliency_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        """Estimate structural foreground saliency via a token-grid wavelet decomposition.

        The patch tokens are treated as a low-resolution feature map
        [B, C, H, W]. Channel-averaged magnitudes of the approximation band
        and of the detail bands (summed over levels) are resized back to
        (H, W), combined with learnable weights and passed through a sigmoid.

        Returns:
            Tensor of shape [B, N] with values in (0, 1).
        """
        B, _, C = patch_tokens.shape
        H, W = self.grid_size
        x2d = patch_tokens.transpose(1, 2).reshape(B, C, H, W)

        # The first entry is used as the approximation band; every following
        # entry is unpacked as three detail bands.
        coeffs = ptwt.wavedec2(x2d, self.wavelet, level=self.wavelet_level)
        ll = coeffs[0]
        detail_coeffs = coeffs[1:]

        target_size = (H, W)
        ll_energy = ll.abs().mean(dim=1, keepdim=True)
        ll_energy = F.interpolate(ll_energy, size=target_size, mode="nearest")

        edge_energy = torch.zeros_like(ll_energy)
        hh_energy = torch.zeros_like(ll_energy)
        for lh, hl, hh in detail_coeffs:
            level_edge = 0.5 * (
                lh.abs().mean(dim=1, keepdim=True) + hl.abs().mean(dim=1, keepdim=True)
            )
            level_hh = hh.abs().mean(dim=1, keepdim=True)
            level_edge = F.interpolate(level_edge, size=target_size, mode="nearest")
            level_hh = F.interpolate(level_hh, size=target_size, mode="nearest")
            edge_energy = edge_energy + level_edge
            hh_energy = hh_energy + level_hh

        raw_score = (
            self.wavelet_ll_weight * ll_energy
            + self.wavelet_edge_weight * edge_energy
            + self.wavelet_hh_weight * hh_energy
        )
        raw_score = raw_score.flatten(1)  # [B, N]
        score = torch.sigmoid(raw_score)
        return score

    def _build_stability_prior(
        self,
        fourier_score: torch.Tensor,
        wavelet_score: torch.Tensor,
    ) -> torch.Tensor:
        """Return the sigmoid of the learnably weighted Fourier + wavelet scores."""
        raw = (
            self.stability_fourier_weight * fourier_score
            + self.stability_wavelet_weight * wavelet_score
        )
        return torch.sigmoid(raw)

    def _build_saliency_prior(self, wavelet_score: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid of the learnably scaled wavelet score."""
        raw = self.saliency_wavelet_weight * wavelet_score
        return torch.sigmoid(raw)

    def _build_context_gate(
        self,
        fourier_score: torch.Tensor,
        wavelet_score: torch.Tensor,
    ) -> torch.Tensor:
        """Return softmax-over-tokens weights used to summarise the patches."""
        context_score = (
            self.context_fourier_weight * fourier_score
            + self.context_wavelet_weight * wavelet_score
        )
        return torch.softmax(context_score / max(self.gate_temperature, self.eps), dim=1)

    def _score_to_map(self, score: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape a per-token score to [batch_size, 1, H, W]."""
        H, W = self.grid_size
        return score.reshape(batch_size, 1, H, W)

    def _cls_alignment_score(self, cls_token: torch.Tensor, patch_tokens: torch.Tensor) -> torch.Tensor:
        """Optional stabiliser: favour tokens already aligned with the CLS token.

        This helps the module act as a correction term rather than as a fully
        independent branch.

        Returns:
            Cosine similarity between each patch token and the CLS token,
            mapped from [-1, 1] to [0, 1]; shape [B, N].
        """
        cls_norm = F.normalize(cls_token, dim=-1)
        patch_norm = F.normalize(patch_tokens, dim=-1)
        score = torch.sum(patch_norm * cls_norm.unsqueeze(1), dim=-1)
        score = 0.5 * (score + 1.0)  # map cosine similarity from [-1, 1] to [0, 1]
        return score


class ViTBlockWithFWTA(nn.Module):
    """Minimal wrapper showing how to insert FWTA after a Transformer block.

    Expected input:
        x: [B, 1 + N, C]

    Output:
        x: [B, 1 + N, C]
        gate: [B, N]
    """

    def __init__(self, block: nn.Module, dim: int, grid_size: Tuple[int, int]) -> None:
        super().__init__()
        self.block = block
        self.fwta = FourierWaveletTokenAggregation(dim=dim, grid_size=grid_size)

    def forward(self, x: torch.Tensor):
        """Run the wrapped block, then refresh the CLS token with FWTA."""
        x = self.block(x)
        cls_token = x[:, 0]
        patch_tokens = x[:, 1:]
        # FWTA takes the patch tokens first; the CLS token is passed by
        # keyword so the two cannot be swapped.
        cls_token, gate = self.fwta(patch_tokens, cls_token=cls_token)
        x = torch.cat([cls_token.unsqueeze(1), patch_tokens], dim=1)
        return x, gate


class FinalAggregatorWithFWTA(nn.Module):
    """Simpler variant for torchvision / timm style ViTs.

    All encoder blocks stay untouched; FWTA is applied once to the final
    encoder tokens, followed by a linear classification head.
    """

    def __init__(self, dim: int, grid_size: Tuple[int, int], num_classes: int) -> None:
        super().__init__()
        self.fwta = FourierWaveletTokenAggregation(dim=dim, grid_size=grid_size)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, encoder_tokens: torch.Tensor):
        """Map encoder tokens [B, 1 + N, C] to ``(logits, gate)``."""
        cls_token = encoder_tokens[:, 0]
        patch_tokens = encoder_tokens[:, 1:]
        # FWTA takes the patch tokens first; the CLS token is passed by
        # keyword so the two cannot be swapped.
        cls_token, gate = self.fwta(patch_tokens, cls_token=cls_token)
        logits = self.head(cls_token)
        return logits, gate