Forráskód Böngészése

feat(data): 添加数据处理模块和配置频率细化参数

添加完整的数据处理模块,包括数据集构建器、增强、加载器等组件,
同时在模型配置中添加低频半径相关参数用于频率细化功能。

- 新增 lib/data 模块,包含数据集索引、增强、加载等功能
- 在 .gitignore 中添加 docs 目录和数据相关路径排除规则
- 在多个 YAML 配置文件中添加频率细化相关的配置参数:
  - low_freq_radius_h: 0.25
  - low_freq_radius_w: 0.25
  - learnable_low_freq_radius: true
kekezack 2 hete
szülő
commit
5f33d6e8b4

+ 60 - 58
.gitignore

@@ -1,58 +1,60 @@
-# Python
-__pycache__/
-*.py[cod]
-*$py.class
-*.so
-*.egg
-*.egg-info/
-dist/
-build/
-*.whl
-
-# TypeScript
-lib/sam2/demo/
-
-
-# IDE
-.idea/
-.vscode/
-*.swp
-*.swo
-
-# OS
-.DS_Store
-Thumbs.db
-
-# Reference code & papers (do not upload)
-ref/
-tmp/
-
-# Weights & checkpoints
-*.pth
-*.pt
-*.ckpt
-*.onnx
-
-# Logs & outputs
-*.log
-outputs/
-results/
-runs/
-lightning_logs/
-swanlog/
-
-# Jupyter
-.ipynb_checkpoints/
-
-# Environment
-.env
-.venv/
-venv/
-
-# data
-data/
-cache/
-
-# Codex and .gitignore
-.codex
-.gitignore
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+*.egg
+*.egg-info/
+dist/
+build/
+*.whl
+
+# TypeScript
+lib/sam2/demo/
+
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Reference code & papers (do not upload)
+ref/
+tmp/
+
+# Weights & checkpoints
+*.pth
+*.pt
+*.ckpt
+*.onnx
+
+# Logs & outputs
+*.log
+outputs/
+results/
+runs/
+lightning_logs/
+swanlog/
+
+# Jupyter
+.ipynb_checkpoints/
+
+# Environment
+.env
+.venv/
+venv/
+
+# Data and generated documentation
+/data/
+cache/
+docs/
+!lib/data/
+
+# Codex and .gitignore
+.codex
+.gitignore

+ 3 - 0
configs/segmentation/train_sup_us_template.yaml

@@ -81,6 +81,9 @@ model:
   ssm_forward_type: v3
   ssm_backend: auto
   use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
   guide_mode: affine
   out_channels: null
 

+ 3 - 0
configs/segmentation/us_exp_sup_busi.yaml

@@ -71,6 +71,9 @@ model:
   ssm_forward_type: v3
   ssm_backend: auto
   use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
   guide_mode: affine
   out_channels: null
 

+ 3 - 0
configs/segmentation/us_exp_sup_busi_ablation.yaml

@@ -71,6 +71,9 @@ model:
   ssm_forward_type: v3
   ssm_backend: auto
   use_frequency_refine: false
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
   guide_mode: affine
   out_channels: null
 

+ 51 - 0
lib/data/__init__.py

@@ -0,0 +1,51 @@
+from .augment import SegmentationAugmentation, build_segmentation_augmentation
+from .builder import DATASET_REGISTRY, build_dataset_index
+from .collate import record_collate_fn
+from .datasets import SegmentationRecordDataset, default_image_loader, default_mask_loader
+from .ddti import build_ddti_mask, parse_ddti_xml
+from .loaders import (
+    OFFICIAL_SPLIT_FILES,
+    apply_official_split,
+    build_dataloader,
+    build_record_dataset,
+    get_official_split_file,
+    list_supported_splits,
+)
+from .project_splits import (
+    PROJECT_SPLIT_DATASETS,
+    PROJECT_SPLIT_ROOT,
+    generate_project_splits,
+    get_project_split_file,
+    load_project_split_ids,
+    select_project_split_base_records,
+)
+from .records import SegSampleRecord
+from .splits import load_id_txt, load_json_split
+
+__all__ = [
+    "DATASET_REGISTRY",
+    "OFFICIAL_SPLIT_FILES",
+    "PROJECT_SPLIT_DATASETS",
+    "PROJECT_SPLIT_ROOT",
+    "SegmentationAugmentation",
+    "SegSampleRecord",
+    "build_segmentation_augmentation",
+    "build_dataset_index",
+    "record_collate_fn",
+    "SegmentationRecordDataset",
+    "default_image_loader",
+    "default_mask_loader",
+    "parse_ddti_xml",
+    "build_ddti_mask",
+    "load_id_txt",
+    "load_json_split",
+    "apply_official_split",
+    "build_record_dataset",
+    "build_dataloader",
+    "get_official_split_file",
+    "generate_project_splits",
+    "get_project_split_file",
+    "load_project_split_ids",
+    "list_supported_splits",
+    "select_project_split_base_records",
+]

+ 73 - 0
lib/data/augment.py

@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+
+
+def _rand_uniform(low: float, high: float) -> float:
+    """Draw one value from a uniform distribution between ``low`` and ``high`` using torch's RNG."""
+    sample = torch.empty(1)
+    sample.uniform_(low, high)
+    return float(sample.item())
+
+
+class SegmentationAugmentation:
+    """Config-driven random augmentation for an image tensor and its optional mask.
+
+    Config keys (all read with ``dict.get``): ``random_flip``,
+    ``random_rotate_90``, ``random_brightness_contrast`` (with
+    ``brightness_limit`` / ``contrast_limit``, default 0.15 each) and
+    ``random_gaussian_noise`` (with ``gaussian_noise_std``, default 0.03).
+    The returned image is always clamped to [0, 1], even when no intensity
+    augmentation is enabled.
+    """
+
+    def __init__(self, config: dict[str, Any] | None = None) -> None:
+        # A missing or empty config disables every augmentation.
+        self.config = config if config else {}
+
+    def __call__(
+        self,
+        image: torch.Tensor,
+        mask: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Apply spatial transforms to image and mask, then intensity transforms to the image only."""
+        moved_image, moved_mask = self._apply_spatial(image, mask)
+        return self._apply_intensity(moved_image), moved_mask
+
+    def _apply_spatial(
+        self,
+        image: torch.Tensor,
+        mask: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Randomly flip and/or rotate the image, mirroring each operation on the mask."""
+        cfg = self.config
+
+        if bool(cfg.get("random_flip", False)):
+            # One independent coin toss per axis: last dim first, then second-to-last.
+            for axis in (-1, -2):
+                if torch.rand(1).item() >= 0.5:
+                    continue
+                image = torch.flip(image, dims=(axis,))
+                if mask is not None:
+                    mask = torch.flip(mask, dims=(axis,))
+
+        if bool(cfg.get("random_rotate_90", False)):
+            quarter_turns = int(torch.randint(0, 4, (1,)).item())
+            if quarter_turns > 0:
+                image = torch.rot90(image, k=quarter_turns, dims=(-2, -1))
+                if mask is not None:
+                    mask = torch.rot90(mask, k=quarter_turns, dims=(-2, -1))
+
+        return image, mask
+
+    def _apply_intensity(self, image: torch.Tensor) -> torch.Tensor:
+        """Optionally jitter contrast/brightness and add Gaussian noise, then clamp to [0, 1]."""
+        cfg = self.config
+
+        if bool(cfg.get("random_brightness_contrast", False)):
+            brightness_limit = float(cfg.get("brightness_limit", 0.15))
+            contrast_limit = float(cfg.get("contrast_limit", 0.15))
+            # Brightness is drawn before contrast; the order matters for seeded runs.
+            brightness_scale = _rand_uniform(1.0 - brightness_limit, 1.0 + brightness_limit)
+            contrast_scale = _rand_uniform(1.0 - contrast_limit, 1.0 + contrast_limit)
+            # Contrast is scaled around the mean over the last two dims, then brightness is applied.
+            plane_mean = image.mean(dim=(-2, -1), keepdim=True)
+            image = ((image - plane_mean) * contrast_scale + plane_mean) * brightness_scale
+
+        if bool(cfg.get("random_gaussian_noise", False)):
+            noise_std = float(cfg.get("gaussian_noise_std", 0.03))
+            if noise_std > 0:
+                image = image + torch.randn_like(image) * noise_std
+
+        return image.clamp(0.0, 1.0)
+
+
+def build_segmentation_augmentation(config: dict[str, Any] | None):
+    """Return a ``SegmentationAugmentation`` for ``config``, or ``None`` when it is missing or empty."""
+    return SegmentationAugmentation(config) if config else None
+
+
+__all__ = ["SegmentationAugmentation", "build_segmentation_augmentation"]

+ 151 - 0
lib/data/builder.py

@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable
+
+from .indexers import (
+    build_filename_matched_records,
+    build_paired_folder_records,
+    build_prefixed_paired_records,
+    build_pre_split_records,
+    build_stem_paired_records,
+    build_xml_annotation_records,
+)
+from .records import SegSampleRecord
+
+
+def _build_bus_uclm(root: Path) -> list[SegSampleRecord]:
+    """Index BUS-UCLM by pairing ``images/`` with ``masks/`` under ``root``."""
+    image_dir, mask_dir = root / "images", root / "masks"
+    return build_paired_folder_records(
+        dataset_name="BUS-UCLM", image_dir=image_dir, mask_dir=mask_dir
+    )
+
+
+def _build_tg3k(root: Path) -> list[SegSampleRecord]:
+    """Index TG3K by pairing ``thyroid-image/`` with ``thyroid-mask/`` under ``root``."""
+    image_dir, mask_dir = root / "thyroid-image", root / "thyroid-mask"
+    return build_paired_folder_records(
+        dataset_name="TG3K", image_dir=image_dir, mask_dir=mask_dir
+    )
+
+
+def _build_tn3k(root: Path) -> list[SegSampleRecord]:
+    """Index TN3K from its separate trainval and test image/mask folders under ``root``."""
+    split_dirs = {
+        "train_image_dir": root / "trainval-image",
+        "train_mask_dir": root / "trainval-mask",
+        "test_image_dir": root / "test-image",
+        "test_mask_dir": root / "test-mask",
+    }
+    return build_pre_split_records(dataset_name="TN3K", **split_dirs)
+
+
+def _build_otu_2d(root: Path) -> list[SegSampleRecord]:
+    """Index OTU_2d by matching ``images/`` to ``annotations/`` on file stem."""
+    image_dir, mask_dir = root / "images", root / "annotations"
+    return build_stem_paired_records(
+        dataset_name="OTU_2d", image_dir=image_dir, mask_dir=mask_dir
+    )
+
+
+def _build_busi(root: Path) -> list[SegSampleRecord]:
+    """Index BUSI from the per-class folders under ``Dataset_BUSI_with_GT``.
+
+    Records are emitted class by class in the order benign, malignant, normal.
+    """
+    base = root / "Dataset_BUSI_with_GT"
+    return [
+        record
+        for class_name in ("benign", "malignant", "normal")
+        for record in build_filename_matched_records(
+            dataset_name="BUSI",
+            folder=base / class_name,
+            class_name=class_name,
+        )
+    ]
+
+
+def _build_bus_bra(root: Path) -> list[SegSampleRecord]:
+    """Index BUS-BRA, pairing ``bus_*`` images with ``mask_*`` masks by the name left after the prefix."""
+    base = root / "BUSBRA" / "BUSBRA"
+    return build_prefixed_paired_records(
+        dataset_name="BUS-BRA",
+        image_dir=base / "Images",
+        mask_dir=base / "Masks",
+        image_prefix_to_strip="bus_",
+        mask_prefix_to_strip="mask_",
+    )
+
+
+def _build_bus_uc(root: Path) -> list[SegSampleRecord]:
+    """Index BUS_UC from its ``All``, ``Benign`` and ``Malignant`` folders, in that order.
+
+    Only the ``All`` folder's records get a split label (``"all"``); the two
+    class folders leave ``split`` unset.
+    """
+    base = root / "BUS_UC" / "BUS_UC"
+    # (folder name, split label, class label)
+    variants = (
+        ("All", "all", "all"),
+        ("Benign", None, "benign"),
+        ("Malignant", None, "malignant"),
+    )
+
+    records: list[SegSampleRecord] = []
+    for folder_name, split_label, class_label in variants:
+        folder = base / folder_name
+        records.extend(
+            build_paired_folder_records(
+                dataset_name="BUS_UC",
+                image_dir=folder / "images",
+                mask_dir=folder / "masks",
+                split=split_label,
+                class_name=class_label,
+            )
+        )
+    return records
+
+
+def _build_ccaui(root: Path) -> list[SegSampleRecord]:
+    """Index CCAUI by pairing ``US images`` with ``Expert mask images`` under the dataset folder."""
+    dataset_dir = root / "Common Carotid Artery Ultrasound Images"
+    image_dir = dataset_dir / "US images"
+    mask_dir = dataset_dir / "Expert mask images"
+    return build_paired_folder_records(
+        dataset_name="CCAUI", image_dir=image_dir, mask_dir=mask_dir
+    )
+
+
+def _build_ddti(root: Path) -> list[SegSampleRecord]:
+    """Index DDTI, whose samples reference XML annotations instead of mask image files."""
+    records = build_xml_annotation_records(dataset_name="DDTI", root=root)
+    return records
+
+
+# Maps each supported dataset name to the function that indexes its on-disk layout.
+# build_dataset_index() looks builders up here and lists these keys in its
+# "Unsupported dataset" error message.
+DATASET_REGISTRY: dict[str, Callable[[Path], list[SegSampleRecord]]] = {
+    "BUS-UCLM": _build_bus_uclm,
+    "TG3K": _build_tg3k,
+    "TN3K": _build_tn3k,
+    "OTU_2d": _build_otu_2d,
+    "BUSI": _build_busi,
+    "BUS-BRA": _build_bus_bra,
+    "BUS_UC": _build_bus_uc,
+    "CCAUI": _build_ccaui,
+    "DDTI": _build_ddti,
+}
+
+
+def build_dataset_index(dataset_name: str, root: str | Path) -> list[SegSampleRecord]:
+    """Build the sample index for ``dataset_name`` rooted at ``root``.
+
+    Raises:
+        ValueError: ``dataset_name`` is not a key of ``DATASET_REGISTRY``.
+        FileNotFoundError: ``root`` does not exist.
+    """
+    if dataset_name not in DATASET_REGISTRY:
+        raise ValueError(
+            f"Unsupported dataset '{dataset_name}'. Expected one of: {', '.join(DATASET_REGISTRY)}."
+        )
+
+    root_path = Path(root)
+    if not root_path.exists():
+        raise FileNotFoundError(f"Dataset root not found: {root_path}")
+
+    build_index = DATASET_REGISTRY[dataset_name]
+    return build_index(root_path)
+
+
+__all__ = ["DATASET_REGISTRY", "build_dataset_index"]

+ 40 - 0
lib/data/collate.py

@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import Any
+
+import torch
+
+
+def record_collate_fn(batch: Sequence[dict[str, Any]]) -> dict[str, Any]:
+    """Collate per-sample dicts into one dict of batched values.
+
+    The keys of the first sample define the output keys. For each key the
+    values are stacked along a new leading dimension when *every* value is a
+    tensor and all shapes match; otherwise they are returned as a plain list
+    in batch order (ragged tensors, ``None``, strings, dicts, ...).
+
+    Raises:
+        ValueError: ``batch`` is empty.
+        KeyError: a later sample lacks a key present in the first sample.
+    """
+    if not batch:
+        raise ValueError("Empty batch is not allowed.")
+
+    collated: dict[str, Any] = {}
+    for key in batch[0].keys():
+        values = [sample[key] for sample in batch]
+        # Check every value, not just the first: a batch may mix tensors with
+        # None (e.g. some samples have no mask) and must not crash on `.shape`.
+        stackable = all(torch.is_tensor(value) for value in values) and all(
+            value.shape == values[0].shape for value in values
+        )
+        collated[key] = torch.stack(values, dim=0) if stackable else values
+
+    return collated
+
+
+__all__ = ["record_collate_fn"]

+ 82 - 0
lib/data/datasets.py

@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+
+from .ddti import build_ddti_mask
+from .records import SegSampleRecord
+
+
+def default_image_loader(path: str | Path) -> torch.Tensor:
+    """Load an image file as a float32 RGB tensor, channels first, scaled by 1/255.
+
+    Args:
+        path: Location of the image file; anything ``PIL.Image.open`` accepts.
+
+    Returns:
+        A contiguous tensor built from the H x W x 3 pixel array and permuted to 3 x H x W.
+    """
+    # The context manager closes the file handle deterministically. ``convert``
+    # returns an independent, fully loaded image, so it stays valid afterwards.
+    with Image.open(path) as source:
+        rgb = source.convert("RGB")
+    array = np.asarray(rgb, dtype=np.float32) / 255.0
+    return torch.from_numpy(array).permute(2, 0, 1).contiguous()
+
+
+def default_mask_loader(path: str | Path) -> torch.Tensor:
+    """Load a mask file as a binary float32 tensor with a leading channel dimension.
+
+    The file is converted to 8-bit greyscale; every pixel greater than 0 becomes
+    1.0 and the rest 0.0.
+
+    Args:
+        path: Location of the mask file; anything ``PIL.Image.open`` accepts.
+
+    Returns:
+        A contiguous 1 x H x W tensor of zeros and ones.
+    """
+    # The context manager closes the file handle deterministically. ``convert``
+    # returns an independent, fully loaded image, so it stays valid afterwards.
+    with Image.open(path) as source:
+        grey = source.convert("L")
+    array = (np.asarray(grey, dtype=np.float32) > 0).astype(np.float32)
+    return torch.from_numpy(array).unsqueeze(0).contiguous()
+
+
+class SegmentationRecordDataset(Dataset):
+    """Map-style dataset that turns ``SegSampleRecord`` entries into sample dicts.
+
+    Each item is a dict with keys ``image``, ``mask``, ``dataset_name``,
+    ``sample_id``, ``split``, ``class_name`` and ``meta``. ``mask`` is ``None``
+    when the record has neither a mask path nor a DDTI annotation.
+    """
+
+    def __init__(
+            self,
+            records: list[SegSampleRecord],
+            image_loader: Callable[[str | Path], torch.Tensor] | None = None,
+            mask_loader: Callable[[str | Path], torch.Tensor] | None = None,
+            joint_transform: Callable[[torch.Tensor, torch.Tensor | None], tuple[torch.Tensor, torch.Tensor | None]] | None = None,
+            image_transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
+            mask_transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
+    ) -> None:
+        self.records = records
+        # Fall back to the module's default loaders when none is supplied.
+        self.image_loader = image_loader if image_loader else default_image_loader
+        self.mask_loader = mask_loader if mask_loader else default_mask_loader
+        self.joint_transform = joint_transform
+        self.image_transform = image_transform
+        self.mask_transform = mask_transform
+
+    def __len__(self) -> int:
+        return len(self.records)
+
+    def _load_mask(self, record: SegSampleRecord) -> torch.Tensor | None:
+        """Load the record's mask from file, or rasterise its DDTI annotation, or return None."""
+        if record.mask_path is not None:
+            return self.mask_loader(record.mask_path)
+        if record.annotation_path is None or record.dataset_name != "DDTI":
+            return None
+        rendered = build_ddti_mask(record.image_path, record.annotation_path)
+        binary = (np.asarray(rendered, dtype=np.float32) > 0).astype(np.float32)
+        return torch.from_numpy(binary).unsqueeze(0).contiguous()
+
+    def __getitem__(self, index: int) -> dict[str, Any]:
+        """Load sample ``index`` and apply joint, image-only and mask-only transforms in that order."""
+        record = self.records[index]
+        image = self.image_loader(record.image_path)
+        mask = self._load_mask(record)
+
+        if self.joint_transform is not None:
+            image, mask = self.joint_transform(image, mask)
+        if self.image_transform is not None:
+            image = self.image_transform(image)
+        # The mask-only transform is skipped for samples without a mask.
+        if self.mask_transform is not None and mask is not None:
+            mask = self.mask_transform(mask)
+
+        sample: dict[str, Any] = {"image": image, "mask": mask}
+        sample["dataset_name"] = record.dataset_name
+        sample["sample_id"] = record.sample_id
+        sample["split"] = record.split
+        sample["class_name"] = record.class_name
+        sample["meta"] = record.meta
+        return sample
+
+
+__all__ = [
+    "SegmentationRecordDataset",
+    "default_image_loader",
+    "default_mask_loader",
+]

+ 81 - 0
lib/data/ddti.py

@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+import xml.etree.ElementTree as ET
+
+from PIL import Image, ImageDraw
+
+
+def parse_ddti_xml(annotation_path: str | Path) -> dict[int, list[list[tuple[int, int]]]]:
+    """
+    Parse a DDTI XML annotation file into per-image polygons.
+
+    Every ``<mark>`` element contributes the shapes found in the JSON text of
+    its ``<svg>`` child to the image index given by its ``<image>`` child.
+    Marks with missing/empty fields, undecodable JSON, or JSON that is not a
+    list are skipped. Polygons with fewer than three points are dropped.
+    Coordinates are rounded to the nearest integer.
+
+    Returns:
+        {image_index: [polygon1, polygon2, ...]}, each polygon a list of (x, y).
+
+    Raises:
+        ValueError: an ``<image>`` value is not an integer.
+    """
+    annotation_path = Path(annotation_path)
+    root = ET.parse(annotation_path).getroot()
+    image_to_polygons: dict[int, list[list[tuple[int, int]]]] = {}
+
+    for mark in root.findall("mark"):
+        image_text = mark.findtext("image")
+        svg_text = mark.findtext("svg")
+        if not image_text or not svg_text:
+            continue
+
+        image_index = int(image_text)
+        try:
+            shapes = json.loads(svg_text)
+        except json.JSONDecodeError:
+            continue
+        # Valid JSON that is not a list of shapes (e.g. "null") is as unusable
+        # as undecodable JSON, so skip it the same way.
+        if not isinstance(shapes, list):
+            continue
+
+        polygons: list[list[tuple[int, int]]] = []
+        for shape in shapes:
+            polygon = [
+                (int(round(point["x"])), int(round(point["y"])))
+                for point in shape.get("points", [])
+            ]
+            if len(polygon) >= 3:
+                polygons.append(polygon)
+
+        if polygons:
+            # Accumulate instead of assigning, so several marks for the same
+            # image do not overwrite each other's polygons.
+            image_to_polygons.setdefault(image_index, []).extend(polygons)
+
+    return image_to_polygons
+
+
+def build_ddti_mask(
+        image_path: str | Path,
+        annotation_path: str | Path,
+        image_index: int | None = None,
+        fill_value: int = 255,
+) -> Image.Image:
+    """
+    Rasterise the DDTI XML polygons for one image into a single-channel mask.
+
+    Args:
+        image_path: Image whose size determines the mask size.
+        annotation_path: DDTI XML file, parsed with ``parse_ddti_xml``.
+        image_index: Index of the image inside the annotation. When ``None`` it
+            is taken from the text after the first ``_`` in the image file stem.
+        fill_value: Pixel value used for polygon outline and interior.
+
+    Returns:
+        A mode ``"L"`` image the same size as the source image; all zeros when
+        the annotation has no polygons for ``image_index``.
+
+    Raises:
+        ValueError: ``image_index`` is ``None`` and the stem has no ``_``, or
+            the text after it is not an integer.
+    """
+    image_path = Path(image_path)
+    annotation_path = Path(annotation_path)
+    # Only the size is needed. Close the file right away instead of leaving an
+    # open handle behind for every sample.
+    with Image.open(image_path) as image:
+        width, height = image.size
+
+    if image_index is None:
+        stem = image_path.stem
+        if "_" not in stem:
+            raise ValueError(f"Cannot infer image index from file name: {image_path.name}")
+        _, image_idx_str = stem.split("_", 1)
+        image_index = int(image_idx_str)
+
+    polygons_map = parse_ddti_xml(annotation_path)
+    polygons = polygons_map.get(int(image_index), [])
+
+    mask = Image.new("L", (width, height), 0)
+    draw = ImageDraw.Draw(mask)
+    for polygon in polygons:
+        draw.polygon(polygon, outline=fill_value, fill=fill_value)
+    return mask
+
+
+__all__ = ["parse_ddti_xml", "build_ddti_mask"]

+ 225 - 0
lib/data/indexers.py

@@ -0,0 +1,225 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+from .records import SegSampleRecord
+from .utils import list_image_files, stem_without_mask_suffix
+
+
+def build_paired_folder_records(
+        dataset_name: str,
+        image_dir: Path,
+        mask_dir: Path,
+        *,
+        split: str | None = None,
+        class_name: str | None = None,
+) -> list[SegSampleRecord]:
+    """Pair each image with the mask that has exactly the same file name.
+
+    Images without a same-named mask are skipped. ``sample_id`` is the image stem.
+    """
+    images = list_image_files(image_dir)
+    masks_by_name = {mask_path.name: mask_path for mask_path in list_image_files(mask_dir)}
+
+    return [
+        SegSampleRecord(
+            dataset_name=dataset_name,
+            image_path=image,
+            mask_path=masks_by_name[image.name],
+            split=split,
+            sample_id=image.stem,
+            class_name=class_name,
+        )
+        for image in images
+        if image.name in masks_by_name
+    ]
+
+
+def build_prefixed_paired_records(
+        dataset_name: str,
+        image_dir: Path,
+        mask_dir: Path,
+        *,
+        image_prefix_to_strip: str = "",
+        mask_prefix_to_strip: str = "",
+        split: str | None = None,
+        class_name: str | None = None,
+) -> list[SegSampleRecord]:
+    """Pair images and masks whose file names match once their prefixes are removed.
+
+    Images without a matching mask are skipped. ``sample_id`` is the image
+    stem with its prefix still attached.
+    """
+    images = list_image_files(image_dir)
+    masks = list_image_files(mask_dir)
+
+    def _match_key(path: Path, prefix: str) -> str:
+        # Drop the prefix only when it is non-empty and actually present.
+        file_name = path.name
+        if not prefix or not file_name.startswith(prefix):
+            return file_name
+        return file_name[len(prefix):]
+
+    masks_by_key = {_match_key(mask_path, mask_prefix_to_strip): mask_path for mask_path in masks}
+
+    paired: list[SegSampleRecord] = []
+    for image in images:
+        match = masks_by_key.get(_match_key(image, image_prefix_to_strip))
+        if match is not None:
+            paired.append(
+                SegSampleRecord(
+                    dataset_name=dataset_name,
+                    image_path=image,
+                    mask_path=match,
+                    split=split,
+                    sample_id=image.stem,
+                    class_name=class_name,
+                )
+            )
+    return paired
+
+
+def build_stem_paired_records(
+        dataset_name: str,
+        image_dir: Path,
+        mask_dir: Path,
+        *,
+        split: str | None = None,
+        class_name: str | None = None,
+        prefer_plain_mask: bool = True,
+) -> list[SegSampleRecord]:
+    """Pair each image with a mask whose normalised name equals the image stem.
+
+    Masks are grouped by ``stem_without_mask_suffix(mask.name)``. When several
+    masks share a key, the first in sorted order is used. With
+    ``prefer_plain_mask`` the first one whose stem lacks ``_binary``
+    (case-insensitive) wins instead, if there is one. ``meta`` records how
+    many candidates were found.
+    """
+    images = list_image_files(image_dir)
+
+    masks_by_key: dict[str, list[Path]] = {}
+    for mask_path in list_image_files(mask_dir):
+        masks_by_key.setdefault(stem_without_mask_suffix(mask_path.name), []).append(mask_path)
+
+    records: list[SegSampleRecord] = []
+    for image in images:
+        sample_id = image.stem
+        candidates = sorted(masks_by_key.get(sample_id, []))
+        if not candidates:
+            continue
+
+        chosen = candidates[0]
+        if prefer_plain_mask:
+            chosen = next(
+                (candidate for candidate in candidates if "_binary" not in candidate.stem.lower()),
+                chosen,
+            )
+
+        records.append(
+            SegSampleRecord(
+                dataset_name=dataset_name,
+                image_path=image,
+                mask_path=chosen,
+                split=split,
+                sample_id=sample_id,
+                class_name=class_name,
+                meta={"mask_candidates": str(len(candidates))},
+            )
+        )
+    return records
+
+
+def build_filename_matched_records(
+        dataset_name: str,
+        folder: Path,
+        *,
+        split: str | None = None,
+        class_name: str | None = None,
+) -> list[SegSampleRecord]:
+    """Pair images and masks that live side by side in one folder.
+
+    A file counts as a mask when its stem contains ``_mask``; both kinds are
+    keyed by ``stem_without_mask_suffix(path.name)``. Records come out sorted
+    by key, use the first mask in sorted order, and store the number of masks
+    found in ``meta``. Images without any mask are skipped.
+    """
+    images_by_key: dict[str, Path] = {}
+    masks_by_key: dict[str, list[Path]] = {}
+
+    for path in list_image_files(folder):
+        key = stem_without_mask_suffix(path.name)
+        if "_mask" not in path.stem:
+            images_by_key[key] = path
+            continue
+        masks_by_key.setdefault(key, []).append(path)
+
+    records: list[SegSampleRecord] = []
+    for key in sorted(images_by_key):
+        mask_paths = sorted(masks_by_key.get(key, []))
+        if mask_paths:
+            records.append(
+                SegSampleRecord(
+                    dataset_name=dataset_name,
+                    image_path=images_by_key[key],
+                    mask_path=mask_paths[0],
+                    split=split,
+                    sample_id=key,
+                    class_name=class_name,
+                    meta={"mask_count": str(len(mask_paths))},
+                )
+            )
+    return records
+
+
+def build_pre_split_records(
+        dataset_name: str,
+        train_image_dir: Path,
+        train_mask_dir: Path,
+        test_image_dir: Path,
+        test_mask_dir: Path,
+) -> list[SegSampleRecord]:
+    """Index a dataset shipped as separate train and test folders.
+
+    Train-folder records are labelled ``"trainval"`` and come first; test-folder
+    records are labelled ``"test"``.
+    """
+    trainval_records = build_paired_folder_records(
+        dataset_name=dataset_name,
+        image_dir=train_image_dir,
+        mask_dir=train_mask_dir,
+        split="trainval",
+    )
+    test_records = build_paired_folder_records(
+        dataset_name=dataset_name,
+        image_dir=test_image_dir,
+        mask_dir=test_mask_dir,
+        split="test",
+    )
+    return [*trainval_records, *test_records]
+
+
+def build_xml_annotation_records(
+        dataset_name: str,
+        root: Path,
+        *,
+        split: str | None = None,
+        class_name: str | None = None,
+) -> list[SegSampleRecord]:
+    """Pair ``*.jpg`` images in ``root`` with the ``*.xml`` annotation they belong to.
+
+    The annotation is looked up by the part of the image stem before the first
+    ``_``. Images without one are skipped. Records have no ``mask_path``; they
+    carry ``annotation_path`` instead.
+    """
+    annotations_by_stem = {xml_path.stem: xml_path for xml_path in sorted(root.glob("*.xml"))}
+
+    records: list[SegSampleRecord] = []
+    for image in sorted(root.glob("*.jpg")):
+        case_id, _, _ = image.stem.partition("_")
+        if case_id not in annotations_by_stem:
+            continue
+        records.append(
+            SegSampleRecord(
+                dataset_name=dataset_name,
+                image_path=image,
+                mask_path=None,
+                annotation_path=annotations_by_stem[case_id],
+                split=split,
+                sample_id=image.stem,
+                class_name=class_name,
+                meta={"annotation_type": "xml"},
+            )
+        )
+    return records
+
+
+__all__ = [
+    "build_paired_folder_records",
+    "build_prefixed_paired_records",
+    "build_stem_paired_records",
+    "build_filename_matched_records",
+    "build_pre_split_records",
+    "build_xml_annotation_records",
+]

+ 224 - 0
lib/data/loaders.py

@@ -0,0 +1,224 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from torch.utils.data import DataLoader
+
+from .augment import build_segmentation_augmentation
+from .builder import build_dataset_index
+from .collate import record_collate_fn
+from .datasets import SegmentationRecordDataset
+from .project_splits import (
+    PROJECT_SPLIT_DATASETS,
+    get_project_split_file,
+    load_project_split_ids,
+    select_project_split_base_records,
+)
+from .records import SegSampleRecord
+from .splits import load_id_txt, load_json_split
+
+# Official split files per dataset, as {dataset name: {split name: file name}}.
+# File names are joined onto the dataset root by get_official_split_file().
+# TN3K and TG3K point every split at the same JSON file; apply_official_split()
+# then picks the split's entry out of that file by name.
+OFFICIAL_SPLIT_FILES: dict[str, dict[str, str]] = {
+    "OTU_2d": {
+        "train": "train.txt",
+        "val": "val.txt",
+    },
+    "TN3K": {
+        "train": "tn3k-trainval.json",
+        "val": "tn3k-trainval.json",
+        "test": "tn3k-trainval.json",
+    },
+    "TG3K": {
+        "train": "tg3k-trainval.json",
+        "val": "tg3k-trainval.json",
+        "test": "tg3k-trainval.json",
+    },
+}
+
+
+def _normalize_id_set(values: list[str]) -> set[str]:
+    """Return the ids as a set, adding a zero-padded 4-digit form for each integer-like id."""
+    ids = set(values)
+    for raw in values:
+        try:
+            number = int(raw)
+        except ValueError:
+            # Not an integer: only the exact spelling is kept.
+            continue
+        ids.add(f"{number:04d}")
+    return ids
+
+
+def _as_exact_id_set(values: list[str]) -> set[str]:
+    """Return the ids as a set, with no normalisation."""
+    return set(values)
+
+
+def _clone_record(record: SegSampleRecord, split_name: str | None) -> SegSampleRecord:
+    """Copy ``record`` with its split replaced by ``split_name``; ``meta`` is shallow-copied."""
+    fields = {
+        "dataset_name": record.dataset_name,
+        "image_path": record.image_path,
+        "mask_path": record.mask_path,
+        "annotation_path": record.annotation_path,
+        "split": split_name,
+        "sample_id": record.sample_id,
+        "class_name": record.class_name,
+        # A fresh dict, so edits to the clone's meta do not touch the source record.
+        "meta": dict(record.meta),
+    }
+    return SegSampleRecord(**fields)
+
+
+def _filter_by_sample_ids(records: list[SegSampleRecord], sample_ids: set[str], split_name: str) -> list[SegSampleRecord]:
+    """Return relabelled copies of the records whose ``sample_id`` is in ``sample_ids``, in input order."""
+    return [
+        _clone_record(record, split_name)
+        for record in records
+        if record.sample_id in sample_ids
+    ]
+
+
+def _filter_by_existing_split(records: list[SegSampleRecord], split: str) -> list[SegSampleRecord]:
+    """Return copies of the records whose own ``split`` label already equals ``split``."""
+    matching: list[SegSampleRecord] = []
+    for record in records:
+        if record.split == split:
+            matching.append(_clone_record(record, split))
+    return matching
+
+
+def get_official_split_file(
+        dataset_name: str,
+        root: str | Path,
+        split: str,
+) -> Path | None:
+    """Return the official split file path under ``root``.
+
+    Returns ``None`` when ``OFFICIAL_SPLIT_FILES`` has no entry for the dataset
+    or for the split.
+    """
+    file_name = OFFICIAL_SPLIT_FILES.get(dataset_name, {}).get(split)
+    return None if file_name is None else Path(root) / file_name
+
+
+def list_supported_splits(dataset_name: str) -> list[str]:
+    """List the split names available for ``dataset_name``.
+
+    Official splits take precedence, then project splits (``train`` / ``val``).
+    Unknown datasets yield an empty list.
+    """
+    if dataset_name in OFFICIAL_SPLIT_FILES:
+        return list(OFFICIAL_SPLIT_FILES[dataset_name])
+    return ["train", "val"] if dataset_name in PROJECT_SPLIT_DATASETS else []
+
+
+def _resolve_official_split_path(
+        dataset_name: str,
+        root: Path,
+        split: str,
+        split_file: str | Path | None,
+) -> Path:
+    """Return the caller-supplied split file, else the registered official one; raise if neither exists."""
+    if split_file is not None:
+        return Path(split_file)
+    split_path = get_official_split_file(dataset_name, root, split)
+    if split_path is None:
+        # Fail here with a clear message instead of handing None to a file loader.
+        raise ValueError(
+            f"No official split file registered for dataset '{dataset_name}' and split '{split}'."
+        )
+    return split_path
+
+
+def apply_official_split(
+        dataset_name: str,
+        root: str | Path,
+        records: list[SegSampleRecord],
+        split: str,
+        *,
+        split_file: str | Path | None = None,
+) -> list[SegSampleRecord]:
+    """Select the records that belong to ``split`` and relabel them with it.
+
+    Handling per dataset:
+      * ``OTU_2d``: ids from a text file; only ``train`` and ``val`` exist.
+      * ``TN3K``: ``test`` is the records already labelled ``"test"``; other
+        splits are read from a JSON split map and applied to the
+        ``"trainval"`` records only.
+      * ``TG3K``: ids from a JSON split map, applied to all records.
+      * datasets in ``PROJECT_SPLIT_DATASETS``: ids from the project split
+        files (``train`` / ``val``), matched exactly against ``sample_id``.
+      * anything else: records whose existing ``split`` label matches.
+
+    For OTU_2d, TN3K and TG3K, ids are matched both as written and zero-padded
+    to four digits.
+
+    Args:
+        dataset_name: Name of the dataset the records belong to.
+        root: Dataset root that default split files are resolved against.
+        records: Full index of the dataset.
+        split: Requested split name.
+        split_file: Optional explicit split file overriding the default one.
+
+    Returns:
+        New record objects for the requested split (inputs are not mutated).
+
+    Raises:
+        ValueError: the split is unsupported, has no registered split file, is
+            missing from the JSON split map, or matches no record in the
+            fallback case.
+    """
+    root = Path(root)
+
+    if dataset_name == "OTU_2d":
+        if split not in {"train", "val"}:
+            raise ValueError("OTU_2d currently supports official splits: train, val.")
+        split_path = _resolve_official_split_path(dataset_name, root, split, split_file)
+        ids = _normalize_id_set(load_id_txt(split_path))
+        return _filter_by_sample_ids(records, ids, split_name=split)
+
+    if dataset_name == "TN3K":
+        if split == "test":
+            return _filter_by_existing_split(records, "test")
+        split_path = _resolve_official_split_path(dataset_name, root, split, split_file)
+        split_map = load_json_split(split_path)
+        if split not in split_map:
+            raise ValueError(f"Split '{split}' not found in {split_path}.")
+        ids = _normalize_id_set(split_map[split])
+        # train/val ids refer to the trainval folder; test-folder records must not match.
+        trainval_records = [record for record in records if record.split == "trainval"]
+        return _filter_by_sample_ids(trainval_records, ids, split_name=split)
+
+    if dataset_name == "TG3K":
+        split_path = _resolve_official_split_path(dataset_name, root, split, split_file)
+        split_map = load_json_split(split_path)
+        if split not in split_map:
+            raise ValueError(f"Split '{split}' not found in {split_path}.")
+        ids = _normalize_id_set(split_map[split])
+        return _filter_by_sample_ids(records, ids, split_name=split)
+
+    if dataset_name in PROJECT_SPLIT_DATASETS:
+        if split not in {"train", "val"}:
+            raise ValueError(
+                f"{dataset_name} currently supports project splits: train, val."
+            )
+        records = select_project_split_base_records(dataset_name, records)
+        if split_file is None:
+            raw_ids = load_project_split_ids(root, split)
+        else:
+            raw_ids = load_id_txt(Path(split_file))
+        return _filter_by_sample_ids(records, _as_exact_id_set(raw_ids), split_name=split)
+
+    filtered = _filter_by_existing_split(records, split)
+    if filtered:
+        return filtered
+    raise ValueError(
+        f"No split handler registered for dataset '{dataset_name}' and split '{split}'."
+    )
+
+
+def build_record_dataset(
+        dataset_name: str,
+        root: str | Path,
+        *,
+        split: str | None = None,
+        split_file: str | Path | None = None,
+        augmentation_config: dict[str, Any] | None = None,
+        image_transform=None,
+        mask_transform=None,
+) -> SegmentationRecordDataset:
+    """Index a dataset, optionally restrict it to one split, and wrap it in a dataset object.
+
+    With ``split=None`` the full index is used. ``augmentation_config`` is
+    turned into the joint image/mask transform (none when it is empty or ``None``).
+    """
+    records = build_dataset_index(dataset_name, root)
+    if split is not None:
+        records = apply_official_split(dataset_name, root, records, split, split_file=split_file)
+
+    joint_transform = build_segmentation_augmentation(augmentation_config)
+    return SegmentationRecordDataset(
+        records,
+        joint_transform=joint_transform,
+        image_transform=image_transform,
+        mask_transform=mask_transform,
+    )
+
+
+def build_dataloader(
+        dataset_name: str,
+        root: str | Path,
+        *,
+        split: str | None = None,
+        split_file: str | Path | None = None,
+        batch_size: int = 4,
+        shuffle: bool = False,
+        num_workers: int = 0,
+        augmentation_config: dict[str, Any] | None = None,
+        image_transform=None,
+        mask_transform=None,
+        **loader_kwargs: Any,
+) -> DataLoader:
+    """Build the record dataset and wrap it in a ``DataLoader``.
+
+    Dataset arguments are forwarded to ``build_record_dataset``. Extra keyword
+    arguments go to ``DataLoader``. ``record_collate_fn`` is the collate
+    function unless ``collate_fn`` is passed explicitly.
+    """
+    dataset = build_record_dataset(
+        dataset_name,
+        root,
+        split=split,
+        split_file=split_file,
+        augmentation_config=augmentation_config,
+        image_transform=image_transform,
+        mask_transform=mask_transform,
+    )
+    # Take collate_fn out of the extras first so it is not passed twice.
+    collate_fn = loader_kwargs.pop("collate_fn", record_collate_fn)
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=num_workers,
+        collate_fn=collate_fn,
+        **loader_kwargs,
+    )
+
+
+__all__ = [
+    "OFFICIAL_SPLIT_FILES",
+    "apply_official_split",
+    "build_record_dataset",
+    "build_dataloader",
+    "get_official_split_file",
+    "list_supported_splits",
+]

+ 159 - 0
lib/data/project_splits.py

@@ -0,0 +1,159 @@
+from __future__ import annotations
+
+import random
+from collections import defaultdict
+from pathlib import Path
+
+from .builder import build_dataset_index
+from .records import SegSampleRecord
+
+
+# Location of the project-defined split files, relative to a dataset root.
+PROJECT_SPLIT_ROOT = Path("splits") / "project"
+# Dataset names for which generate_project_splits() is allowed to run.
+PROJECT_SPLIT_DATASETS = {"BUS-UCLM", "BUSI", "BUS-BRA", "BUS_UC", "CCAUI", "DDTI"}
+
+
+def _project_split_dir(root: str | Path) -> Path:
+    """Return the project split directory located under ``root``."""
+    dataset_root = Path(root)
+    return dataset_root.joinpath(PROJECT_SPLIT_ROOT)
+
+
+def get_project_split_file(
+        root: str | Path,
+        split: str,
+) -> Path:
+    """Return the path of the ``<split>.txt`` ID list in the project split directory."""
+    file_name = f"{split}.txt"
+    split_dir = _project_split_dir(root)
+    return split_dir / file_name
+
+
+def load_project_split_ids(
+        root: str | Path,
+        split: str,
+) -> list[str]:
+    """Read the sample IDs stored in the project split file for ``split``.
+
+    Each line is stripped of surrounding whitespace and blank lines are
+    skipped. Bytes that are not valid UTF-8 are dropped while decoding.
+
+    Raises:
+        FileNotFoundError: if the split file does not exist.
+    """
+    path = get_project_split_file(root, split)
+    if not path.exists():
+        raise FileNotFoundError(f"Project split file not found: {path}")
+
+    text = path.read_text(encoding="utf-8", errors="ignore")
+    sample_ids: list[str] = []
+    for raw_line in text.splitlines():
+        sample_id = raw_line.strip()
+        if sample_id:
+            sample_ids.append(sample_id)
+    return sample_ids
+
+
+def _write_split_ids(path: Path, sample_ids: list[str]) -> None:
+    """Write ``sample_ids`` to ``path``, one ID per line.
+
+    Missing parent directories are created. A non-empty list is written with
+    a trailing newline; an empty list produces an empty file.
+    """
+    path.parent.mkdir(parents=True, exist_ok=True)
+    content = ("\n".join(sample_ids) + "\n") if sample_ids else ""
+    path.write_text(content, encoding="utf-8")
+
+
+def _deduplicate_records(
+        dataset_name: str,
+        records: list[SegSampleRecord],
+) -> list[SegSampleRecord]:
+    """Drop duplicated BUS_UC samples; every other dataset is returned as is.
+
+    For BUS_UC only the records whose ``class_name`` is ``"all"`` are kept.
+    If there are none, the input list is returned unchanged.
+    """
+    if dataset_name == "BUS_UC":
+        # In BUS_UC the "All" samples duplicate the Benign/Malignant ones, so
+        # by default only "All" is kept as the base for the official split.
+        only_all = [item for item in records if item.class_name == "all"]
+        if only_all:
+            return only_all
+    return records
+
+
+def select_project_split_base_records(
+        dataset_name: str,
+        records: list[SegSampleRecord],
+) -> list[SegSampleRecord]:
+    """Public wrapper around ``_deduplicate_records`` with the same arguments and result."""
+    base_records = _deduplicate_records(dataset_name, records)
+    return base_records
+
+
+def _group_records_for_split(
+        records: list[SegSampleRecord],
+) -> dict[str, list[SegSampleRecord]]:
+    """Group records by ``class_name``, preserving the input order.
+
+    Records whose ``class_name`` is falsy (``None`` or empty) are collected
+    under the ``"__default__"`` key.
+    """
+    grouped: dict[str, list[SegSampleRecord]] = defaultdict(list)
+    for item in records:
+        label = item.class_name if item.class_name else "__default__"
+        grouped[label].append(item)
+    return grouped
+
+
+def _split_group(
+        group_records: list[SegSampleRecord],
+        *,
+        val_ratio: float,
+        rng: random.Random,
+) -> tuple[list[SegSampleRecord], list[SegSampleRecord]]:
+    """Shuffle a copy of ``group_records`` and split it into ``(train, val)``.
+
+    The validation size is ``round(len * val_ratio)``. With two or more
+    records it is clamped so both parts are non-empty. A single record always
+    goes to train. The input list itself is not modified.
+    """
+    pool = list(group_records)
+    rng.shuffle(pool)
+
+    total = len(pool)
+    n_val = int(round(total * val_ratio))
+    if total >= 2:
+        # Keep at least one record on each side of the split.
+        n_val = max(1, min(total - 1, n_val))
+    elif total == 1:
+        n_val = 0
+
+    # Validation takes the head of the shuffled pool, train takes the rest.
+    return pool[n_val:], pool[:n_val]
+
+
+def generate_project_splits(
+        dataset_name: str,
+        root: str | Path,
+        *,
+        val_ratio: float = 0.2,
+        seed: int = 42,
+        stratify_by_class: bool = True,
+        reuse_existing: bool = True,
+) -> dict[str, list[str]]:
+    """Create (or reload) the project train/val split files for a dataset.
+
+    When ``reuse_existing`` is true and both split files already exist, their
+    contents are returned as is and ``val_ratio``/``seed`` are not consulted.
+    Otherwise the dataset index is built, de-duplicated, split with a
+    ``random.Random(seed)`` generator (per ``class_name`` group when
+    ``stratify_by_class`` is true) and the sorted sample IDs are written to
+    ``train.txt`` and ``val.txt`` in the project split directory.
+
+    Records whose ``sample_id`` is ``None`` are left out of the ID lists.
+
+    Returns:
+        A dict with the keys ``"train"`` and ``"val"`` mapping to ID lists.
+
+    Raises:
+        ValueError: if the dataset is not in ``PROJECT_SPLIT_DATASETS`` or
+            ``val_ratio`` is not strictly between 0 and 1.
+    """
+    if dataset_name not in PROJECT_SPLIT_DATASETS:
+        raise ValueError(
+            f"Dataset '{dataset_name}' is not enabled for project split generation."
+        )
+    # Keep the chained comparison: written this way it also rejects NaN.
+    if not 0.0 < val_ratio < 1.0:
+        raise ValueError(f"val_ratio must be between 0 and 1, got {val_ratio}.")
+
+    split_names = ("train", "val")
+    split_paths = {name: get_project_split_file(root, name) for name in split_names}
+    if reuse_existing and all(path.exists() for path in split_paths.values()):
+        return {name: load_project_split_ids(root, name) for name in split_names}
+
+    base_records = _deduplicate_records(
+        dataset_name, build_dataset_index(dataset_name, root)
+    )
+    rng = random.Random(seed)
+
+    # One (train, val) pair per class group, or a single pair when the split
+    # is not stratified. Groups are processed in first-seen order so the
+    # shared generator is consumed deterministically.
+    if stratify_by_class:
+        partitions = [
+            _split_group(members, val_ratio=val_ratio, rng=rng)
+            for members in _group_records_for_split(base_records).values()
+        ]
+    else:
+        partitions = [_split_group(base_records, val_ratio=val_ratio, rng=rng)]
+
+    train_records = [item for train_part, _ in partitions for item in train_part]
+    val_records = [item for _, val_part in partitions for item in val_part]
+
+    split_ids = {
+        "train": sorted(
+            item.sample_id for item in train_records if item.sample_id is not None
+        ),
+        "val": sorted(
+            item.sample_id for item in val_records if item.sample_id is not None
+        ),
+    }
+
+    _project_split_dir(root).mkdir(parents=True, exist_ok=True)
+    for name in split_names:
+        _write_split_ids(split_paths[name], split_ids[name])
+
+    return split_ids
+
+
+__all__ = [
+    "PROJECT_SPLIT_DATASETS",
+    "PROJECT_SPLIT_ROOT",
+    "generate_project_splits",
+    "get_project_split_file",
+    "load_project_split_ids",
+    "select_project_split_base_records",
+]

+ 25 - 0
lib/data/records.py

@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from pathlib import Path
+
+
+@dataclass(slots=True)
+class SegSampleRecord:
+    """
+    Unified record format for one segmentation sample.
+
+    This layer is only responsible for indexing; it does not read any image
+    data itself.
+    """
+
+    # Name of the dataset this sample belongs to.
+    dataset_name: str
+    # Path to the image file.
+    image_path: Path
+    # Optional path to a mask file.
+    mask_path: Path | None = None
+    # Optional path to an annotation file.
+    annotation_path: Path | None = None
+    # Optional split label for this sample.
+    split: str | None = None
+    # Optional sample identifier.
+    sample_id: str | None = None
+    # Optional class label.
+    class_name: str | None = None
+    # Free-form string metadata; each record gets its own dict.
+    meta: dict[str, str] = field(default_factory=dict)
+
+
+__all__ = ["SegSampleRecord"]

+ 21 - 0
lib/data/splits.py

@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+
+def load_id_txt(path: str | Path) -> list[str]:
+    """Read a text file of IDs, one per line.
+
+    Lines are stripped of surrounding whitespace and blank lines are skipped.
+    Bytes that are not valid UTF-8 are dropped while decoding.
+    """
+    text = Path(path).read_text(encoding="utf-8", errors="ignore")
+    stripped_lines = (raw_line.strip() for raw_line in text.splitlines())
+    return [entry for entry in stripped_lines if entry]
+
+
+def load_json_split(path: str | Path) -> dict[str, list[str]]:
+    """Load a JSON object that maps split names to lists of IDs.
+
+    Every list item is converted with ``str()``, so numeric IDs come back as
+    strings.
+    """
+    payload = json.loads(Path(path).read_text(encoding="utf-8"))
+    return {
+        split_name: [str(entry) for entry in members]
+        for split_name, members in payload.items()
+    }
+
+
+__all__ = ["load_id_txt", "load_json_split"]

+ 36 - 0
lib/data/utils.py

@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from pathlib import Path
+import re
+
+
+# Lower-case file suffixes that is_image_file() accepts as images.
+IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
+
+
+def is_image_file(path: Path) -> bool:
+    """Return True when ``path`` is an existing file with a known image suffix.
+
+    The suffix comparison is case-insensitive.
+    """
+    if not path.is_file():
+        return False
+    suffix = path.suffix.lower()
+    return suffix in IMAGE_EXTENSIONS
+
+
+def list_image_files(folder: Path) -> list[Path]:
+    """Return the image files directly inside ``folder``, sorted by path.
+
+    Sub-directories are not searched.
+
+    Raises:
+        FileNotFoundError: if ``folder`` does not exist.
+    """
+    if not folder.exists():
+        raise FileNotFoundError(f"Folder not found: {folder}")
+    images = [entry for entry in folder.iterdir() if is_image_file(entry)]
+    images.sort()
+    return images
+
+
+def stem_without_mask_suffix(name: str) -> str:
+    """Return the file stem of ``name`` with a trailing mask marker removed.
+
+    The marker is ``_mask`` optionally followed by ``_<digits>`` at the very
+    end of the stem, e.g. ``case_mask_1.png`` -> ``case``.
+    """
+    base = Path(name).stem
+    return re.sub(r"_mask(_\d+)?$", "", base)
+
+
+def relative_stem(path: Path) -> str:
+    """Return ``path.stem``: the final path component without its last suffix."""
+    stem = path.stem
+    return stem
+
+
+__all__ = [
+    "IMAGE_EXTENSIONS",
+    "is_image_file",
+    "list_image_files",
+    "stem_without_mask_suffix",
+    "relative_stem",
+]

+ 127 - 37
lib/modules/xnet_2d.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 
+import ptwt
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -11,6 +12,7 @@ from .lib_mamba.vmamba import SS2D as VMambaSS2D
 
 
 class XNetStem2d(nn.Module):
+    # Stem reduces spatial size by 4x while lifting features into encoder stage 1.
     def __init__(self, in_channels: int, stem_channels: int, out_channels: int) -> None:
         super().__init__()
         self.block = nn.Sequential(
@@ -43,6 +45,7 @@ class XNetDownsample2d(nn.Module):
 
 
 class XLocalBranch2d(nn.Module):
+    # Parallel depthwise branches capture short-range texture at two kernel scales.
     def __init__(self, channels: int) -> None:
         super().__init__()
         self.branch3 = nn.Sequential(
@@ -60,45 +63,36 @@ class XLocalBranch2d(nn.Module):
         return self.branch3(x) + self.branch5(x)
 
 
-class XHaarWaveletTransform2d(nn.Module):
-    def __init__(self, channels: int) -> None:
+class XWaveletTransform2d(nn.Module):
+    # ptwt-based wavelet decomposition/reconstruction with explicit crop so odd
+    # input sizes round-trip to the exact original spatial shape.
+    def __init__(
+        self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
+    ) -> None:
         super().__init__()
-        ll = torch.tensor([[0.5, 0.5], [0.5, 0.5]], dtype=torch.float32)
-        lh = torch.tensor([[-0.5, -0.5], [0.5, 0.5]], dtype=torch.float32)
-        hl = torch.tensor([[-0.5, 0.5], [-0.5, 0.5]], dtype=torch.float32)
-        hh = torch.tensor([[0.5, -0.5], [-0.5, 0.5]], dtype=torch.float32)
-        filt = torch.stack([ll, lh, hl, hh], dim=0).unsqueeze(1)
-        self.register_buffer(
-            "analysis_filter", filt.repeat(channels, 1, 1, 1), persistent=False
-        )
-        self.register_buffer(
-            "synthesis_filter", filt.repeat(channels, 1, 1, 1), persistent=False
-        )
         self.channels = channels
+        self.wavelet_type = wavelet_type
+        self.wavelet_level = wavelet_level
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        b, c, h, w = x.shape
-        pad_h = h % 2
-        pad_w = w % 2
-        if pad_h or pad_w:
-            x = F.pad(x, (0, pad_w, 0, pad_h))
-        y = F.conv2d(x, self.analysis_filter, stride=2, groups=self.channels)
-        y = y.view(b, c, 4, y.shape[-2], y.shape[-1])
-        ll = y[:, :, 0]
-        high = y[:, :, 1:].reshape(b, c * 3, y.shape[-2], y.shape[-1])
+        coeffs = ptwt.wavedec2(x, self.wavelet_type, level=self.wavelet_level)
+        ll = coeffs[0]
+        high_parts = coeffs[1]
+        high = torch.cat(high_parts, dim=1)
         return ll, high
 
     def inverse(
         self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
     ) -> torch.Tensor:
-        b, c, h, w = ll.shape
-        high = high.view(b, c, 3, h, w)
-        y = torch.cat([ll.unsqueeze(2), high], dim=2).reshape(b, c * 4, h, w)
-        x = F.conv_transpose2d(y, self.synthesis_filter, stride=2, groups=self.channels)
+        lh, hl, hh = torch.chunk(high, 3, dim=1)
+        coeffs = [ll, (lh, hl, hh)]
+        x = ptwt.waverec2(coeffs, self.wavelet_type)
         return x[:, :, : output_size[0], : output_size[1]]
 
 
 class XWaveletBranch2d(nn.Module):
+    # The wavelet branch learns on low/high-frequency components separately and
+    # then reconstructs back to the original feature size.
     def __init__(
         self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
     ) -> None:
@@ -109,7 +103,9 @@ class XWaveletBranch2d(nn.Module):
             raise ValueError(
                 "Initial XNet implementation only supports wavelet_level=1."
             )
-        self.wavelet = XHaarWaveletTransform2d(channels)
+        self.wavelet = XWaveletTransform2d(
+            channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
+        )
         self.ll_proj = nn.Sequential(
             Conv2dBN(channels, channels, 3, 1, 1),
             nn.ReLU(inplace=True),
@@ -134,6 +130,7 @@ class XWaveletBranch2d(nn.Module):
 
 
 class XSSMGlobalBranch2d(nn.Module):
+    # The global branch wraps VMamba and switches scan backend at runtime.
     def __init__(
         self,
         channels: int,
@@ -240,6 +237,7 @@ class XBranchFusion2d(nn.Module):
 
 
 class XTEB2d(nn.Module):
+    # XTEB fuses local, wavelet, and global branches with residual post/ffn blocks.
     def __init__(
         self,
         channels: int,
@@ -333,6 +331,7 @@ class XNetEncoderStage2d(nn.Module):
 
 
 class XNetEncoder2d(nn.Module):
+    # The encoder is a 4-stage feature pyramid with optional stage-1 global branch.
     def __init__(
         self,
         in_channels: int,
@@ -416,6 +415,7 @@ class XNetEncoder2d(nn.Module):
 
 
 class XGuideProjector2d(nn.Module):
+    # Guides are projected from encoder features and aligned to decoder resolution.
     def __init__(
         self, in_channels: int, out_channels: int, mode: str = "affine"
     ) -> None:
@@ -450,6 +450,7 @@ class XGuideProjector2d(nn.Module):
 
 
 class XSkipFusion2d(nn.Module):
+    # Decoder input and skip feature are aligned, projected, and fused together.
     def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
         super().__init__()
         self.input_proj = nn.Sequential(
@@ -473,6 +474,7 @@ class XSkipFusion2d(nn.Module):
 
 
 class XGuideModulation2d(nn.Module):
+    # Apply either direct affine guide or feature-to-affine modulation.
     def __init__(self, channels: int, guide_mode: str = "affine") -> None:
         super().__init__()
         self.guide_mode = guide_mode
@@ -493,15 +495,23 @@ class XGuideModulation2d(nn.Module):
 
 
 class XFrequencyRefine2d(nn.Module):
-    def __init__(self, channels: int) -> None:
+    def __init__(
+        self,
+        channels: int,
+        low_freq_radius_h: float = 0.25,
+        low_freq_radius_w: float = 0.25,
+        learnable_low_freq_radius: bool = True,
+    ) -> None:
         super().__init__()
+        if low_freq_radius_h <= 0.0 or low_freq_radius_w <= 0.0:
+            raise ValueError("Low-frequency radii must be positive.")
+        # Gates are predicted from half-spectrum magnitude statistics instead of
+        # directly reusing spatial-domain pooled features.
         self.low_gate = nn.Sequential(
-            nn.AdaptiveAvgPool2d(1),
             nn.Conv2d(channels, channels, kernel_size=1, bias=True),
             nn.Sigmoid(),
         )
         self.high_gate = nn.Sequential(
-            nn.AdaptiveAvgPool2d(1),
             nn.Conv2d(channels, channels, kernel_size=1, bias=True),
             nn.Sigmoid(),
         )
@@ -510,25 +520,77 @@ class XFrequencyRefine2d(nn.Module):
             nn.ReLU(inplace=True),
             Conv2dBN(channels, channels, 1, 1, 0),
         )
+        self.learnable_low_freq_radius = learnable_low_freq_radius
+        if learnable_low_freq_radius:
+            self.low_freq_radius_h = nn.Parameter(
+                torch.tensor(low_freq_radius_h, dtype=torch.float32)
+            )
+            self.low_freq_radius_w = nn.Parameter(
+                torch.tensor(low_freq_radius_w, dtype=torch.float32)
+            )
+        else:
+            self.register_buffer(
+                "low_freq_radius_h",
+                torch.tensor(low_freq_radius_h, dtype=torch.float32),
+                persistent=False,
+            )
+            self.register_buffer(
+                "low_freq_radius_w",
+                torch.tensor(low_freq_radius_w, dtype=torch.float32),
+                persistent=False,
+            )
+
+    def _resolve_radius(
+        self, value: torch.Tensor, max_ratio: float, device: torch.device
+    ) -> torch.Tensor:
+        radius = value.to(device=device, dtype=torch.float32)
+        if self.learnable_low_freq_radius:
+            radius = torch.sigmoid(radius) * max_ratio
+        return torch.clamp(radius, min=1.0e-3, max=max_ratio)
+
+    def _build_low_frequency_mask(
+        self, h_freq: int, w_freq: int, device: torch.device
+    ) -> torch.Tensor:
+        y = torch.arange(h_freq, device=device, dtype=torch.float32)
+        x = torch.arange(w_freq, device=device, dtype=torch.float32)
+        y = torch.minimum(y, h_freq - y)
+        radius_h = self._resolve_radius(self.low_freq_radius_h, 0.5, device) * max(
+            h_freq, 1
+        )
+        radius_w = self._resolve_radius(self.low_freq_radius_w, 1.0, device) * max(
+            w_freq, 1
+        )
+        y = y / torch.clamp(radius_h, min=1.0)
+        x = x / torch.clamp(radius_w, min=1.0)
+        y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")
+        mask = (y_grid.square() + x_grid.square()) <= 1.0
+        return mask.unsqueeze(0).unsqueeze(0)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         input_dtype = x.dtype
         if x.dtype != torch.float32:
             x = x.to(torch.float32)
         fft = torch.fft.rfft2(x, norm="ortho")
-        low = fft.clone()
-        h_freq, w_freq = low.shape[-2], low.shape[-1]
-        low[:, :, h_freq // 4 :, :] = 0
-        low[:, :, :, w_freq // 4 :] = 0
+        h_freq, w_freq = fft.shape[-2], fft.shape[-1]
+        low_mask = self._build_low_frequency_mask(h_freq, w_freq, fft.device).to(
+            dtype=x.dtype
+        )
+        low = fft * low_mask
         high = fft - low
-        low = low * self.low_gate(x)
-        high = high * self.high_gate(x)
+
+        magnitude = fft.abs()
+        low_stats = (magnitude * low_mask).mean(dim=(-2, -1), keepdim=True)
+        high_stats = (magnitude * (1.0 - low_mask)).mean(dim=(-2, -1), keepdim=True)
+
+        low = low * self.low_gate(low_stats)
+        high = high * self.high_gate(high_stats)
         out = torch.fft.irfft2(low + high, s=x.shape[-2:], norm="ortho")
         out = out.to(dtype=input_dtype)
         return self.refine(out)
 
 
 class XCRB2d(nn.Module):
+    # Decoder block: skip fusion -> guide modulation -> frequency refine -> residual output.
     def __init__(
         self,
         in_channels: int,
@@ -537,12 +599,22 @@ class XCRB2d(nn.Module):
         out_channels: int,
         guide_mode: str = "affine",
         use_frequency_refine: bool = True,
+        low_freq_radius_h: float = 0.25,
+        low_freq_radius_w: float = 0.25,
+        learnable_low_freq_radius: bool = True,
     ) -> None:
         super().__init__()
         self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
         self.guide_modulation = XGuideModulation2d(out_channels, guide_mode=guide_mode)
         self.frequency_refine = (
-            XFrequencyRefine2d(out_channels) if use_frequency_refine else nn.Identity()
+            XFrequencyRefine2d(
+                out_channels,
+                low_freq_radius_h=low_freq_radius_h,
+                low_freq_radius_w=low_freq_radius_w,
+                learnable_low_freq_radius=learnable_low_freq_radius,
+            )
+            if use_frequency_refine
+            else nn.Identity()
         )
         self.out_refine = nn.Sequential(
             Conv2dBN(out_channels, out_channels, 3, 1, 1),
@@ -586,6 +658,9 @@ class XNetDecoder2d(nn.Module):
         decoder_channels: Sequence[int] = (128, 64, 32),
         guide_mode: str = "affine",
         use_frequency_refine: bool = True,
+        low_freq_radius_h: float = 0.25,
+        low_freq_radius_w: float = 0.25,
+        learnable_low_freq_radius: bool = True,
         out_channels: int | None = None,
     ) -> None:
         super().__init__()
@@ -605,6 +680,9 @@ class XNetDecoder2d(nn.Module):
             d4,
             guide_mode=guide_mode,
             use_frequency_refine=use_frequency_refine,
+            low_freq_radius_h=low_freq_radius_h,
+            low_freq_radius_w=low_freq_radius_w,
+            learnable_low_freq_radius=learnable_low_freq_radius,
         )
         self.dec3 = XCRB2d(
             d4,
@@ -613,6 +691,9 @@ class XNetDecoder2d(nn.Module):
             d3,
             guide_mode=guide_mode,
             use_frequency_refine=use_frequency_refine,
+            low_freq_radius_h=low_freq_radius_h,
+            low_freq_radius_w=low_freq_radius_w,
+            learnable_low_freq_radius=learnable_low_freq_radius,
         )
         self.dec2 = XCRB2d(
             d3,
@@ -621,6 +702,9 @@ class XNetDecoder2d(nn.Module):
             d2,
             guide_mode=guide_mode,
             use_frequency_refine=use_frequency_refine,
+            low_freq_radius_h=low_freq_radius_h,
+            low_freq_radius_w=low_freq_radius_w,
+            learnable_low_freq_radius=learnable_low_freq_radius,
         )
         self.head_refine = XNetHeadRefine2d(d2, out_channels or d2)
         self.out_channels = out_channels or d2
@@ -680,6 +764,9 @@ class XNet2d(nn.Module):
         ssm_forward_type: str = "v3",
         ssm_backend: str = "auto",
         use_frequency_refine: bool = True,
+        low_freq_radius_h: float = 0.25,
+        low_freq_radius_w: float = 0.25,
+        learnable_low_freq_radius: bool = True,
         guide_mode: str = "affine",
         out_channels: int | None = None,
     ) -> None:
@@ -720,6 +807,9 @@ class XNet2d(nn.Module):
             decoder_channels=decoder_channels,
             guide_mode=guide_mode,
             use_frequency_refine=use_frequency_refine,
+            low_freq_radius_h=low_freq_radius_h,
+            low_freq_radius_w=low_freq_radius_w,
+            learnable_low_freq_radius=learnable_low_freq_radius,
             out_channels=out_channels,
         )
         head_in_channels = self.decoder.out_channels

+ 980 - 0
lib/modules/xnet_2d_zh.py

@@ -0,0 +1,980 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import ptwt
+
+from .layers_2d import Conv2dBN
+from .lib_mamba.vmamba import SS2D as VMambaSS2D
+
+"""
+## 完成的修改
+
+### 1. 小波变换模块迁移至 ptwt
+- **替换 `XHaarWaveletTransform2d` → `XWaveletTransform2d`**:使用 `ptwt.wavedec2` / `ptwt.waverec2` 实现可逆小波变换
+- **优势**:
+  - 支持任意 pywt 兼容小波(haar, db4, sym2, db6 等),通过 `wavelet_type` 参数切换
+  - 无需手动 padding:ptwt 会在分解前自动补齐奇数尺寸;但重建结果可能比原输入多出一行/列,需要按原始尺寸裁剪
+  - 代码更简洁,无手工卷积滤波器
+- **`XWaveletBranch2d`** 已更新引用新类,移除了 wavelet 类型限制检查
+
+### 2. XFrequencyRefine2d 频率域精炼模块分析
+
+**发现的问题与修复:**
+- **原代码 FFT 低频掩码位置错误**:未使用 `fftshift`,直接在左上角做十字掩码,与真实低频位置(四角)不匹配
+- **已修复**:使用 `fftshift` → 圆形低频掩码 → `ifftshift` 还原的正确流程
+
+**设计合理性评估:**
+| 方面 | 评价 |
+|------|------|
+| 低频/高频分离 | ✅ 圆形掩码合理,可调节半径 |
+| 门控机制 | ⚠️ 门控值来自空间域而非频域,可能损失频域选择性 |
+| 通道注意力 | ✅ 每个通道独立门控,灵活 |
+| 重建精度 | ✅ 正交归一化 FFT + 完整频域保留 |
+| 计算开销 | ⚠️ meshgrid 每步计算,可缓存优化 |
+
+**改进建议:**
+1. 门控可改为频域计算(对 `|fft|` 做平均池化)而非空间域
+2. 低频半径可改为可学习参数
+3. meshgrid 可缓存为 buffer 避免重复计算
+
+### 验证结果
+所有模块测试通过,小波分解→重建误差 < 1e-4,输出形状一致。
+"""
+
+# ============================================================
+# 核心架构:XNet2D 医学图像分割网络
+# 业务意图:针对超声等医学图像分割任务,融合局部纹理、频率域、全局序列建模三重能力
+# 设计约束:
+#   - 2D 张量通道优先 (N,C,H,W)
+#   - 所有可逆变换需支持 inverse 恢复原始空间尺寸
+#   - SSM 后端可切换:GPU→oflex,CPU→torch
+# ============================================================
+
+
+# --------------------------------------------------------------------------
+# XNetStem2d:输入茎(Stem)
+# 为什么:将单张输入图快速降采样 4 倍 (H/4, W/4),并逐步提升通道维度
+# 关键行为:
+#   - 两次步幅为 2 的卷积实现 4 倍下采样
+#   - 中间嵌入 depthwise 卷积增强局部通道交互
+# --------------------------------------------------------------------------
+class XNetStem2d(nn.Module):
+    """Input stem built from four Conv2dBN + ReLU stages.
+
+    Per the original inline notes, the first and last convolutions each halve
+    the spatial size (4x reduction overall), the second is depthwise
+    (``groups == stem_channels``) and the third is a 1x1 channel projection.
+    NOTE(review): this reading assumes Conv2dBN's positional arguments are
+    (in, out, kernel, stride, padding) -- confirm against layers_2d.
+    """
+
+    def __init__(self, in_channels: int, stem_channels: int, out_channels: int) -> None:
+        super().__init__()
+        # Layers are created in execution order, so the Sequential indices
+        # follow the data flow.
+        stages = [
+            Conv2dBN(in_channels, stem_channels, 3, 2, 1),
+            nn.ReLU(inplace=True),
+            Conv2dBN(stem_channels, stem_channels, 3, 1, 1, groups=stem_channels),
+            nn.ReLU(inplace=True),
+            Conv2dBN(stem_channels, out_channels, 1, 1, 0),
+            nn.ReLU(inplace=True),
+            Conv2dBN(out_channels, out_channels, 3, 2, 1),
+            nn.ReLU(inplace=True),
+        ]
+        self.block = nn.Sequential(*stages)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        stem_features = self.block(x)
+        return stem_features
+
+
+# --------------------------------------------------------------------------
+# XNetDownsample2d:阶段间下采样器
+# 为什么:在编码器各阶段之间平滑过渡,降低空间分辨率同时增加通道数
+# 关键行为:
+#   - 仅支持 conv 模式(扩展点由子类控制)
+# --------------------------------------------------------------------------
+class XNetDownsample2d(nn.Module):
+    """Downsampler between stages: a single Conv2dBN followed by ReLU.
+
+    Only ``mode="conv"`` is implemented; any other value raises ``ValueError``.
+    NOTE(review): the original notes describe the convolution as a stride-2
+    (H/2, W/2) step -- this assumes Conv2dBN's positional arguments are
+    (in, out, kernel, stride, padding); confirm against layers_2d.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, mode: str = "conv") -> None:
+        super().__init__()
+        if mode != "conv":
+            raise ValueError(f"Unsupported downsample mode: {mode}")
+        conv = Conv2dBN(in_channels, out_channels, 3, 2, 1)
+        activation = nn.ReLU(inplace=True)
+        self.block = nn.Sequential(conv, activation)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        reduced = self.block(x)
+        return reduced
+
+
+# --------------------------------------------------------------------------
+# XLocalBranch2d:局部感受野分支
+# 为什么:并行捕获 3×3 和 5×5 多尺度局部纹理,对医学图像边缘/细微结构敏感
+# 关键行为:
+#   - 两组 depthwise 卷积 + 1×1 通道压缩
+#   - 输出直接相加(残差式局部特征累积)
+# --------------------------------------------------------------------------
+class XLocalBranch2d(nn.Module):
+    """Local branch: two parallel depthwise paths (kernel 3 and kernel 5).
+
+    Each path is a depthwise Conv2dBN, a ReLU and a 1x1 Conv2dBN; the two
+    outputs are summed.
+    """
+
+    def __init__(self, channels: int) -> None:
+        super().__init__()
+        # branch3 is built before branch5, matching the attribute order.
+        self.branch3 = self._make_branch(channels, kernel_size=3)
+        self.branch5 = self._make_branch(channels, kernel_size=5)
+
+    @staticmethod
+    def _make_branch(channels: int, kernel_size: int) -> nn.Sequential:
+        """Build one depthwise -> ReLU -> pointwise path for ``kernel_size``."""
+        # kernel_size // 2 gives padding 1 for kernel 3 and 2 for kernel 5.
+        padding = kernel_size // 2
+        return nn.Sequential(
+            Conv2dBN(channels, channels, kernel_size, 1, padding, groups=channels),
+            nn.ReLU(inplace=True),
+            Conv2dBN(channels, channels, 1, 1, 0),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        small_scale = self.branch3(x)
+        large_scale = self.branch5(x)
+        return small_scale + large_scale
+
+
+# --------------------------------------------------------------------------
+# XWaveletTransform2d:基于 ptwt 的 2D 小波变换
+# 为什么:将特征分解为低频近似 (LL) 与高频细节 (LH, HL, HH),便于频率域操作
+# 关键行为:
+#   - 使用 ptwt.wavedec2 / ptwt.waverec2 实现可逆小波分解与重建
+#   - 支持任意 pywt 兼容小波(haar, db4, sym2 等)
+#   - 输出格式:(ll_coeff, (lh_coeff, hl_coeff, hh_coeff))
+# --------------------------------------------------------------------------
+class XWaveletTransform2d(nn.Module):
+    """2D wavelet analysis/synthesis built on ``ptwt.wavedec2``/``ptwt.waverec2``.
+
+    ``forward`` returns the approximation band and the three detail bands
+    concatenated along the channel axis; ``inverse`` reverses that packing and
+    reconstructs the signal.
+
+    Only one detail level is carried through (``coeffs[1]``), so the
+    round-trip is only meaningful for ``level=1``.
+    """
+
+    def __init__(self, wavelet: str = "haar", level: int = 1) -> None:
+        super().__init__()
+        self.wavelet = wavelet
+        self.level = level
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Decompose ``x``.
+
+        Returns:
+            ll: the approximation coefficients (``coeffs[0]``).
+            high: the three detail tensors of ``coeffs[1]`` concatenated on
+                dim 1, in the order ptwt returns them.
+        """
+        coeffs = ptwt.wavedec2(x, self.wavelet, level=self.level)
+        ll = coeffs[0]
+        detail_tuple = coeffs[1]
+        high = torch.cat([detail_tuple[0], detail_tuple[1], detail_tuple[2]], dim=1)
+        return ll, high
+
+    def inverse(
+        self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
+    ) -> torch.Tensor:
+        """Reconstruct from ``ll`` and the packed detail tensor.
+
+        Args:
+            ll: approximation coefficients.
+            high: detail tensor as produced by ``forward`` (three equal
+                channel groups on dim 1).
+            output_size: ``(H, W)`` of the tensor that was decomposed.
+
+        Returns:
+            The reconstruction cropped to ``output_size`` on the last two dims.
+        """
+        total = high.shape[1]
+        first_cut = total // 3
+        second_cut = 2 * total // 3
+        lh = high[:, 0:first_cut]
+        hl = high[:, first_cut:second_cut]
+        hh = high[:, second_cut:]
+        coeffs = [ll, (lh, hl, hh)]
+        reconstructed = ptwt.waverec2(coeffs, self.wavelet)
+        # The reconstruction can be larger than the original when the input
+        # had an odd height or width (the decomposition pads it), so crop
+        # back explicitly instead of ignoring output_size.
+        return reconstructed[..., : output_size[0], : output_size[1]]
+
+
+# --------------------------------------------------------------------------
+# XWaveletBranch2d:小波分支
+# 为什么:对小波分解后的低频和高频分别做特征学习,再重建回空间域
+# 关键行为:
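+#   - wavelet_type is forwarded to ptwt unchanged; only level=1 is handled correctly, because the transform reads a single detail level (coeffs[1])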
+#   - 高频通道数 = channels * 3,需单独投影
+# --------------------------------------------------------------------------
+class XWaveletBranch2d(nn.Module):
+    """Wavelet branch: decompose, process low/high bands separately, reconstruct.
+
+    The approximation band goes through ``ll_proj`` and the packed detail
+    bands (``channels * 3`` channels) through ``high_proj``; the result is
+    reconstructed to the input's spatial size and passed through ``out_proj``.
+
+    Raises:
+        ValueError: if ``wavelet_level`` is not 1. The transform only carries
+            one detail level, so deeper decompositions would reconstruct at
+            the wrong resolution.
+    """
+
+    def __init__(
+        self, channels: int, wavelet_type: str = "haar", wavelet_level: int = 1
+    ) -> None:
+        super().__init__()
+        # Fail early: XWaveletTransform2d.forward reads only coeffs[1], so any
+        # other level would silently drop detail bands and change the output
+        # size of this branch.
+        if wavelet_level != 1:
+            raise ValueError(
+                "Initial XNet implementation only supports wavelet_level=1."
+            )
+        self.wavelet = XWaveletTransform2d(wavelet=wavelet_type, level=wavelet_level)
+        # Projection for the approximation (low-frequency) band.
+        self.ll_proj = nn.Sequential(
+            Conv2dBN(channels, channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+        # Projection for the packed detail bands: depthwise 3x3, then 1x1.
+        high_channels = channels * 3
+        self.high_proj = nn.Sequential(
+            Conv2dBN(high_channels, high_channels, 3, 1, 1, groups=high_channels),
+            nn.ReLU(inplace=True),
+            Conv2dBN(high_channels, high_channels, 1, 1, 0),
+        )
+        # Projection applied after reconstruction.
+        self.out_proj = nn.Sequential(
+            Conv2dBN(channels, channels, 1, 1, 0),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output_size = x.shape[-2:]
+        ll, high = self.wavelet(x)
+        ll = self.ll_proj(ll)
+        high = self.high_proj(high)
+        reconstructed = self.wavelet.inverse(ll, high, output_size=output_size)
+        return self.out_proj(reconstructed)
+
+
+# --------------------------------------------------------------------------
+# XSSMGlobalBranch2d:SSM 全局分支(核心:VMamba SS2D)
+# 为什么:用 State Space Model 捕获长程依赖,弥补卷积局部感受野不足
+# 关键行为:
+#   - 自动选择后端:CUDA→oflex(快速),否则→torch(兼容)
+#   - 通过 monkey-patch forward_core 动态切换 scan 策略
+#   - 用完后恢复原始 forward_core 避免状态污染
+# --------------------------------------------------------------------------
+class XSSMGlobalBranch2d(nn.Module):
+    """Global branch: 1x1 projection -> VMamba SS2D -> 1x1 projection.
+
+    The selective-scan backend is chosen per call. ``"auto"`` resolves to
+    ``"oflex"`` for CUDA tensors and ``"torch"`` otherwise. For those two
+    backends ``self.ssm.forward_core`` is temporarily wrapped so the choice is
+    passed as keyword arguments, and the original attribute is restored
+    afterwards even if the SSM call raises. Any other backend string leaves
+    the SSM untouched.
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        global_ratio: float = 2.0,
+        d_state: int = 16,
+        forward_type: str = "v3",
+        ssm_backend: str = "auto",
+    ) -> None:
+        super().__init__()
+        # The SSM expansion ratio is never allowed below 1.0.
+        hidden_ratio = max(global_ratio, 1.0)
+        self.backend = ssm_backend
+        self.pre = nn.Sequential(
+            Conv2dBN(channels, channels, 1, 1, 0),
+            nn.ReLU(inplace=True),
+        )
+        self.ssm = VMambaSS2D(
+            d_model=channels,
+            d_state=d_state,
+            ssm_ratio=hidden_ratio,
+            d_conv=3,
+            dropout=0.0,
+            initialize="v0",
+            forward_type=forward_type,
+            channel_first=True,
+        )
+        self.post = nn.Sequential(
+            Conv2dBN(channels, channels, 1, 1, 0),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        features = self.pre(x)
+
+        backend = self.backend.lower()
+        if backend == "auto":
+            backend = "oflex" if features.is_cuda else "torch"
+
+        # Maps each patched backend to its scan_force_torch flag.
+        force_torch_by_backend = {"oflex": False, "torch": True}
+        original_core = None
+        if backend in force_torch_by_backend and hasattr(self.ssm, "forward_core"):
+            original_core = self.ssm.forward_core
+            # Defaults bind the current values so the wrapper does not depend
+            # on later changes to the local variables.
+            self.ssm.forward_core = (
+                lambda z,
+                _core=original_core,
+                _name=backend,
+                _force=force_torch_by_backend[backend]: _core(
+                    z,
+                    selective_scan_backend=_name,
+                    scan_force_torch=_force,
+                )
+            )
+        try:
+            features = self.ssm(features)
+        finally:
+            # Undo the temporary patch so no state leaks between calls.
+            if original_core is not None:
+                self.ssm.forward_core = original_core
+        return self.post(features)
+
+
+# --------------------------------------------------------------------------
+# XGlobalBranch2d:全局分支包装器
+# 为什么:提供统一接口,将 SSM 分支暴露为可开关的模块
+# --------------------------------------------------------------------------
+class XGlobalBranch2d(nn.Module):
+    """Thin wrapper that owns one ``XSSMGlobalBranch2d`` and delegates to it."""
+
+    def __init__(
+        self,
+        channels: int,
+        global_ratio: float = 2.0,
+        ssm_d_state: int = 16,
+        ssm_forward_type: str = "v3",
+        ssm_backend: str = "auto",
+    ) -> None:
+        super().__init__()
+        # Translate the ssm_-prefixed names to the inner branch's arguments.
+        branch_options = {
+            "channels": channels,
+            "global_ratio": global_ratio,
+            "d_state": ssm_d_state,
+            "forward_type": ssm_forward_type,
+            "ssm_backend": ssm_backend,
+        }
+        self.ssm_branch = XSSMGlobalBranch2d(**branch_options)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        global_features = self.ssm_branch(x)
+        return global_features
+
+
+# --------------------------------------------------------------------------
+# XBranchFusion2d: multi-branch feature fusion
+# Why: adaptively weights and merges the outputs of the local / wavelet /
+#   global branches.
+# Key behaviour:
+#   - channel concat -> 1x1 reduction -> channel-attention gate
+#   - the sigmoid gate is multiplied element-wise with the fused features
+# forward() expects `num_branches` tensors that can be concatenated on dim 1
+# and that together carry `channels * num_branches` channels.
+# --------------------------------------------------------------------------
+class XBranchFusion2d(nn.Module):
+    def __init__(self, channels: int, num_branches: int = 3) -> None:
+        super().__init__()
+        fused_channels = channels * num_branches
+        hidden_channels = max(channels // 4, 8)  # hidden width of the gate network
+        self.fuse = nn.Sequential(
+            Conv2dBN(fused_channels, channels, 1, 1, 0),  # reduce concat back to `channels`
+            nn.ReLU(inplace=True),
+        )
+        # Channel-attention gate
+        self.gate = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),  # global average pool down to 1x1
+            nn.Conv2d(fused_channels, hidden_channels, kernel_size=1, bias=True),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=True),
+            nn.Sigmoid(),  # gate values in (0, 1)
+        )
+
+    def forward(self, branch_outputs: Sequence[torch.Tensor]) -> torch.Tensor:
+        x_cat = torch.cat(list(branch_outputs), dim=1)  # concatenate all branches
+        x_fused = self.fuse(x_cat)
+        gate = self.gate(x_cat)  # per-channel gate, broadcast over H and W
+        return x_fused * gate  # gated fusion
+
+
+# --------------------------------------------------------------------------
+# XTEB2d: X-Tri-Enhance-Block (2D), the core building block
+# Why: runs the local, wavelet and global branches in parallel, fuses them,
+#   and stacks an FFN residual on top.
+# Key behaviour:
+#   - pre_norm: a Conv2dBN projection applied before the branches
+#     (presumably 1x1 / stride 1 / no padding -- confirm against Conv2dBN)
+#   - fusion: XBranchFusion2d merges the three branch outputs
+#   - post + FFN: two residual connections (post-fusion, then FFN)
+#   - a disabled wavelet/global branch is replaced by nn.Identity, so the
+#     fusion module always receives three tensors
+# --------------------------------------------------------------------------
+class XTEB2d(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        global_ratio: float = 2.0,
+        wavelet_type: str = "haar",
+        wavelet_level: int = 1,
+        use_wavelet_branch: bool = True,
+        use_global_branch: bool = True,
+        ssm_d_state: int = 16,
+        ssm_forward_type: str = "v3",
+        ssm_backend: str = "auto",
+    ) -> None:
+        super().__init__()
+        self.pre_norm = Conv2dBN(channels, channels, 1, 1, 0)  # pre-projection
+        self.local_branch = XLocalBranch2d(channels)  # local branch (always on)
+        # Wavelet branch (switchable)
+        self.wavelet_branch = (
+            XWaveletBranch2d(
+                channels, wavelet_type=wavelet_type, wavelet_level=wavelet_level
+            )
+            if use_wavelet_branch
+            else nn.Identity()
+        )
+        # Global SSM branch (switchable)
+        self.global_branch = (
+            XGlobalBranch2d(
+                channels,
+                global_ratio=global_ratio,
+                ssm_d_state=ssm_d_state,
+                ssm_forward_type=ssm_forward_type,
+                ssm_backend=ssm_backend,
+            )
+            if use_global_branch
+            else nn.Identity()
+        )
+        self.fusion = XBranchFusion2d(channels, num_branches=3)  # three-branch fusion
+        # Post-fusion residual block
+        self.post = nn.Sequential(
+            Conv2dBN(channels, channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            Conv2dBN(channels, channels, 1, 1, 0, bn_weight_init=0.0),  # zero-init
+        )
+        # FFN residual block
+        self.ffn = nn.Sequential(
+            Conv2dBN(channels, channels * 2, 1, 1, 0),  # channel expansion
+            nn.ReLU(inplace=True),
+            Conv2dBN(channels * 2, channels, 1, 1, 0, bn_weight_init=0.0),  # zero-init
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_in = x
+        x = self.pre_norm(x)
+        # Three parallel branches + fusion, added back onto the un-projected input
+        x = x_in + self.post(
+            self.fusion(
+                [self.local_branch(x), self.wavelet_branch(x), self.global_branch(x)]
+            )
+        )
+        # FFN residual
+        return x + self.ffn(x)
+
+
+# --------------------------------------------------------------------------
+# XNetEncoderStage2d: one encoder stage
+# Why: stacks `depth` identically configured XTEB2d blocks into a single
+#   encoder level; forward() runs them sequentially.
+# --------------------------------------------------------------------------
+class XNetEncoderStage2d(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        depth: int,
+        global_ratio: float = 2.0,
+        wavelet_type: str = "haar",
+        wavelet_level: int = 1,
+        use_wavelet_branch: bool = True,
+        use_global_branch: bool = True,
+        ssm_d_state: int = 16,
+        ssm_forward_type: str = "v3",
+        ssm_backend: str = "auto",
+    ) -> None:
+        super().__init__()
+        # One XTEB2d per unit of depth, all sharing the same hyper-parameters
+        self.blocks = nn.Sequential(
+            *[
+                XTEB2d(
+                    channels=channels,
+                    global_ratio=global_ratio,
+                    wavelet_type=wavelet_type,
+                    wavelet_level=wavelet_level,
+                    use_wavelet_branch=use_wavelet_branch,
+                    use_global_branch=use_global_branch,
+                    ssm_d_state=ssm_d_state,
+                    ssm_forward_type=ssm_forward_type,
+                    ssm_backend=ssm_backend,
+                )
+                for _ in range(depth)
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.blocks(x)
+
+
+# --------------------------------------------------------------------------
+# XNetEncoder2d: the full encoder
+# Why: stem + 4 stages + 3 downsample layers produce the multi-scale
+#   feature pyramid [e1, e2, e3, e4].
+# Key constraints:
+#   - the number of stages is fixed at 4 (validated in the constructor,
+#     which raises ValueError otherwise)
+#   - stage 1 has the global SSM branch off by default
+#     (use_global_branch_stage1=False); stages 2-4 always enable it
+#   - the `stage_channels` attribute exposes each stage's output channels
+# --------------------------------------------------------------------------
+class XNetEncoder2d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        stem_channels: int,
+        encoder_channels: Sequence[int],
+        encoder_depths: Sequence[int],
+        global_ratio: float = 2.0,
+        wavelet_type: str = "haar",
+        wavelet_level: int = 1,
+        use_wavelet_branch: bool = True,
+        use_global_branch_stage1: bool = False,
+        ssm_d_state: int = 16,
+        ssm_forward_type: str = "v3",
+        ssm_backend: str = "auto",
+    ) -> None:
+        super().__init__()
+        if len(encoder_channels) != 4 or len(encoder_depths) != 4:
+            raise ValueError("XNetEncoder2d expects 4 encoder stages.")
+        c1, c2, c3, c4 = encoder_channels
+        d1, d2, d3, d4 = encoder_depths
+        self.stem = XNetStem2d(in_channels, stem_channels, c1)
+        # Stage 1: shallow level, global branch is optional
+        self.stage1 = XNetEncoderStage2d(
+            c1,
+            d1,
+            global_ratio,
+            wavelet_type,
+            wavelet_level,
+            use_wavelet_branch=use_wavelet_branch,
+            use_global_branch=use_global_branch_stage1,
+            ssm_d_state=ssm_d_state,
+            ssm_forward_type=ssm_forward_type,
+            ssm_backend=ssm_backend,
+        )
+        self.down1 = XNetDownsample2d(c1, c2)
+        # Stages 2-4: global branch always enabled (positional `True` below
+        # is the use_global_branch argument)
+        self.stage2 = XNetEncoderStage2d(
+            c2,
+            d2,
+            global_ratio,
+            wavelet_type,
+            wavelet_level,
+            use_wavelet_branch,
+            True,
+            ssm_d_state=ssm_d_state,
+            ssm_forward_type=ssm_forward_type,
+            ssm_backend=ssm_backend,
+        )
+        self.down2 = XNetDownsample2d(c2, c3)
+        self.stage3 = XNetEncoderStage2d(
+            c3,
+            d3,
+            global_ratio,
+            wavelet_type,
+            wavelet_level,
+            use_wavelet_branch,
+            True,
+            ssm_d_state=ssm_d_state,
+            ssm_forward_type=ssm_forward_type,
+            ssm_backend=ssm_backend,
+        )
+        self.down3 = XNetDownsample2d(c3, c4)
+        self.stage4 = XNetEncoderStage2d(
+            c4,
+            d4,
+            global_ratio,
+            wavelet_type,
+            wavelet_level,
+            use_wavelet_branch,
+            True,
+            ssm_d_state=ssm_d_state,
+            ssm_forward_type=ssm_forward_type,
+            ssm_backend=ssm_backend,
+        )
+        self.stage_channels = list(encoder_channels)  # expose per-stage channels
+
+    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+        e1 = self.stage1(self.stem(x))  # shallowest features
+        e2 = self.stage2(self.down1(e1))  # mid-level features
+        e3 = self.stage3(self.down2(e2))  # deep features
+        e4 = self.stage4(self.down3(e3))  # deepest features
+        return [e1, e2, e3, e4]  # multi-scale feature pyramid
+
+
+# --------------------------------------------------------------------------
+# XGuideProjector2d: guide projector
+# Why: turns an encoder feature map into a guide signal that the decoder
+#   uses for adaptive modulation.
+# Key behaviour:
+#   - "affine" mode: returns a (gamma, beta) pair for affine modulation
+#   - "feature" mode: returns the projected feature map directly
+#   - any other mode raises ValueError in the constructor
+# --------------------------------------------------------------------------
+class XGuideProjector2d(nn.Module):
+    def __init__(
+        self, in_channels: int, out_channels: int, mode: str = "affine"
+    ) -> None:
+        super().__init__()
+        self.mode = mode
+        if mode == "affine":
+            # Emit twice the channels; forward() splits them into gamma and beta
+            self.proj = nn.Sequential(
+                Conv2dBN(in_channels, out_channels * 2, 1, 1, 0),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(out_channels * 2, out_channels * 2, kernel_size=1, bias=True),
+            )
+        elif mode == "feature":
+            self.proj = nn.Sequential(
+                Conv2dBN(in_channels, out_channels, 1, 1, 0),
+                nn.ReLU(inplace=True),
+            )
+        else:
+            raise ValueError(f"Unsupported guide mode: {mode}")
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        target_size: tuple[int, int],
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        # Resize first so the guide is spatially aligned with the decoder
+        # feature it will modulate; the projection runs at the target size
+        x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
+        x = self.proj(x)
+        if self.mode == "affine":
+            gamma, beta = torch.chunk(x, 2, dim=1)  # split into affine parameters
+            gamma = torch.sigmoid(gamma) + 0.5  # shift gamma into (0.5, 1.5)
+            return gamma, beta
+        return x
+
+
+# --------------------------------------------------------------------------
+# XSkipFusion2d: skip-connection fusion
+# Why: merges an encoder (skip) feature with the incoming decoder feature.
+# Key behaviour:
+#   - the input is bilinearly resized to the skip feature's spatial size
+#   - input and skip are each projected to `out_channels`
+#   - the two are concatenated on the channel axis and fused by a 3x3 conv
+# --------------------------------------------------------------------------
+class XSkipFusion2d(nn.Module):
+    def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.input_proj = nn.Sequential(
+            Conv2dBN(in_channels, out_channels, 1, 1, 0),  # decoder-feature projection
+            nn.ReLU(inplace=True),
+        )
+        self.skip_proj = nn.Sequential(
+            Conv2dBN(skip_channels, out_channels, 1, 1, 0),  # skip-feature projection
+            nn.ReLU(inplace=True),
+        )
+        self.fuse = nn.Sequential(
+            Conv2dBN(out_channels * 2, out_channels, 3, 1, 1),  # fuse after concat
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
+        # Bilinear resize so both tensors share the skip's spatial size
+        x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
+        x = self.input_proj(x)
+        skip = self.skip_proj(skip)
+        return self.fuse(torch.cat([x, skip], dim=1))  # channel concat + fusion
+
+
+# --------------------------------------------------------------------------
+# XGuideModulation2d: guide modulation
+# Why: applies the affine modulation `gamma * x + beta` to a feature map.
+#   - "affine" mode: `guide` is already a (gamma, beta) pair
+#   - otherwise: `guide` is a feature map that a 1x1 conv converts into
+#     (gamma, beta), with gamma squashed into (0.5, 1.5)
+# NOTE(review): `to_affine` is only created when guide_mode == "feature",
+#   while forward() uses it for every non-"affine" mode; an unexpected mode
+#   string would fail with AttributeError at call time -- confirm that the
+#   mode is always validated upstream.
+# --------------------------------------------------------------------------
+class XGuideModulation2d(nn.Module):
+    def __init__(self, channels: int, guide_mode: str = "affine") -> None:
+        super().__init__()
+        self.guide_mode = guide_mode
+        if guide_mode == "feature":
+            # In feature mode the guide is first converted to affine parameters
+            self.to_affine = nn.Conv2d(channels, channels * 2, kernel_size=1, bias=True)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor:
+        if self.guide_mode == "affine":
+            gamma, beta = guide  # use the supplied affine parameters as-is
+        else:
+            gamma, beta = torch.chunk(self.to_affine(guide), 2, dim=1)
+            gamma = torch.sigmoid(gamma) + 0.5
+        return gamma * x + beta  # affine modulation
+
+
+# --------------------------------------------------------------------------
+# XFrequencyRefine2d:频率域精炼
+# 为什么:在频域对低频/高频分别应用门控,增强关键频率成分
+# 关键行为:
+#   - FFT → 低频中心保留 + 高频带通 → 逆 FFT
+#   - 门控由自适应平均池化生成
+# --------------------------------------------------------------------------
+class XFrequencyRefine2d(nn.Module):
+    def __init__(self, channels: int) -> None:
+        super().__init__()
+        # 低频门控
+        self.low_gate = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(channels, channels, kernel_size=1, bias=True),
+            nn.Sigmoid(),
+        )
+        # 高频门控
+        self.high_gate = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(channels, channels, kernel_size=1, bias=True),
+            nn.Sigmoid(),
+        )
+        # 频域精炼后的空间域细化
+        self.refine = nn.Sequential(
+            Conv2dBN(
+                channels, channels, 3, 1, 1, groups=channels
+            ),  # depthwise 局部细化
+            nn.ReLU(inplace=True),
+            Conv2dBN(channels, channels, 1, 1, 0),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        input_dtype = x.dtype
+        if x.dtype != torch.float32:
+            x = x.to(torch.float32)  # FFT 需要 float32 精度
+        fft = torch.fft.rfft2(x, norm="ortho")  # 实值 FFT
+        h_freq, w_freq = fft.shape[-2], fft.shape[-1]
+        # 构建圆形低频掩码(中心位于四个角:FFT 未 shift 时低频在四角)
+        # 使用 fftshift 将低频移至中心,应用掩码后再 ifftshift 还原
+        fft_shifted = torch.fft.fftshift(fft, dim=(-2, -1))
+        low = fft_shifted.clone()
+        # 圆形低频掩码:保留中心区域
+        radius_h = h_freq // 4
+        radius_w = w_freq // 4
+        y_grid, x_grid = torch.meshgrid(
+            torch.arange(h_freq, device=fft.device),
+            torch.arange(w_freq, device=fft.device),
+            indexing="ij",
+        )
+        center_y, center_x = h_freq // 2, w_freq // 2
+        mask = (y_grid - center_y) ** 2 + (x_grid - center_x) ** 2 <= max(
+            radius_h, radius_w
+        ) ** 2
+        mask = mask.unsqueeze(0).unsqueeze(0).expand(fft.shape[0], fft.shape[1], -1, -1)
+        low = low * mask  # 低频分量
+        high = fft_shifted - low  # 高频 = 全部 - 低频
+        # 还原到原始 FFT 坐标系
+        low = torch.fft.ifftshift(low, dim=(-2, -1))
+        high = torch.fft.ifftshift(high, dim=(-2, -1))
+        # 应用通道门控(门控值来自空间域)
+        low = low * self.low_gate(x)
+        high = high * self.high_gate(x)
+        out = torch.fft.irfft2(low + high, s=x.shape[-2:], norm="ortho")  # 逆 FFT
+        out = out.to(dtype=input_dtype)
+        return self.refine(out)  # 空间域细化
+
+
+# --------------------------------------------------------------------------
+# XCRB2d: X-ResBlock with Guide (2D), the core decoder block
+# Why: combines skip fusion, guide modulation and frequency refinement into
+#   the basic unit used for decoder reconstruction.
+# Data flow:
+#   input -> SkipFusion -> GuideModulation -> (+) FrequencyRefine -> (+) OutRefine
+#   Only the last two steps are residual; skip fusion and guide modulation
+#   replace the tensor. With use_frequency_refine=False the frequency step
+#   is nn.Identity, so that residual becomes x + x.
+# `guide_channels` is only stored on the instance; this block does not use
+#   it to build any layer.
+# --------------------------------------------------------------------------
+class XCRB2d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        skip_channels: int,
+        guide_channels: int,
+        out_channels: int,
+        guide_mode: str = "affine",
+        use_frequency_refine: bool = True,
+    ) -> None:
+        super().__init__()
+        self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
+        self.guide_modulation = XGuideModulation2d(out_channels, guide_mode=guide_mode)
+        self.frequency_refine = (
+            XFrequencyRefine2d(out_channels) if use_frequency_refine else nn.Identity()
+        )
+        # Output refinement (last layer zero-initialised so it is learned gradually)
+        self.out_refine = nn.Sequential(
+            Conv2dBN(out_channels, out_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
+        )
+        self.guide_channels = guide_channels
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        skip: torch.Tensor,
+        guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor:
+        x = self.skip_fusion(x, skip)  # skip fusion
+        x = self.guide_modulation(x, guide)  # guide modulation
+        x = x + self.frequency_refine(x)  # frequency-refinement residual
+        return x + self.out_refine(x)  # output-refinement residual
+
+
+# --------------------------------------------------------------------------
+# XNetHeadRefine2d: feature refinement head
+# Why: performs a final feature enhancement at the end of the decoder.
+# Two Conv2dBN + ReLU layers; `out_channels` defaults to `channels` when
+#   it is None.
+# --------------------------------------------------------------------------
+class XNetHeadRefine2d(nn.Module):
+    def __init__(self, channels: int, out_channels: int | None = None) -> None:
+        super().__init__()
+        if out_channels is None:
+            out_channels = channels
+        self.block = nn.Sequential(
+            Conv2dBN(channels, out_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            Conv2dBN(out_channels, out_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.block(x)
+
+
+# --------------------------------------------------------------------------
+# XNetDecoder2d: the full decoder
+# Why: starts from the deepest feature e4 and decodes towards e1's
+#   resolution, injecting a guide signal and a skip connection at each level.
+# Data flow:
+#   g4 = guide4(e4) ; d4 = dec4(e4, skip=e3, g4)
+#   g3 = guide3(e3) ; d3 = dec3(d4, skip=e2, g3)
+#   g2 = guide2(e2) ; d2 = dec2(d3, skip=e1, g2)
+#   d1 = head_refine(d2)
+# Returns: (d1, [d4, d3, d2, d1], [g4, g3, g2]).
+# Raises ValueError unless exactly 4 encoder channels and 3 decoder channels
+#   are given.
+# Note: `out_channels or d2` treats 0 the same as None.
+# --------------------------------------------------------------------------
+class XNetDecoder2d(nn.Module):
+    def __init__(
+        self,
+        encoder_channels: Sequence[int],
+        decoder_channels: Sequence[int] = (128, 64, 32),
+        guide_mode: str = "affine",
+        use_frequency_refine: bool = True,
+        out_channels: int | None = None,
+    ) -> None:
+        super().__init__()
+        if len(encoder_channels) != 4:
+            raise ValueError("XNetDecoder2d expects 4 encoder stages.")
+        if len(decoder_channels) != 3:
+            raise ValueError("XNetDecoder2d expects 3 decoder channels.")
+        c1, c2, c3, c4 = encoder_channels
+        d4, d3, d2 = decoder_channels
+        # Guide projectors (build a guide from an encoder feature)
+        self.guide4 = XGuideProjector2d(c4, d4, mode=guide_mode)
+        self.guide3 = XGuideProjector2d(c3, d3, mode=guide_mode)
+        self.guide2 = XGuideProjector2d(c2, d2, mode=guide_mode)
+        # Decoder blocks; positional args are
+        # (in_channels, skip_channels, guide_channels, out_channels)
+        self.dec4 = XCRB2d(
+            c4,
+            c3,
+            d4,
+            d4,
+            guide_mode=guide_mode,
+            use_frequency_refine=use_frequency_refine,
+        )
+        self.dec3 = XCRB2d(
+            d4,
+            c2,
+            d3,
+            d3,
+            guide_mode=guide_mode,
+            use_frequency_refine=use_frequency_refine,
+        )
+        self.dec2 = XCRB2d(
+            d3,
+            c1,
+            d2,
+            d2,
+            guide_mode=guide_mode,
+            use_frequency_refine=use_frequency_refine,
+        )
+        self.head_refine = XNetHeadRefine2d(d2, out_channels or d2)
+        self.out_channels = out_channels or d2
+
+    def forward(
+        self,
+        features: Sequence[torch.Tensor],
+    ) -> tuple[
+        torch.Tensor,
+        list[torch.Tensor],
+        list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]],
+    ]:
+        e1, e2, e3, e4 = features
+        # Decode level by level, from deep to shallow
+        g4 = self.guide4(e4, target_size=e3.shape[-2:])  # guide built from e4
+        d4 = self.dec4(e4, e3, g4)  # decode with skip e3
+        g3 = self.guide3(e3, target_size=e2.shape[-2:])
+        d3 = self.dec3(d4, e2, g3)  # decode with skip e2
+        g2 = self.guide2(e2, target_size=e1.shape[-2:])
+        d2 = self.dec2(d3, e1, g2)  # decode with skip e1
+        d1 = self.head_refine(d2)  # final refinement
+        # Return the decoder output, the intermediate features and the guides
+        return d1, [d4, d3, d2, d1], [g4, g3, g2]
+
+
+# --------------------------------------------------------------------------
+# XNetSegHead2d: segmentation head
+# Why: maps the final features to a logits map and resizes it to the
+#   requested output size.
+# Note: `upsample_scale` is stored on the instance but forward() does not
+#   read it; the output size comes solely from the `output_size` argument.
+# --------------------------------------------------------------------------
+class XNetSegHead2d(nn.Module):
+    def __init__(
+        self, in_channels: int, num_classes: int, upsample_scale: int = 4
+    ) -> None:
+        super().__init__()
+        self.block = nn.Sequential(
+            Conv2dBN(in_channels, in_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(
+                in_channels, num_classes, kernel_size=1, bias=True
+            ),  # map to the number of classes
+        )
+        self.upsample_scale = upsample_scale
+
+    def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
+        x = self.block(x)
+        # Bilinear resize to the target size (XNet2d passes the input's H, W)
+        return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
+
+
+# ==========================================================================
+# XNet2d: the full network (encoder + bottleneck + decoder + segmentation head)
+# Architecture overview:
+#   input -> Stem -> [Stage1 v Stage2 v Stage3 v Stage4] -> Bottleneck
+#         -> [dec4 -> dec3 -> dec2] -> Head -> Logits
+# Characteristics:
+#   - the shallow encoder stage (Stage1) has the SSM branch off by default
+#   - the decoder injects a guide signal at every level for adaptive
+#     feature modulation
+#   - every decoder block can apply frequency refinement
+# forward() returns a dict with keys "logits", "seg_logits" (the same
+#   tensor), "encoder_features", "decoder_features" and "guides".
+# NOTE(review): callers in this change set pass low_freq_radius_h,
+#   low_freq_radius_w and learnable_low_freq_radius to XNet2d, but this
+#   signature does not accept them -- confirm which version is current.
+# ==========================================================================
+class XNet2d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        num_classes: int,
+        encoder_channels: Sequence[int] = (32, 64, 128, 192),
+        encoder_depths: Sequence[int] = (2, 2, 2, 2),
+        decoder_channels: Sequence[int] = (128, 64, 32),
+        stem_channels: int = 24,
+        bottleneck_depth: int = 1,
+        global_ratio: float = 2.0,
+        wavelet_type: str = "haar",
+        wavelet_level: int = 1,
+        use_wavelet_branch: bool = True,
+        use_global_branch_stage1: bool = False,
+        ssm_d_state: int = 16,
+        ssm_forward_type: str = "v3",
+        ssm_backend: str = "auto",
+        use_frequency_refine: bool = True,
+        guide_mode: str = "affine",
+        out_channels: int | None = None,
+    ) -> None:
+        super().__init__()
+        # Encoder: multi-scale feature pyramid
+        self.encoder = XNetEncoder2d(
+            in_channels=in_channels,
+            stem_channels=stem_channels,
+            encoder_channels=encoder_channels,
+            encoder_depths=encoder_depths,
+            global_ratio=global_ratio,
+            wavelet_type=wavelet_type,
+            wavelet_level=wavelet_level,
+            use_wavelet_branch=use_wavelet_branch,
+            use_global_branch_stage1=use_global_branch_stage1,
+            ssm_d_state=ssm_d_state,
+            ssm_forward_type=ssm_forward_type,
+            ssm_backend=ssm_backend,
+        )
+        # Bottleneck: extra modelling on the deepest feature
+        bottleneck_channels = encoder_channels[-1]
+        self.bottleneck = nn.Sequential(
+            *[
+                XTEB2d(
+                    channels=bottleneck_channels,
+                    global_ratio=global_ratio,
+                    wavelet_type=wavelet_type,
+                    wavelet_level=wavelet_level,
+                    use_wavelet_branch=use_wavelet_branch,
+                    use_global_branch=True,  # the bottleneck always uses the global branch
+                    ssm_d_state=ssm_d_state,
+                    ssm_forward_type=ssm_forward_type,
+                    ssm_backend=ssm_backend,
+                )
+                for _ in range(bottleneck_depth)
+            ]
+        )
+        # Decoder
+        self.decoder = XNetDecoder2d(
+            encoder_channels=encoder_channels,
+            decoder_channels=decoder_channels,
+            guide_mode=guide_mode,
+            use_frequency_refine=use_frequency_refine,
+            out_channels=out_channels,
+        )
+        # Segmentation head
+        head_in_channels = self.decoder.out_channels
+        self.segmentation_head = XNetSegHead2d(head_in_channels, num_classes)
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> dict[
+        str, torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]
+    ]:
+        encoder_features = self.encoder(x)  # multi-scale features [e1, e2, e3, e4]
+        # The list is updated in place, so the returned "encoder_features"
+        # carries the bottleneck output as its last element
+        encoder_features[-1] = self.bottleneck(encoder_features[-1])  # bottleneck
+        decoder_out, decoder_features, guides = self.decoder(encoder_features)  # decode
+        output_size = x.shape[-2:]
+        logits = self.segmentation_head(
+            decoder_out, output_size=output_size
+        )  # segmentation logits at the input's spatial size
+        # Result dict: logits, intermediate features and guide signals
+        outputs: dict[
+            str,
+            torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]],
+        ] = {
+            "logits": logits,
+            "seg_logits": logits,
+            "encoder_features": encoder_features,
+            "decoder_features": decoder_features,
+            "guides": guides,
+        }
+        return outputs

+ 5 - 0
lib/trainers/supervised.py

@@ -43,6 +43,11 @@ class SupervisedSegmentationTrainer(BaseTrainer):
             ssm_forward_type=str(model_cfg.get("ssm_forward_type", "v3")),
             ssm_backend=str(model_cfg.get("ssm_backend", "auto")),
             use_frequency_refine=bool(model_cfg.get("use_frequency_refine", True)),
+            low_freq_radius_h=float(model_cfg.get("low_freq_radius_h", 0.25)),
+            low_freq_radius_w=float(model_cfg.get("low_freq_radius_w", 0.25)),
+            learnable_low_freq_radius=bool(
+                model_cfg.get("learnable_low_freq_radius", True)
+            ),
             guide_mode=str(model_cfg.get("guide_mode", "affine")),
             out_channels=model_cfg.get("out_channels"),
         ).to(self.device)

+ 26 - 0
requirements.txt

@@ -0,0 +1,26 @@
+# Core training stack
+torch>=2.2
+torchvision>=0.17
+numpy>=1.24
+Pillow>=10.0
+PyYAML>=6.0
+
+# Medical segmentation losses and metrics
+monai>=1.3
+
+# XNet2d wavelet and VMamba dependencies
+ptwt>=0.1.9
+PyWavelets>=1.5
+timm>=1.0
+fvcore>=0.1.5
+einops>=0.7
+packaging>=23.0
+triton>=2.2; platform_system == "Linux"
+
+# Experiment logging and utilities
+swanlab>=0.6
+tqdm>=4.66
+matplotlib>=3.8
+
+# Tests
+pytest>=8.0

+ 38 - 0
tests/test_xnet_2d.py

@@ -0,0 +1,38 @@
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from lib.modules.xnet_2d import XNet2d, XTEB2d
+
+
+def test_xnet2d_forward_preserves_segmentation_shape() -> None:
+    # Smoke test: a small XNet2d must return logits with the input's spatial
+    # size under both the "seg_logits" and "logits" keys.
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # NOTE(review): `learnable_low_freq_radius` must be accepted by
+    # XNet2d.__init__ for this call to succeed -- confirm the model signature.
+    model = XNet2d(
+        in_channels=3,
+        num_classes=1,
+        encoder_channels=(8, 16, 24, 32),
+        encoder_depths=(1, 1, 1, 1),
+        decoder_channels=(24, 16, 8),
+        stem_channels=8,
+        bottleneck_depth=1,
+        global_ratio=1.0,
+        use_wavelet_branch=True,
+        use_global_branch_stage1=False,
+        ssm_d_state=1,
+        ssm_backend="torch",
+        use_frequency_refine=True,
+        learnable_low_freq_radius=False,
+    ).to(device)
+    # On CPU every XTEB2d global branch is swapped for an identity, so the
+    # SSM path is not exercised there (presumably it needs CUDA -- confirm).
+    if device.type == "cpu":
+        for module in model.modules():
+            if isinstance(module, XTEB2d):
+                module.global_branch = nn.Identity()
+    model.eval()
+
+    x = torch.randn(2, 3, 64, 64, device=device)
+    with torch.no_grad():
+        outputs = model(x)
+
+    assert outputs["seg_logits"].shape == (2, 1, 64, 64)
+    assert outputs["logits"].shape == outputs["seg_logits"].shape

+ 0 - 1421
tmp/docs/training/当前项目详解与纯文本架构流程图.md

@@ -1,1421 +0,0 @@
-# 当前项目详解与纯文本架构流程图
-
-## 1. 当前项目定位
-
-`X_SSL_Net` 当前 active 主线是一个面向 2D 超声图像分割的全监督训练工程。
-
-当前真实训练链路:
-
-```text
-shell script
--> tools/train.py
--> SupervisedSegmentationTrainer
--> SegmentationRecordDataset / DataLoader
--> XNet2d
--> seg_logits
--> DiceCE loss / BCE fallback
--> Dice / IoU validation
--> best.pth / last.pth
-```
-
-当前真实模型主线:
-
-```text
-XNet2d = X-shaped CNN-Wavelet-VMamba hybrid segmentation network
-```
-
-当前主训练只使用一个分割头:
-
-```text
-outputs["seg_logits"]
-```
-
-当前主线不调用:
-
-1. `lib/sam2`
-2. `lib/SwinTransformer`
-3. SwinV2 segmentation config
-4. boundary auxiliary head
-5. semi-supervised trainer
-
-`lib/sam2` 与 `lib/SwinTransformer` 目前作为外部代码资产保留,不进入当前训练路径。
-
-## 2. 一句话总览
-
-当前项目可以概括为:
-
-```text
-用 XNet2d 在 BUSI / DDTI / TN3K / TG3K 等 2D 超声数据集上做全监督分割训练。
-XNet2d 的 encoder 用 local + wavelet + VMamba-style SS2D 三分支建模,
-decoder 用同尺度 skip + 斜向 guide + 频率细化恢复 mask。
-```
-
-## 3. 启动入口
-
-### 3.1 推荐 shell 入口
-
-最常用入口:
-
-```bash
-DATASET=BUSI bash tools/run_us_experiments.sh
-```
-
-短跑调试入口:
-
-```bash
-DATASET=BUSI \
-EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8" \
-bash tools/run_us_experiments.sh
-```
-
-### 3.2 shell 脚本职责
-
-文件:
-
-```text
-tools/run_us_experiments.sh
-```
-
-它做四件事:
-
-1. 解析 `DATASET`
-2. 映射数据集根目录
-3. 对需要项目级划分的数据集生成或加载 `train/val`
-4. 调用 `tools/train.py`
-
-支持的数据集名称:
-
-```text
-BUS-UCLM
-BUSI
-BUS-BRA
-BUS_UC
-CCAUI
-DDTI
-OTU_2d
-TN3K
-TG3K
-```
-
-数据集根目录映射:
-
-```text
-BUSI      -> data/BUSI
-DDTI      -> data/DDTI
-TN3K      -> data/TN3K
-TG3K      -> data/TG3K
-BUS_UC    -> data/BUS_UC
-...
-```
-
-项目级 split 数据集:
-
-```text
-BUS-UCLM, BUSI, BUS-BRA, BUS_UC, CCAUI, DDTI
-```
-
-官方 split 数据集:
-
-```text
-OTU_2d, TN3K, TG3K
-```
-
-## 4. 从 shell 到 Python 的总流程图
-
-```text
-User command
-  |
-  |  DATASET=BUSI EXTRA_SET_ARGS="..." bash tools/run_us_experiments.sh
-  v
-+----------------------------------------------------------------------------------+
-| tools/run_us_experiments.sh                                                       |
-+----------------------------------------------------------------------------------+
-| 1. read DATASET / SEED / EXTRA_SET_ARGS                                           |
-| 2. dataset_root(DATASET)                                                          |
-| 3. if DATASET needs project split:                                                |
-|      python scripts/generate_project_split.py --dataset DATASET --root ROOT       |
-| 4. python tools/train.py                                                          |
-|      --config configs/segmentation/train_sup_us_template.yaml                     |
-|      --set dataset.dataset_name=DATASET dataset.root=ROOT ... EXTRA_SET_ARGS      |
-+----------------------------------------------------------------------------------+
-  |
-  v
-+----------------------------------------------------------------------------------+
-| tools/train.py                                                                    |
-+----------------------------------------------------------------------------------+
-| 1. parse --config / --trainer / --set                                             |
-| 2. load yaml config                                                               |
-| 3. apply dotlist overrides                                                        |
-| 4. optional override trainer.name                                                 |
-| 5. build_trainer(cfg)                                                             |
-| 6. trainer.train()                                                                |
-+----------------------------------------------------------------------------------+
-```
-
-## 5. 配置系统
-
-当前主配置:
-
-```text
-configs/segmentation/train_sup_us_template.yaml
-```
-
-当前保留的 segmentation 配置:
-
-```text
-configs/segmentation/train_sup_us_template.yaml
-configs/segmentation/us_exp_sup_busi.yaml
-configs/segmentation/us_exp_sup_busi_ablation.yaml
-```
-
-### 5.1 配置覆盖方式
-
-`tools/train.py` 支持:
-
-```text
---set key=value key=value ...
-```
-
-例如:
-
-```bash
---set train.epochs=1 train.batch_size=8 model.use_frequency_refine=false
-```
-
-覆盖逻辑:
-
-```text
-load_yaml_config(path)
-  |
-  v
-apply_dotlist_overrides(cfg, args.set)
-  |
-  v
-nested dict update
-```
-
-### 5.2 当前关键配置
-
-训练:
-
-```yaml
-train:
-  epochs: 200
-  batch_size: 4
-  val_batch_size: 4
-  amp: true
-  num_workers: 4
-  pin_memory: true
-  persistent_workers: true
-  prefetch_factor: 2
-  device: cuda
-  grad_clip:
-    enabled: true
-    max_norm: 1.0
-```
-
-数据:
-
-```yaml
-dataset:
-  dataset_name: BUSI
-  root: data/BUSI
-  split: train
-  val_split: val
-  image_size: [256, 256]
-  in_channels: 3
-  num_classes: 1
-```
-
-模型:
-
-```yaml
-model:
-  in_channels: 3
-  encoder_channels: [32, 64, 128, 192]
-  encoder_depths: [2, 2, 2, 2]
-  decoder_channels: [128, 64, 32]
-  stem_channels: 24
-  bottleneck_depth: 1
-  global_ratio: 2.0
-  wavelet_type: haar
-  wavelet_level: 1
-  use_wavelet_branch: true
-  use_global_branch_stage1: false
-  ssm_d_state: 16
-  ssm_forward_type: v3
-  ssm_backend: auto
-  use_frequency_refine: true
-  guide_mode: affine
-  out_channels: null
-```
-
-优化:
-
-```yaml
-optimizer:
-  name: adamw
-  lr: 1.0e-4
-  weight_decay: 0.05
-
-scheduler:
-  name: cosine
-  warmup:
-    name: linear
-    params:
-      start_factor: 0.1
-      total_iters: 10
-  params:
-    T_max: 190
-    eta_min: 1.0e-6
-```
-
-loss 与 metric:
-
-```yaml
-loss:
-  name: dicece
-  task_mode: binary
-  params:
-    include_background: true
-    lambda_dice: 0.7
-    lambda_ce: 0.3
-
-validation:
-  threshold: 0.5
-  metrics:
-    task_mode: binary
-    metrics:
-      - name: dice
-      - name: iou
-```
-
-## 6. Trainer 构建流程
-
-入口:
-
-```text
-lib/trainers/builder.py::build_trainer
-```
-
-当前 trainer:
-
-```text
-lib/trainers/supervised.py::SupervisedSegmentationTrainer
-```
-
-构建流程:
-
-```text
-build_trainer(cfg)
-  |
-  v
-read cfg.trainer.name
-  |
-  v
-TRAINER_REGISTRY["supervised_segmentation"]
-  |
-  v
-trainer = SupervisedSegmentationTrainer(cfg, args)
-  |
-  v
-trainer.build()
-  |
-  v
-return trainer
-```
-
-`SupervisedSegmentationTrainer.build()` 做:
-
-```text
-1. dataset_cfg = cfg["dataset"]
-2. model_cfg   = cfg["model"]
-3. train_cfg   = cfg["train"]
-
-4. build XNet2d from model_cfg
-5. move model to device
-6. build optimizer
-7. build scheduler
-8. build loss if cfg.loss is not null
-9. build train dataloader
-10. build validation dataloader
-11. maybe resume checkpoint
-12. maybe init SwanLab
-```
-
-## 7. BaseTrainer 公共职责
-
-文件:
-
-```text
-lib/trainers/base.py
-```
-
-公共职责:
-
-```text
-BaseTrainer
-├─ random seed
-├─ device selection
-├─ output directory
-├─ AMP GradScaler
-├─ batch size resolution
-├─ dataloader construction helper
-├─ validation metric construction
-├─ checkpoint save / resume
-├─ early stopping
-├─ SwanLab logging
-├─ training setup summary
-├─ step performance logging
-└─ epoch finalization
-```
-
-设备选择:
-
-```text
-cfg.train.device == "cuda" and torch.cuda.is_available()
-  -> cuda
-else
-  -> cpu
-```
-
-AMP 开关:
-
-```text
-cfg.train.amp == true and device == cuda
-  -> enabled
-else
-  -> disabled
-```
-
-当前已验证目标环境:
-
-```text
-conda env: xnet_mamba
-torch: 2.10.0+cu126
-GPU: NVIDIA GeForce RTX 4070 Ti SUPER
-selective_scan_cuda_oflex: available
-```
-
-## 8. 数据链路
-
-### 8.1 数据 index 构建
-
-入口:
-
-```text
-lib/data/builder.py::build_dataset_index
-```
-
-核心 registry:
-
-```text
-BUS-UCLM -> paired images/masks
-BUSI     -> Dataset_BUSI_with_GT/{benign,malignant,normal}
-BUS-BRA  -> prefixed image/mask matching
-BUS_UC   -> All / Benign / Malignant folders
-CCAUI    -> US images / Expert mask images
-DDTI     -> XML annotation records
-OTU_2d   -> images / annotations
-TN3K     -> trainval/test image/mask folders
-TG3K     -> thyroid-image / thyroid-mask
-```
-
-### 8.2 split 应用
-
-入口:
-
-```text
-lib/data/loaders.py::apply_official_split
-```
-
-流程:
-
-```text
-build_dataset_index(dataset_name, root)
-  |
-  v
-if split is requested:
-  |
-  +-- OTU_2d: read train.txt / val.txt
-  |
-  +-- TN3K: read tn3k-trainval.json or use test folder
-  |
-  +-- TG3K: read tg3k-trainval.json
-  |
-  +-- project split dataset:
-        read data/<dataset>/splits/project/train.txt or val.txt
-```
-
-项目级 split 生成:
-
-```text
-scripts/generate_project_split.py
-  |
-  v
-generate_project_splits()
-  |
-  v
-write:
-  data/<dataset>/splits/project/train.txt
-  data/<dataset>/splits/project/val.txt
-```
-
-### 8.3 Dataset 读取
-
-文件:
-
-```text
-lib/data/datasets.py::SegmentationRecordDataset
-```
-
-单样本读取:
-
-```text
-record
-  |
-  +-- image_path -> PIL RGB -> float32 [3,H,W] in [0,1]
-  |
-  +-- mask_path  -> PIL L -> binary float32 [1,H,W]
-  |
-  +-- DDTI special:
-        annotation_path XML -> build_ddti_mask() -> binary [1,H,W]
-  |
-  +-- joint augmentation
-  |
-  +-- resize image to dataset.image_size
-  |
-  +-- resize mask to dataset.image_size
-  |
-  v
-{
-  "image": image,
-  "mask": mask,
-  "dataset_name": ...,
-  "sample_id": ...,
-  "split": ...,
-  "class_name": ...,
-  "meta": ...
-}
-```
-
-### 8.4 augmentation
-
-文件:
-
-```text
-lib/data/augment.py::SegmentationAugmentation
-```
-
-当前支持:
-
-```text
-spatial:
-  random horizontal flip
-  random vertical flip
-  random rotate 90
-
-intensity:
-  random brightness / contrast
-  random gaussian noise
-  clamp to [0,1]
-```
-
-### 8.5 collate
-
-文件:
-
-```text
-lib/data/collate.py::record_collate_fn
-```
-
-逻辑:
-
-```text
-if all tensor shapes same:
-  torch.stack(values, dim=0)
-else:
-  keep list
-
-strings / dict / metadata:
-  keep list
-```
-
-最终 batch:
-
-```text
-image: [B,3,256,256]
-mask : [B,1,256,256]
-```
-
-## 9. Dataloader 流程图
-
-```text
-SupervisedSegmentationTrainer.build()
-  |
-  v
-_build_segmentation_loader(split="train")
-  |
-  v
-build_dataloader()
-  |
-  v
-build_record_dataset()
-  |
-  v
-build_dataset_index()
-  |
-  v
-apply_official_split()
-  |
-  v
-SegmentationRecordDataset(records, transforms)
-  |
-  v
-DataLoader(
-  batch_size,
-  shuffle,
-  num_workers,
-  pin_memory,
-  persistent_workers,
-  prefetch_factor,
-  collate_fn=record_collate_fn
-)
-```
-
-注意:`DataLoader` worker 的真实启动通常发生在第一次迭代时,也就是 `======== END TRAINING SETUP ========` 之后。若 `num_workers > 0`,第一批数据可能出现一次性等待。
-
-## 10. XNet2d 总体结构
-
-文件:
-
-```text
-lib/modules/xnet_2d.py
-```
-
-当前默认参数量:
-
-```text
-total parameters:     9,432,129
-trainable parameters: 9,432,129
-```
-
-顶层结构:
-
-```text
-XNet2d
-├─ XNetEncoder2d
-│  ├─ XNetStem2d
-│  ├─ Encoder Stage 1: XTEB2d x 2
-│  ├─ Downsample 1
-│  ├─ Encoder Stage 2: XTEB2d x 2
-│  ├─ Downsample 2
-│  ├─ Encoder Stage 3: XTEB2d x 2
-│  ├─ Downsample 3
-│  └─ Encoder Stage 4: XTEB2d x 2
-│
-├─ Bottleneck: XTEB2d x 1
-│
-├─ XNetDecoder2d
-│  ├─ guide4: E4 -> D4 affine guide
-│  ├─ dec4: XCRB2d(E4, E3, guide4)
-│  ├─ guide3: E3 -> D3 affine guide
-│  ├─ dec3: XCRB2d(D4, E2, guide3)
-│  ├─ guide2: E2 -> D2 affine guide
-│  ├─ dec2: XCRB2d(D3, E1, guide2)
-│  └─ head_refine
-│
-└─ XNetSegHead2d
-```
-
-## 11. XNet2d 纯文本架构图
-
-以输入 `[B,3,256,256]` 为例,默认通道为 `[32,64,128,192]`:
-
-```text
-Input
-[B, 3, 256, 256]
-  |
-  v
-XNetStem2d
-  Conv3x3 s2:       [B, 24, 128, 128]
-  DWConv3x3:        [B, 24, 128, 128]
-  PWConv1x1:        [B, 32, 128, 128]
-  Conv3x3 s2:       [B, 32,  64,  64]
-  |
-  v
-E1 = Encoder Stage 1, XTEB x2
-[B, 32, 64, 64]
-  |
-  v
-Down1
-[B, 64, 32, 32]
-  |
-  v
-E2 = Encoder Stage 2, XTEB x2
-[B, 64, 32, 32]
-  |
-  v
-Down2
-[B, 128, 16, 16]
-  |
-  v
-E3 = Encoder Stage 3, XTEB x2
-[B, 128, 16, 16]
-  |
-  v
-Down3
-[B, 192, 8, 8]
-  |
-  v
-E4 = Encoder Stage 4, XTEB x2
-[B, 192, 8, 8]
-  |
-  v
-Bottleneck XTEB x1
-[B, 192, 8, 8]
-```
-
-Decoder:
-
-```text
-E4 [B,192,8,8]
-  |
-  +-- guide4 = Phi(E4) -> resize to E3 size -> affine gamma/beta for d4
-  |
-  v
-dec4 input:
-  decoder input: E4 [B,192,8,8]
-  same-scale skip: E3 [B,128,16,16]
-  guide: g4
-  output D4 [B,128,16,16]
-
-D4 [B,128,16,16]
-  |
-  +-- guide3 = Phi(E3) -> resize to E2 size -> affine gamma/beta for d3
-  |
-  v
-dec3 input:
-  decoder input: D4 [B,128,16,16]
-  same-scale skip: E2 [B,64,32,32]
-  guide: g3
-  output D3 [B,64,32,32]
-
-D3 [B,64,32,32]
-  |
-  +-- guide2 = Phi(E2) -> resize to E1 size -> affine gamma/beta for d2
-  |
-  v
-dec2 input:
-  decoder input: D3 [B,64,32,32]
-  same-scale skip: E1 [B,32,64,64]
-  guide: g2
-  output D2 [B,32,64,64]
-
-D2 [B,32,64,64]
-  |
-  v
-HeadRefine
-[B,32,64,64]
-  |
-  v
-SegHead + upsample to input size
-[B,1,256,256]
-```
-
-## 12. XTEB2d 详解
-
-`XTEB2d` 是 encoder 的基本 block。
-
-名字含义:
-
-```text
-XTEB = XNet Tri-branch Encoding Block
-```
-
-输入输出:
-
-```text
-input : X [B,C,H,W]
-output: Z [B,C,H,W]
-```
-
-内部结构:
-
-```text
-X
-│
-├─ pre_norm: 1x1 Conv2dBN
-│
-├─ Local branch
-│   ├─ DWConv3x3 + PWConv1x1
-│   └─ DWConv5x5 + PWConv1x1
-│
-├─ Wavelet branch
-│   ├─ Haar DWT
-│   │   ├─ LL
-│   │   └─ LH/HL/HH high bands
-│   ├─ LL projection
-│   ├─ high-band projection
-│   └─ inverse Haar transform
-│
-├─ Global branch
-│   ├─ 1x1 pre projection
-│   ├─ VMamba-style SS2D
-│   └─ 1x1 post projection
-│
-├─ concat(local, wavelet, global)
-├─ 1x1 fusion
-├─ channel gate from GAP + MLP + sigmoid
-├─ residual add
-└─ lightweight FFN + residual add
-```
-
-公式化:
-
-```text
-X0 = PreNorm(X)
-
-L = Local(X0)
-W = Wavelet(X0)
-G = GlobalSS2D(X0)
-
-F = Fuse([L,W,G])
-Y = X + Post(F)
-Z = Y + FFN(Y)
-```
-
-### 12.1 Local branch
-
-职责:
-
-```text
-局部纹理、边界、短程结构
-```
-
-结构:
-
-```text
-DWConv3x3 -> ReLU -> PWConv1x1
-DWConv5x5 -> ReLU -> PWConv1x1
-sum
-```
-
-### 12.2 Wavelet branch
-
-职责:
-
-```text
-低频轮廓 + 高频边界/纹理
-```
-
-结构:
-
-```text
-Haar DWT:
-  LL      -> low-frequency structure
-  LH/HL/HH -> high-frequency directional details
-
-LL -> Conv projection
-High bands -> depthwise conv + pointwise conv
-IDWT -> output projection
-```
-
-当前限制:
-
-```text
-wavelet_type = haar
-wavelet_level = 1
-```
-
-### 12.3 Global SS2D branch
-
-职责:
-
-```text
-高效长程依赖建模、全局结构一致性
-```
-
-当前实现:
-
-```text
-lib/modules/lib_mamba/vmamba.py::SS2D
-```
-
-来源:
-
-```text
-VMamba-style SS2D operator
-```
-
-后端选择:
-
-```text
-ssm_backend = auto
-  |
-  +-- if x.is_cuda:
-        selective_scan_backend = oflex
-        scan_force_torch = false
-  |
-  +-- else:
-        selective_scan_backend = torch
-        scan_force_torch = true
-
-ssm_backend = oflex
-  -> force oflex
-
-ssm_backend = torch
-  -> force torch fallback
-```
-
-当前默认:
-
-```text
-ssm_forward_type = v3
-ssm_backend = auto
-```
-
-在 `xnet_mamba` + RTX 4070 Ti SUPER 环境中已验证:
-
-```text
-selective_scan_cuda_oflex import OK
-WITH_SELECTIVESCAN_OFLEX = True
-```
-
-## 13. XCRB2d 详解
-
-`XCRB2d` 是 decoder 的基本 block。
-
-名字含义:
-
-```text
-XCRB = XNet Cross-guided Reconstruction Block
-```
-
-输入:
-
-```text
-decoder input: deeper decoder or bottleneck feature
-same-scale skip: encoder feature at target scale
-diagonal guide: deeper encoder semantic guide
-```
-
-内部结构:
-
-```text
-decoder input
-  |
-  v
-bilinear upsample to skip size
-  |
-  v
-1x1 projection
-  |
-  +-----------------------------+
-                                |
-same-scale skip                 |
-  |                             |
-  v                             |
-1x1 projection                  |
-  |                             |
-  +----------- concat ----------+
-                  |
-                  v
-             3x3 fusion
-                  |
-                  v
-      guide affine modulation
-                  |
-                  v
-        optional frequency refine
-                  |
-                  v
-        residual spatial refine
-```
-
-### 13.1 X-shaped 信息流
-
-当前 decoder 不只是普通 U-Net 横向 skip。
-
-它同时使用:
-
-```text
-same-scale path:
-  E3 -> D4
-  E2 -> D3
-  E1 -> D2
-
-diagonal guide path:
-  E4 -> D4
-  E3 -> D3
-  E2 -> D2
-```
-
-纯文本示意:
-
-```text
-Encoder: E1 ---------------------------> D2
-            \                          /
-             \                        /
-Encoder:      E2 -------------------> D3
-               \      guide to D2   /
-                \                  /
-Encoder:         E3 -------------> D4
-                  \ guide to D3  /
-                   \            /
-Encoder:            E4 --------/
-                     guide to D4
-```
-
-### 13.2 Guide modulation
-
-默认 `guide_mode=affine`。
-
-流程:
-
-```text
-guide feature
-  |
-  v
-resize to target decoder scale
-  |
-  v
-projection -> [gamma, beta]
-  |
-  v
-gamma = sigmoid(gamma) + 0.5
-  |
-  v
-F' = gamma * F + beta
-```
-
-### 13.3 Frequency refine
-
-默认 `use_frequency_refine=true`。
-
-流程:
-
-```text
-feature F
-  |
-  v
-cast to float32 if needed
-  |
-  v
-rfft2
-  |
-  +-- low frequency mask
-  |
-  +-- high frequency residual
-  |
-  v
-low/high learnable gates
-  |
-  v
-irfft2
-  |
-  v
-cast back to input dtype
-  |
-  v
-depthwise conv refine
-```
-
-这里显式将 FFT 计算放在 `float32` 中,避免 AMP 下触发 `ComplexHalf support is experimental` warning。
-
-## 14. XNet2d forward 输出
-
-`XNet2d.forward(x)` 返回:
-
-```python
-{
-    "logits": logits,
-    "seg_logits": logits,
-    "encoder_features": encoder_features,
-    "decoder_features": decoder_features,
-    "guides": guides,
-}
-```
-
-训练只使用:
-
-```text
-outputs["seg_logits"]
-```
-
-其余输出用于:
-
-```text
-debug
-visualization
-future auxiliary analysis
-```
-
-当前没有边界辅助输出。
-
-## 15. 训练循环详解
-
-入口:
-
-```text
-SupervisedSegmentationTrainer.train()
-```
-
-流程:
-
-```text
-train()
-  |
-  v
-print training setup
-  |
-  v
-for epoch in range(start_epoch, epochs):
-  |
-  +-- model.train()
-  +-- optimizer.zero_grad()
-  +-- for step, batch in train_loader:
-        |
-        +-- measure data_time
-        |
-        +-- image = batch["image"].to(device)
-        +-- mask  = batch["mask"].to(device)
-        |
-        +-- with autocast(enabled=amp):
-              outputs = model(image)
-              seg_logits = outputs["seg_logits"]
-              seg_loss = loss(seg_logits, mask)
-              total_loss = seg_loss
-        |
-        +-- scaled_total_loss = total_loss / accum_steps
-        +-- grad_scaler.scale(scaled_total_loss).backward()
-        |
-        +-- if should optimizer step:
-              unscale gradients if grad clipping enabled
-              clip grad norm
-              grad_scaler.step(optimizer)
-              grad_scaler.update()
-              optimizer.zero_grad()
-        |
-        +-- log step every logging.log_interval
-  |
-  +-- scheduler.step()
-  |
-  +-- validate if enabled and interval matches
-  |
-  +-- finalize epoch
-        |
-        +-- merge train / val metrics
-        +-- update best metric
-        +-- save best.pth if improved
-        +-- save last.pth if enabled
-        +-- early stopping check
-```
-
-## 16. Loss 路径
-
-当前配置使用:
-
-```text
-MONAI DiceCELoss
-```
-
-构建路径:
-
-```text
-cfg.loss
-  |
-  v
-lib/tools/loss.py::build_loss
-  |
-  v
-DiceCELoss(sigmoid=True, include_background=True, lambda_dice=0.7, lambda_ce=0.3)
-```
-
-如果 `loss: null`:
-
-```text
-torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
-```
-
-该 fallback 适合在环境临时缺少 MONAI 时做 smoke test,不建议作为正式论文训练的默认配置。
-
-## 17. Validation 路径
-
-验证函数:
-
-```text
-SupervisedSegmentationTrainer._validate()
-```
-
-流程:
-
-```text
-model.eval()
-build validation metrics
-for batch in val_loader:
-  image -> device
-  mask -> device
-  outputs, losses = _compute_losses(image, mask)
-  update loss sums
-  update metrics with outputs["seg_logits"]
-
-average val loss
-compute Dice / IoU
-reset metric states
-return val_metrics
-```
-
-metric 输入处理:
-
-```text
-binary mode:
-  pred = sigmoid(logits) >= threshold
-  target = target > 0
-
-multiclass mode:
-  pred = argmax(logits)
-  target = one-hot or class index
-```
-
-当前默认:
-
-```text
-threshold = 0.5
-metrics = Dice, IoU
-```
-
-## 18. Checkpoint 路径
-
-checkpoint 目录:
-
-```text
-cfg.checkpoint.dir
-```
-
-默认脚本会覆盖为:
-
-```text
-outputs/experiments/supervised/<DATASET>
-```
-
-保存文件:
-
-```text
-best.pth
-last.pth
-```
-
-checkpoint 内容:
-
-```text
-epoch
-cfg
-metrics
-model state_dict
-optimizer state_dict
-scheduler state_dict
-grad_scaler state_dict
-best_metric
-no_improve_epochs
-```
-
-best 判断:
-
-```text
-monitor = dice
-monitor_mode = max
-```
-
-即:
-
-```text
-val_dice 越大越好
-```
-
-## 19. 日志与性能字段
-
-每隔 `logging.log_interval` step 打印:
-
-```text
-epoch
-step
-num_steps
-data_time
-iter_time
-gpu_memory_mb
-lr
-train_total
-train_seg
-train_grad_norm
-```
-
-含义:
-
-```text
-data_time:
-  从上一步结束到当前 batch 可用的时间。
-  num_workers > 0 时,第一批 worker 启动开销发生在 END TRAINING SETUP 之后。
-
-iter_time:
-  当前 step 的训练计算时间,包括 forward、loss、backward、optimizer step。
-
-gpu_memory_mb:
-  torch.cuda.max_memory_allocated。
-```
-
-当前实测参考:
-
-```text
-batch_size = 8
-image_size = 256
-ssm_backend = auto -> oflex
-iter_time ≈ 0.09 - 0.11 s / step
-GPU memory ≈ 850 MB
-```
-
-## 20. 从输入到 loss 的端到端流程图
-
-```text
-Batch from DataLoader
-  |
-  +-- image [B,3,256,256]
-  +-- mask  [B,1,256,256]
-  |
-  v
-image.to(cuda), mask.to(cuda)
-  |
-  v
-autocast(enabled=True)
-  |
-  v
-XNet2d(image)
-  |
-  +-- encoder_features = [E1,E2,E3,E4]
-  |
-  +-- bottleneck(E4)
-  |
-  +-- decoder_out, decoder_features, guides
-  |
-  +-- segmentation_head(decoder_out)
-  |
-  v
-seg_logits [B,1,256,256]
-  |
-  v
-DiceCELoss(seg_logits, mask)
-  |
-  v
-total_loss
-  |
-  v
-GradScaler.scale(total_loss).backward()
-  |
-  v
-clip gradients
-  |
-  v
-optimizer.step()
-```
-
-## 21. 关键运行命令
-
-GPU 环境检查:
-
-```bash
-python -c "import sys, torch; print(sys.executable); print(torch.__version__); print(torch.cuda.is_available()); print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'no cuda')"
-```
-
-oflex 检查:
-
-```bash
-python -c "import torch; import selective_scan_cuda_oflex; print('oflex import OK')"
-python -c "import torch; from lib.modules.lib_mamba import csms6s; print(csms6s.WITH_SELECTIVESCAN_OFLEX)"
-```
-
-前向检查:
-
-```bash
-python - <<'PY'
-import torch
-from lib.modules import XNet2d
-
-model = XNet2d(in_channels=3, num_classes=1, ssm_backend="auto", ssm_forward_type="v3").cuda().eval()
-x = torch.randn(1, 3, 128, 128, device="cuda")
-with torch.no_grad():
-    y = model(x)
-print(sorted(y.keys()))
-print(tuple(y["seg_logits"].shape))
-PY
-```
-
-短训:
-
-```bash
-DATASET=BUSI \
-EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8" \
-bash tools/run_us_experiments.sh
-```
-
-关闭 frequency refine 消融:
-
-```bash
-DATASET=BUSI \
-EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 model.use_frequency_refine=false logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8_no_freq" \
-bash tools/run_us_experiments.sh
-```
-
-汇总结果:
-
-```bash
-bash tools/summarize_results.sh
-sed -n '1,40p' results/experiment_summary.md
-```
-
-## 22. 推荐实验主线
-
-第一阶段:训练链路稳定性
-
-```text
-BUSI smoke
-BUSI batch size 8
-BUSI no frequency refine
-```
-
-第二阶段:甲状腺主线
-
-```text
-DDTI
-TN3K
-TG3K
-DDTI -> TN3K / TN3K -> DDTI 跨数据集泛化
-```
-
-第三阶段:乳腺扩展
-
-```text
-BUSI
-BUS_UC
-BUS-BRA
-BUS-UCLM
-```
-
-第四阶段:核心消融
-
-```text
-use_wavelet_branch=false
-use_frequency_refine=false
-ssm_backend=torch
-use_global_branch_stage1=true
-encoder_depths=[2,2,3,2]
-```
-
-## 23. 当前边界与注意事项
-
-1. 当前文档描述的是 active XNet2d 全监督主链。
-2. 当前训练主链只优化 `seg_logits`。
-3. `lib/sam2` 保留但不参与训练。
-4. `lib/SwinTransformer` 保留但不参与训练。
-5. `ssm_backend=auto` 在 CUDA 上应走 `oflex`,这是当前速度优化后的默认路径。
-6. `XFrequencyRefine2d` 的 FFT 计算使用 float32,避免 AMP 下 ComplexHalf warning。
-7. `num_workers > 0` 时,第一次进入 dataloader 迭代可能在 `END TRAINING SETUP` 后产生一次性等待。