Przeglądaj źródła

refactor: switch branch to supervised-only training

kekezack 1 miesiąc temu
rodzic
commit
120b6a80b0

+ 1 - 1
README.md

@@ -1,3 +1,3 @@
 # X_SSL_Net
 
-ultrasound segmentation semi-supervised medical-imaging
+ultrasound segmentation supervised medical-imaging

+ 102 - 0
configs/segmentation/train_sup_us_template.yaml

@@ -0,0 +1,102 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  epochs: 200
+  batch_size: 4
+  val_batch_size: 4
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+  auto_batch_size:
+    enabled: false
+    target_memory_fraction: 0.75
+    reference_gpu_gb: 8.0
+    reference_batch_size: 4
+    min_batch_size: 1
+    max_batch_size: 8
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 40
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  name: ultrasound_sup_seg
+  dataset_name: BUSI
+  root: data/BUSI
+  split: train
+  split_file: null
+  val_split: val
+  val_split_file: null
+  image_size: [256, 256]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  model_name: swinv2_tiny_patch4_window8_256
+  load_weights: false
+  decoder_channels: [384, 192, 96, 96]
+  fwta_wavelet: haar
+  fwta_level: 1
+  fwta_sigma_ratio: 0.35
+  fwta_tau_fourier: 0.15
+  fwta_gate_temperature: 1.0
+  use_multiscale_features: true
+  include_patch_embed: true
+
+optimizer:
+  name: adamw
+  lr: 1.0e-4
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 10
+  params:
+    T_max: 190
+    eta_min: 1.0e-6
+
+checkpoint:
+  dir: outputs/supervised_segmentation/train_sup_us_template
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: supervised_segmentation
+  swanlab_mode: null

+ 92 - 0
configs/segmentation/us_exp_sup_busi.yaml

@@ -0,0 +1,92 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  epochs: 200
+  batch_size: 4
+  val_batch_size: 4
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 40
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  dataset_name: BUSI
+  root: data/BUSI
+  split: train
+  val_split: val
+  image_size: [256, 256]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  model_name: swinv2_tiny_patch4_window8_256
+  load_weights: false
+  decoder_channels: [384, 192, 96, 96]
+  fwta_wavelet: haar
+  fwta_level: 1
+  fwta_sigma_ratio: 0.35
+  fwta_tau_fourier: 0.15
+  fwta_gate_temperature: 1.0
+  use_multiscale_features: true
+  include_patch_embed: true
+
+optimizer:
+  name: adamw
+  lr: 1.0e-4
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 10
+  params:
+    T_max: 190
+    eta_min: 1.0e-6
+
+checkpoint:
+  dir: outputs/experiments/supervised/BUSI
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: sup_busi
+  swanlab_mode: null

+ 14 - 19
lib/modules/__init__.py

@@ -1,5 +1,8 @@
 from .attentions_2d import CirculantAttention2d, ComplexLinear, WaveletAttentionGlobalBranch2d
 from .blocks_2d import WaveletFFTBlock2d, WaveletFFTMRFFIModule2d
+from .build_swinv2 import build_swinv2, build_swinv2_auto
+from .decoder_2d import BoundaryRefineBlock2d, StructureAwareDecodeBlock2d, StructureAwareDecoder2d
+from .fwta_2d import FourierWaveletTokenAggregation
 from .layers_2d import (
     BNLinear1d,
     Conv2dBN,
@@ -10,16 +13,8 @@ from .layers_2d import (
     Residual,
     Scale,
 )
-from .nets_2d import (
-    WaveletFFTNet2d,
-    wavelet_fft_b1,
-    wavelet_fft_b2,
-    wavelet_fft_b4,
-    wavelet_fft_s6,
-    wavelet_fft_t2,
-    wavelet_fft_t4,
-)
-from .build_swinv2 import build_swinv2, build_swinv2_auto
+from .segmentation_2d import GlobalTokenConditioning2d, SegmentationNet2d
+from .swinv2_fwta_encoder_2d import SwinV2FWTAEncoder2d
 
 __all__ = [
     "CirculantAttention2d",
@@ -27,6 +22,12 @@ __all__ = [
     "WaveletAttentionGlobalBranch2d",
     "WaveletFFTBlock2d",
     "WaveletFFTMRFFIModule2d",
+    "build_swinv2",
+    "build_swinv2_auto",
+    "BoundaryRefineBlock2d",
+    "StructureAwareDecodeBlock2d",
+    "StructureAwareDecoder2d",
+    "FourierWaveletTokenAggregation",
     "BNLinear1d",
     "Conv2dBN",
     "DWConv2dBNReLU",
@@ -35,13 +36,7 @@ __all__ = [
     "PatchMerging2d",
     "Residual",
     "Scale",
-    "WaveletFFTNet2d",
-    "wavelet_fft_t2",
-    "wavelet_fft_t4",
-    "wavelet_fft_s6",
-    "wavelet_fft_b1",
-    "wavelet_fft_b2",
-    "wavelet_fft_b4",
-    "build_swinv2",
-    "build_swinv2_auto",
+    "GlobalTokenConditioning2d",
+    "SegmentationNet2d",
+    "SwinV2FWTAEncoder2d",
 ]

+ 138 - 0
lib/modules/segmentation_2d.py

@@ -0,0 +1,138 @@
+from __future__ import annotations
+
+from argparse import Namespace
+from pathlib import Path
+from typing import Any, Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .decoder_2d import StructureAwareDecoder2d
+from .layers_2d import Conv2dBN
+from .swinv2_fwta_encoder_2d import SwinV2FWTAEncoder2d
+
+
+class SegmentationHead2d(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.block = nn.Sequential(
+            Conv2dBN(in_channels, in_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
+        )
+
+    def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
+        x = self.block(x)
+        return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
+
+
+class BoundaryHead2d(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int = 1) -> None:
+        super().__init__()
+        self.block = nn.Sequential(
+            Conv2dBN(in_channels, in_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
+        )
+
+    def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
+        x = self.block(x)
+        return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
+
+
+class GlobalTokenConditioning2d(nn.Module):
+    """
+    使用 FWTA 更新后的全局前景 token 对解码特征做通道调制。
+    """
+
+    def __init__(self, token_channels: int, feature_channels: int) -> None:
+        super().__init__()
+        hidden_channels = max(feature_channels // 2, 32)
+        self.gate = nn.Sequential(
+            nn.LayerNorm(token_channels),
+            nn.Linear(token_channels, hidden_channels),
+            nn.GELU(),
+            nn.Linear(hidden_channels, feature_channels),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x: torch.Tensor, global_token: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        channel_gate = self.gate(global_token).unsqueeze(-1).unsqueeze(-1)
+        return x * (1.0 + channel_gate), channel_gate
+
+
+class SegmentationNet2d(nn.Module):
+    """
+    第一版超声分割主网络骨架。
+
+    当前职责:
+    - 编码器输出多尺度特征和稳定性图
+    - 结构感知解码器恢复分割特征
+    - 同时输出分割图和边界图
+    """
+
+    def __init__(
+            self,
+            num_classes: int,
+            model_name: str | None = None,
+            config_path: str | Path | None = None,
+            weight_path: str | Path | None = None,
+            args: Namespace | None = None,
+            *,
+            decoder_channels: Sequence[int] | None = None,
+            load_weights: bool = True,
+            **encoder_kwargs: Any,
+    ) -> None:
+        super().__init__()
+        self.encoder = SwinV2FWTAEncoder2d(
+            model_name=model_name,
+            config_path=config_path,
+            weight_path=weight_path,
+            args=args,
+            load_weights=load_weights,
+            **encoder_kwargs,
+        )
+        self.decoder = StructureAwareDecoder2d(
+            encoder_channels=self.encoder.stage_channels,
+            decoder_channels=decoder_channels,
+        )
+        self.global_conditioning = GlobalTokenConditioning2d(
+            token_channels=self.encoder.stage_channels[-1],
+            feature_channels=self.decoder.out_channels,
+        )
+        self.segmentation_head = SegmentationHead2d(self.decoder.out_channels, num_classes)
+        self.boundary_head = BoundaryHead2d(self.decoder.out_channels, out_channels=1)
+
+    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
+        encoder_outputs = self.encoder(x)
+        features = encoder_outputs["features"]
+        stability_map = encoder_outputs["stability_map"]
+
+        decoder_out, decoder_features = self.decoder(
+            features=features,
+            stability_map=stability_map,
+        )
+        conditioned_decoder_out, global_channel_gate = self.global_conditioning(
+            decoder_out,
+            encoder_outputs["global_token"],
+        )
+
+        output_size = x.shape[-2:]
+        seg_logits = self.segmentation_head(conditioned_decoder_out, output_size=output_size)
+        boundary_logits = self.boundary_head(conditioned_decoder_out, output_size=output_size)
+
+        return {
+            "seg_logits": seg_logits,
+            "boundary_logits": boundary_logits,
+            "stability_map": F.interpolate(
+                stability_map, size=output_size, mode="bilinear", align_corners=False
+            ),
+            "encoder_features": features,
+            "decoder_features": decoder_features,
+            "conditioned_decoder_feature": conditioned_decoder_out,
+            "deepest_feature": encoder_outputs["deepest_feature"],
+            "global_token": encoder_outputs["global_token"],
+            "global_channel_gate": global_channel_gate,
+            "token_gate": encoder_outputs["token_gate"],
+        }

+ 8 - 0
lib/tools/__init__.py

@@ -1,4 +1,6 @@
+from .boundary import boundary_band_map, logits_to_boundary, logits_to_binary_mask, mask_to_boundary_map
 from .loss import DEFAULT_TASK_LOSS, LOSS_REGISTRY, build_loss
+from .loss import BinaryBoundaryLoss, MaskBoundaryConsistencyLoss
 from .metrics import (
     DEFAULT_METRIC_CONFIG,
     METRIC_REGISTRY,
@@ -26,7 +28,13 @@ __all__ = [
     "METRIC_REGISTRY",
     "OPTIMIZER_REGISTRY",
     "SCHEDULER_REGISTRY",
+    "mask_to_boundary_map",
+    "logits_to_binary_mask",
+    "logits_to_boundary",
+    "boundary_band_map",
     "build_loss",
+    "BinaryBoundaryLoss",
+    "MaskBoundaryConsistencyLoss",
     "build_metric",
     "build_metrics",
     "compute_metrics",

+ 52 - 0
lib/tools/boundary.py

@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+
+
+def _ensure_nchw(mask: torch.Tensor) -> torch.Tensor:
+    if mask.ndim == 3:
+        return mask.unsqueeze(1)
+    if mask.ndim != 4:
+        raise ValueError(f"Expected mask with 3 or 4 dims, got shape {tuple(mask.shape)}")
+    return mask
+
+
+def mask_to_boundary_map(mask: torch.Tensor, dilation: int = 1) -> torch.Tensor:
+    """
+    通过最大池化近似形态学梯度,生成边界图。
+    """
+    mask = _ensure_nchw(mask).float()
+    kernel_size = dilation * 2 + 1
+    pad = dilation
+    dilated = F.max_pool2d(mask, kernel_size=kernel_size, stride=1, padding=pad)
+    eroded = -F.max_pool2d(-mask, kernel_size=kernel_size, stride=1, padding=pad)
+    boundary = (dilated - eroded).clamp_min(0.0)
+    return (boundary > 0).float()
+
+
+def logits_to_binary_mask(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
+    if logits.shape[1] == 1:
+        probs = torch.sigmoid(logits)
+        return (probs >= threshold).float()
+    preds = torch.argmax(logits, dim=1, keepdim=True)
+    return preds.float()
+
+
+def logits_to_boundary(logits: torch.Tensor, threshold: float = 0.5, dilation: int = 1) -> torch.Tensor:
+    mask = logits_to_binary_mask(logits, threshold=threshold)
+    return mask_to_boundary_map(mask, dilation=dilation)
+
+
+def boundary_band_map(boundary: torch.Tensor, radius: int = 2) -> torch.Tensor:
+    boundary = _ensure_nchw(boundary).float()
+    kernel_size = radius * 2 + 1
+    return F.max_pool2d(boundary, kernel_size=kernel_size, stride=1, padding=radius)
+
+
+__all__ = [
+    "mask_to_boundary_map",
+    "logits_to_binary_mask",
+    "logits_to_boundary",
+    "boundary_band_map",
+]

+ 42 - 1
lib/tools/loss.py

@@ -2,6 +2,8 @@ from __future__ import annotations
 
 from typing import Any
 
+import torch
+import torch.nn.functional as F
 from torch import nn
 
 try:
@@ -140,4 +142,43 @@ def build_loss(config: dict[str, Any]) -> nn.Module:
     return loss_cls(**params)
 
 
-__all__ = ["DEFAULT_TASK_LOSS", "LOSS_REGISTRY", "build_loss"]
+class BinaryBoundaryLoss(nn.Module):
+    def __init__(self, bce_weight: float = 1.0, dice_weight: float = 1.0, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.bce_weight = bce_weight
+        self.dice_weight = dice_weight
+        self.eps = eps
+
+    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        target = target.float()
+        bce = F.binary_cross_entropy_with_logits(logits, target)
+        probs = torch.sigmoid(logits)
+        intersection = (probs * target).sum(dim=(1, 2, 3))
+        union = probs.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
+        dice = 1.0 - ((2.0 * intersection + self.eps) / (union + self.eps))
+        return self.bce_weight * bce + self.dice_weight * dice.mean()
+
+
+class MaskBoundaryConsistencyLoss(nn.Module):
+    def forward(self, seg_logits: torch.Tensor, boundary_logits: torch.Tensor) -> torch.Tensor:
+        if seg_logits.shape[1] == 1:
+            seg_prob = torch.sigmoid(seg_logits)
+        else:
+            seg_prob = torch.softmax(seg_logits, dim=1)[:, 1:2]
+
+        boundary_prob = torch.sigmoid(boundary_logits)
+        grad_x = torch.abs(seg_prob[:, :, :, 1:] - seg_prob[:, :, :, :-1])
+        grad_y = torch.abs(seg_prob[:, :, 1:, :] - seg_prob[:, :, :-1, :])
+        grad_x = F.pad(grad_x, (0, 1, 0, 0))
+        grad_y = F.pad(grad_y, (0, 0, 0, 1))
+        edge_proxy = torch.clamp(grad_x + grad_y, 0.0, 1.0)
+        return F.l1_loss(boundary_prob, edge_proxy)
+
+
+__all__ = [
+    "DEFAULT_TASK_LOSS",
+    "LOSS_REGISTRY",
+    "build_loss",
+    "BinaryBoundaryLoss",
+    "MaskBoundaryConsistencyLoss",
+]

+ 10 - 0
lib/trainers/__init__.py

@@ -0,0 +1,10 @@
+from .base import BaseTrainer
+from .builder import TRAINER_REGISTRY, build_trainer
+from .supervised import SupervisedSegmentationTrainer
+
+__all__ = [
+    "BaseTrainer",
+    "TRAINER_REGISTRY",
+    "build_trainer",
+    "SupervisedSegmentationTrainer",
+]

+ 636 - 0
lib/trainers/base.py

@@ -0,0 +1,636 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+import pprint
+import time
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.amp import GradScaler
+
+from lib.data import build_dataloader
+from lib.tools import build_metrics, compute_metrics, reset_metrics, update_metrics
+
+try:
+    import swanlab
+except ImportError:
+    swanlab = None
+
+
+class BaseTrainer(ABC):
+    """
+    训练器基类。
+
+    设计目标:
+    - 统一配置入口
+    - 统一模型/优化器/调度器创建
+    - 不同训练流程只重写最少的方法
+    """
+
+    def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
+        self.cfg = cfg
+        self.args = args
+        self.device = self._build_device()
+        self.output_dir = self._build_output_dir()
+        self.start_epoch = 0
+        self.best_metric: float | None = None
+        self.no_improve_epochs = 0
+        self.swanlab_run = None
+        self.grad_scaler = GradScaler("cuda", enabled=self._amp_enabled())
+
+    def _build_device(self) -> torch.device:
+        device_name = self.cfg.get("train", {}).get("device", "cpu")
+        if device_name == "cuda" and not torch.cuda.is_available():
+            device_name = "cpu"
+        return torch.device(device_name)
+
+    def _build_output_dir(self) -> Path:
+        output_dir = self.cfg.get("checkpoint", {}).get("dir", "outputs/supervised_segmentation")
+        path = Path(output_dir)
+        path.mkdir(parents=True, exist_ok=True)
+        return path
+
+    def _amp_enabled(self) -> bool:
+        return bool(self.cfg.get("train", {}).get("amp", False)) and self.device.type == "cuda"
+
+    def _auto_batch_size_cfg(self) -> dict[str, Any]:
+        cfg = self.cfg.get("train", {}).get("auto_batch_size", {})
+        return cfg if isinstance(cfg, dict) else {}
+
+    def _auto_batch_size_enabled(self) -> bool:
+        return bool(self._auto_batch_size_cfg().get("enabled", False))
+
+    def _gpu_total_memory_gb(self) -> float | None:
+        if self.device.type != "cuda" or not torch.cuda.is_available():
+            return None
+        props = torch.cuda.get_device_properties(self.device)
+        return float(props.total_memory / (1024 ** 3))
+
+    def _estimate_auto_batch_size(self, *, default_batch_size: int, ssl: bool = False) -> int:
+        cfg = self._auto_batch_size_cfg()
+        if not cfg.get("enabled", False):
+            return int(default_batch_size)
+
+        total_gb = self._gpu_total_memory_gb()
+        if total_gb is None:
+            return int(default_batch_size)
+
+        target_fraction = float(cfg.get("target_memory_fraction", 0.75))
+        target_fraction = min(max(target_fraction, 0.1), 0.95)
+        reference_gpu_gb = float(cfg.get("reference_gpu_gb", 8.0))
+        reference_batch_size = int(cfg.get("reference_batch_size", default_batch_size))
+        max_batch_size = int(cfg.get("max_batch_size", reference_batch_size))
+        min_batch_size = int(cfg.get("min_batch_size", 1))
+
+        memory_penalty = float(cfg.get("memory_penalty", 1.0 if not ssl else 1.35))
+        scaled = int((reference_batch_size * total_gb * target_fraction) / max(reference_gpu_gb * 0.75 * memory_penalty, 1e-6))
+        batch_size = max(min_batch_size, min(max_batch_size, max(default_batch_size, scaled)))
+        return int(batch_size)
+
+    def _resolve_batch_size(self, key: str, default: int, *, ssl: bool = False) -> int:
+        train_cfg = self.cfg.get("train", {})
+        configured = int(train_cfg.get(key, default))
+        batch_size = self._estimate_auto_batch_size(default_batch_size=configured, ssl=ssl)
+        if self._auto_batch_size_enabled() and batch_size != configured:
+            print(
+                {
+                    "message": "auto_batch_size adjusted",
+                    "key": key,
+                    "configured": configured,
+                    "resolved": batch_size,
+                    "gpu_total_gb": self._gpu_total_memory_gb(),
+                }
+            )
+        return batch_size
+
+    def _dataset_cfg(self) -> dict[str, Any]:
+        return self.cfg.get("dataset", {})
+
+    def _dataset_name(self) -> str:
+        dataset_cfg = self._dataset_cfg()
+        dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
+        if not dataset_name:
+            raise ValueError("dataset.dataset_name is required.")
+        return str(dataset_name)
+
+    def _dataset_root(self) -> str:
+        dataset_cfg = self._dataset_cfg()
+        root = dataset_cfg.get("root")
+        if not root:
+            raise ValueError("dataset.root is required.")
+        return str(root)
+
+    def _image_size(self) -> tuple[int, int]:
+        dataset_cfg = self._dataset_cfg()
+        image_size = dataset_cfg.get("image_size")
+        if image_size is None:
+            raise ValueError("dataset.image_size is required.")
+        return int(image_size[0]), int(image_size[1])
+
+    def _build_resize_transform(self, *, mode: str) -> Any:
+        height, width = self._image_size()
+        interpolation_mode = "bilinear" if mode == "image" else "nearest"
+
+        def _transform(tensor: torch.Tensor) -> torch.Tensor:
+            resized = F.interpolate(
+                tensor.unsqueeze(0),
+                size=(height, width),
+                mode=interpolation_mode,
+                align_corners=False if interpolation_mode != "nearest" else None,
+            )
+            return resized.squeeze(0)
+
+        return _transform
+
+    def _build_segmentation_loader(
+            self,
+            *,
+            split: str,
+            batch_size: int,
+            shuffle: bool,
+            split_file: str | None = None,
+    ):
+        dataset_cfg = self._dataset_cfg()
+        train_cfg = self.cfg.get("train", {})
+        num_workers = max(0, int(train_cfg.get("num_workers", 0)))
+        persistent_workers = bool(train_cfg.get("persistent_workers", False)) if num_workers > 0 else False
+        loader = build_dataloader(
+            dataset_name=self._dataset_name(),
+            root=self._dataset_root(),
+            split=split,
+            split_file=split_file,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=num_workers,
+            image_transform=self._build_resize_transform(mode="image"),
+            mask_transform=self._build_resize_transform(mode="mask"),
+            pin_memory=bool(train_cfg.get("pin_memory", self.device.type == "cuda")),
+            persistent_workers=persistent_workers,
+            prefetch_factor=train_cfg.get("prefetch_factor") if num_workers > 0 else None,
+        )
+        return loader
+
+    def _build_val_loader(
+            self,
+            *,
+            batch_size: int,
+            shuffle: bool = False,
+    ):
+        dataset_cfg = self._dataset_cfg()
+        val_split = dataset_cfg.get("val_split", "val")
+        if val_split is None:
+            return None
+        return self._build_segmentation_loader(
+            split=str(val_split),
+            split_file=dataset_cfg.get("val_split_file"),
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+
+    def _checkpoint_cfg(self) -> dict[str, Any]:
+        return self.cfg.get("checkpoint", {})
+
+    def _logging_cfg(self) -> dict[str, Any]:
+        return self.cfg.get("logging", {})
+
+    def _validation_cfg(self) -> dict[str, Any]:
+        return self.cfg.get("validation", {})
+
+    def _checkpoint_enabled(self) -> bool:
+        return bool(self._checkpoint_cfg().get("save", True))
+
+    def _best_mode(self) -> str:
+        return str(self._checkpoint_cfg().get("monitor_mode", "min"))
+
+    def _is_better_metric(self, metric: float) -> bool:
+        if self.best_metric is None:
+            return True
+        if self._best_mode() == "max":
+            return metric > self.best_metric
+        return metric < self.best_metric
+
+    def _save_checkpoint(self, filename: str, state: dict[str, Any]) -> Path | None:
+        if not self._checkpoint_enabled():
+            return None
+        path = self.output_dir / filename
+        torch.save(state, path)
+        return path
+
+    def _resume_checkpoint_path(self) -> Path | None:
+        resume_path = self._checkpoint_cfg().get("resume")
+        if not resume_path:
+            return None
+        path = Path(str(resume_path))
+        if not path.is_absolute():
+            path = Path.cwd() / path
+        return path
+
+    def _maybe_resume(
+            self,
+            *,
+            module_map: dict[str, Any],
+            optimizer: Any | None = None,
+            scheduler: Any | None = None,
+    ) -> dict[str, Any] | None:
+        path = self._resume_checkpoint_path()
+        if path is None:
+            return None
+        if not path.exists():
+            raise FileNotFoundError(f"Resume checkpoint not found: {path}")
+
+        checkpoint = torch.load(path, map_location="cpu")
+        strict = bool(self._checkpoint_cfg().get("resume_strict", True))
+        for key, module in module_map.items():
+            if module is None:
+                continue
+            state_dict = checkpoint.get(key)
+            if state_dict is not None:
+                module.load_state_dict(state_dict, strict=strict)
+
+        if optimizer is not None and checkpoint.get("optimizer") is not None:
+            optimizer.load_state_dict(checkpoint["optimizer"])
+        if scheduler is not None and checkpoint.get("scheduler") is not None:
+            scheduler.load_state_dict(checkpoint["scheduler"])
+        if checkpoint.get("grad_scaler") is not None:
+            self.grad_scaler.load_state_dict(checkpoint["grad_scaler"])
+
+        if checkpoint.get("best_metric") is not None:
+            self.best_metric = float(checkpoint["best_metric"])
+        elif checkpoint.get("metrics") is not None:
+            monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
+            monitor_value = checkpoint["metrics"].get(f"val_{monitor_name}")
+            if monitor_value is None:
+                monitor_value = checkpoint["metrics"].get(monitor_name)
+            if monitor_value is not None:
+                self.best_metric = float(monitor_value)
+        if checkpoint.get("no_improve_epochs") is not None:
+            self.no_improve_epochs = int(checkpoint["no_improve_epochs"])
+
+        if bool(self._checkpoint_cfg().get("resume_training", True)):
+            self.start_epoch = int(checkpoint.get("epoch", -1)) + 1
+        return checkpoint
+
+    def _validation_enabled(self) -> bool:
+        return bool(self._validation_cfg().get("enabled", True))
+
+    def _validation_interval(self) -> int:
+        return max(1, int(self._validation_cfg().get("interval", 1)))
+
+    def _should_validate(self, epoch: int) -> bool:
+        return self._validation_enabled() and ((epoch + 1) % self._validation_interval() == 0)
+
+    def _metric_task_mode(self) -> str:
+        validation_cfg = self._validation_cfg()
+        metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
+        if isinstance(metrics_cfg, dict):
+            return str(metrics_cfg.get("task_mode", "binary"))
+        return "binary"
+
+    def _metric_threshold(self) -> float:
+        validation_cfg = self._validation_cfg()
+        threshold = validation_cfg.get("threshold", 0.5)
+        return float(threshold)
+
+    def _build_validation_metrics(self) -> dict[str, Any]:
+        validation_cfg = self._validation_cfg()
+        metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
+        if metrics_cfg is None:
+            return {}
+        return build_metrics(metrics_cfg)
+
+    def _early_stopping_enabled(self) -> bool:
+        return bool(self._validation_cfg().get("early_stopping", False))
+
+    def _early_stopping_patience(self) -> int:
+        return max(1, int(self._validation_cfg().get("early_stopping_patience", 10)))
+
+    def _early_stopping_min_delta(self) -> float:
+        return float(self._validation_cfg().get("early_stopping_min_delta", 0.0))
+
+    def _update_validation_metrics(
+            self,
+            metrics: dict[str, Any],
+            *,
+            logits: torch.Tensor,
+            target: torch.Tensor,
+    ) -> None:
+        if not metrics:
+            return
+        update_metrics(
+            metrics,
+            logits,
+            target,
+            task_mode=self._metric_task_mode(),
+            threshold=self._metric_threshold(),
+            num_classes=int(self._dataset_cfg().get("num_classes", 1)),
+        )
+
+    def _compute_validation_metric_values(self, metrics: dict[str, Any]) -> dict[str, float]:
+        if not metrics:
+            return {}
+        values = compute_metrics(metrics)
+        reset_metrics(metrics)
+        return values
+
+    def _init_swanlab(self) -> None:
+        logging_cfg = self._logging_cfg()
+        if not bool(logging_cfg.get("use_swanlab", False)):
+            return
+        if swanlab is None:
+            print("SwanLab is not installed. Logging will continue without SwanLab.")
+            return
+
+        run_name = logging_cfg.get("experiment_name") or self.output_dir.name
+        self.swanlab_run = swanlab.init(
+            project=logging_cfg.get("project", "X_SSL_Net"),
+            name=run_name,
+            config=self.cfg,
+            mode=logging_cfg.get("swanlab_mode"),
+        )
+
+    def _log_metrics(self, metrics: dict[str, float], *, step: int) -> None:
+        if self.swanlab_run is None:
+            return
+        swanlab.log(metrics, step=step)
+
+    def _close_loggers(self) -> None:
+        if self.swanlab_run is not None:
+            swanlab.finish()
+            self.swanlab_run = None
+
+    def _log_interval(self) -> int:
+        return max(1, int(self._logging_cfg().get("log_interval", 20)))
+
+    def _grad_clip_cfg(self) -> dict[str, Any]:
+        cfg = self.cfg.get("train", {}).get("grad_clip", {})
+        return cfg if isinstance(cfg, dict) else {}
+
+    def _grad_clip_enabled(self) -> bool:
+        return bool(self._grad_clip_cfg().get("enabled", False))
+
+    def _clip_gradients(self, module: nn.Module | None) -> float | None:
+        if module is None or not self._grad_clip_enabled():
+            return None
+        cfg = self._grad_clip_cfg()
+        max_norm = float(cfg.get("max_norm", 1.0))
+        norm_type = float(cfg.get("norm_type", 2.0))
+        params = [param for param in module.parameters() if param.requires_grad and param.grad is not None]
+        if not params:
+            return None
+        total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm, norm_type=norm_type)
+        return float(total_norm.detach().cpu() if isinstance(total_norm, torch.Tensor) else total_norm)
+
+    def _current_lrs(self, optimizer: Any | None) -> list[float]:
+        if optimizer is None:
+            return []
+        return [float(group.get("lr", 0.0)) for group in optimizer.param_groups]
+
+    @staticmethod
+    def _count_parameters(module: nn.Module | None) -> dict[str, int]:
+        if module is None:
+            return {"total": 0, "trainable": 0}
+        total = sum(param.numel() for param in module.parameters())
+        trainable = sum(param.numel() for param in module.parameters() if param.requires_grad)
+        return {"total": int(total), "trainable": int(trainable)}
+
+    @staticmethod
+    def _loader_summary(loader: Any | None) -> dict[str, Any] | None:
+        if loader is None:
+            return None
+        dataset = getattr(loader, "dataset", None)
+        return {
+            "dataset_size": len(dataset) if dataset is not None else None,
+            "num_batches": len(loader),
+            "batch_size": getattr(loader, "batch_size", None),
+            "num_workers": getattr(loader, "num_workers", None),
+            "pin_memory": getattr(loader, "pin_memory", None),
+            "persistent_workers": getattr(loader, "persistent_workers", None),
+            "prefetch_factor": getattr(loader, "prefetch_factor", None),
+            "drop_last": getattr(loader, "drop_last", None),
+        }
+
+    def _training_setup_summary(
+            self,
+            *,
+            model_map: dict[str, nn.Module | None],
+            loader_map: dict[str, Any | None],
+            optimizer: Any | None = None,
+            scheduler: Any | None = None,
+    ) -> dict[str, Any]:
+        return {
+            "trainer": self.cfg.get("trainer", {}).get("name"),
+            "device": str(self.device),
+            "amp_enabled": self._amp_enabled(),
+            "output_dir": str(self.output_dir),
+            "start_epoch": self.start_epoch,
+            "train": self.cfg.get("train", {}),
+            "dataset": self.cfg.get("dataset", {}),
+            "model": self.cfg.get("model", {}),
+            "optimizer": self.cfg.get("optimizer", {}),
+            "scheduler": self.cfg.get("scheduler"),
+            "current_lrs": self._current_lrs(optimizer),
+            "validation": self.cfg.get("validation", {}),
+            "checkpoint": self.cfg.get("checkpoint", {}),
+            "logging": self.cfg.get("logging", {}),
+            "model_parameters": {
+                name: self._count_parameters(module)
+                for name, module in model_map.items()
+            },
+            "loaders": {
+                name: self._loader_summary(loader)
+                for name, loader in loader_map.items()
+            },
+            "cuda": {
+                "available": torch.cuda.is_available(),
+                "device_name": torch.cuda.get_device_name(self.device) if self.device.type == "cuda" else None,
+                "device_count": torch.cuda.device_count(),
+            },
+        }
+
+    def _print_training_setup(
+            self,
+            *,
+            model_map: dict[str, nn.Module | None],
+            loader_map: dict[str, Any | None],
+            optimizer: Any | None = None,
+            scheduler: Any | None = None,
+    ) -> None:
+        if not bool(self._logging_cfg().get("print_training_setup", True)):
+            return
+        summary = self._training_setup_summary(
+            model_map=model_map,
+            loader_map=loader_map,
+            optimizer=optimizer,
+            scheduler=scheduler,
+        )
+        print("========== TRAINING SETUP ==========")
+        pprint.pprint(summary, sort_dicts=False, width=120)
+        print("======== END TRAINING SETUP ========")
+
+    def _gpu_memory_mb(self) -> float:
+        if self.device.type != "cuda" or not torch.cuda.is_available():
+            return 0.0
+        return float(torch.cuda.max_memory_allocated(device=self.device) / (1024 ** 2))
+
+    def _performance_snapshot(
+            self,
+            *,
+            epoch: int,
+            step: int,
+            num_steps: int,
+            data_time: float,
+            iter_time: float,
+            metrics: dict[str, float],
+            prefix: str = "train",
+    ) -> dict[str, float | int]:
+        snapshot: dict[str, float | int] = {
+            "epoch": epoch,
+            "step": step,
+            "num_steps": num_steps,
+            "data_time": data_time,
+            "iter_time": iter_time,
+            "gpu_memory_mb": self._gpu_memory_mb(),
+        }
+        lrs = self._current_lrs(getattr(self, "optimizer", None))
+        if lrs:
+            snapshot["lr"] = lrs[0]
+        for key, value in metrics.items():
+            snapshot[f"{prefix}_{key}"] = value
+        return snapshot
+
+    def _maybe_log_step(
+            self,
+            *,
+            epoch: int,
+            step: int,
+            num_steps: int,
+            data_time: float,
+            iter_time: float,
+            metrics: dict[str, float],
+            prefix: str = "train",
+    ) -> None:
+        if step % self._log_interval() != 0 and step != num_steps:
+            return
+        snapshot = self._performance_snapshot(
+            epoch=epoch,
+            step=step,
+            num_steps=num_steps,
+            data_time=data_time,
+            iter_time=iter_time,
+            metrics=metrics,
+            prefix=prefix,
+        )
+        print(snapshot)
+        log_metrics = {
+            f"{prefix}/{key}": value
+            for key, value in metrics.items()
+        }
+        log_metrics.update(
+            {
+                f"{prefix}/data_time": data_time,
+                f"{prefix}/iter_time": iter_time,
+                f"{prefix}/gpu_memory_mb": float(snapshot["gpu_memory_mb"]),
+            }
+        )
+        if "lr" in snapshot:
+            log_metrics[f"{prefix}/lr"] = float(snapshot["lr"])
+        self._log_metrics(log_metrics, step=epoch * max(1, num_steps) + step)
+
+    @staticmethod
+    def _average_metric_sums(metric_sums: dict[str, float], steps: int) -> dict[str, float]:
+        if steps <= 0:
+            return {}
+        return {key: value / steps for key, value in metric_sums.items()}
+
+    def _base_checkpoint_state(self, *, epoch: int, metrics: dict[str, float] | None = None) -> dict[str, Any]:
+        state = {
+            "epoch": epoch,
+            "cfg": self.cfg,
+            "metrics": metrics or {},
+            "grad_scaler": self.grad_scaler.state_dict(),
+            "no_improve_epochs": self.no_improve_epochs,
+        }
+        return state
+
+    def _finalize_epoch(
+            self,
+            *,
+            epoch: int,
+            train_metrics: dict[str, float],
+            val_metrics: dict[str, float] | None,
+            checkpoint_state: dict[str, Any],
+    ) -> tuple[dict[str, Any], bool]:
+        merged_metrics = dict(train_metrics)
+        if val_metrics is not None:
+            merged_metrics.update({f"val_{key}": value for key, value in val_metrics.items()})
+
+        improved = False
+        if val_metrics is not None:
+            monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
+            if monitor_name not in val_metrics:
+                raise KeyError(f"Checkpoint monitor '{monitor_name}' not found in val metrics.")
+            monitor_value = float(val_metrics[monitor_name])
+            delta = self._early_stopping_min_delta()
+            previous_best = self.best_metric
+            is_better = self._is_better_metric(monitor_value)
+            if previous_best is not None and self._best_mode() == "max":
+                is_better = monitor_value > (previous_best + delta)
+            elif previous_best is not None and self._best_mode() == "min":
+                is_better = monitor_value < (previous_best - delta)
+
+            if is_better:
+                self.best_metric = monitor_value
+                self.no_improve_epochs = 0
+                improved = True
+                best_state = dict(checkpoint_state)
+                best_state.update(
+                    self._base_checkpoint_state(
+                        epoch=epoch,
+                        metrics=merged_metrics,
+                    )
+                )
+                best_state["best_metric"] = self.best_metric
+                self._save_checkpoint("best.pth", best_state)
+            else:
+                self.no_improve_epochs += 1
+
+        save_last = bool(self._checkpoint_cfg().get("save_last", True))
+        if save_last:
+            last_state = dict(checkpoint_state)
+            last_state.update(self._base_checkpoint_state(epoch=epoch, metrics=merged_metrics))
+            if self.best_metric is not None:
+                last_state["best_metric"] = self.best_metric
+            self._save_checkpoint("last.pth", last_state)
+
+        summary = {"epoch": epoch}
+        summary.update(train_metrics)
+        if val_metrics is not None:
+            summary.update({f"val_{key}": value for key, value in val_metrics.items()})
+        if self.best_metric is not None:
+            summary["best_metric"] = float(self.best_metric)
+        summary["no_improve_epochs"] = self.no_improve_epochs
+        lrs = self._current_lrs(getattr(self, "optimizer", None))
+        if lrs:
+            summary["lr"] = lrs[0]
+        self._log_metrics(summary, step=epoch)
+        should_stop = False
+        if val_metrics is not None and self._early_stopping_enabled():
+            should_stop = self.no_improve_epochs >= self._early_stopping_patience()
+            summary["early_stop"] = should_stop
+            summary["improved"] = improved
+        return summary, should_stop
+
+    @abstractmethod
+    def build(self) -> None:
+        """
+        创建模型、优化器、数据加载器等运行所需对象。
+        """
+
+    @abstractmethod
+    def train(self) -> None:
+        """
+        执行完整训练流程。
+        """

+ 27 - 0
lib/trainers/builder.py

@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+from typing import Any
+
+from .base import BaseTrainer
+from .supervised import SupervisedSegmentationTrainer
+
+
+TRAINER_REGISTRY = {
+    "supervised_segmentation": SupervisedSegmentationTrainer,
+}
+
+
+def build_trainer(cfg: dict[str, Any], args: Any | None = None) -> BaseTrainer:
+    trainer_cfg = cfg.get("trainer", {})
+    trainer_name = trainer_cfg.get("name", "supervised_segmentation")
+    trainer_cls = TRAINER_REGISTRY.get(trainer_name)
+    if trainer_cls is None:
+        raise ValueError(
+            f"Unsupported trainer '{trainer_name}'. Expected one of: {', '.join(TRAINER_REGISTRY)}."
+        )
+    trainer = trainer_cls(cfg=cfg, args=args)
+    trainer.build()
+    return trainer
+
+
+__all__ = ["TRAINER_REGISTRY", "build_trainer"]

+ 216 - 0
lib/trainers/supervised.py

@@ -0,0 +1,216 @@
+from __future__ import annotations
+
+import time
+from typing import Any
+
+import torch
+from torch.utils.data import DataLoader
+
+from lib.modules import SegmentationNet2d
+from lib.tools import (
+    BinaryBoundaryLoss,
+    MaskBoundaryConsistencyLoss,
+    build_optimizer,
+    build_scheduler,
+    mask_to_boundary_map,
+)
+from .base import BaseTrainer
+
+
+class SupervisedSegmentationTrainer(BaseTrainer):
+    def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
+        super().__init__(cfg=cfg, args=args)
+        self.model: SegmentationNet2d | None = None
+        self.optimizer = None
+        self.scheduler = None
+        self.loader: DataLoader | None = None
+        self.val_loader: DataLoader | None = None
+        self.seg_loss = None
+        self.boundary_loss = BinaryBoundaryLoss()
+        self.consistency_loss = MaskBoundaryConsistencyLoss()
+
+    def build(self) -> None:
+        dataset_cfg = self.cfg["dataset"]
+        model_cfg = self.cfg["model"]
+        train_cfg = self.cfg["train"]
+
+        self.model = SegmentationNet2d(
+            num_classes=dataset_cfg["num_classes"],
+            model_name=model_cfg["model_name"],
+            load_weights=model_cfg.get("load_weights", False),
+            decoder_channels=model_cfg.get("decoder_channels"),
+            fwta_wavelet=model_cfg.get("fwta_wavelet", "haar"),
+            fwta_level=model_cfg.get("fwta_level", 1),
+            fwta_sigma_ratio=model_cfg.get("fwta_sigma_ratio", 0.35),
+            fwta_tau_fourier=model_cfg.get("fwta_tau_fourier", 0.15),
+            fwta_gate_temperature=model_cfg.get("fwta_gate_temperature", 1.0),
+        ).to(self.device)
+
+        self.optimizer = build_optimizer(self.model, self.cfg["optimizer"])
+        self.scheduler = build_scheduler(self.optimizer, self.cfg.get("scheduler"))
+        self.loader = self._build_segmentation_loader(
+            split=str(dataset_cfg.get("split", "train")),
+            split_file=dataset_cfg.get("split_file"),
+            batch_size=self._resolve_batch_size("batch_size", 4),
+            shuffle=bool(train_cfg.get("shuffle", True)),
+        )
+        self.val_loader = self._build_val_loader(
+            batch_size=self._resolve_batch_size(
+                "val_batch_size",
+                int(train_cfg.get("batch_size", 4)),
+            ),
+            shuffle=False,
+        )
+        self._maybe_resume(
+            module_map={"model": self.model},
+            optimizer=self.optimizer,
+            scheduler=self.scheduler,
+        )
+        self._init_swanlab()
+
+    def _compute_losses(
+            self,
+            image: torch.Tensor,
+            mask: torch.Tensor,
+    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+        if self.model is None:
+            raise RuntimeError("Model is not initialized.")
+        with torch.autocast(device_type=self.device.type, enabled=self._amp_enabled()):
+            outputs = self.model(image)
+            seg_logits = outputs["seg_logits"]
+            boundary_logits = outputs["boundary_logits"]
+
+            seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
+            boundary_target = mask_to_boundary_map(mask)
+            boundary_loss = self.boundary_loss(boundary_logits, boundary_target)
+            consistency_loss = self.consistency_loss(seg_logits, boundary_logits)
+            total_loss = seg_loss + boundary_loss + 0.1 * consistency_loss
+
+        losses = {
+            "total": total_loss,
+            "seg": seg_loss,
+            "boundary": boundary_loss,
+            "consistency": consistency_loss,
+        }
+        return outputs, losses
+
+    @staticmethod
+    def _detach_metrics(losses: dict[str, torch.Tensor]) -> dict[str, float]:
+        return {key: float(value.detach().cpu()) for key, value in losses.items()}
+
+    def _validate(self) -> dict[str, float] | None:
+        if self.model is None or self.val_loader is None:
+            return None
+
+        self.model.eval()
+        metrics = self._build_validation_metrics()
+        total = 0.0
+        seg = 0.0
+        boundary = 0.0
+        consistency = 0.0
+        steps = 0
+        with torch.no_grad():
+            for batch in self.val_loader:
+                image = batch["image"].to(self.device)
+                mask = batch["mask"].to(self.device)
+                outputs, losses = self._compute_losses(image, mask)
+                total += float(losses["total"].detach().cpu())
+                seg += float(losses["seg"].detach().cpu())
+                boundary += float(losses["boundary"].detach().cpu())
+                consistency += float(losses["consistency"].detach().cpu())
+                self._update_validation_metrics(
+                    metrics,
+                    logits=outputs["seg_logits"],
+                    target=mask,
+                )
+                steps += 1
+
+        if steps == 0:
+            return None
+        val_metrics = {
+            "total": total / steps,
+            "seg": seg / steps,
+            "boundary": boundary / steps,
+            "consistency": consistency / steps,
+        }
+        val_metrics.update(self._compute_validation_metric_values(metrics))
+        return val_metrics
+
+    def train(self) -> None:
+        if self.model is None or self.loader is None or self.optimizer is None:
+            raise RuntimeError("Trainer.build() must be called before train().")
+
+        epochs = int(self.cfg["train"].get("epochs", 1))
+        try:
+            self._print_training_setup(
+                model_map={"model": self.model},
+                loader_map={"train": self.loader, "val": self.val_loader},
+                optimizer=self.optimizer,
+                scheduler=self.scheduler,
+            )
+            for epoch in range(self.start_epoch, epochs):
+                self.model.train()
+                train_metric_sums = {
+                    "total": 0.0,
+                    "seg": 0.0,
+                    "boundary": 0.0,
+                    "consistency": 0.0,
+                }
+                train_metrics: dict[str, float] | None = None
+                end_time = time.perf_counter()
+                num_steps = len(self.loader)
+                for step, batch in enumerate(self.loader, start=1):
+                    data_time = time.perf_counter() - end_time
+                    iter_start = time.perf_counter()
+                    image = batch["image"].to(self.device)
+                    mask = batch["mask"].to(self.device)
+                    _, losses = self._compute_losses(image, mask)
+                    self.optimizer.zero_grad()
+                    self.grad_scaler.scale(losses["total"]).backward()
+                    grad_norm = None
+                    if self._grad_clip_enabled():
+                        self.grad_scaler.unscale_(self.optimizer)
+                        grad_norm = self._clip_gradients(self.model)
+                    self.grad_scaler.step(self.optimizer)
+                    self.grad_scaler.update()
+                    train_metrics = self._detach_metrics(losses)
+                    if grad_norm is not None:
+                        train_metrics["grad_norm"] = grad_norm
+                    for key, value in train_metrics.items():
+                        train_metric_sums.setdefault(key, 0.0)
+                        train_metric_sums[key] += value
+                    iter_time = time.perf_counter() - iter_start
+                    self._maybe_log_step(
+                        epoch=epoch,
+                        step=step,
+                        num_steps=num_steps,
+                        data_time=data_time,
+                        iter_time=iter_time,
+                        metrics=train_metrics,
+                        prefix="train",
+                    )
+                    end_time = time.perf_counter()
+
+                if self.scheduler is not None:
+                    self.scheduler.step()
+
+                if train_metrics is None:
+                    raise RuntimeError("Training loader is empty.")
+                train_metrics = self._average_metric_sums(train_metric_sums, num_steps)
+                val_metrics = self._validate() if self._should_validate(epoch) else None
+                summary, should_stop = self._finalize_epoch(
+                    epoch=epoch,
+                    train_metrics=train_metrics,
+                    val_metrics=val_metrics,
+                    checkpoint_state={
+                        "model": self.model.state_dict(),
+                        "optimizer": self.optimizer.state_dict(),
+                        "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None,
+                    },
+                )
+                print(summary)
+                if should_stop:
+                    print({"epoch": epoch, "message": "early stopping triggered"})
+                    break
+        finally:
+            self._close_loggers()

+ 91 - 0
tools/run_us_experiments.sh

@@ -0,0 +1,91 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "$ROOT_DIR"
+
+# ===== 可直接改这里 =====
+DATASET="${DATASET:-BUSI}"          # BUS-UCLM | BUSI | BUS-BRA | BUS_UC | CCAUI | DDTI | OTU_2d | TN3K | TG3K
+SEED="${SEED:-42}"
+RUN_ALL_SUP="${RUN_ALL_SUP:-0}"     # 1 表示跑内置所有全监督实验
+PYTHON_BIN="${PYTHON_BIN:-python}"
+EXTRA_SET_ARGS="${EXTRA_SET_ARGS:-}"
+
+# ===== 数据集根目录 =====
+dataset_root() {
+  case "$1" in
+    "BUS-UCLM") echo "data/BUS-UCLM" ;;
+    "BUSI") echo "data/BUSI" ;;
+    "BUS-BRA") echo "data/BUS-BRA" ;;
+    "BUS_UC") echo "data/BUS_UC" ;;
+    "CCAUI") echo "data/CCAUI" ;;
+    "DDTI") echo "data/DDTI" ;;
+    "OTU_2d") echo "data/OTU_2d" ;;
+    "TN3K") echo "data/TN3K" ;;
+    "TG3K") echo "data/TG3K" ;;
+    *) echo "Unsupported dataset: $1" >&2; exit 1 ;;
+  esac
+}
+
+# ===== 是否需要项目级 train/val =====
+needs_project_split() {
+  case "$1" in
+    "BUS-UCLM"|"BUSI"|"BUS-BRA"|"BUS_UC"|"CCAUI"|"DDTI") return 0 ;;
+    *) return 1 ;;
+  esac
+}
+
+prepare_project_splits() {
+  local dataset="$1"
+  local root
+  root="$(dataset_root "$dataset")"
+
+  if needs_project_split "$dataset"; then
+    echo "[split] generate project split for ${dataset}"
+    "$PYTHON_BIN" tmp/generate_project_split.py --dataset "$dataset" --root "$root" --seed "$SEED"
+  fi
+}
+
+run_supervised() {
+  local dataset="$1"
+  local root
+  root="$(dataset_root "$dataset")"
+  prepare_project_splits "$dataset"
+  echo "[train] supervised ${dataset}"
+  "$PYTHON_BIN" tools/train.py \
+    --config configs/segmentation/train_sup_us_template.yaml \
+    --set \
+      dataset.dataset_name="$dataset" \
+      dataset.root="$root" \
+      checkpoint.dir="outputs/experiments/supervised/${dataset}" \
+      logging.experiment_name="sup_${dataset}" \
+      ${EXTRA_SET_ARGS}
+}
+
+run_all_supervised_suite() {
+  local datasets=(
+    "BUS-UCLM"
+    "BUSI"
+    "BUS-BRA"
+    "BUS_UC"
+    "CCAUI"
+    "DDTI"
+    "OTU_2d"
+    "TN3K"
+    "TG3K"
+  )
+  for ds in "${datasets[@]}"; do
+    run_supervised "$ds"
+  done
+}
+
+main() {
+  if [[ "$RUN_ALL_SUP" == "1" ]]; then
+    run_all_supervised_suite
+    exit 0
+  fi
+
+  run_supervised "$DATASET"
+}
+
+main "$@"

+ 110 - 0
tools/summarize_results.py

@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+import argparse
+import csv
+from pathlib import Path
+from typing import Any
+
+import torch
+
+
+def _infer_mode(path: Path) -> str:
+    parts = set(path.parts)
+    if "supervised" in parts:
+        return "supervised"
+    return "unknown"
+
+
+def _infer_dataset(ckpt: dict[str, Any], path: Path) -> str:
+    cfg = ckpt.get("cfg", {})
+    dataset_cfg = cfg.get("dataset", {})
+    dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
+    if dataset_name:
+        return str(dataset_name)
+
+    parts = path.parts
+    if "supervised" in parts:
+        idx = parts.index("supervised")
+        if idx + 1 < len(parts):
+            return parts[idx + 1]
+    return "unknown"
+
+
+def _infer_ratio(ckpt: dict[str, Any], path: Path) -> str:
+    return "-"
+
+
+def _extract_metric(metrics: dict[str, Any], *names: str) -> float | None:
+    for name in names:
+        value = metrics.get(name)
+        if value is not None:
+            return float(value)
+    return None
+
+
+def collect_rows(outputs_dir: Path) -> list[dict[str, Any]]:
+    rows: list[dict[str, Any]] = []
+    for best_path in sorted(outputs_dir.rglob("best.pth")):
+        ckpt = torch.load(best_path, map_location="cpu")
+        metrics = ckpt.get("metrics", {}) or {}
+        row = {
+            "dataset": _infer_dataset(ckpt, best_path),
+            "mode": _infer_mode(best_path),
+            "ratio": _infer_ratio(ckpt, best_path),
+            "epoch": ckpt.get("epoch"),
+            "best_metric": ckpt.get("best_metric"),
+            "dice": _extract_metric(metrics, "val_dice", "dice"),
+            "iou": _extract_metric(metrics, "val_iou", "val_miou", "iou", "miou"),
+            "checkpoint": str(best_path),
+        }
+        rows.append(row)
+    return rows
+
+
+def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    fieldnames = ["dataset", "mode", "ratio", "epoch", "best_metric", "dice", "iou", "checkpoint"]
+    with path.open("w", encoding="utf-8", newline="") as handle:
+        writer = csv.DictWriter(handle, fieldnames=fieldnames)
+        writer.writeheader()
+        writer.writerows(rows)
+
+
+def write_markdown(rows: list[dict[str, Any]], path: Path) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    lines = [
+        "# 实验结果汇总",
+        "",
+        "| dataset | mode | ratio | epoch | best_metric | dice | iou | checkpoint |",
+        "| --- | --- | --- | --- | --- | --- | --- | --- |",
+    ]
+    for row in rows:
+        lines.append(
+            f"| {row['dataset']} | {row['mode']} | {row['ratio']} | {row['epoch']} | "
+            f"{row['best_metric']} | {row['dice']} | {row['iou']} | {row['checkpoint']} |"
+        )
+    if not rows:
+        lines.append("| - | - | - | - | - | - | - | - |")
+    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Summarize best experiment results from best.pth files.")
+    parser.add_argument("--outputs-dir", default="outputs", help="Root output directory")
+    parser.add_argument("--results-dir", default="results", help="Directory to write summary tables")
+    args = parser.parse_args()
+
+    outputs_dir = Path(args.outputs_dir)
+    results_dir = Path(args.results_dir)
+    rows = collect_rows(outputs_dir)
+
+    csv_path = results_dir / "experiment_summary.csv"
+    md_path = results_dir / "experiment_summary.md"
+    write_csv(rows, csv_path)
+    write_markdown(rows, md_path)
+
+    print({"num_results": len(rows), "csv": str(csv_path), "markdown": str(md_path)})
+
+
+if __name__ == "__main__":
+    main()

+ 15 - 0
tools/summarize_results.sh

@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "$ROOT_DIR"
+
+PYTHON_BIN="${PYTHON_BIN:-python}"
+OUTPUTS_DIR="${OUTPUTS_DIR:-outputs}"
+RESULTS_DIR="${RESULTS_DIR:-results}"
+
+"$PYTHON_BIN" tools/summarize_results.py --outputs-dir "$OUTPUTS_DIR" --results-dir "$RESULTS_DIR"
+
+echo "[done] results written to:"
+echo "  - ${RESULTS_DIR}/experiment_summary.csv"
+echo "  - ${RESULTS_DIR}/experiment_summary.md"

+ 53 - 0
tools/train.py

@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import argparse
+import sys
+from pathlib import Path
+
+ROOT_DIR = Path(__file__).resolve().parents[1]
+if str(ROOT_DIR) not in sys.path:
+    sys.path.insert(0, str(ROOT_DIR))
+
+from lib.trainers import build_trainer
+from lib.utils.config import apply_dotlist_overrides, load_yaml_config
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Unified training entrypoint.")
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=True,
+        help="Path to yaml config.",
+    )
+    parser.add_argument(
+        "--trainer",
+        type=str,
+        default=None,
+        help="Override trainer name from config.",
+    )
+    parser.add_argument(
+        "--set",
+        nargs="*",
+        default=None,
+        help="Override config values with key=value pairs, e.g. train.epochs=2 model.load_weights=false",
+    )
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    cfg_path = ROOT_DIR / args.config if not Path(args.config).is_absolute() else Path(args.config)
+    cfg = load_yaml_config(cfg_path)
+    cfg = apply_dotlist_overrides(cfg, args.set)
+
+    if args.trainer is not None:
+        cfg.setdefault("trainer", {})
+        cfg["trainer"]["name"] = args.trainer
+
+    trainer = build_trainer(cfg, args=args)
+    trainer.train()
+
+
+if __name__ == "__main__":
+    main()