Просмотр исходного кода

refactor: simplify active segmentation baseline

kekezack 1 месяц назад
Родитель
Коммит
92bf26dc75

+ 21 - 5
configs/segmentation/train_sup_us_template.yaml

@@ -3,9 +3,11 @@ trainer:
 
 train:
   seed: 42
+  deterministic: false
   epochs: 200
   batch_size: 4
   val_batch_size: 4
+  accum_steps: 1
   amp: true
   num_workers: 4
   pin_memory: true
@@ -30,6 +32,14 @@ metrics:
     - name: dice
     - name: iou
 
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
 validation:
   enabled: true
   interval: 1
@@ -59,11 +69,6 @@ model:
   model_name: swinv2_tiny_patch4_window8_256
   load_weights: false
   decoder_channels: [384, 192, 96, 96]
-  fwta_wavelet: haar
-  fwta_level: 1
-  fwta_sigma_ratio: 0.35
-  fwta_tau_fourier: 0.15
-  fwta_gate_temperature: 1.0
   use_multiscale_features: true
   include_patch_embed: true
 
@@ -83,6 +88,17 @@ scheduler:
     T_max: 190
     eta_min: 1.0e-6
 
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.15
+    contrast_limit: 0.15
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.03
+  val: {}
+
 checkpoint:
   dir: outputs/supervised_segmentation/train_sup_us_template
   save: true

+ 21 - 5
configs/segmentation/us_exp_sup_busi.yaml

@@ -3,9 +3,11 @@ trainer:
 
 train:
   seed: 42
+  deterministic: false
   epochs: 200
   batch_size: 4
   val_batch_size: 4
+  accum_steps: 1
   amp: true
   num_workers: 4
   pin_memory: true
@@ -23,6 +25,14 @@ metrics:
     - name: dice
     - name: iou
 
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
 validation:
   enabled: true
   interval: 1
@@ -49,11 +59,6 @@ model:
   model_name: swinv2_tiny_patch4_window8_256
   load_weights: false
   decoder_channels: [384, 192, 96, 96]
-  fwta_wavelet: haar
-  fwta_level: 1
-  fwta_sigma_ratio: 0.35
-  fwta_tau_fourier: 0.15
-  fwta_gate_temperature: 1.0
   use_multiscale_features: true
   include_patch_embed: true
 
@@ -73,6 +78,17 @@ scheduler:
     T_max: 190
     eta_min: 1.0e-6
 
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.15
+    contrast_limit: 0.15
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.03
+  val: {}
+
 checkpoint:
   dir: outputs/experiments/supervised/BUSI
   save: true

+ 108 - 0
configs/segmentation/us_exp_sup_busi_ablation.yaml

@@ -0,0 +1,108 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  deterministic: false
+  epochs: 200
+  batch_size: 4
+  val_batch_size: 4
+  accum_steps: 1
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 40
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  dataset_name: BUSI
+  root: data/BUSI
+  split: train
+  val_split: val
+  image_size: [256, 256]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  model_name: swinv2_tiny_patch4_window8_256
+  load_weights: false
+  decoder_channels: [384, 192, 96, 96]
+  use_multiscale_features: true
+  include_patch_embed: true
+
+optimizer:
+  name: adamw
+  lr: 1.0e-4
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 10
+  params:
+    T_max: 190
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.15
+    contrast_limit: 0.15
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.03
+  val: {}
+
+checkpoint:
+  dir: outputs/experiments/supervised/BUSI_ablation
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: sup_busi_ablation
+  swanlab_mode: null

+ 8 - 11
lib/modules/__init__.py

@@ -1,8 +1,7 @@
 from .attentions_2d import CirculantAttention2d, ComplexLinear, WaveletAttentionGlobalBranch2d
 from .blocks_2d import WaveletFFTBlock2d, WaveletFFTMRFFIModule2d
 from .build_swinv2 import build_swinv2, build_swinv2_auto
-from .decoder_2d import BoundaryRefineBlock2d, StructureAwareDecodeBlock2d, StructureAwareDecoder2d
-from .fwta_2d import FourierWaveletTokenAggregation
+from .decoder_2d import DecodeRefineBlock2d, SegmentationDecodeBlock2d, SegmentationDecoder2d
 from .layers_2d import (
     BNLinear1d,
     Conv2dBN,
@@ -13,8 +12,8 @@ from .layers_2d import (
     Residual,
     Scale,
 )
-from .segmentation_2d import GlobalTokenConditioning2d, SegmentationNet2d
-from .swinv2_fwta_encoder_2d import SwinV2FWTAEncoder2d
+from .segmentation_2d import SegmentationModel2d
+from .swinv2_encoder_2d import SwinV2Encoder2d
 
 __all__ = [
     "CirculantAttention2d",
@@ -24,10 +23,9 @@ __all__ = [
     "WaveletFFTMRFFIModule2d",
     "build_swinv2",
     "build_swinv2_auto",
-    "BoundaryRefineBlock2d",
-    "StructureAwareDecodeBlock2d",
-    "StructureAwareDecoder2d",
-    "FourierWaveletTokenAggregation",
+    "DecodeRefineBlock2d",
+    "SegmentationDecodeBlock2d",
+    "SegmentationDecoder2d",
     "BNLinear1d",
     "Conv2dBN",
     "DWConv2dBNReLU",
@@ -36,7 +34,6 @@ __all__ = [
     "PatchMerging2d",
     "Residual",
     "Scale",
-    "GlobalTokenConditioning2d",
-    "SegmentationNet2d",
-    "SwinV2FWTAEncoder2d",
+    "SegmentationModel2d",
+    "SwinV2Encoder2d",
 ]

+ 20 - 52
lib/modules/decoder_2d.py

@@ -9,9 +9,9 @@ import torch.nn.functional as F
 from .layers_2d import Conv2dBN
 
 
-class BoundaryRefineBlock2d(nn.Module):
+class DecodeRefineBlock2d(nn.Module):
     """
-    使用边界提示和稳定性图对解码特征做轻量细化。
+    对解码后的融合特征做轻量残差细化。
     """
 
     def __init__(self, channels: int) -> None:
@@ -22,32 +22,13 @@ class BoundaryRefineBlock2d(nn.Module):
             Conv2dBN(channels, channels, 3, 1, 1),
         )
 
-    def forward(
-            self,
-            x: torch.Tensor,
-            boundary_hint: torch.Tensor | None = None,
-            stability_map: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        modulator = 1.0
-
-        if stability_map is not None:
-            stability_map = F.interpolate(
-                stability_map, size=x.shape[-2:], mode="bilinear", align_corners=False
-            )
-            modulator = modulator + stability_map
-
-        if boundary_hint is not None:
-            boundary_hint = F.interpolate(
-                boundary_hint, size=x.shape[-2:], mode="bilinear", align_corners=False
-            )
-            modulator = modulator + boundary_hint
-
-        return x + self.refine(x * modulator)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + self.refine(x)
 
 
-class StructureAwareDecodeBlock2d(nn.Module):
+class SegmentationDecodeBlock2d(nn.Module):
     """
-    单层结构感知解码块。
+    单层解码块:上采样高层特征,与 skip 特征融合后细化。
     """
 
     def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
@@ -66,31 +47,28 @@ class StructureAwareDecodeBlock2d(nn.Module):
             Conv2dBN(out_channels, out_channels, 3, 1, 1),
             nn.ReLU(inplace=True),
         )
-        self.refine = BoundaryRefineBlock2d(out_channels)
+        self.refine = DecodeRefineBlock2d(out_channels)
 
-    def forward(
-            self,
-            x: torch.Tensor,
-            skip: torch.Tensor,
-            stability_map: torch.Tensor | None = None,
-            boundary_hint: torch.Tensor | None = None,
-    ) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
         x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
         x = self.high_proj(x)
         skip = self.skip_proj(skip)
         x = self.fuse(torch.cat([x, skip], dim=1))
-        x = self.refine(x, boundary_hint=boundary_hint, stability_map=stability_map)
-        return x
+        return self.refine(x)
 
 
-class StructureAwareDecoder2d(nn.Module):
+class SegmentationDecoder2d(nn.Module):
     """
-    第一版结构感知解码器骨架。
+    纯净的多尺度解码器骨架。
 
     输入特征默认按从浅到深排列,最后一个特征视为最深层输入。
     """
 
-    def __init__(self, encoder_channels: Sequence[int], decoder_channels: Sequence[int] | None = None) -> None:
+    def __init__(
+            self,
+            encoder_channels: Sequence[int],
+            decoder_channels: Sequence[int] | None = None,
+    ) -> None:
         super().__init__()
         if len(encoder_channels) < 2:
             raise ValueError("encoder_channels must contain at least two stages.")
@@ -107,17 +85,12 @@ class StructureAwareDecoder2d(nn.Module):
 
         blocks = []
         for skip_ch, out_ch in zip(skip_channels, decoder_channels):
-            blocks.append(StructureAwareDecodeBlock2d(in_channels, skip_ch, out_ch))
+            blocks.append(SegmentationDecodeBlock2d(in_channels, skip_ch, out_ch))
             in_channels = out_ch
         self.blocks = nn.ModuleList(blocks)
         self.out_channels = in_channels
 
-    def forward(
-            self,
-            features: Sequence[torch.Tensor],
-            stability_map: torch.Tensor | None = None,
-            boundary_hints: Sequence[torch.Tensor] | None = None,
-    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+    def forward(self, features: Sequence[torch.Tensor]) -> tuple[torch.Tensor, list[torch.Tensor]]:
         if len(features) != len(self.encoder_channels):
             raise ValueError(
                 f"feature count mismatch: got {len(features)}, expected {len(self.encoder_channels)}"
@@ -127,13 +100,8 @@ class StructureAwareDecoder2d(nn.Module):
         skips = list(reversed(features[:-1]))
         decoder_features = []
 
-        if boundary_hints is None:
-            boundary_hints = [None] * len(self.blocks)
-        elif len(boundary_hints) != len(self.blocks):
-            raise ValueError("boundary_hints length must match decoder depth.")
-
-        for block, skip, boundary_hint in zip(self.blocks, skips, boundary_hints):
-            x = block(x, skip, stability_map=stability_map, boundary_hint=boundary_hint)
+        for block, skip in zip(self.blocks, skips):
+            x = block(x, skip)
             decoder_features.append(x)
 
         return x, decoder_features

+ 11 - 81
lib/modules/segmentation_2d.py

@@ -8,9 +8,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .decoder_2d import StructureAwareDecoder2d
+from .decoder_2d import SegmentationDecoder2d
 from .layers_2d import Conv2dBN
-from .swinv2_fwta_encoder_2d import SwinV2FWTAEncoder2d
+from .swinv2_encoder_2d import SwinV2Encoder2d
 
 
 class SegmentationHead2d(nn.Module):
@@ -27,49 +27,9 @@ class SegmentationHead2d(nn.Module):
         return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
 
 
-class BoundaryHead2d(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int = 1) -> None:
-        super().__init__()
-        self.block = nn.Sequential(
-            Conv2dBN(in_channels, in_channels, 3, 1, 1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
-        )
-
-    def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
-        x = self.block(x)
-        return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
-
-
-class GlobalTokenConditioning2d(nn.Module):
-    """
-    使用 FWTA 更新后的全局前景 token 对解码特征做通道调制。
+class SegmentationModel2d(nn.Module):
     """
-
-    def __init__(self, token_channels: int, feature_channels: int) -> None:
-        super().__init__()
-        hidden_channels = max(feature_channels // 2, 32)
-        self.gate = nn.Sequential(
-            nn.LayerNorm(token_channels),
-            nn.Linear(token_channels, hidden_channels),
-            nn.GELU(),
-            nn.Linear(hidden_channels, feature_channels),
-            nn.Sigmoid(),
-        )
-
-    def forward(self, x: torch.Tensor, global_token: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        channel_gate = self.gate(global_token).unsqueeze(-1).unsqueeze(-1)
-        return x * (1.0 + channel_gate), channel_gate
-
-
-class SegmentationNet2d(nn.Module):
-    """
-    第一版超声分割主网络骨架。
-
-    当前职责:
-    - 编码器输出多尺度特征和稳定性图
-    - 结构感知解码器恢复分割特征
-    - 同时输出分割图和边界图
+    简化后的超声分割网络。
     """
 
     def __init__(
@@ -85,7 +45,7 @@ class SegmentationNet2d(nn.Module):
             **encoder_kwargs: Any,
     ) -> None:
         super().__init__()
-        self.encoder = SwinV2FWTAEncoder2d(
+        self.encoder = SwinV2Encoder2d(
             model_name=model_name,
             config_path=config_path,
             weight_path=weight_path,
@@ -93,46 +53,16 @@ class SegmentationNet2d(nn.Module):
             load_weights=load_weights,
             **encoder_kwargs,
         )
-        self.decoder = StructureAwareDecoder2d(
+        self.decoder = SegmentationDecoder2d(
             encoder_channels=self.encoder.stage_channels,
             decoder_channels=decoder_channels,
         )
-        self.global_conditioning = GlobalTokenConditioning2d(
-            token_channels=self.encoder.stage_channels[-1],
-            feature_channels=self.decoder.out_channels,
-        )
         self.segmentation_head = SegmentationHead2d(self.decoder.out_channels, num_classes)
-        self.boundary_head = BoundaryHead2d(self.decoder.out_channels, out_channels=1)
-
-    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
-        encoder_outputs = self.encoder(x)
-        features = encoder_outputs["features"]
-        stability_map = encoder_outputs["stability_map"]
 
-        decoder_out, decoder_features = self.decoder(
-            features=features,
-            stability_map=stability_map,
-        )
-        conditioned_decoder_out, global_channel_gate = self.global_conditioning(
-            decoder_out,
-            encoder_outputs["global_token"],
-        )
+    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        features = self.encoder(x)["features"]
+        decoder_out, _ = self.decoder(features)
 
         output_size = x.shape[-2:]
-        seg_logits = self.segmentation_head(conditioned_decoder_out, output_size=output_size)
-        boundary_logits = self.boundary_head(conditioned_decoder_out, output_size=output_size)
-
-        return {
-            "seg_logits": seg_logits,
-            "boundary_logits": boundary_logits,
-            "stability_map": F.interpolate(
-                stability_map, size=output_size, mode="bilinear", align_corners=False
-            ),
-            "encoder_features": features,
-            "decoder_features": decoder_features,
-            "conditioned_decoder_feature": conditioned_decoder_out,
-            "deepest_feature": encoder_outputs["deepest_feature"],
-            "global_token": encoder_outputs["global_token"],
-            "global_channel_gate": global_channel_gate,
-            "token_gate": encoder_outputs["token_gate"],
-        }
+        seg_logits = self.segmentation_head(decoder_out, output_size=output_size)
+        return {"seg_logits": seg_logits}

+ 0 - 8
lib/tools/__init__.py

@@ -1,6 +1,4 @@
-from .boundary import boundary_band_map, logits_to_boundary, logits_to_binary_mask, mask_to_boundary_map
 from .loss import DEFAULT_TASK_LOSS, LOSS_REGISTRY, build_loss
-from .loss import BinaryBoundaryLoss, MaskBoundaryConsistencyLoss
 from .metrics import (
     DEFAULT_METRIC_CONFIG,
     METRIC_REGISTRY,
@@ -28,13 +26,7 @@ __all__ = [
     "METRIC_REGISTRY",
     "OPTIMIZER_REGISTRY",
     "SCHEDULER_REGISTRY",
-    "mask_to_boundary_map",
-    "logits_to_binary_mask",
-    "logits_to_boundary",
-    "boundary_band_map",
     "build_loss",
-    "BinaryBoundaryLoss",
-    "MaskBoundaryConsistencyLoss",
     "build_metric",
     "build_metrics",
     "compute_metrics",

+ 28 - 40
lib/trainers/supervised.py

@@ -6,53 +6,46 @@ from typing import Any
 import torch
 from torch.utils.data import DataLoader
 
-from lib.modules import SegmentationNet2d
-from lib.tools import (
-    BinaryBoundaryLoss,
-    MaskBoundaryConsistencyLoss,
-    build_optimizer,
-    build_scheduler,
-    mask_to_boundary_map,
-)
+from lib.modules import SegmentationModel2d
+from lib.tools import build_loss, build_optimizer, build_scheduler
 from .base import BaseTrainer
 
 
 class SupervisedSegmentationTrainer(BaseTrainer):
     def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
         super().__init__(cfg=cfg, args=args)
-        self.model: SegmentationNet2d | None = None
+        self.model: SegmentationModel2d | None = None
         self.optimizer = None
         self.scheduler = None
         self.loader: DataLoader | None = None
         self.val_loader: DataLoader | None = None
         self.seg_loss = None
-        self.boundary_loss = BinaryBoundaryLoss()
-        self.consistency_loss = MaskBoundaryConsistencyLoss()
 
     def build(self) -> None:
         dataset_cfg = self.cfg["dataset"]
         model_cfg = self.cfg["model"]
         train_cfg = self.cfg["train"]
 
-        self.model = SegmentationNet2d(
+        self.model = SegmentationModel2d(
             num_classes=dataset_cfg["num_classes"],
             model_name=model_cfg["model_name"],
             load_weights=model_cfg.get("load_weights", False),
             decoder_channels=model_cfg.get("decoder_channels"),
-            fwta_wavelet=model_cfg.get("fwta_wavelet", "haar"),
-            fwta_level=model_cfg.get("fwta_level", 1),
-            fwta_sigma_ratio=model_cfg.get("fwta_sigma_ratio", 0.35),
-            fwta_tau_fourier=model_cfg.get("fwta_tau_fourier", 0.15),
-            fwta_gate_temperature=model_cfg.get("fwta_gate_temperature", 1.0),
+            use_multiscale_features=model_cfg.get("use_multiscale_features", True),
+            include_patch_embed=model_cfg.get("include_patch_embed", True),
         ).to(self.device)
 
         self.optimizer = build_optimizer(self.model, self.cfg["optimizer"])
         self.scheduler = build_scheduler(self.optimizer, self.cfg.get("scheduler"))
+        loss_cfg = self.cfg.get("loss")
+        if loss_cfg is not None:
+            self.seg_loss = build_loss(loss_cfg)
         self.loader = self._build_segmentation_loader(
             split=str(dataset_cfg.get("split", "train")),
             split_file=dataset_cfg.get("split_file"),
             batch_size=self._resolve_batch_size("batch_size", 4),
             shuffle=bool(train_cfg.get("shuffle", True)),
+            augmentation_config=self.cfg.get("augmentation", {}).get("train"),
         )
         self.val_loader = self._build_val_loader(
             batch_size=self._resolve_batch_size(
@@ -78,19 +71,17 @@ class SupervisedSegmentationTrainer(BaseTrainer):
         with torch.autocast(device_type=self.device.type, enabled=self._amp_enabled()):
             outputs = self.model(image)
             seg_logits = outputs["seg_logits"]
-            boundary_logits = outputs["boundary_logits"]
 
-            seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
-            boundary_target = mask_to_boundary_map(mask)
-            boundary_loss = self.boundary_loss(boundary_logits, boundary_target)
-            consistency_loss = self.consistency_loss(seg_logits, boundary_logits)
-            total_loss = seg_loss + boundary_loss + 0.1 * consistency_loss
+            if self.seg_loss is None:
+                seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
+            else:
+                seg_loss = self.seg_loss(seg_logits, mask)
+
+            total_loss = seg_loss
 
         losses = {
             "total": total_loss,
             "seg": seg_loss,
-            "boundary": boundary_loss,
-            "consistency": consistency_loss,
         }
         return outputs, losses
 
@@ -106,8 +97,6 @@ class SupervisedSegmentationTrainer(BaseTrainer):
         metrics = self._build_validation_metrics()
         total = 0.0
         seg = 0.0
-        boundary = 0.0
-        consistency = 0.0
         steps = 0
         with torch.no_grad():
             for batch in self.val_loader:
@@ -116,8 +105,6 @@ class SupervisedSegmentationTrainer(BaseTrainer):
                 outputs, losses = self._compute_losses(image, mask)
                 total += float(losses["total"].detach().cpu())
                 seg += float(losses["seg"].detach().cpu())
-                boundary += float(losses["boundary"].detach().cpu())
-                consistency += float(losses["consistency"].detach().cpu())
                 self._update_validation_metrics(
                     metrics,
                     logits=outputs["seg_logits"],
@@ -130,8 +117,6 @@ class SupervisedSegmentationTrainer(BaseTrainer):
         val_metrics = {
             "total": total / steps,
             "seg": seg / steps,
-            "boundary": boundary / steps,
-            "consistency": consistency / steps,
         }
         val_metrics.update(self._compute_validation_metric_values(metrics))
         return val_metrics
@@ -141,6 +126,7 @@ class SupervisedSegmentationTrainer(BaseTrainer):
             raise RuntimeError("Trainer.build() must be called before train().")
 
         epochs = int(self.cfg["train"].get("epochs", 1))
+        accum_steps = self._accum_steps()
         try:
             self._print_training_setup(
                 model_map={"model": self.model},
@@ -150,11 +136,10 @@ class SupervisedSegmentationTrainer(BaseTrainer):
             )
             for epoch in range(self.start_epoch, epochs):
                 self.model.train()
+                self.optimizer.zero_grad()
                 train_metric_sums = {
                     "total": 0.0,
                     "seg": 0.0,
-                    "boundary": 0.0,
-                    "consistency": 0.0,
                 }
                 train_metrics: dict[str, float] | None = None
                 end_time = time.perf_counter()
@@ -165,14 +150,17 @@ class SupervisedSegmentationTrainer(BaseTrainer):
                     image = batch["image"].to(self.device)
                     mask = batch["mask"].to(self.device)
                     _, losses = self._compute_losses(image, mask)
-                    self.optimizer.zero_grad()
-                    self.grad_scaler.scale(losses["total"]).backward()
+                    scaled_total_loss = losses["total"] / accum_steps
+                    self.grad_scaler.scale(scaled_total_loss).backward()
                     grad_norm = None
-                    if self._grad_clip_enabled():
-                        self.grad_scaler.unscale_(self.optimizer)
-                        grad_norm = self._clip_gradients(self.model)
-                    self.grad_scaler.step(self.optimizer)
-                    self.grad_scaler.update()
+                    should_step = (step % accum_steps == 0) or (step == num_steps)
+                    if should_step:
+                        if self._grad_clip_enabled():
+                            self.grad_scaler.unscale_(self.optimizer)
+                            grad_norm = self._clip_gradients(self.model)
+                        self.grad_scaler.step(self.optimizer)
+                        self.grad_scaler.update()
+                        self.optimizer.zero_grad()
                     train_metrics = self._detach_metrics(losses)
                     if grad_norm is not None:
                         train_metrics["grad_norm"] = grad_norm

+ 822 - 0
tmp/docs/training/当前项目详解与纯文本架构流程图.md

@@ -0,0 +1,822 @@
+# 当前项目详解与纯文本架构流程图
+
+## 1. 当前项目定位
+
+当前项目的现役主路径已经收缩为一个纯净的超声分割全监督基线。
+
+现在真实生效的模型与训练链路只有一条:
+
+`输入图像 -> 纯 SwinV2 编码器 -> 纯分割解码器 -> 分割头 -> seg_logits -> 分割损失`
+
+当前主路径已经满足以下要求:
+
+1. 不调用 FWTA
+2. 不调用 FWTA 编码器
+3. 不存在边界分支
+4. 不存在边界损失
+5. 不存在一致性损失
+6. 配置文件中不再使用 `aux_loss`
+7. 实验脚本中不再做边界相关消融
+
+这份文档只描述当前真实主链,不再记录历史结构语义。
+
+---
+
+## 2. 当前项目一句话概述
+
+`X_SSL_Net` 当前可以理解为:
+
+`一个面向超声图像分割的纯单头 2D segmentation baseline,主干为 SwinV2 encoder + segmentation decoder。`
+
+---
+
+## 3. 当前主链总览
+
+### 3.1 当前训练入口
+
+训练统一从下面这个入口启动:
+
+`tools/train.py`
+
+它负责:
+
+1. 读取 YAML 配置
+2. 接收命令行 `--set` 覆盖参数
+3. 构建 trainer
+4. 调用 `trainer.train()`
+
+### 3.2 当前 trainer
+
+当前正式 trainer 仍然叫:
+
+`SupervisedSegmentationTrainer`
+
+但它现在已经只做纯分割训练,不再处理任何边界辅助分支。
+
+### 3.3 当前模型
+
+当前现役模型是:
+
+`SegmentationModel2d`
+
+它的内部结构为:
+
+1. `SwinV2Encoder2d`
+2. `SegmentationDecoder2d`
+3. `SegmentationHead2d`
+
+最终只输出:
+
+1. `seg_logits`
+
+---
+
+## 4. 目录职责总览
+
+```text
+X_SSL_Net/
+|
+|-- tools/
+|   |-- train.py
+|   |-- run_us_experiments.sh
+|   `-- summarize_results.py
+|
+|-- configs/
+|   |-- segmentation/
+|   `-- swinv2/
+|
+|-- lib/
+|   |-- trainers/
+|   |-- modules/
+|   |-- data/
+|   `-- tools/
+|
+`-- tmp/docs/training/
+```
+
+### 4.1 `tools/`
+
+1. `tools/train.py`
+   统一训练启动入口
+2. `tools/run_us_experiments.sh`
+   纯分割实验脚本
+3. `tools/summarize_results.py`
+   汇总 `best.pth` 中的实验结果
+
+### 4.2 `configs/segmentation/`
+
+1. `train_sup_us_template.yaml`
+   当前最核心的纯分割训练模板
+2. `us_exp_sup_busi.yaml`
+   BUSI 示例配置
+3. `us_exp_sup_busi_ablation.yaml`
+   目前内容已与普通纯分割配置对齐,不再承载边界消融语义
+
+### 4.3 `lib/trainers/`
+
+1. `builder.py`
+   trainer 注册与构建
+2. `base.py`
+   训练公共底座
+3. `supervised.py`
+   当前纯分割训练流程实现
+
+### 4.4 `lib/modules/`
+
+1. `swinv2_encoder_2d.py`
+   纯 SwinV2 编码器封装
+2. `decoder_2d.py`
+   当前纯分割解码器实现
+3. `segmentation_2d.py`
+   当前主模型封装
+4. `build_swinv2.py`
+   SwinV2 backbone 构建器
+
+### 4.5 `lib/data/`
+
+1. `builder.py`
+   构建样本索引
+2. `loaders.py`
+   构建 dataset 与 dataloader
+3. `datasets.py`
+   真正读取 image 和 mask
+4. `augment.py`
+   数据增强
+
+### 4.6 `lib/tools/`
+
+当前主路径真正会用到的核心工具只有:
+
+1. `loss.py`
+   主分割损失构建
+2. `metrics.py`
+   Dice / IoU 等验证指标
+3. `optim.py`
+   optimizer 和 scheduler 构建
+
+---
+
+## 5. 当前纯文本架构图
+
+### 5.1 训练系统总架构图
+
+```text
++----------------------------------------------------------------------------------+
+|                              当前纯分割训练系统                                  |
++----------------------------------------------------------------------------------+
+|                                                                                  |
+|  tools/train.py                                                                  |
+|      |                                                                           |
+|      v                                                                           |
+|  读取 YAML 配置 + 应用 --set 覆盖                                                |
+|      |                                                                           |
+|      v                                                                           |
+|  build_trainer(cfg)                                                              |
+|      |                                                                           |
+|      v                                                                           |
+|  SupervisedSegmentationTrainer                                                   |
+|      |                                                                           |
+|      +------------------- build() -------------------+                           |
+|      |                                               |                           |
+|      v                                               v                           |
+|  build_dataloader()                            SegmentationModel2d               |
+|      |                                               |                           |
+|      v                                               v                           |
+|  SegmentationRecordDataset                     SwinV2Encoder2d                   |
+|      |                                               |                           |
+|      v                                               v                           |
+|  image, mask batch                            multi-scale features               |
+|                                                        |                         |
+|                                                        v                         |
+|                                              SegmentationDecoder2d              |
+|                                                        |                         |
+|                                                        v                         |
+|                                                decoded feature                   |
+|                                                        |                         |
+|                                                        v                         |
+|                                                SegmentationHead2d                |
+|                                                        |                         |
+|                                                        v                         |
+|                                                     seg_logits                   |
+|                                                        |                         |
+|                                                        v                         |
+|                                                     seg_loss                     |
+|                                                        |                         |
+|                                                        v                         |
+|                                            backward + optimizer.step()           |
+|                                                        |                         |
+|                                                        v                         |
+|                                          validation / metric / checkpoint        |
+|                                                                                  |
++----------------------------------------------------------------------------------+
+```
+
+### 5.2 模型内部结构图
+
+```text
+Input Image [B, 3, H, W]
+    |
+    v
+SwinV2Encoder2d
+    |
+    +-- feature_0  (可选 patch_embed 特征)
+    +-- feature_1
+    +-- feature_2
+    +-- feature_3
+    `-- feature_4
+            |
+            v
+SegmentationDecoder2d
+    |
+    +-- Decode Block 1: deepest + skip_3
+    +-- Decode Block 2: upsample + skip_2
+    +-- Decode Block 3: upsample + skip_1
+    `-- Decode Block 4: upsample + skip_0
+            |
+            v
+decoded feature
+    |
+    v
+SegmentationHead2d
+    |
+    v
+seg_logits
+    |
+    v
+输出(已在 SegmentationHead2d 内部双线性上采样回输入分辨率)
+```
+
+### 5.3 数据系统结构图
+
+```text
+dataset.root
+    |
+    v
+build_dataset_index(dataset_name, root)
+    |
+    v
+apply split
+    |
+    v
+SegmentationRecordDataset
+    |
+    +-- 读取 image
+    +-- 读取 mask
+    +-- 应用 augmentation
+    +-- resize image/mask
+    |
+    v
+DataLoader
+    |
+    v
+batch = {
+    image,
+    mask,
+    dataset_name,
+    sample_id,
+    split,
+    class_name,
+    meta
+}
+```
+
+---
+
+## 6. 当前纯文本流程图
+
+### 6.1 启动流程图
+
+```text
+[开始]
+   |
+   v
+执行 tools/train.py --config xxx.yaml --set key=value ...
+   |
+   v
+parse_args()
+   |
+   v
+load_yaml_config()
+   |
+   v
+apply_dotlist_overrides()
+   |
+   v
+build_trainer(cfg)
+   |
+   v
+trainer.build()
+   |
+   +--> 构建 model
+   +--> 构建 optimizer
+   +--> 构建 scheduler
+   +--> 构建 train loader
+   +--> 构建 val loader
+   +--> 恢复 checkpoint(如配置)
+   `--> 初始化 SwanLab(如启用)
+   |
+   v
+trainer.train()
+   |
+   v
+[进入 epoch 循环]
+```
+
+### 6.2 单个 epoch 流程图
+
+```text
+[epoch 开始]
+   |
+   v
+model.train()
+optimizer.zero_grad()
+   |
+   v
+for batch in train_loader:
+   |
+   +--> 取 image, mask
+   |
+   +--> model(image)
+   |      |
+   |      `--> seg_logits
+   |
+   +--> 计算 seg_loss
+   |
+   +--> total_loss = seg_loss
+   |
+   +--> backward()
+   |
+   +--> 如果到达 accum_steps:
+   |      |
+   |      +--> 可选 gradient clipping
+   |      +--> optimizer.step()
+   |      +--> scaler.update()
+   |      `--> optimizer.zero_grad()
+   |
+   `--> 记录 step 日志
+   |
+   v
+epoch 结束后 scheduler.step()
+   |
+   v
+如果需要验证:
+   |
+   +--> model.eval()
+   +--> 遍历 val_loader
+   +--> 统计 val loss
+   `--> 统计 Dice / IoU
+   |
+   v
+保存 checkpoint
+   |
+   v
+判断 early stopping
+```
+
+### 6.3 单次前向传播流程图
+
+```text
+image
+  |
+  v
+SegmentationModel2d.forward()
+  |
+  +--> features = encoder(image)["features"]
+  |
+  +--> decoder_out, _ = decoder(features)
+  |
+  `--> seg_logits = segmentation_head(decoder_out, output_size=input_size)
+           |
+           v
+return {
+  seg_logits
+}
+```
+
+---
+
+## 7. 当前模型链路详解
+
+### 7.1 `SegmentationModel2d`
+
+文件:`lib/modules/segmentation_2d.py`
+
+这是当前主模型封装。
+
+它的职责非常直接:
+
+1. 调用编码器提取多尺度特征
+2. 调用解码器恢复高分辨率分割特征
+3. 调用分割头输出 `seg_logits`
+
+它不再负责:
+
+1. 边界输出
+2. 多分支辅助预测
+3. 任何历史先验注入接口
+
+### 7.2 `SwinV2Encoder2d`
+
+文件:`lib/modules/swinv2_encoder_2d.py`
+
+职责:
+
+1. 构建 SwinV2 backbone
+2. 输出多尺度特征列表
+3. 支持通过 `model.include_patch_embed` 配置是否包含 patch embed 特征
+4. 支持通过 `model.use_multiscale_features` 配置是否输出多尺度特征
+
+当前主链只使用它的:
+
+1. `features`
+
+### 7.3 `SegmentationDecoder2d`
+
+文件:`lib/modules/decoder_2d.py`
+
+这是当前纯分割解码器。
+
+每层解码块 `SegmentationDecodeBlock2d` 的工作方式是:
+
+1. 对高层特征上采样
+2. 对 skip 特征做通道映射
+3. 拼接后卷积融合
+4. 用 `DecodeRefineBlock2d` 做轻量残差细化
+
+当前已经没有:
+
+1. FWTA 注入接口
+2. stability prior
+3. saliency prior
+4. boundary hint
+
+### 7.4 `SegmentationHead2d`
+
+文件:`lib/modules/segmentation_2d.py`
+
+职责:
+
+1. 接收解码器输出特征
+2. 经过卷积块生成分割 logits
+3. 上采样到输入分辨率
+
+最终输出:
+
+1. `seg_logits`
+
+---
+
+## 8. 当前训练器链路详解
+
+### 8.1 `SupervisedSegmentationTrainer`
+
+文件:`lib/trainers/supervised.py`
+
+虽然类名仍然沿用 `SupervisedSegmentationTrainer`,但当前内容已经是纯单头分割训练。
+
+### 8.2 `build()`
+
+`build()` 当前做的事:
+
+1. 构建 `SegmentationModel2d`
+2. 构建 optimizer
+3. 构建 scheduler
+4. 构建主分割 loss
+5. 构建 train loader
+6. 构建 val loader
+7. 恢复 checkpoint
+8. 初始化日志系统
+
+### 8.3 `_compute_losses()`
+
+当前 loss 计算已经简化为:
+
+```text
+outputs = model(image)
+seg_logits = outputs["seg_logits"]
+seg_loss = criterion(seg_logits, mask)
+total_loss = seg_loss
+```
+
+也就是:
+
+```text
+total_loss = seg_loss
+```
+
+### 8.4 `train()`
+
+当前训练循环做的事:
+
+1. 前向得到 `seg_logits`
+2. 计算 `seg_loss`
+3. backward
+4. 可选梯度累计
+5. 可选梯度裁剪
+6. optimizer step
+7. scheduler step
+8. 验证集评估
+9. checkpoint 与 early stopping
+
+---
+
+## 9. 当前数据系统详解
+
+### 9.1 当前输入输出格式
+
+`SegmentationRecordDataset` 返回:
+
+1. `image`
+2. `mask`
+3. `dataset_name`
+4. `sample_id`
+5. `split`
+6. `class_name`
+7. `meta`
+
+### 9.2 当前数据读取流程
+
+1. 读 RGB 图像,归一化到 `[0, 1]`
+2. 读二值分割 mask
+3. 对 image 和 mask 做联合增强
+4. resize 到配置指定尺寸
+5. 送入 DataLoader
+
+### 9.3 当前支持的数据集
+
+当前 `lib/data/builder.py` 支持:
+
+1. `BUS-UCLM`
+2. `BUSI`
+3. `BUS-BRA`
+4. `BUS_UC`
+5. `CCAUI`
+6. `DDTI`
+7. `OTU_2d`
+8. `TN3K`
+9. `TG3K`
+
+---
+
+## 10. 当前损失、指标与优化流程
+
+### 10.1 当前损失
+
+当前主路径只保留主分割损失。
+
+默认配置示例:
+
+```yaml
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+```
+
+也就是说当前优化目标只围绕 mask 分割本身。
+
+### 10.2 当前指标
+
+当前验证指标通常为:
+
+1. `dice`
+2. `iou`
+
+### 10.3 当前优化器
+
+默认配置:
+
+1. `adamw`
+2. 支持 warmup + cosine scheduler
+
+### 10.4 当前 AMP
+
+当前仍支持:
+
+1. `torch.autocast`
+2. `GradScaler`
+
+---
+
+## 11. 关键配置变量解释(中英对照)
+
+### 11.1 `trainer.name`
+
+1. 中文:训练器名称
+2. English: trainer name
+3. 作用:决定实例化哪个 trainer
+4. 当前值:`supervised_segmentation`
+
+### 11.2 `train.seed`
+
+1. 中文:随机种子
+2. English: random seed
+3. 作用:控制随机性
+
+### 11.3 `train.epochs`
+
+1. 中文:训练轮数
+2. English: number of epochs
+3. 作用:决定最大训练轮数
+
+### 11.4 `train.batch_size`
+
+1. 中文:训练 batch 大小
+2. English: training batch size
+3. 作用:控制每批样本数
+
+### 11.5 `train.val_batch_size`
+
+1. 中文:验证 batch 大小
+2. English: validation batch size
+3. 作用:控制验证批次大小
+
+### 11.6 `train.accum_steps`
+
+1. 中文:梯度累计步数
+2. English: gradient accumulation steps
+3. 作用:模拟更大有效 batch
+
+### 11.7 `train.amp`
+
+1. 中文:自动混合精度
+2. English: automatic mixed precision
+3. 作用:降低显存、提高吞吐
+
+### 11.8 `train.grad_clip.enabled`
+
+1. 中文:是否启用梯度裁剪
+2. English: enable gradient clipping
+3. 作用:提升训练稳定性
+
+### 11.9 `metrics.task_mode`
+
+1. 中文:指标任务模式
+2. English: metric task mode
+3. 作用:指定 binary 或 multiclass
+
+### 11.10 `loss.name`
+
+1. 中文:损失名称
+2. English: loss name
+3. 作用:指定分割损失类型
+
+### 11.11 `dataset.dataset_name`
+
+1. 中文:数据集名称
+2. English: dataset name
+3. 作用:决定用哪个数据集构建器
+
+### 11.12 `dataset.root`
+
+1. 中文:数据根目录
+2. English: dataset root
+3. 作用:决定从哪里读取数据
+
+### 11.13 `dataset.image_size`
+
+1. 中文:输入图像尺寸
+2. English: input image size
+3. 作用:控制 resize 大小
+
+### 11.14 `dataset.num_classes`
+
+1. 中文:分割类别数
+2. English: number of classes
+3. 作用:决定分割头输出通道数
+
+### 11.15 `model.model_name`
+
+1. 中文:backbone 名称
+2. English: backbone model name
+3. 作用:决定加载哪个 SwinV2 结构配置
+
+### 11.16 `model.decoder_channels`
+
+1. 中文:解码器通道数配置
+2. English: decoder channel sizes
+3. 作用:控制解码器每层输出维度
+
+### 11.17 `model.use_multiscale_features`
+
+1. 中文:是否使用多尺度特征
+2. English: use multi-scale features
+3. 作用:控制 encoder 是否输出多层特征给 decoder
+
+### 11.18 `model.include_patch_embed`
+
+1. 中文:是否包含 patch embed 特征
+2. English: include patch embedding feature
+3. 作用:控制最浅层特征是否进入解码链
+
+### 11.19 `optimizer.lr`
+
+1. 中文:学习率
+2. English: learning rate
+3. 作用:控制参数更新步长
+
+### 11.20 `checkpoint.monitor`
+
+1. 中文:监控指标
+2. English: monitored metric
+3. 作用:决定 best checkpoint 基于什么指标保存
+
+---
+
+## 12. 关键运行时变量解释(中英对照)
+
+### 12.1 `cfg`
+
+1. 中文:总配置字典
+2. English: global config dictionary
+3. 作用:控制整个训练流程
+
+### 12.2 `image`
+
+1. 中文:输入图像
+2. English: input image
+3. 典型形状:`[B, C, H, W]`
+
+### 12.3 `mask`
+
+1. 中文:真实分割掩膜
+2. English: ground-truth mask
+3. 典型形状:`[B, 1, H, W]`
+
+### 12.4 `features`
+
+1. 中文:编码器多尺度特征
+2. English: multi-scale features
+3. 作用:供 decoder 恢复空间细节
+
+### 12.5 `decoder_out`
+
+1. 中文:解码器输出特征
+2. English: decoder output feature
+3. 作用:作为分割头输入
+
+### 12.6 `seg_logits`
+
+1. 中文:分割 logits
+2. English: segmentation logits
+3. 作用:主预测结果
+
+### 12.7 `seg_loss`
+
+1. 中文:主分割损失
+2. English: segmentation loss
+3. 作用:训练优化目标
+
+### 12.8 `total_loss`
+
+1. 中文:总损失
+2. English: total loss
+3. 当前关系:`total_loss = seg_loss`
+
+### 12.9 `best_metric`
+
+1. 中文:最佳验证指标
+2. English: best validation metric
+3. 作用:控制 best checkpoint 与 early stopping
+
+### 12.10 `grad_scaler`
+
+1. 中文:混合精度梯度缩放器
+2. English: gradient scaler
+3. 作用:对 loss 做动态缩放,防止 AMP(fp16)下梯度下溢,保证训练稳定
+
+---
+
+## 13. 当前实验脚本的真实含义
+
+文件:`tools/run_us_experiments.sh`
+
+当前脚本只做纯分割实验组织。
+
+支持两种模式:
+
+1. 单个数据集训练
+2. 所有数据集批量训练
+
+已经不再支持:
+
+1. 边界辅助损失消融
+2. 一致性损失消融
+3. FWTA 消融
+
+---
+
+## 14. 当前状态的最简结论
+
+当前项目已经完成主路径净化,可以用下面几句话概括:
+
+1. 当前主模型是 `SegmentationModel2d`。
+2. 当前主结构是 `SwinV2Encoder2d + SegmentationDecoder2d + SegmentationHead2d`。
+3. 当前训练只做单头分割,不再有边界分支。
+4. 当前总损失就是主分割损失,不再有边界损失和一致性损失。
+5. 当前配置文件与实验脚本已经去掉 `aux_loss` 和相关消融入口。
+6. 当前现役主链不再调用 FWTA,也不再对外保留 FWTA 语义。
+

+ 17 - 6
tools/summarize_results.py

@@ -34,6 +34,10 @@ def _infer_ratio(ckpt: dict[str, Any], path: Path) -> str:
     return "-"
 
 
+def _infer_ablation_case(ckpt: dict[str, Any], path: Path) -> str:
+    return "-"
+
+
 def _extract_metric(metrics: dict[str, Any], *names: str) -> float | None:
     for name in names:
         value = metrics.get(name)
@@ -50,6 +54,7 @@ def collect_rows(outputs_dir: Path) -> list[dict[str, Any]]:
         row = {
             "dataset": _infer_dataset(ckpt, best_path),
             "mode": _infer_mode(best_path),
+            "ablation_case": _infer_ablation_case(ckpt, best_path),
             "ratio": _infer_ratio(ckpt, best_path),
             "epoch": ckpt.get("epoch"),
             "best_metric": ckpt.get("best_metric"),
@@ -63,7 +68,7 @@ def collect_rows(outputs_dir: Path) -> list[dict[str, Any]]:
 
 def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
     path.parent.mkdir(parents=True, exist_ok=True)
-    fieldnames = ["dataset", "mode", "ratio", "epoch", "best_metric", "dice", "iou", "checkpoint"]
+    fieldnames = ["dataset", "mode", "ablation_case", "ratio", "epoch", "best_metric", "dice", "iou", "checkpoint"]
     with path.open("w", encoding="utf-8", newline="") as handle:
         writer = csv.DictWriter(handle, fieldnames=fieldnames)
         writer.writeheader()
@@ -75,16 +80,16 @@ def write_markdown(rows: list[dict[str, Any]], path: Path) -> None:
     lines = [
         "# 实验结果汇总",
         "",
-        "| dataset | mode | ratio | epoch | best_metric | dice | iou | checkpoint |",
-        "| --- | --- | --- | --- | --- | --- | --- | --- |",
+        "| dataset | mode | ablation_case | ratio | epoch | best_metric | dice | iou | checkpoint |",
+        "| --- | --- | --- | --- | --- | --- | --- | --- | --- |",
     ]
     for row in rows:
         lines.append(
-            f"| {row['dataset']} | {row['mode']} | {row['ratio']} | {row['epoch']} | "
+            f"| {row['dataset']} | {row['mode']} | {row['ablation_case']} | {row['ratio']} | {row['epoch']} | "
             f"{row['best_metric']} | {row['dice']} | {row['iou']} | {row['checkpoint']} |"
         )
     if not rows:
-        lines.append("| - | - | - | - | - | - | - | - |")
+        lines.append("| - | - | - | - | - | - | - | - | - |")
     path.write_text("\n".join(lines) + "\n", encoding="utf-8")
 
 
@@ -103,7 +108,13 @@ def main() -> None:
     write_csv(rows, csv_path)
     write_markdown(rows, md_path)
 
-    print({"num_results": len(rows), "csv": str(csv_path), "markdown": str(md_path)})
+    print(
+        {
+            "num_results": len(rows),
+            "csv": str(csv_path),
+            "markdown": str(md_path),
+        }
+    )
 
 
 if __name__ == "__main__":