| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302 |
- from __future__ import annotations
- from dataclasses import dataclass
- from typing import Optional, Tuple
- import ptwt
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
def build_gaussian_lowpass(
    channels: int,
    sigma_ratio: float = 0.35,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Build a 1D Gaussian low-pass window over the channel axis.

    The window is a Gaussian centred at ``(channels - 1) / 2`` with standard
    deviation ``max(channels * sigma_ratio, 1.0)``. It is rescaled so that its
    largest entry is exactly 1.

    FourierWaveletTokenAggregation multiplies this window with an
    ``fftshift``-ed channel spectrum. For odd ``channels`` the centre coincides
    with the zero-frequency bin. For even ``channels`` the centre falls half a
    bin below ``channels // 2``; the two middle entries are then equal and are
    both rescaled to 1.

    Args:
        channels: window length C; must be >= 1.
        sigma_ratio: Gaussian width as a fraction of ``channels``. The
            resulting width is floored at 1.0.
        device: device of the returned tensor.
        dtype: floating dtype of the returned tensor (float32 when None).

    Returns:
        Tensor of shape [1, 1, C], broadcastable against [B, N, C].

    Raises:
        ValueError: if ``channels < 1``.
    """
    if channels < 1:
        # Without this guard, an empty window makes ``kernel.max()`` fail with
        # an opaque reduction error.
        raise ValueError(f"channels must be >= 1, got {channels}")
    sigma = max(channels * sigma_ratio, 1.0)
    center = (channels - 1) / 2.0
    coords = torch.arange(
        channels,
        device=device,
        dtype=dtype if dtype is not None else torch.float32,
    )
    kernel = torch.exp(-0.5 * ((coords - center) / sigma) ** 2)
    # The peak of a Gaussian is positive; the clamp only guards the division.
    kernel = kernel / kernel.max().clamp_min(1e-6)
    return kernel.view(1, 1, channels)
@dataclass(init=True, repr=True, eq=True)
class FWTADebug:
    """
    Intermediate tensors exposed by FourierWaveletTokenAggregation.forward
    when it is called with ``return_debug=True``.
    """

    # Fourier-branch stability score per patch token, [B, N].
    fourier_score: torch.Tensor
    # Wavelet-branch saliency score per patch token, [B, N].
    wavelet_score: torch.Tensor
    # Output of the score-fusion MLP before the softmax, [B, N].
    fused_score: torch.Tensor
    # Softmax gate over the patch tokens, [B, N].
    gate: torch.Tensor
    # Gate-weighted pooled patch token after the token projection, [B, C].
    pooled_token: torch.Tensor
class FourierWaveletTokenAggregation(nn.Module):
    """
    Fourier-wavelet token aggregation (FWTA) module.

    Scores each patch token with up to three cues. A small MLP fuses the
    scores into a softmax gate over tokens. The patch tokens are pooled with
    that gate, and the projected pooled token is added back to the CLS token
    as a residual update.

    Inputs:
        cls_token: [B, C]
        patch_tokens: [B, N, C], with N == H * W for grid_size == (H, W)
    Output:
        cls_out: [B, C]
        gate: [B, N] (softmax over N, so it sums to 1 per batch element)

    Design:
        - Fourier branch: tokens whose channel vector changes little under a
          channel-axis Gaussian low-pass filter get a higher score.
        - Wavelet branch: a 2D wavelet decomposition of the token grid gives a
          per-position score from approximation (LL) energy, edge energy
          (first two detail bands) and noise energy (third detail band).
        - Optional CLS conditioning: cosine similarity of each patch token to
          the incoming CLS token, mapped to [0, 1].
        - The fused score goes through a temperature softmax over tokens; the
          gated pooled token updates the CLS token residually.
    """

    def __init__(
        self,
        dim: int,
        grid_size: Tuple[int, int],
        wavelet: str = "haar",
        wavelet_level: int = 1,
        sigma_ratio: float = 0.35,
        tau_fourier: float = 0.15,
        gate_temperature: float = 1.0,
        residual_scale_init: float = 1.0,
        fusion_hidden_ratio: float = 0.5,
        use_cls_conditioning: bool = True,
        eps: float = 1e-6,
    ) -> None:
        """
        Args:
            dim: channel width C of the tokens.
            grid_size: (H, W) layout of the patch tokens; N must equal H * W.
            wavelet: wavelet name passed to ``ptwt.wavedec2``.
            wavelet_level: decomposition level passed to ``ptwt.wavedec2``.
            sigma_ratio: Gaussian width as a fraction of ``dim``
                (see ``build_gaussian_lowpass``).
            tau_fourier: temperature of the ``exp(-delta / tau)`` stability score.
            gate_temperature: softmax temperature of the token gate.
            residual_scale_init: initial value of the learnable residual scale.
            fusion_hidden_ratio: hidden width of the fusion MLP as a fraction
                of ``dim`` (at least 32).
            use_cls_conditioning: add the CLS-alignment score as a third
                fusion input.
            eps: lower bound for the temperatures, to avoid division by zero.
        """
        super().__init__()
        self.dim = dim
        self.grid_size = grid_size
        self.wavelet = wavelet
        self.wavelet_level = wavelet_level
        self.sigma_ratio = sigma_ratio
        self.tau_fourier = tau_fourier
        self.gate_temperature = gate_temperature
        self.use_cls_conditioning = use_cls_conditioning
        self.eps = eps

        # Per-token fusion MLP over scalar scores: Fourier and wavelet, plus
        # CLS alignment when conditioning is enabled.
        hidden_dim = max(int(dim * fusion_hidden_ratio), 32)
        fuse_in_dim = 3 if use_cls_conditioning else 2
        self.score_fuser = nn.Sequential(
            nn.Linear(fuse_in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )
        # Projection applied to the gate-pooled token before the residual update.
        self.token_proj = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        self.out_norm = nn.LayerNorm(dim)
        self.residual_scale = nn.Parameter(torch.tensor(float(residual_scale_init)))
        # Learnable coefficients balancing coarse structure, edge cues and noise.
        self.wavelet_ll_weight = nn.Parameter(torch.tensor(1.0))
        self.wavelet_edge_weight = nn.Parameter(torch.tensor(0.5))
        self.wavelet_noise_weight = nn.Parameter(torch.tensor(0.5))
        # Channel-axis Gaussian window, [1, 1, dim]. It is non-persistent, so it
        # is rebuilt from (dim, sigma_ratio) instead of being saved in state_dict.
        self.register_buffer("gaussian_kernel", build_gaussian_lowpass(dim, sigma_ratio), persistent=False)

    def forward(
        self,
        cls_token: torch.Tensor,
        patch_tokens: torch.Tensor,
        return_debug: bool = False,
    ):
        """
        Update the CLS token with a gated pooling of the patch tokens.

        Args:
            cls_token: [B, C]
            patch_tokens: [B, N, C]
            return_debug: also return an FWTADebug with intermediate tensors.

        Returns:
            ``(cls_out, gate)``, or ``(cls_out, gate, debug)`` when
            ``return_debug`` is True.

        Raises:
            ValueError: if the input shapes do not match ``dim``/``grid_size``.
        """
        fourier_score, wavelet_score, fused_score, gate = self._compute_gate(cls_token, patch_tokens)
        pooled_token = torch.sum(gate.unsqueeze(-1) * patch_tokens, dim=1)  # [B, C]
        pooled_token = self.token_proj(pooled_token)
        cls_out = cls_token + self.residual_scale * pooled_token
        cls_out = self.out_norm(cls_out)
        if return_debug:
            debug = FWTADebug(
                fourier_score=fourier_score,
                wavelet_score=wavelet_score,
                fused_score=fused_score,
                gate=gate,
                pooled_token=pooled_token,
            )
            return cls_out, gate, debug
        return cls_out, gate

    def get_stability_map(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        """
        Return the token gate as a 2D map, e.g. for segmentation.

        The mean patch token stands in for the CLS token. Only the gate is
        computed; the CLS projection and normalisation are skipped because
        their outputs are not needed here.

        Returns:
            Tensor of shape [B, 1, H, W].
        """
        _, _, _, gate = self._compute_gate(patch_tokens.mean(dim=1), patch_tokens)
        H, W = self.grid_size
        return gate.reshape(patch_tokens.shape[0], 1, H, W)

    def forward_with_map(
        self,
        cls_token: torch.Tensor,
        patch_tokens: torch.Tensor,
        return_debug: bool = False,
    ):
        """
        Like ``forward``, but also return the gate reshaped to a 2D map.

        Returns:
            ``(cls_out, gate, stability_map)``, or
            ``(cls_out, gate, stability_map, debug)`` when ``return_debug`` is
            True. ``stability_map`` has shape [B, 1, H, W].
        """
        outputs = self.forward(cls_token, patch_tokens, return_debug=return_debug)
        H, W = self.grid_size
        if return_debug:
            cls_out, gate, debug = outputs
            stability_map = gate.reshape(patch_tokens.shape[0], 1, H, W)
            return cls_out, gate, stability_map, debug
        cls_out, gate = outputs
        stability_map = gate.reshape(patch_tokens.shape[0], 1, H, W)
        return cls_out, gate, stability_map

    def _compute_gate(self, cls_token: torch.Tensor, patch_tokens: torch.Tensor):
        """
        Validate input shapes, then compute the per-token scores and the gate.

        Returns:
            ``(fourier_score, wavelet_score, fused_score, gate)``, each [B, N].

        Raises:
            ValueError: on shape mismatches with ``dim``/``grid_size``, or if
                ``cls_token`` is not [B, C].
        """
        if patch_tokens.dim() != 3:
            raise ValueError(f"patch_tokens must be [B, N, C], got shape {tuple(patch_tokens.shape)}")
        B, N, C = patch_tokens.shape
        H, W = self.grid_size
        if N != H * W:
            raise ValueError(f"patch count mismatch: got N={N}, expected H*W={H * W}")
        if C != self.dim:
            raise ValueError(f"channel mismatch: got C={C}, expected dim={self.dim}")
        if cls_token.shape != (B, C):
            # Without this check, e.g. a [B, 1, C] CLS token silently broadcasts
            # into wrongly shaped scores and residual outputs.
            raise ValueError(f"cls_token must have shape [B, C]=[{B}, {C}], got {tuple(cls_token.shape)}")

        fourier_score = self._fourier_stability_score(patch_tokens)
        wavelet_score = self._wavelet_saliency_score(patch_tokens)
        fuse_inputs = [fourier_score, wavelet_score]
        if self.use_cls_conditioning:
            fuse_inputs.append(self._cls_alignment_score(cls_token, patch_tokens))
        fused_input = torch.stack(fuse_inputs, dim=-1)  # [B, N, 2 or 3]
        fused_score = self.score_fuser(fused_input).squeeze(-1)  # [B, N]
        gate = torch.softmax(fused_score / max(self.gate_temperature, self.eps), dim=1)
        return fourier_score, wavelet_score, fused_score, gate

    def _fourier_stability_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        """
        Score tokens by how little they change under a channel-axis low-pass filter.

        Each token's channel vector is FFT-ed along C, shifted so DC sits in the
        middle, multiplied by ``gaussian_kernel``, then unshifted and inverse
        transformed (real part kept). The score is
        ``exp(-mean|x - x_low| / tau_fourier)`` in (0, 1]. A higher score means
        the token is more stable under the filter.

        Returns:
            Tensor [B, N] with the same dtype as ``patch_tokens``.
        """
        # torch.fft has limited half-precision support: no bfloat16, and
        # float16 only on CUDA with power-of-two sizes. Low-precision inputs
        # are therefore processed in float32 and the score is cast back.
        if patch_tokens.dtype in (torch.float32, torch.float64):
            x = patch_tokens
        else:
            x = patch_tokens.float()
        kernel = self.gaussian_kernel.to(device=x.device, dtype=x.dtype)
        xf = torch.fft.fft(x, dim=-1)
        xf = torch.fft.fftshift(xf, dim=-1)
        xf_low = xf * kernel
        xf_low = torch.fft.ifftshift(xf_low, dim=-1)
        x_low = torch.fft.ifft(xf_low, dim=-1).real
        delta = torch.mean(torch.abs(x - x_low), dim=-1)  # [B, N]
        score = torch.exp(-delta / max(self.tau_fourier, self.eps))
        return score.to(patch_tokens.dtype)

    def _wavelet_saliency_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        """
        Estimate structural saliency from a wavelet decomposition of the token grid.

        The patch tokens are treated as a low-resolution feature map
        [B, C, H, W] and decomposed with ``ptwt.wavedec2``. Channel-averaged
        absolute energies of the approximation band (LL), the first two detail
        bands (edges) and the third detail band (noise) are upsampled back to
        (H, W) with nearest-neighbour interpolation. They are combined with
        the learnable weights as ``ll + edge - noise`` and passed through a
        sigmoid.

        Returns:
            Tensor [B, N] with values in (0, 1).
        """
        B, N, C = patch_tokens.shape
        H, W = self.grid_size
        target_size = (H, W)
        x2d = patch_tokens.transpose(1, 2).reshape(B, C, H, W)
        coeffs = ptwt.wavedec2(x2d, self.wavelet, level=self.wavelet_level)
        ll = coeffs[0]
        detail_coeffs = coeffs[1:]

        ll_energy = ll.abs().mean(dim=1, keepdim=True)
        ll_energy = F.interpolate(ll_energy, size=target_size, mode="nearest")
        edge_energy = torch.zeros_like(ll_energy)
        noise_energy = torch.zeros_like(ll_energy)
        # Energies from every decomposition level are accumulated at full (H, W) resolution.
        for level_detail in detail_coeffs:
            lh, hl, hh = level_detail
            level_edge = 0.5 * (lh.abs().mean(dim=1, keepdim=True) + hl.abs().mean(dim=1, keepdim=True))
            level_noise = hh.abs().mean(dim=1, keepdim=True)
            level_edge = F.interpolate(level_edge, size=target_size, mode="nearest")
            level_noise = F.interpolate(level_noise, size=target_size, mode="nearest")
            edge_energy = edge_energy + level_edge
            noise_energy = noise_energy + level_noise

        raw_score = (
            self.wavelet_ll_weight * ll_energy
            + self.wavelet_edge_weight * edge_energy
            - self.wavelet_noise_weight * noise_energy
        )
        raw_score = raw_score.flatten(1)  # [B, N]
        return torch.sigmoid(raw_score)

    def _cls_alignment_score(self, cls_token: torch.Tensor, patch_tokens: torch.Tensor) -> torch.Tensor:
        """
        Optional stabiliser: prefer tokens already aligned with the current CLS token.

        This lets the module act as a correction on top of the existing CLS
        token rather than as a fully independent branch.

        Returns:
            Tensor [B, N]: cosine similarity mapped from [-1, 1] to [0, 1].
        """
        cls_norm = F.normalize(cls_token, dim=-1)
        patch_norm = F.normalize(patch_tokens, dim=-1)
        score = torch.sum(patch_norm * cls_norm.unsqueeze(1), dim=-1)
        return 0.5 * (score + 1.0)
class ViTBlockWithFWTA(nn.Module):
    """
    Minimal wrapper showing how to insert FWTA after a Transformer block.

    The wrapped ``block`` runs first. Position 0 of its output is treated as
    the CLS token and positions 1: as patch tokens. FWTA updates the CLS token,
    and the sequence is reassembled with the patch tokens left unchanged.

    Expected input:
        x: [B, 1 + N, C]
    Returns:
        ``(x, gate)``: the updated sequence [B, 1 + N, C] and the FWTA gate
        [B, N]. Callers must unpack the tuple before feeding the next block.
    """

    def __init__(self, block: nn.Module, dim: int, grid_size: Tuple[int, int]) -> None:
        super().__init__()
        self.block = block
        self.fwta = FourierWaveletTokenAggregation(dim=dim, grid_size=grid_size)

    def forward(self, x: torch.Tensor):
        hidden = self.block(x)
        refined_cls, gate = self.fwta(hidden[:, 0], hidden[:, 1:])
        # Put the refined CLS back in front of the untouched patch tokens.
        sequence = torch.cat((refined_cls[:, None], hidden[:, 1:]), dim=1)
        return sequence, gate
class FinalAggregatorWithFWTA(nn.Module):
    """
    Simpler variant for torchvision / timm style ViTs.

    Every encoder block is left untouched. FWTA is applied once, to the final
    encoder output, and a linear classification head reads the refined CLS
    token.

    Input:
        encoder_tokens: [B, 1 + N, C] with the CLS token at position 0.
    Returns:
        ``(logits, gate)``: logits [B, num_classes] and the FWTA gate [B, N].
    """

    def __init__(self, dim: int, grid_size: Tuple[int, int], num_classes: int) -> None:
        super().__init__()
        self.fwta = FourierWaveletTokenAggregation(dim=dim, grid_size=grid_size)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, encoder_tokens: torch.Tensor):
        refined_cls, gate = self.fwta(encoder_tokens[:, 0], encoder_tokens[:, 1:])
        return self.head(refined_cls), gate
|