|
@@ -31,8 +31,11 @@ def build_gaussian_lowpass(
|
|
|
|
|
|
|
|
@dataclass
|
|
@dataclass
|
|
|
class FWTADebug:
|
|
class FWTADebug:
|
|
|
|
|
+ initial_global_token: torch.Tensor
|
|
|
fourier_score: torch.Tensor
|
|
fourier_score: torch.Tensor
|
|
|
wavelet_score: torch.Tensor
|
|
wavelet_score: torch.Tensor
|
|
|
|
|
+ stability_prior: torch.Tensor
|
|
|
|
|
+ saliency_prior: torch.Tensor
|
|
|
fused_score: torch.Tensor
|
|
fused_score: torch.Tensor
|
|
|
gate: torch.Tensor
|
|
gate: torch.Tensor
|
|
|
pooled_token: torch.Tensor
|
|
pooled_token: torch.Tensor
|
|
@@ -69,6 +72,8 @@ class FourierWaveletTokenAggregation(nn.Module):
|
|
|
residual_scale_init: float = 1.0,
|
|
residual_scale_init: float = 1.0,
|
|
|
fusion_hidden_ratio: float = 0.5,
|
|
fusion_hidden_ratio: float = 0.5,
|
|
|
use_cls_conditioning: bool = True,
|
|
use_cls_conditioning: bool = True,
|
|
|
|
|
+ learnable_global_token: bool = True,
|
|
|
|
|
+ global_token_use_image_conditioning: bool = True,
|
|
|
eps: float = 1e-6,
|
|
eps: float = 1e-6,
|
|
|
) -> None:
|
|
) -> None:
|
|
|
super().__init__()
|
|
super().__init__()
|
|
@@ -80,6 +85,8 @@ class FourierWaveletTokenAggregation(nn.Module):
|
|
|
self.tau_fourier = tau_fourier
|
|
self.tau_fourier = tau_fourier
|
|
|
self.gate_temperature = gate_temperature
|
|
self.gate_temperature = gate_temperature
|
|
|
self.use_cls_conditioning = use_cls_conditioning
|
|
self.use_cls_conditioning = use_cls_conditioning
|
|
|
|
|
+ self.learnable_global_token = learnable_global_token
|
|
|
|
|
+ self.global_token_use_image_conditioning = global_token_use_image_conditioning
|
|
|
self.eps = eps
|
|
self.eps = eps
|
|
|
|
|
|
|
|
hidden_dim = max(int(dim * fusion_hidden_ratio), 32)
|
|
hidden_dim = max(int(dim * fusion_hidden_ratio), 32)
|
|
@@ -101,17 +108,42 @@ class FourierWaveletTokenAggregation(nn.Module):
|
|
|
self.out_norm = nn.LayerNorm(dim)
|
|
self.out_norm = nn.LayerNorm(dim)
|
|
|
self.residual_scale = nn.Parameter(torch.tensor(float(residual_scale_init)))
|
|
self.residual_scale = nn.Parameter(torch.tensor(float(residual_scale_init)))
|
|
|
|
|
|
|
|
- # 学习系数以平衡粗结构、边缘线索和噪声。
|
|
|
|
|
|
|
+ self.base_global_token = nn.Parameter(torch.zeros(1, dim))
|
|
|
|
|
+ nn.init.trunc_normal_(self.base_global_token, std=0.02)
|
|
|
|
|
+ if learnable_global_token and global_token_use_image_conditioning:
|
|
|
|
|
+ self.global_context_proj = nn.Sequential(
|
|
|
|
|
+ nn.LayerNorm(dim),
|
|
|
|
|
+ nn.Linear(dim, dim),
|
|
|
|
|
+ nn.GELU(),
|
|
|
|
|
+ nn.Linear(dim, dim),
|
|
|
|
|
+ )
|
|
|
|
|
+ self.global_token_norm = nn.LayerNorm(dim)
|
|
|
|
|
+ elif learnable_global_token:
|
|
|
|
|
+ self.global_context_proj = None
|
|
|
|
|
+ self.global_token_norm = nn.LayerNorm(dim)
|
|
|
|
|
+ else:
|
|
|
|
|
+ self.global_context_proj = None
|
|
|
|
|
+ self.global_token_norm = nn.Identity()
|
|
|
|
|
+
|
|
|
|
|
+ # 学习系数以平衡粗结构、边缘线索和高频细节。
|
|
|
|
|
+ # 注意:HH 子带不被预设为纯噪声,而是允许模型学习其正负贡献。
|
|
|
self.wavelet_ll_weight = nn.Parameter(torch.tensor(1.0))
|
|
self.wavelet_ll_weight = nn.Parameter(torch.tensor(1.0))
|
|
|
self.wavelet_edge_weight = nn.Parameter(torch.tensor(0.5))
|
|
self.wavelet_edge_weight = nn.Parameter(torch.tensor(0.5))
|
|
|
- self.wavelet_noise_weight = nn.Parameter(torch.tensor(0.5))
|
|
|
|
|
|
|
+ self.wavelet_hh_weight = nn.Parameter(torch.tensor(-0.25))
|
|
|
|
|
+
|
|
|
|
|
+ self.stability_fourier_weight = nn.Parameter(torch.tensor(0.7))
|
|
|
|
|
+ self.stability_wavelet_weight = nn.Parameter(torch.tensor(0.3))
|
|
|
|
|
+ self.saliency_wavelet_weight = nn.Parameter(torch.tensor(1.0))
|
|
|
|
|
+ self.context_fourier_weight = nn.Parameter(torch.tensor(0.5))
|
|
|
|
|
+ self.context_wavelet_weight = nn.Parameter(torch.tensor(0.5))
|
|
|
|
|
+ self.alignment_residual_weight = nn.Parameter(torch.tensor(0.1))
|
|
|
|
|
|
|
|
self.register_buffer("gaussian_kernel", build_gaussian_lowpass(dim, sigma_ratio), persistent=False)
|
|
self.register_buffer("gaussian_kernel", build_gaussian_lowpass(dim, sigma_ratio), persistent=False)
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
self,
|
|
self,
|
|
|
- cls_token: torch.Tensor,
|
|
|
|
|
patch_tokens: torch.Tensor,
|
|
patch_tokens: torch.Tensor,
|
|
|
|
|
+ cls_token: torch.Tensor | None = None,
|
|
|
return_debug: bool = False,
|
|
return_debug: bool = False,
|
|
|
):
|
|
):
|
|
|
B, N, C = patch_tokens.shape
|
|
B, N, C = patch_tokens.shape
|
|
@@ -123,26 +155,35 @@ class FourierWaveletTokenAggregation(nn.Module):
|
|
|
|
|
|
|
|
fourier_score = self._fourier_stability_score(patch_tokens)
|
|
fourier_score = self._fourier_stability_score(patch_tokens)
|
|
|
wavelet_score = self._wavelet_saliency_score(patch_tokens)
|
|
wavelet_score = self._wavelet_saliency_score(patch_tokens)
|
|
|
|
|
+ initial_global_token = self._build_global_token(
|
|
|
|
|
+ patch_tokens,
|
|
|
|
|
+ fourier_score=fourier_score,
|
|
|
|
|
+ wavelet_score=wavelet_score,
|
|
|
|
|
+ cls_token=cls_token,
|
|
|
|
|
+ )
|
|
|
|
|
+ stability_prior = self._build_stability_prior(fourier_score, wavelet_score)
|
|
|
|
|
+ saliency_prior = self._build_saliency_prior(wavelet_score)
|
|
|
|
|
|
|
|
- fuse_inputs = [fourier_score, wavelet_score]
|
|
|
|
|
- if self.use_cls_conditioning:
|
|
|
|
|
- cls_alignment = self._cls_alignment_score(cls_token, patch_tokens)
|
|
|
|
|
- fuse_inputs.append(cls_alignment)
|
|
|
|
|
-
|
|
|
|
|
- fused_input = torch.stack(fuse_inputs, dim=-1) # [B, N, 2 or 3]
|
|
|
|
|
|
|
+ fused_input = torch.stack([fourier_score, wavelet_score], dim=-1) # [B, N, 2]
|
|
|
fused_score = self.score_fuser(fused_input).squeeze(-1) # [B, N]
|
|
fused_score = self.score_fuser(fused_input).squeeze(-1) # [B, N]
|
|
|
|
|
+ if self.use_cls_conditioning:
|
|
|
|
|
+ cls_alignment = self._cls_alignment_score(initial_global_token.detach(), patch_tokens)
|
|
|
|
|
+ fused_score = fused_score + self.alignment_residual_weight * cls_alignment
|
|
|
gate = torch.softmax(fused_score / max(self.gate_temperature, self.eps), dim=1)
|
|
gate = torch.softmax(fused_score / max(self.gate_temperature, self.eps), dim=1)
|
|
|
|
|
|
|
|
pooled_token = torch.sum(gate.unsqueeze(-1) * patch_tokens, dim=1) # [B, C]
|
|
pooled_token = torch.sum(gate.unsqueeze(-1) * patch_tokens, dim=1) # [B, C]
|
|
|
pooled_token = self.token_proj(pooled_token)
|
|
pooled_token = self.token_proj(pooled_token)
|
|
|
|
|
|
|
|
- cls_out = cls_token + self.residual_scale * pooled_token
|
|
|
|
|
|
|
+ cls_out = initial_global_token + self.residual_scale * pooled_token
|
|
|
cls_out = self.out_norm(cls_out)
|
|
cls_out = self.out_norm(cls_out)
|
|
|
|
|
|
|
|
if return_debug:
|
|
if return_debug:
|
|
|
debug = FWTADebug(
|
|
debug = FWTADebug(
|
|
|
|
|
+ initial_global_token=initial_global_token,
|
|
|
fourier_score=fourier_score,
|
|
fourier_score=fourier_score,
|
|
|
wavelet_score=wavelet_score,
|
|
wavelet_score=wavelet_score,
|
|
|
|
|
+ stability_prior=stability_prior,
|
|
|
|
|
+ saliency_prior=saliency_prior,
|
|
|
fused_score=fused_score,
|
|
fused_score=fused_score,
|
|
|
gate=gate,
|
|
gate=gate,
|
|
|
pooled_token=pooled_token,
|
|
pooled_token=pooled_token,
|
|
@@ -152,39 +193,65 @@ class FourierWaveletTokenAggregation(nn.Module):
|
|
|
|
|
|
|
|
def get_stability_map(self, patch_tokens: torch.Tensor) -> torch.Tensor:
|
|
def get_stability_map(self, patch_tokens: torch.Tensor) -> torch.Tensor:
|
|
|
"""
|
|
"""
|
|
|
- 为分割任务提供二维稳定性图接口。
|
|
|
|
|
|
|
+ 为分割任务提供二维稳定性先验图接口。
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
Tensor of shape [B, 1, H, W].
|
|
Tensor of shape [B, 1, H, W].
|
|
|
"""
|
|
"""
|
|
|
- _, gate = self.forward(
|
|
|
|
|
- cls_token=patch_tokens.mean(dim=1),
|
|
|
|
|
|
|
+ _, _, debug = self.forward(
|
|
|
patch_tokens=patch_tokens,
|
|
patch_tokens=patch_tokens,
|
|
|
- return_debug=False,
|
|
|
|
|
|
|
+ return_debug=True,
|
|
|
)
|
|
)
|
|
|
- H, W = self.grid_size
|
|
|
|
|
- return gate.reshape(patch_tokens.shape[0], 1, H, W)
|
|
|
|
|
|
|
+ return self._score_to_map(debug.stability_prior, patch_tokens.shape[0])
|
|
|
|
|
|
|
|
def forward_with_map(
|
|
def forward_with_map(
|
|
|
self,
|
|
self,
|
|
|
- cls_token: torch.Tensor,
|
|
|
|
|
patch_tokens: torch.Tensor,
|
|
patch_tokens: torch.Tensor,
|
|
|
|
|
+ cls_token: torch.Tensor | None = None,
|
|
|
return_debug: bool = False,
|
|
return_debug: bool = False,
|
|
|
):
|
|
):
|
|
|
"""
|
|
"""
|
|
|
同时返回 CLS 更新结果、门控权重以及二维稳定性图。
|
|
同时返回 CLS 更新结果、门控权重以及二维稳定性图。
|
|
|
"""
|
|
"""
|
|
|
- outputs = self.forward(cls_token, patch_tokens, return_debug=return_debug)
|
|
|
|
|
|
|
+ outputs = self.forward(patch_tokens, cls_token=cls_token, return_debug=return_debug)
|
|
|
H, W = self.grid_size
|
|
H, W = self.grid_size
|
|
|
|
|
|
|
|
if return_debug:
|
|
if return_debug:
|
|
|
cls_out, gate, debug = outputs
|
|
cls_out, gate, debug = outputs
|
|
|
- stability_map = gate.reshape(patch_tokens.shape[0], 1, H, W)
|
|
|
|
|
- return cls_out, gate, stability_map, debug
|
|
|
|
|
|
|
+ stability_map = self._score_to_map(debug.stability_prior, patch_tokens.shape[0])
|
|
|
|
|
+ saliency_map = self._score_to_map(debug.saliency_prior, patch_tokens.shape[0])
|
|
|
|
|
+ return cls_out, gate, stability_map, saliency_map, debug
|
|
|
|
|
|
|
|
cls_out, gate = outputs
|
|
cls_out, gate = outputs
|
|
|
- stability_map = gate.reshape(patch_tokens.shape[0], 1, H, W)
|
|
|
|
|
- return cls_out, gate, stability_map
|
|
|
|
|
|
|
+ stability_map = self._score_to_map(self._build_stability_prior(
|
|
|
|
|
+ self._fourier_stability_score(patch_tokens),
|
|
|
|
|
+ self._wavelet_saliency_score(patch_tokens),
|
|
|
|
|
+ ), patch_tokens.shape[0])
|
|
|
|
|
+ saliency_map = self._score_to_map(self._build_saliency_prior(
|
|
|
|
|
+ self._wavelet_saliency_score(patch_tokens)
|
|
|
|
|
+ ), patch_tokens.shape[0])
|
|
|
|
|
+ return cls_out, gate, stability_map, saliency_map
|
|
|
|
|
+
|
|
|
|
|
+ def _build_global_token(
|
|
|
|
|
+ self,
|
|
|
|
|
+ patch_tokens: torch.Tensor,
|
|
|
|
|
+ fourier_score: torch.Tensor,
|
|
|
|
|
+ wavelet_score: torch.Tensor,
|
|
|
|
|
+ cls_token: torch.Tensor | None = None,
|
|
|
|
|
+ ) -> torch.Tensor:
|
|
|
|
|
+ if cls_token is not None:
|
|
|
|
|
+ return cls_token
|
|
|
|
|
+
|
|
|
|
|
+ if not self.learnable_global_token:
|
|
|
|
|
+ return patch_tokens.mean(dim=1)
|
|
|
|
|
+
|
|
|
|
|
+ batch_size, _, channels = patch_tokens.shape
|
|
|
|
|
+ token = self.base_global_token.expand(batch_size, channels)
|
|
|
|
|
+ if self.global_context_proj is not None:
|
|
|
|
|
+ pre_context_gate = self._build_context_gate(fourier_score, wavelet_score)
|
|
|
|
|
+ image_context = torch.sum(pre_context_gate.unsqueeze(-1) * patch_tokens, dim=1)
|
|
|
|
|
+ token = token + self.global_context_proj(image_context)
|
|
|
|
|
+ return self.global_token_norm(token)
|
|
|
|
|
|
|
|
def _fourier_stability_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
|
|
def _fourier_stability_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
|
|
|
"""
|
|
"""
|
|
@@ -223,29 +290,59 @@ class FourierWaveletTokenAggregation(nn.Module):
|
|
|
ll_energy = F.interpolate(ll_energy, size=(H, W), mode="nearest")
|
|
ll_energy = F.interpolate(ll_energy, size=(H, W), mode="nearest")
|
|
|
|
|
|
|
|
edge_energy = torch.zeros_like(ll_energy)
|
|
edge_energy = torch.zeros_like(ll_energy)
|
|
|
- noise_energy = torch.zeros_like(ll_energy)
|
|
|
|
|
|
|
+ hh_energy = torch.zeros_like(ll_energy)
|
|
|
|
|
|
|
|
for level_detail in detail_coeffs:
|
|
for level_detail in detail_coeffs:
|
|
|
lh, hl, hh = level_detail
|
|
lh, hl, hh = level_detail
|
|
|
level_edge = 0.5 * (lh.abs().mean(dim=1, keepdim=True) + hl.abs().mean(dim=1, keepdim=True))
|
|
level_edge = 0.5 * (lh.abs().mean(dim=1, keepdim=True) + hl.abs().mean(dim=1, keepdim=True))
|
|
|
- level_noise = hh.abs().mean(dim=1, keepdim=True)
|
|
|
|
|
|
|
+ level_hh = hh.abs().mean(dim=1, keepdim=True)
|
|
|
|
|
|
|
|
target_size = (H, W)
|
|
target_size = (H, W)
|
|
|
level_edge = F.interpolate(level_edge, size=target_size, mode="nearest")
|
|
level_edge = F.interpolate(level_edge, size=target_size, mode="nearest")
|
|
|
- level_noise = F.interpolate(level_noise, size=target_size, mode="nearest")
|
|
|
|
|
|
|
+ level_hh = F.interpolate(level_hh, size=target_size, mode="nearest")
|
|
|
|
|
|
|
|
edge_energy = edge_energy + level_edge
|
|
edge_energy = edge_energy + level_edge
|
|
|
- noise_energy = noise_energy + level_noise
|
|
|
|
|
|
|
+ hh_energy = hh_energy + level_hh
|
|
|
|
|
|
|
|
raw_score = (
|
|
raw_score = (
|
|
|
self.wavelet_ll_weight * ll_energy
|
|
self.wavelet_ll_weight * ll_energy
|
|
|
+ self.wavelet_edge_weight * edge_energy
|
|
+ self.wavelet_edge_weight * edge_energy
|
|
|
- - self.wavelet_noise_weight * noise_energy
|
|
|
|
|
|
|
+ + self.wavelet_hh_weight * hh_energy
|
|
|
)
|
|
)
|
|
|
raw_score = raw_score.flatten(1) # [B, N]
|
|
raw_score = raw_score.flatten(1) # [B, N]
|
|
|
score = torch.sigmoid(raw_score)
|
|
score = torch.sigmoid(raw_score)
|
|
|
return score
|
|
return score
|
|
|
|
|
|
|
|
|
|
+ def _build_stability_prior(
|
|
|
|
|
+ self,
|
|
|
|
|
+ fourier_score: torch.Tensor,
|
|
|
|
|
+ wavelet_score: torch.Tensor,
|
|
|
|
|
+ ) -> torch.Tensor:
|
|
|
|
|
+ raw = (
|
|
|
|
|
+ self.stability_fourier_weight * fourier_score
|
|
|
|
|
+ + self.stability_wavelet_weight * wavelet_score
|
|
|
|
|
+ )
|
|
|
|
|
+ return torch.sigmoid(raw)
|
|
|
|
|
+
|
|
|
|
|
+ def _build_saliency_prior(self, wavelet_score: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ raw = self.saliency_wavelet_weight * wavelet_score
|
|
|
|
|
+ return torch.sigmoid(raw)
|
|
|
|
|
+
|
|
|
|
|
+ def _build_context_gate(
|
|
|
|
|
+ self,
|
|
|
|
|
+ fourier_score: torch.Tensor,
|
|
|
|
|
+ wavelet_score: torch.Tensor,
|
|
|
|
|
+ ) -> torch.Tensor:
|
|
|
|
|
+ context_score = (
|
|
|
|
|
+ self.context_fourier_weight * fourier_score
|
|
|
|
|
+ + self.context_wavelet_weight * wavelet_score
|
|
|
|
|
+ )
|
|
|
|
|
+ return torch.softmax(context_score / max(self.gate_temperature, self.eps), dim=1)
|
|
|
|
|
+
|
|
|
|
|
+ def _score_to_map(self, score: torch.Tensor, batch_size: int) -> torch.Tensor:
|
|
|
|
|
+ H, W = self.grid_size
|
|
|
|
|
+ return score.reshape(batch_size, 1, H, W)
|
|
|
|
|
+
|
|
|
def _cls_alignment_score(self, cls_token: torch.Tensor, patch_tokens: torch.Tensor) -> torch.Tensor:
|
|
def _cls_alignment_score(self, cls_token: torch.Tensor, patch_tokens: torch.Tensor) -> torch.Tensor:
|
|
|
"""
|
|
"""
|
|
|
可选稳定器:偏好已与现有 CLS 令牌对齐的令牌。
|
|
可选稳定器:偏好已与现有 CLS 令牌对齐的令牌。
|