Kaynağa Gözat

chore: organize training infrastructure and legacy encoders

kekezack 1 ay önce
ebeveyn
işleme
f9d0facd5d

+ 7 - 6
.gitignore

@@ -33,12 +33,13 @@ tmp/
 *.ckpt
 *.onnx
 
-# Logs & outputs
-*.log
-outputs/
-results/
-runs/
-lightning_logs/
+# Logs & outputs
+*.log
+outputs/
+results/
+runs/
+lightning_logs/
+swanlog/
 
 # Jupyter
 .ipynb_checkpoints/

+ 124 - 27
lib/modules/fwta_2d.py

@@ -31,8 +31,11 @@ def build_gaussian_lowpass(
 
 @dataclass
 class FWTADebug:
+    initial_global_token: torch.Tensor
     fourier_score: torch.Tensor
     wavelet_score: torch.Tensor
+    stability_prior: torch.Tensor
+    saliency_prior: torch.Tensor
     fused_score: torch.Tensor
     gate: torch.Tensor
     pooled_token: torch.Tensor
@@ -69,6 +72,8 @@ class FourierWaveletTokenAggregation(nn.Module):
             residual_scale_init: float = 1.0,
             fusion_hidden_ratio: float = 0.5,
             use_cls_conditioning: bool = True,
+            learnable_global_token: bool = True,
+            global_token_use_image_conditioning: bool = True,
             eps: float = 1e-6,
     ) -> None:
         super().__init__()
@@ -80,6 +85,8 @@ class FourierWaveletTokenAggregation(nn.Module):
         self.tau_fourier = tau_fourier
         self.gate_temperature = gate_temperature
         self.use_cls_conditioning = use_cls_conditioning
+        self.learnable_global_token = learnable_global_token
+        self.global_token_use_image_conditioning = global_token_use_image_conditioning
         self.eps = eps
 
         hidden_dim = max(int(dim * fusion_hidden_ratio), 32)
@@ -101,17 +108,42 @@ class FourierWaveletTokenAggregation(nn.Module):
         self.out_norm = nn.LayerNorm(dim)
         self.residual_scale = nn.Parameter(torch.tensor(float(residual_scale_init)))
 
-        # 学习系数以平衡粗结构、边缘线索和噪声。
+        self.base_global_token = nn.Parameter(torch.zeros(1, dim))
+        nn.init.trunc_normal_(self.base_global_token, std=0.02)
+        if learnable_global_token and global_token_use_image_conditioning:
+            self.global_context_proj = nn.Sequential(
+                nn.LayerNorm(dim),
+                nn.Linear(dim, dim),
+                nn.GELU(),
+                nn.Linear(dim, dim),
+            )
+            self.global_token_norm = nn.LayerNorm(dim)
+        elif learnable_global_token:
+            self.global_context_proj = None
+            self.global_token_norm = nn.LayerNorm(dim)
+        else:
+            self.global_context_proj = None
+            self.global_token_norm = nn.Identity()
+
+        # 学习系数以平衡粗结构、边缘线索和高频细节。
+        # 注意:HH 子带不被预设为纯噪声,而是允许模型学习其正负贡献。
         self.wavelet_ll_weight = nn.Parameter(torch.tensor(1.0))
         self.wavelet_edge_weight = nn.Parameter(torch.tensor(0.5))
-        self.wavelet_noise_weight = nn.Parameter(torch.tensor(0.5))
+        self.wavelet_hh_weight = nn.Parameter(torch.tensor(-0.25))
+
+        self.stability_fourier_weight = nn.Parameter(torch.tensor(0.7))
+        self.stability_wavelet_weight = nn.Parameter(torch.tensor(0.3))
+        self.saliency_wavelet_weight = nn.Parameter(torch.tensor(1.0))
+        self.context_fourier_weight = nn.Parameter(torch.tensor(0.5))
+        self.context_wavelet_weight = nn.Parameter(torch.tensor(0.5))
+        self.alignment_residual_weight = nn.Parameter(torch.tensor(0.1))
 
         self.register_buffer("gaussian_kernel", build_gaussian_lowpass(dim, sigma_ratio), persistent=False)
 
     def forward(
             self,
-            cls_token: torch.Tensor,
             patch_tokens: torch.Tensor,
+            cls_token: torch.Tensor | None = None,
             return_debug: bool = False,
     ):
         B, N, C = patch_tokens.shape
@@ -123,26 +155,35 @@ class FourierWaveletTokenAggregation(nn.Module):
 
         fourier_score = self._fourier_stability_score(patch_tokens)
         wavelet_score = self._wavelet_saliency_score(patch_tokens)
+        initial_global_token = self._build_global_token(
+            patch_tokens,
+            fourier_score=fourier_score,
+            wavelet_score=wavelet_score,
+            cls_token=cls_token,
+        )
+        stability_prior = self._build_stability_prior(fourier_score, wavelet_score)
+        saliency_prior = self._build_saliency_prior(wavelet_score)
 
-        fuse_inputs = [fourier_score, wavelet_score]
-        if self.use_cls_conditioning:
-            cls_alignment = self._cls_alignment_score(cls_token, patch_tokens)
-            fuse_inputs.append(cls_alignment)
-
-        fused_input = torch.stack(fuse_inputs, dim=-1)  # [B, N, 2 or 3]
+        fused_input = torch.stack([fourier_score, wavelet_score], dim=-1)  # [B, N, 2]
         fused_score = self.score_fuser(fused_input).squeeze(-1)  # [B, N]
+        if self.use_cls_conditioning:
+            cls_alignment = self._cls_alignment_score(initial_global_token.detach(), patch_tokens)
+            fused_score = fused_score + self.alignment_residual_weight * cls_alignment
         gate = torch.softmax(fused_score / max(self.gate_temperature, self.eps), dim=1)
 
         pooled_token = torch.sum(gate.unsqueeze(-1) * patch_tokens, dim=1)  # [B, C]
         pooled_token = self.token_proj(pooled_token)
 
-        cls_out = cls_token + self.residual_scale * pooled_token
+        cls_out = initial_global_token + self.residual_scale * pooled_token
         cls_out = self.out_norm(cls_out)
 
         if return_debug:
             debug = FWTADebug(
+                initial_global_token=initial_global_token,
                 fourier_score=fourier_score,
                 wavelet_score=wavelet_score,
+                stability_prior=stability_prior,
+                saliency_prior=saliency_prior,
                 fused_score=fused_score,
                 gate=gate,
                 pooled_token=pooled_token,
@@ -152,39 +193,65 @@ class FourierWaveletTokenAggregation(nn.Module):
 
     def get_stability_map(self, patch_tokens: torch.Tensor) -> torch.Tensor:
         """
-        为分割任务提供二维稳定性图接口。
+        为分割任务提供二维稳定性先验图接口。
 
         Returns:
             Tensor of shape [B, 1, H, W].
         """
-        _, gate = self.forward(
-            cls_token=patch_tokens.mean(dim=1),
+        _, _, debug = self.forward(
             patch_tokens=patch_tokens,
-            return_debug=False,
+            return_debug=True,
         )
-        H, W = self.grid_size
-        return gate.reshape(patch_tokens.shape[0], 1, H, W)
+        return self._score_to_map(debug.stability_prior, patch_tokens.shape[0])
 
     def forward_with_map(
             self,
-            cls_token: torch.Tensor,
             patch_tokens: torch.Tensor,
+            cls_token: torch.Tensor | None = None,
             return_debug: bool = False,
     ):
         """
         同时返回 CLS 更新结果、门控权重以及二维稳定性图。
         """
-        outputs = self.forward(cls_token, patch_tokens, return_debug=return_debug)
+        outputs = self.forward(patch_tokens, cls_token=cls_token, return_debug=return_debug)
         H, W = self.grid_size
 
         if return_debug:
             cls_out, gate, debug = outputs
-            stability_map = gate.reshape(patch_tokens.shape[0], 1, H, W)
-            return cls_out, gate, stability_map, debug
+            stability_map = self._score_to_map(debug.stability_prior, patch_tokens.shape[0])
+            saliency_map = self._score_to_map(debug.saliency_prior, patch_tokens.shape[0])
+            return cls_out, gate, stability_map, saliency_map, debug
 
         cls_out, gate = outputs
-        stability_map = gate.reshape(patch_tokens.shape[0], 1, H, W)
-        return cls_out, gate, stability_map
+        stability_map = self._score_to_map(self._build_stability_prior(
+            self._fourier_stability_score(patch_tokens),
+            self._wavelet_saliency_score(patch_tokens),
+        ), patch_tokens.shape[0])
+        saliency_map = self._score_to_map(self._build_saliency_prior(
+            self._wavelet_saliency_score(patch_tokens)
+        ), patch_tokens.shape[0])
+        return cls_out, gate, stability_map, saliency_map
+
+    def _build_global_token(
+            self,
+            patch_tokens: torch.Tensor,
+            fourier_score: torch.Tensor,
+            wavelet_score: torch.Tensor,
+            cls_token: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if cls_token is not None:
+            return cls_token
+
+        if not self.learnable_global_token:
+            return patch_tokens.mean(dim=1)
+
+        batch_size, _, channels = patch_tokens.shape
+        token = self.base_global_token.expand(batch_size, channels)
+        if self.global_context_proj is not None:
+            pre_context_gate = self._build_context_gate(fourier_score, wavelet_score)
+            image_context = torch.sum(pre_context_gate.unsqueeze(-1) * patch_tokens, dim=1)
+            token = token + self.global_context_proj(image_context)
+        return self.global_token_norm(token)
 
     def _fourier_stability_score(self, patch_tokens: torch.Tensor) -> torch.Tensor:
         """
@@ -223,29 +290,59 @@ class FourierWaveletTokenAggregation(nn.Module):
         ll_energy = F.interpolate(ll_energy, size=(H, W), mode="nearest")
 
         edge_energy = torch.zeros_like(ll_energy)
-        noise_energy = torch.zeros_like(ll_energy)
+        hh_energy = torch.zeros_like(ll_energy)
 
         for level_detail in detail_coeffs:
             lh, hl, hh = level_detail
             level_edge = 0.5 * (lh.abs().mean(dim=1, keepdim=True) + hl.abs().mean(dim=1, keepdim=True))
-            level_noise = hh.abs().mean(dim=1, keepdim=True)
+            level_hh = hh.abs().mean(dim=1, keepdim=True)
 
             target_size = (H, W)
             level_edge = F.interpolate(level_edge, size=target_size, mode="nearest")
-            level_noise = F.interpolate(level_noise, size=target_size, mode="nearest")
+            level_hh = F.interpolate(level_hh, size=target_size, mode="nearest")
 
             edge_energy = edge_energy + level_edge
-            noise_energy = noise_energy + level_noise
+            hh_energy = hh_energy + level_hh
 
         raw_score = (
                 self.wavelet_ll_weight * ll_energy
                 + self.wavelet_edge_weight * edge_energy
-                - self.wavelet_noise_weight * noise_energy
+                + self.wavelet_hh_weight * hh_energy
         )
         raw_score = raw_score.flatten(1)  # [B, N]
         score = torch.sigmoid(raw_score)
         return score
 
+    def _build_stability_prior(
+            self,
+            fourier_score: torch.Tensor,
+            wavelet_score: torch.Tensor,
+    ) -> torch.Tensor:
+        raw = (
+            self.stability_fourier_weight * fourier_score
+            + self.stability_wavelet_weight * wavelet_score
+        )
+        return torch.sigmoid(raw)
+
+    def _build_saliency_prior(self, wavelet_score: torch.Tensor) -> torch.Tensor:
+        raw = self.saliency_wavelet_weight * wavelet_score
+        return torch.sigmoid(raw)
+
+    def _build_context_gate(
+            self,
+            fourier_score: torch.Tensor,
+            wavelet_score: torch.Tensor,
+    ) -> torch.Tensor:
+        context_score = (
+            self.context_fourier_weight * fourier_score
+            + self.context_wavelet_weight * wavelet_score
+        )
+        return torch.softmax(context_score / max(self.gate_temperature, self.eps), dim=1)
+
+    def _score_to_map(self, score: torch.Tensor, batch_size: int) -> torch.Tensor:
+        H, W = self.grid_size
+        return score.reshape(batch_size, 1, H, W)
+
     def _cls_alignment_score(self, cls_token: torch.Tensor, patch_tokens: torch.Tensor) -> torch.Tensor:
         """
         可选稳定器:偏好已与现有 CLS 令牌对齐的令牌。

+ 76 - 0
lib/modules/swinv2_encoder_2d.py

@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+from argparse import Namespace
+from pathlib import Path
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from .build_swinv2 import build_swinv2
+
+
+class SwinV2Encoder2d(nn.Module):
+    """
+    面向分割的纯 SwinV2 编码器封装
+    """
+
+    def __init__(
+        self,
+        model_name: str | None = None,
+        config_path: str | Path | None = None,
+        weight_path: str | Path | None = None,
+        args: Namespace | None = None,
+        *,
+        load_weights: bool = True,
+        normalize_features: bool = True,
+        use_multiscale_features: bool = True,
+        include_patch_embed: bool = True,
+        **model_kwargs: Any,
+    ) -> None:
+        super().__init__()
+        backbone, cfg = build_swinv2(
+            model_name=model_name,
+            config_path=config_path,
+            weight_path=weight_path,
+            args=args,
+            load_weights=load_weights,
+            return_config=True,
+            **model_kwargs,
+        )
+        self.backbone = backbone
+        self.cfg = cfg
+        self.normalize_features = normalize_features
+        self.use_multiscale_features = use_multiscale_features
+        self.include_patch_embed = include_patch_embed
+
+        depths = tuple(cfg.MODEL.SWINV2.DEPTHS)
+        embed_dim = int(cfg.MODEL.SWINV2.EMBED_DIM)
+        if self.use_multiscale_features:
+            stage_channels = []
+            if self.include_patch_embed:
+                stage_channels.append(embed_dim)
+            for i in range(len(depths)):
+                channel_multiplier = 2 ** min(i + 1, len(depths) - 1)
+                stage_channels.append(int(embed_dim * channel_multiplier))
+            self.stage_channels = stage_channels
+        else:
+            self.stage_channels = [int(embed_dim * 2**i) for i in range(len(depths))]
+
+    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
+        if self.use_multiscale_features:
+            features = self.backbone.forward_multiscale_features(
+                x,
+                normalize=self.normalize_features,
+                include_patch_embed=self.include_patch_embed,
+            )
+        else:
+            features = self.backbone.forward_stage_features(
+                x, normalize=self.normalize_features
+            )
+
+        deepest = features[-1]
+        return {
+            "features": features,
+            "deepest_feature": deepest,
+        }

+ 10 - 5
lib/modules/swinv2_fwta_encoder_2d.py

@@ -35,6 +35,8 @@ class SwinV2FWTAEncoder2d(nn.Module):
             fwta_fusion_hidden_ratio: float = 0.5,
             fwta_use_global_conditioning: bool = True,
             fwta_residual_scale_init: float = 1.0,
+            fwta_learnable_global_token: bool = True,
+            fwta_global_token_use_image_conditioning: bool = True,
             **model_kwargs: Any,
     ) -> None:
         super().__init__()
@@ -60,7 +62,6 @@ class SwinV2FWTAEncoder2d(nn.Module):
             if self.include_patch_embed:
                 stage_channels.append(embed_dim)
             for i in range(len(depths)):
-                # forward_multiscale_features appends each layer output after its internal downsample.
                 channel_multiplier = 2 ** min(i + 1, len(depths) - 1)
                 stage_channels.append(int(embed_dim * channel_multiplier))
             self.stage_channels = stage_channels
@@ -82,6 +83,8 @@ class SwinV2FWTAEncoder2d(nn.Module):
             residual_scale_init=fwta_residual_scale_init,
             fusion_hidden_ratio=fwta_fusion_hidden_ratio,
             use_cls_conditioning=fwta_use_global_conditioning,
+            learnable_global_token=fwta_learnable_global_token,
+            global_token_use_image_conditioning=fwta_global_token_use_image_conditioning,
         )
 
     def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
@@ -93,16 +96,18 @@ class SwinV2FWTAEncoder2d(nn.Module):
             )
         else:
             features = self.backbone.forward_stage_features(x, normalize=self.normalize_features)
+
         deepest = features[-1]
-        b, c, h, w = deepest.shape
         patch_tokens = deepest.flatten(2).transpose(1, 2)
-        cls_token = patch_tokens.mean(dim=1)
-        cls_out, gate, stability_map = self.fwta.forward_with_map(cls_token, patch_tokens)
+        cls_out, gate, stability_prior, saliency_prior = self.fwta.forward_with_map(
+            patch_tokens=patch_tokens,
+        )
 
         return {
             "features": features,
             "deepest_feature": deepest,
             "global_token": cls_out,
             "token_gate": gate,
-            "stability_map": stability_map,
+            "stability_prior": stability_prior,
+            "saliency_prior": saliency_prior,
         }

+ 26 - 0
lib/trainers/base.py

@@ -3,9 +3,11 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from pathlib import Path
 import pprint
+import random
 import time
 from typing import Any
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -33,6 +35,7 @@ class BaseTrainer(ABC):
     def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
         self.cfg = cfg
         self.args = args
+        self._set_random_seed()
         self.device = self._build_device()
         self.output_dir = self._build_output_dir()
         self.start_epoch = 0
@@ -41,6 +44,23 @@ class BaseTrainer(ABC):
         self.swanlab_run = None
         self.grad_scaler = GradScaler("cuda", enabled=self._amp_enabled())
 
+    def _set_random_seed(self) -> None:
+        seed = int(self.cfg.get("train", {}).get("seed", 42))
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
+            torch.cuda.manual_seed_all(seed)
+
+        deterministic = bool(self.cfg.get("train", {}).get("deterministic", False))
+        if deterministic:
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+        else:
+            torch.backends.cudnn.deterministic = False
+            torch.backends.cudnn.benchmark = True
+
     def _build_device(self) -> torch.device:
         device_name = self.cfg.get("train", {}).get("device", "cpu")
         if device_name == "cuda" and not torch.cuda.is_available():
@@ -152,6 +172,7 @@ class BaseTrainer(ABC):
             batch_size: int,
             shuffle: bool,
             split_file: str | None = None,
+            augmentation_config: dict[str, Any] | None = None,
     ):
         dataset_cfg = self._dataset_cfg()
         train_cfg = self.cfg.get("train", {})
@@ -165,6 +186,7 @@ class BaseTrainer(ABC):
             batch_size=batch_size,
             shuffle=shuffle,
             num_workers=num_workers,
+            augmentation_config=augmentation_config,
             image_transform=self._build_resize_transform(mode="image"),
             mask_transform=self._build_resize_transform(mode="mask"),
             pin_memory=bool(train_cfg.get("pin_memory", self.device.type == "cuda")),
@@ -188,6 +210,7 @@ class BaseTrainer(ABC):
             split_file=dataset_cfg.get("val_split_file"),
             batch_size=batch_size,
             shuffle=shuffle,
+            augmentation_config=self.cfg.get("augmentation", {}).get("val"),
         )
 
     def _checkpoint_cfg(self) -> dict[str, Any]:
@@ -371,6 +394,9 @@ class BaseTrainer(ABC):
     def _grad_clip_enabled(self) -> bool:
         return bool(self._grad_clip_cfg().get("enabled", False))
 
+    def _accum_steps(self) -> int:
+        return max(1, int(self.cfg.get("train", {}).get("accum_steps", 1)))
+
     def _clip_gradients(self, module: nn.Module | None) -> float | None:
         if module is None or not self._grad_clip_enabled():
             return None