Parcourir la source

feat(Initial commit):
XNet - Polyp Segmentation with Wavelet-FFT Enhanced SwinUNETR

Features:
- Implementation of Wavelet_FFT_SwinUNETR architecture for medical image segmentation
- Adaptive Wavelet Augmented Enhancer (AWAE) module for multi-scale feature enhancement
- FFT-based frequency domain enhancement via FFTRCAB modules
- Combined Dice + CrossEntropy + IoU loss function for robust training
- Comprehensive data augmentation pipeline with geometric and photometric transforms
- Early stopping mechanism with automatic resume from best checkpoint
- SwanLab experiment tracking integration
- Full evaluation pipeline with metrics: mDice, mIoU, mHD, mHD95

Project Structure:
- train.py: Main training script with configurable hyperparameters
- eval.py: Model evaluation and visualization tool
- lib/model/: Core model architecture (SwinUNETR backbone)
- lib/modules/: Enhancement modules (AWAE, FFTRCAB)
- datasets/: Polyp detection dataset loader with MONAI integration

Technical Highlights:
- Multi-level wavelet decomposition with adaptive attention
- Frequency-domain feature enhancement using FFT

kekeZack il y a 1 mois
commit
d5359a8c2e

+ 60 - 0
.gitignore

@@ -0,0 +1,60 @@
+# ---> Python
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*,cover
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# IntelliJ IDEA project files
+.idea

+ 72 - 0
LICENSE

@@ -0,0 +1,72 @@
+Apache License 
+Version 2.0, January 2004 
+http://www.apache.org/licenses/
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
+
+(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
+
+(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
+
+(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
+
+(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
+
+You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License"); 
+you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software 
+distributed under the License is distributed on an "AS IS" BASIS, 
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and 
+limitations under the License.

+ 3 - 0
README.md

@@ -0,0 +1,3 @@
+# XNet
+
+XNet: A Staged Dual-Frequency Synergistic Framework via Wavelet-FFT for Medical Image Segmentation of Small Objects and Weak Boundaries

+ 110 - 0
datasets/PolypDetectionDataset/PolypDetectionDataset.py

@@ -0,0 +1,110 @@
+from pathlib import Path
+from typing import Callable, Optional, List, Dict, Any
+from monai.data import Dataset
+
+
+class PolypDetectionDataset(Dataset):
+    """Polyp detection dataset for training and validation."""
+
+    def __init__(
+        self, root_dir: str, flag: str = "train", transform: Optional[Callable] = None
+    ):
+        """Initialize dataset.
+        
+        Args:
+            root_dir: Root directory containing images, masks, and split files
+            flag: Dataset split ('train' or 'val')
+            transform: Optional transformations to apply
+        """
+        # Set up directory paths
+        self.root_dir = Path(root_dir)
+        self.images_dir = self.root_dir / "images"
+        self.labels_dir = self.root_dir / "masks"
+        self.flag = flag.lower()
+        self.transform = transform
+
+        # Validate split flag
+        if self.flag not in ["train", "val"]:
+            raise ValueError(f"flag must be 'train' or 'val', got '{self.flag}'")
+
+        # Check if directories exist
+        if not self.images_dir.exists():
+            raise FileNotFoundError(f"Image directory does not exist: {self.images_dir}")
+        if not self.labels_dir.exists():
+            raise FileNotFoundError(f"Label directory does not exist: {self.labels_dir}")
+
+        # Load image filenames from split file
+        txt_file = self.root_dir / f"{self.flag}.txt"
+        if not txt_file.exists():
+            raise FileNotFoundError(
+                f"{self.flag}.txt file not found in {self.root_dir}\n"
+                f"Please ensure train.txt and val.txt files exist in this directory"
+            )
+
+        with open(txt_file, "r", encoding="utf-8") as f:
+            self.images: List[str] = [
+                line.strip() for line in f.readlines() if line.strip()
+            ]
+
+        # Labels have same filenames as images
+        self.labels: List[str] = self.images.copy()
+
+        # Create data list with image-label pairs
+        data = [
+            {"image": str(self.images_dir / img), "label": str(self.labels_dir / lbl)}
+            for img, lbl in zip(self.images, self.labels)
+        ]
+
+        super().__init__(data=data, transform=transform)
+
+    def __len__(self) -> int:
+        """Return number of samples in dataset."""
+        return len(self.images)
+
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        """Get a single sample by index.
+        
+        Args:
+            idx: Sample index
+            
+        Returns:
+            Dictionary with 'image' and 'label' tensors
+        """
+        # Build file paths
+        image_path = str(self.images_dir / self.images[idx])
+        label_path = str(self.labels_dir / self.labels[idx])
+
+        data = {"image": image_path, "label": label_path}
+
+        # Apply transformations if specified
+        if self.transform is not None:
+            data = self.transform(data)
+
+        return data
+
+    def get_image_filename(self, idx: int) -> str:
+        """Get image filename by index."""
+        return self.images[idx]
+
+    def get_label_filename(self, idx: int) -> str:
+        """Get label filename by index."""
+        return self.labels[idx]
+
+    def get_dataset_info(self, dataset_name="CVC_300") -> Dict[str, Any]:
+        """Get dataset information.
+        
+        Args:
+            dataset_name: Name of the dataset
+            
+        Returns:
+            Dictionary containing dataset metadata
+        """
+        return {
+            "dataset_name": dataset_name,
+            "split": self.flag,
+            "num_samples": len(self),
+            "root_dir": str(self.root_dir),
+            "images_dir": str(self.images_dir),
+            "labels_dir": str(self.labels_dir),
+            "has_transform": self.transform is not None,
+        }

+ 0 - 0
datasets/__init__.py


+ 564 - 0
eval.py

@@ -0,0 +1,564 @@
+import argparse
+import os
+from datetime import datetime
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
+from monai.transforms import (
+    Compose, LoadImaged, ScaleIntensityd, EnsureChannelFirstd,
+    ToTensord, Resized, Lambdad
+)
+from torch.utils.data import DataLoader
+
+from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
+from lib.model.model import Wavelet_FFT_SwinUNETR
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Polyp Segmentation Model Evaluation Script")
+
+    # ==================== Dataset-related Parameters ====================
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        required=True,
+        help="Dataset name"
+    )
+    parser.add_argument(
+        "--data_root",
+        type=str,
+        default=r"./data/Polyp-Detection-Dataset",
+        help="Root directory path of the dataset"
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=0,
+        help="Number of worker processes for data loader"
+    )
+    parser.add_argument(
+        "--pin_memory",
+        type=bool,
+        default=True,
+        help="Whether to enable pinned memory"
+    )
+    parser.add_argument(
+        "--target_spatial_size",
+        type=tuple,
+        default=(512, 512),
+        help="Target spatial size, format like: '512,512' or '(512,512)'"
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="Batch size (smaller batch size recommended for evaluation)"
+    )
+
+    # ==================== Model-related Parameters ====================
+    parser.add_argument(
+        "--in_channels",
+        type=int,
+        default=3,
+        help="Number of input image channels"
+    )
+    parser.add_argument(
+        "--out_channels",
+        type=int,
+        default=1,
+        help="Number of output foreground channels"
+    )
+    parser.add_argument(
+        "--feature_size",
+        type=int,
+        default=48,
+        help="Network feature dimension"
+    )
+    parser.add_argument(
+        "--spatial_dims",
+        type=int,
+        default=2,
+        choices=[2, 3],
+        help="Spatial dimension (2D or 3D)"
+    )
+    parser.add_argument(
+        "--use_wavelet",
+        type=bool,
+        default=True,
+        help="Whether to enable wavelet enhancement module"
+    )
+    parser.add_argument(
+        "--wavelet_J",
+        type=int,
+        default=2,
+        help="Wavelet decomposition levels"
+    )
+    parser.add_argument(
+        "--wavelet_wave",
+        type=str,
+        default="db4",
+        help="Wavelet basis type"
+    )
+    parser.add_argument(
+        "--wavelet_reduction",
+        type=int,
+        default=16,
+        help="Wavelet attention compression ratio"
+    )
+    parser.add_argument(
+        "--use_fft",
+        type=bool,
+        default=True,
+        help="Whether to enable FFT enhancement module"
+    )
+    parser.add_argument(
+        "--use_v2",
+        type=bool,
+        default=True,
+        help="Whether to enable Swin-UNETR v2 module"
+    )
+
+    # ==================== Model Loading Parameters ====================
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Training device (cuda or cpu)"
+    )
+
+    # ==================== Other Parameters ====================
+    parser.add_argument(
+        "--save_results",
+        type=bool,
+        default=False,
+        help="是否保存预测结果"
+    )
+    parser.add_argument(
+        "--dir_flag",
+        type=str,
+        default="_v4_minute",
+        help="Prediction result save file suffix"
+    )
+    parser.add_argument(
+        "--results_dir",
+        type=str,
+        default="./evaluation_results",
+        help="Directory for saving prediction results"
+    )
+    parser.add_argument(
+        "--outputs_dir",
+        type=str,
+        default="./outputs",
+        help="是否保存预测结果"
+    )
+    parser.add_argument(
+        "--save_visualization",
+        type=bool,
+        default=True,
+        help="Whether to save visualization results"
+    )
+    parser.add_argument(
+        "--vis_num_samples",
+        type=int,
+        default=1000,
+        help="Number of samples to save for visualization"
+    )
+
+    parser.add_argument(
+        "--best_metric",
+        type=str,
+        default=False,
+        help="Load best overall model, False means load best Dice model by default"
+    )
+
+    return parser.parse_args()
+
+
+def create_val_transform(target_spatial_size=(512, 512)):
+    """
+    Create validation set transformations
+    
+    Args:
+        target_spatial_size: Target spatial size
+        
+    Returns:
+        Compose: Validation transformation composition
+    """
+
+    def convert_label_to_single_channel(label_tensor):
+        """Convert RGB labels to single-channel binary mask"""
+        single_channel = label_tensor[0:1, :, :]
+        binary_label = (single_channel > 127).float()
+        return binary_label
+
+    val_transforms = Compose([
+        LoadImaged(keys=["image", "label"]),
+        EnsureChannelFirstd(keys=["image", "label"]),
+        Lambdad(keys=["label"], func=convert_label_to_single_channel),
+        Resized(keys=["image", "label"], spatial_size=target_spatial_size,
+                mode=("bilinear", "nearest")),
+        ScaleIntensityd(keys=["image"]),
+        ToTensord(keys=["image", "label"]),
+    ])
+
+    return val_transforms
+
+
+def create_dataloader(args, dataset):
+    """
+    Create data loader
+    
+    Args:
+        args: Command line arguments
+        dataset: Dataset
+        
+    Returns:
+        DataLoader: Data loader
+    """
+    loader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=False
+    )
+
+    return loader
+
+
+def load_model(args, checkpoint_path):
+    """
+    Load model
+    
+    Args:
+        args: Command line arguments
+        checkpoint_path: Checkpoint path
+        
+    Returns:
+        model: Model with loaded weights
+    """
+    print(f"\nLoading model: {checkpoint_path}")
+
+    # Create model
+    model = Wavelet_FFT_SwinUNETR(
+        in_channels=args.in_channels,
+        out_channels=args.out_channels,
+        feature_size=args.feature_size,
+        spatial_dims=args.spatial_dims,
+        wavelet_enhancement=args.use_wavelet,
+        wavelet_J=args.wavelet_J,
+        wavelet_wave=args.wavelet_wave,
+        wavelet_mode='symmetric',
+        wavelet_reduction=args.wavelet_reduction,
+        fft_enhancement=args.use_fft,
+        use_v2=args.use_v2
+    )
+
+    # Load weights
+    if not os.path.exists(checkpoint_path):
+        raise FileNotFoundError(f"Model file does not exist: {checkpoint_path}")
+
+    checkpoint = torch.load(checkpoint_path, map_location=args.device)
+
+    # Check checkpoint format
+    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
+        model.load_state_dict(checkpoint["model_state_dict"])
+        print("✓ Model weights loaded from training checkpoint")
+    else:
+        model.load_state_dict(checkpoint)
+        print("✓ Model weights loaded directly")
+
+    model = model.to(args.device)
+    model.eval()
+
+    print(f"✓ Model loaded and set to evaluation mode")
+    print(f"✓ 使用设备:{args.device}")
+
+    return model
+
+
+def evaluate_model(model, dataloader, args):
+    """
+    Evaluate model performance
+
+    Args:
+        model: Model
+        dataloader: Data loader
+        args: Command line arguments
+
+    Returns:
+        dict: Dictionary containing various metrics
+    """
+    print("\n" + "=" * 60)
+    print("Starting evaluation...")
+    print("=" * 60)
+
+    # Initialize metrics
+    dice_metric = DiceMetric(reduction="mean")
+    iou_metric = MeanIoU(reduction="mean")
+    hd_metric = HausdorffDistanceMetric(reduction="mean")
+    hd95_metric = HausdorffDistanceMetric(reduction="mean", percentile=95)
+    total_samples = 0
+    flag = 0
+    saved_vis_count = 0
+    vis_dir = None
+    # Create visualization save directory
+    if args.save_visualization:
+        vis_dir = os.path.join(args.results_dir, f"visualization_{args.dataset_name}")
+        os.makedirs(vis_dir, exist_ok=True)
+        print(f"✓ Visualization results will be saved to: {vis_dir}")
+
+    with torch.no_grad():
+        dice_metric.reset()
+        iou_metric.reset()
+        hd_metric.reset()
+        hd95_metric.reset()
+        for batch_idx, batch_data in enumerate(dataloader):
+            images = batch_data["image"].to(args.device)
+            labels = batch_data["label"].to(args.device)
+            # Forward propagation
+            outputs = model(images)  # [B, 1, H, W]
+
+            # Post-processing
+            outputs = torch.sigmoid(outputs)
+            outputs = (outputs > 0.5).int()
+            if flag == 0:
+                flag = 1
+                print(f"\n{'=' * 60}")
+                print(f"[Detailed Debug Info - Batch 0]")
+                print(f"{'=' * 60}")
+                print(f"Input image size: {images.shape}")
+                print(f"Output image size after post-processing: {outputs.shape}")
+                print(f"Unique values: {torch.unique(outputs)}")
+
+                print(f"\n--- Labels ---")
+                print(f"Labels shape: {labels.shape}")
+                print(f"Labels unique values: {torch.unique(labels)}")
+                print(f"{'=' * 60}\n")
+            # Calculate metrics for current batch - use directly without additional processing
+            dice_metric(y_pred=outputs, y=labels)
+            iou_metric(y_pred=outputs, y=labels)
+            hd_metric(y_pred=outputs, y=labels)
+            hd95_metric(y_pred=outputs, y=labels)
+            # Save visualization results
+            if args.save_visualization and saved_vis_count < args.vis_num_samples:
+                save_visualization(
+                    images=images,
+                    labels=labels,
+                    predictions=outputs,
+                    save_dir=vis_dir,
+                    batch_idx=batch_idx,
+                    max_samples=args.vis_num_samples - saved_vis_count
+                )
+                saved_vis_count += images.shape[0]
+
+            # Print progress
+            if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(dataloader):
+                print(f"进度:{batch_idx + 1}/{len(dataloader)} batches")
+            total_samples += images.shape[0]
+
+    # Aggregate all metrics
+    mean_dice = dice_metric.aggregate().item()
+    mean_iou = iou_metric.aggregate().item()
+    mean_hd = hd_metric.aggregate().item()
+    mean_hd95 = hd95_metric.aggregate().item()
+
+    results = {
+        "mDice": mean_dice,
+        "mIoU": mean_iou,
+        "mHD": mean_hd,
+        "mHD95": mean_hd95,
+        "total_samples": total_samples,
+    }
+
+    return results
+
+
+def save_visualization(images, labels, predictions, save_dir, batch_idx, max_samples):
+    """
+    Save visualization results: original image, ground truth label, prediction label combined
+    
+    Args:
+        images: Input image batch [B, C, H, W]
+        labels: Ground truth labels [B, 1, H, W]
+        predictions: Prediction labels [B, 1, H, W]
+        save_dir: Save directory
+        batch_idx: Batch index
+        max_samples: Maximum samples to save
+    """
+    for i in range(min(images.shape[0], max_samples)):
+        try:
+            # Extract single sample
+            image = images[i].cpu()
+            label = labels[i].cpu()
+            prediction = predictions[i].cpu()
+
+            # Image processing: de-normalize and convert to RGB
+            if image.shape[0] == 1:  # 灰度图
+                image_np = image[0].numpy() * 255
+                image_rgb = np.stack([image_np] * 3, axis=-1).astype(np.uint8)
+            else:  # RGB image
+                image_np = image.numpy().transpose(1, 2, 0)
+                # De-normalize (assuming z-score normalization, approximate handling)
+                image_np = np.clip((image_np - image_np.min()) / (image_np.max() - image_np.min() + 1e-8) * 255, 0, 255)
+                image_rgb = image_np.astype(np.uint8)
+
+            # Label processing: convert to binary mask
+            label_np = label[0].numpy() if label.shape[0] == 1 else label.numpy()
+            label_binary = (label_np > 0.5).astype(np.float32)
+
+            # Prediction processing: convert to binary mask
+            pred_np = prediction[0].numpy() if prediction.shape[0] == 1 else prediction.numpy()
+            pred_binary = (pred_np > 0.5).astype(np.float32)
+
+            # Create pure black and white label image
+            label_bw = (label_binary * 255).astype(np.uint8)
+            label_bw_3ch = np.stack([label_bw] * 3, axis=-1)  # Convert to 3 channels for concatenation
+
+            # Create pure black and white prediction image
+            pred_bw = (pred_binary * 255).astype(np.uint8)
+            pred_bw_3ch = np.stack([pred_bw] * 3, axis=-1)  # Convert to 3 channels for concatenation
+
+            # Horizontally concatenate three images: original, ground truth B&W, prediction B&W
+            combined = np.hstack([image_rgb, label_bw_3ch, pred_bw_3ch])
+
+            # Add text annotations (at the top of the image)
+            h, w = image_rgb.shape[:2]
+            # Calculate text width for centering
+            orig_text_size = cv2.getTextSize('Original', cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
+            gt_text_size = cv2.getTextSize('Ground Truth', cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
+            pred_text_size = cv2.getTextSize('Prediction', cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
+
+            combined = cv2.putText(combined, 'Original', ((w - orig_text_size[0]) // 2, 30),
+                                   cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
+            combined = cv2.putText(combined, 'Ground Truth', (w + (w - gt_text_size[0]) // 2, 30),
+                                   cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
+            combined = cv2.putText(combined, 'Prediction', (2 * w + (w - pred_text_size[0]) // 2, 30),
+                                   cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
+
+            # Save
+            sample_idx = batch_idx * images.shape[0] + i
+            save_path = os.path.join(save_dir, f'sample_{sample_idx:04d}.png')
+            cv2.imwrite(save_path, cv2.cvtColor(combined, cv2.COLOR_RGB2BGR))
+
+        except Exception as e:
+            print(f"⚠️ Error saving sample {i}: {e}")
+            continue
+
+
+def print_results(results, dataset_name, checkpoint_path, model):
+    """
+    Print evaluation results
+    
+    Args:
+        results: Evaluation results dictionary
+        dataset_name: Dataset name
+        checkpoint_path: Model checkpoint path
+        model: Model
+    """
+    print("\n" + "=" * 60)
+    print("Evaluation Results")
+    print("=" * 60)
+    print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
+    print(f"Dataset: {dataset_name}")
+    print(f"Model: {checkpoint_path}")
+    print(f"Number of samples: {results['total_samples']}")
+    print("-" * 60)
+    print(f"mDice (Mean Dice Coefficient):     {results['mDice']:.3f}")
+    print(f"mIoU (Mean Intersection over Union): {results['mIoU']:.3f}")
+    print(f"mHD (Mean Hausdorff Distance):       {results['mHD']:.3f}")
+    print(f"mHD95 (95% Hausdorff Distance):      {results['mHD95']:.3f}")
+    print("=" * 60)
+
+
+def save_results(results, dataset_name, checkpoint_path, results_dir, model):
+    """
+    Save evaluation results to file
+    
+    Args:
+        results: Evaluation results dictionary
+        dataset_name: Dataset name
+        checkpoint_path: Model checkpoint path
+        results_dir: Results save directory
+        model: Model
+    """
+    os.makedirs(results_dir, exist_ok=True)
+
+    # Generate filename
+    result_file = os.path.join(results_dir, f"eval_{dataset_name}.txt")
+
+    with open(result_file, 'w', encoding='utf-8') as f:
+        f.write("=" * 60 + "\n")
+        f.write("Polyp Segmentation Model Evaluation Report\n")
+        f.write("=" * 60 + "\n\n")
+        f.write(f"Model parameters: {sum(p.numel() for p in model.parameters())}\n")
+        f.write(f"Dataset: {dataset_name}\n")
+        f.write(f"Model checkpoint: {checkpoint_path}\n")
+        f.write(f"Evaluation time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
+        f.write(f"Number of samples: {results['total_samples']}\n\n")
+        f.write("-" * 60 + "\n")
+        f.write("Evaluation Metrics:\n")
+        f.write("-" * 60 + "\n")
+        f.write(f"mDice (Mean Dice Coefficient):     {results['mDice']:.3f}\n")
+        f.write(f"mIoU (Mean Intersection over Union): {results['mIoU']:.3f}\n")
+        f.write(f"mHD (Mean Hausdorff Distance):       {results['mHD']:.3f}\n")
+        f.write(f"mHD95 (95% Hausdorff Distance):      {results['mHD95']:.3f}\n")
+        f.write("-" * 60 + "\n")
+
+    print(f"\n✓ Evaluation results saved to: {result_file}")
+
+
+def main():
+    """
+    Main evaluation function
+    """
+    # ==================== Step 1: Parse arguments ====================
+    args = parse_args()
+    checkpoint_path = Path(args.outputs_dir + args.dir_flag) / f"best_dice_model_{args.dataset_name}.pt"
+    if args.best_metric:
+        checkpoint_path = Path(args.outputs_dir + args.dir_flag) / f"best_metric_model_{args.dataset_name}.pt"
+    print("\n" + "=" * 60)
+    print("Polyp Segmentation Model Evaluation")
+    print("=" * 60)
+    print(f"Dataset: {args.dataset_name}")
+    print(f"Model checkpoint: {checkpoint_path}")
+
+    # ==================== Step 2: Create validation set and data loader ====================
+    print("\nLoading validation set...")
+
+    val_transform = create_val_transform(args.target_spatial_size)
+
+    val_dataset = PolypDetectionDataset(
+        root_dir=Path(args.data_root) / args.dataset_name,
+        flag='val',
+        transform=val_transform
+    )
+
+    val_loader = create_dataloader(args, val_dataset)
+
+    print(f"✓ Validation set size: {len(val_dataset)} samples")
+    print(f"✓ Data loader: {len(val_loader)} batches")
+
+    # ==================== Step 3: Load model ====================
+    model = load_model(args, checkpoint_path)
+
+    # ==================== Step 4: Evaluate model ====================
+    results = evaluate_model(model, val_loader, args)
+
+    # ==================== Step 5: Print and save results ====================
+    print_results(results, args.dataset_name, checkpoint_path, model)
+
+    if args.save_results:
+        save_results(results, args.dataset_name, checkpoint_path, args.results_dir + args.dir_flag, model)
+
+    print("\n" + "=" * 60)
+    print("Evaluation completed!")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()

+ 9 - 0
lib/__init__.py

@@ -0,0 +1,9 @@
+from lib.model.model import Wavelet_FFT_SwinUNETR
+from lib.modules.awae import AdaptiveWaveletAugmentedEnhancer
+from lib.modules.frcab import FFTRCAB
+
+__all__ = [
+    'Wavelet_FFT_SwinUNETR',
+    'AdaptiveWaveletAugmentedEnhancer',
+    'FFTRCAB'
+]

+ 232 - 0
lib/model/model.py

@@ -0,0 +1,232 @@
+from typing import List, Union
+
+import torch
+from monai.networks.nets import SwinUNETR
+from torch import Tensor
+from torch import nn
+
+from lib.modules.awae import AdaptiveWaveletAugmentedEnhancer
+from lib.modules.frcab import FFTRCAB
+
+
+class Wavelet_FFT_SwinUNETR(SwinUNETR):
+    """SwinUNETR with Wavelet and FFT enhancement modules."""
+    
+    def __init__(
+            self,
+            in_channels=3,
+            out_channels=2,
+            patch_size=2,
+            depths=(2, 2, 2, 2),
+            num_heads=(3, 6, 12, 24),
+            window_size=7,
+            qkv_bias=True,
+            mlp_ratio=4.0,
+            feature_size=48,
+            norm_name="instance",
+            drop_rate=0.0,
+            attn_drop_rate=0.0,
+            dropout_path_rate=0.0,
+            normalize=True,
+            norm_layer=nn.LayerNorm,
+            patch_norm=False,
+            use_checkpoint=False,
+            spatial_dims=2,
+            downsample="merging",
+            use_v2=True,
+            wavelet_enhancement=True,
+            wavelet_J=2,
+            wavelet_wave="db4",
+            wavelet_mode="symmetric",
+            wavelet_reduction=16,
+            fft_enhancement=True,
+    ):
+        """Initialize model.
+        
+        Args:
+            in_channels: Number of input channels
+            out_channels: Number of output channels
+            feature_size: Base feature dimension
+            spatial_dims: Spatial dimensions (2D or 3D)
+            wavelet_enhancement: Enable wavelet enhancement module
+            wavelet_J: Wavelet decomposition levels
+            wavelet_wave: Wavelet basis type
+            wavelet_mode: Wavelet mode
+            wavelet_reduction: Wavelet attention compression ratio
+            fft_enhancement: Enable FFT enhancement module
+        """
+
+        # Initialize parent SwinUNETR class
+        super().__init__(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            patch_size=patch_size,
+            feature_size=feature_size,
+            depths=depths,
+            num_heads=num_heads,
+            window_size=window_size,
+            qkv_bias=qkv_bias,
+            mlp_ratio=mlp_ratio,
+            norm_name=norm_name,
+            drop_rate=drop_rate,
+            attn_drop_rate=attn_drop_rate,
+            dropout_path_rate=dropout_path_rate,
+            normalize=normalize,
+            norm_layer=norm_layer,
+            patch_norm=patch_norm,
+            use_checkpoint=use_checkpoint,
+            spatial_dims=spatial_dims,
+            downsample=downsample,
+            use_v2=use_v2,
+        )
+
+        # Initialize wavelet enhancement modules for different feature levels
+        # Initialize wavelet enhancement modules for different feature levels
+        self.wavelet_enhancement = wavelet_enhancement
+        if self.wavelet_enhancement:
+            self.wavelet_enhancer_f0 = AdaptiveWaveletAugmentedEnhancer(
+                in_channels=feature_size,
+                J=wavelet_J,
+                wave=wavelet_wave,
+                mode=wavelet_mode,
+                reduction_ratio=wavelet_reduction,
+            )
+
+            self.wavelet_enhancer_f1 = AdaptiveWaveletAugmentedEnhancer(
+                in_channels=feature_size,
+                J=wavelet_J,
+                wave=wavelet_wave,
+                mode=wavelet_mode,
+                reduction_ratio=wavelet_reduction,
+            )
+
+            self.wavelet_enhancer_f2 = AdaptiveWaveletAugmentedEnhancer(
+                in_channels=2 * feature_size,
+                J=wavelet_J,
+                wave=wavelet_wave,
+                mode=wavelet_mode,
+                reduction_ratio=wavelet_reduction,
+            )
+
+            self.wavelet_enhancer_f3 = AdaptiveWaveletAugmentedEnhancer(
+                in_channels=4 * feature_size,
+                J=wavelet_J,
+                wave=wavelet_wave,
+                mode=wavelet_mode,
+                reduction_ratio=wavelet_reduction,
+            )
+
+            self.wavelet_enhancer_bottleneck = AdaptiveWaveletAugmentedEnhancer(
+                in_channels=16 * feature_size,
+                J=1,
+                wave=wavelet_wave,
+                mode=wavelet_mode,
+                reduction_ratio=wavelet_reduction,
+            )
+        # Initialize FFT enhancement modules for different decoder levels
+        # Initialize FFT enhancement modules for different decoder levels
+        self.fft_enhancement = fft_enhancement
+        if self.fft_enhancement:
+            self.fft_enhancer_f1 = FFTRCAB(feature_size)
+
+            self.fft_enhancer_f2 = FFTRCAB(2 * feature_size)
+
+            self.fft_enhancer_f3 = FFTRCAB(4 * feature_size)
+
+            self.fft_enhancer_f4 = FFTRCAB(8 * feature_size)
+
+    def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]:
+        """Forward pass.
+        
+        Args:
+            x: Input tensor [B, C, H, W]
+            
+        Returns:
+            Output logits
+        """
+
+        # Check input size compatibility
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+            self._check_input_size(x.shape[2:])
+
+        # Extract multiscale features from Swin ViT backbone
+        hidden_states_out = self.swinViT(x, self.normalize)
+
+        if self.wavelet_enhancement:
+
+            enc0 = self.wavelet_enhancer_f0(self.encoder1(x))
+            enc1 = self.wavelet_enhancer_f1(self.encoder2(hidden_states_out[0]))
+            enc2 = self.wavelet_enhancer_f2(self.encoder3(hidden_states_out[1]))
+            enc3 = self.wavelet_enhancer_f3(self.encoder4(hidden_states_out[2]))
+        else:
+            # Use standard encoder features without enhancement
+            enc0 = self.encoder1(x)
+            enc1 = self.encoder2(hidden_states_out[0])
+            enc2 = self.encoder3(hidden_states_out[1])
+            enc3 = self.encoder4(hidden_states_out[2])
+
+        # Process bottleneck features
+        dec4 = self.encoder10(hidden_states_out[4])
+        if self.wavelet_enhancement:
+            dec4 = self.wavelet_enhancer_bottleneck(dec4)
+
+        if self.fft_enhancement:
+
+            dec3 = self.decoder5(dec4, hidden_states_out[3])
+            dec3 = self.fft_enhancer_f4(dec3)
+            dec2 = self.decoder4(dec3, enc3)
+            dec2 = self.fft_enhancer_f3(dec2)
+            dec1 = self.decoder3(dec2, enc2)
+            dec1 = self.fft_enhancer_f2(dec1)
+            dec0 = self.decoder2(dec1, enc1)
+            dec0 = self.fft_enhancer_f1(dec0)
+            out = self.decoder1(dec0, enc0)
+        else:
+            # Standard decoding without FFT enhancement
+
+            dec3 = self.decoder5(dec4, hidden_states_out[3])
+            dec2 = self.decoder4(dec3, enc3)
+            dec1 = self.decoder3(dec2, enc2)
+            dec0 = self.decoder2(dec1, enc1)
+            out = self.decoder1(dec0, enc0)
+
+        # Generate final output logits
+        logits = self.out(out)
+
+        return logits
+
+
+if __name__ == "__main__":
+    image = torch.randn(1, 3, 512, 512)
+    model = Wavelet_FFT_SwinUNETR(
+        in_channels=3,
+        out_channels=1,
+        patch_size=2,
+        depths=(2, 2, 2, 2),
+        num_heads=(3, 6, 12, 24),
+        window_size=7,
+        qkv_bias=True,
+        mlp_ratio=4.0,
+        feature_size=48,
+        norm_name="instance",
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        dropout_path_rate=0.0,
+        normalize=True,
+        norm_layer=nn.LayerNorm,
+        patch_norm=False,
+        use_checkpoint=False,
+        spatial_dims=2,
+        downsample="merging",
+        use_v2=True,
+        wavelet_enhancement=True,
+        wavelet_J=2,
+        wavelet_wave="db4",
+        wavelet_mode="symmetric",
+        wavelet_reduction=16,
+        fft_enhancement=True,
+    )
+    hidden_states_out = model.swinViT(image, normalize=True)
+    print("hidden_states_out:", [i.shape for i in hidden_states_out])
+    print(model(image).shape)
+    print("total parameters: ", sum(p.numel() for p in model.parameters()))

+ 150 - 0
lib/modules/awae.py

@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+from pytorch_wavelets import DWT, IDWT
+
+
+class AdaptiveWaveletAttention(nn.Module):
+    """Adaptive wavelet attention module for feature enhancement."""
+
+    def __init__(
+            self, in_channels: int, reduction_ratio: int = 4, init_bias: float = 0.2
+    ):
+        """Initialize attention module.
+        
+        Args:
+            in_channels: Number of input channels
+            reduction_ratio: Channel compression ratio
+            init_bias: Initial bias value
+        """
+        super().__init__()
+        # Ensure safe reduction ratio
+        safe_reduction = max(1, min(reduction_ratio, in_channels))
+        # Channel attention branch with squeeze-and-excitation
+        self.channel_attention = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels, in_channels // safe_reduction, 1, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels // safe_reduction, in_channels, 1, bias=False),
+            nn.Sigmoid(),
+        )
+
+        # Learnable bias scale parameter
+        self.bias_scale = nn.Parameter(torch.tensor(init_bias))
+        # Fusion gate for adaptive enhancement
+        self.fusion_gate = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels, in_channels // safe_reduction, 1, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels // safe_reduction, in_channels, 1, bias=False),
+            nn.Sigmoid(),
+        )
+        self._init_weight()
+
+    def _init_weight(self):
+        """Initialize weights using Kaiming initialization."""
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass with adaptive wavelet enhancement.
+        
+        Args:
+            x: Input tensor [B, C, H, W]
+            
+        Returns:
+            Enhanced tensor with same shape
+        """
+        # Compute base channel attention weights
+        # Compute base channel attention weights
+        base_weight = self.channel_attention(x)
+        # Compute adaptive gate factor
+        gate_factor = self.fusion_gate(x)
+        enhanced_factor = 1.0 + (self.bias_scale * gate_factor)
+        final_weight = base_weight * enhanced_factor
+        return x * torch.clamp(final_weight, min=0.1)
+
+
+class AdaptiveWaveletAugmentedEnhancer(nn.Module):
+    """Adaptive wavelet augmented enhancer with multi-level attention."""
+
+    def __init__(
+            self,
+            in_channels: int,
+            J: int = 1,
+            wave: str = "db4",
+            mode: str = "symmetric",
+            reduction_ratio: int = 4,
+    ):
+        """Initialize wavelet enhancer.
+        
+        Args:
+            in_channels: Number of input channels
+            J: Wavelet decomposition levels
+            wave: Wavelet basis type
+            mode: Wavelet padding mode
+            reduction_ratio: Attention compression ratio
+        """
+        super().__init__()
+        # Validate decomposition levels
+        assert 1 <= J <= 3, "J must be in [1, 3]"
+        self.J = J
+        # Initialize discrete wavelet transform
+        self.dwt = DWT(J=J, wave=wave, mode=mode)
+        self.idwt = IDWT(wave=wave, mode=mode)
+
+        # Low-frequency attention for approximation coefficients
+        self.ll_att = AdaptiveWaveletAttention(
+            in_channels=in_channels, reduction_ratio=reduction_ratio, init_bias=0.2
+        )
+
+        # High-frequency attention for detail coefficients at each level
+        self.yh_att = nn.ModuleList(
+            [
+                AdaptiveWaveletAttention(
+                    in_channels=in_channels,
+                    reduction_ratio=reduction_ratio,
+                    init_bias=0.4,
+                )
+                for _ in range(J)
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass with multi-level wavelet enhancement.
+        
+        Args:
+            x: Input tensor [B, C, H, W]
+            
+        Returns:
+            Enhanced tensor after inverse wavelet transform
+        """
+        # Perform wavelet decomposition
+        # Perform wavelet decomposition
+        yl, yh = self.dwt(x)
+        # Enhance low-frequency approximation
+        yl_enhanced = self.ll_att(yl)
+        # Enhance high-frequency details at each level
+        yh_enhanced = []
+        for i in range(self.J):
+            level_features = []
+            for j in range(3):  # LH, HL, HH subbands
+                subband = yh[i][:, :, j, :, :]
+                enhanced_subband = self.yh_att[i](subband)
+                level_features.append(enhanced_subband)
+            yh_enhanced.append(torch.stack(level_features, dim=2))
+        # Reconstruct enhanced signal via inverse wavelet transform
+        return self.idwt((yl_enhanced, yh_enhanced))
+
+
+if __name__ == "__main__":
+    input_tensor = torch.randn(1, 64, 256, 256)
+    wavelet_enhancer = AdaptiveWaveletAugmentedEnhancer(in_channels=64, J=2)
+    enhanced_tensor = wavelet_enhancer(input_tensor)
+    print("input shape:", input_tensor.shape)
+    print("output shape:", enhanced_tensor.shape)

+ 64 - 0
lib/modules/frcab.py

@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+
+
+class FFTRCAB(nn.Module):
+
+    def __init__(self, dim):
+        super(FFTRCAB, self).__init__()
+
+        self.CBG3x3 = nn.Sequential(
+            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.LeakyReLU(0.1, inplace=True),
+            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
+        )
+
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+
+        self.xc_aEnhance = nn.Sequential(
+            nn.Conv2d(dim, dim // 2 + 1, 1, 1, 0),
+            nn.LeakyReLU(0.1, inplace=True),
+            nn.Conv2d(dim // 2 + 1, dim // 2 + 1, 1, 1, 0),
+        )
+
+        self.xc_pEnhance = nn.Sequential(
+            nn.Conv2d(dim, dim // 2 + 1, 1, 1, 0),
+            nn.LeakyReLU(0.1, inplace=True),
+            nn.Conv2d(dim // 2 + 1, dim // 2 + 1, 1, 1, 0),
+        )
+
+    def forward(self, x):
+        x_conv = self.CBG3x3(x)
+        x_conv = x_conv.to(torch.float32)
+
+        x_pool = self.avg_pool(x_conv)
+
+        xc_a = self.xc_aEnhance(x_pool)
+        xc_p = self.xc_pEnhance(x_pool)
+
+        x_fft = torch.fft.rfft2(x_pool, dim=1, norm="ortho")
+        x_a = torch.abs(x_fft)
+        x_p = torch.angle(x_fft)
+
+        xa_enh = x_a * xc_a
+        xp_enh = x_p * xc_p
+
+        xa = xa_enh * torch.cos(xp_enh)
+        xp = xa_enh * torch.sin(xp_enh)
+        x_comp = torch.complex(xa, xp)
+
+        xc = torch.fft.irfft2(x_comp, dim=1, norm="ortho")
+
+        x_out = x_conv * xc
+
+        return x_out + x
+
+
+if __name__ == "__main__":
+    input_tensor = torch.randn(3, 64, 128, 128)
+
+    fft_rcab = FFTRCAB(64)
+
+    output_tensor = fft_rcab(input_tensor)
+
+    print(output_tensor.shape)

+ 0 - 0
lib/tools/__init__.py


+ 256 - 0
lib/tools/combined_loss.py

@@ -0,0 +1,256 @@
+from __future__ import annotations
+
+import warnings
+from collections.abc import Callable, Sequence
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# noinspection PyProtectedMember
+from torch.nn.modules.loss import _Loss
+
+from monai.losses.utils import compute_tp_fp_fn
+from monai.networks import one_hot
+from monai.utils import LossReduction, Weight, look_up_option
+from monai.losses import DiceLoss, DiceCELoss, HausdorffDTLoss
+
+
+class IoULoss(_Loss):
+    """
+    Compute average IoU (Intersection over Union) loss between two tensors.
+    
+    IoU Loss 直接优化交并比指标,相比 Dice Loss 对边界误差更敏感,
+    能够提供更强的梯度信号用于训练。
+    """
+
+    def __init__(
+            self,
+            include_background: bool = True,
+            to_onehot_y: bool = False,
+            sigmoid: bool = False,
+            softmax: bool = False,
+            other_act: Callable | None = None,
+            squared_pred: bool = False,
+            reduction: LossReduction | str = LossReduction.MEAN,
+            smooth_nr: float = 1e-5,
+            smooth_dr: float = 1e-5,
+            batch: bool = False,
+            weight: Sequence[float] | float | int | torch.Tensor | None = None,
+            soft_label: bool = False,
+    ) -> None:
+        """
+        Args:
+            include_background: if False, channel index 0 (background category) is excluded from the calculation.
+                如果非背景区域相对于整个图像较小,排除背景可以帮助收敛。
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
+            sigmoid: if True, apply a sigmoid function to the prediction.
+            softmax: if True, apply a softmax function to the prediction.
+            other_act: callable function to execute other activation layers, Defaults to ``None``. 
+                for example: ``other_act = torch.tanh``.
+            squared_pred: use squared versions of targets and predictions in the denominator or not.
+                使用平方可以加强大误差区域的惩罚。
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+                - ``"none"``: no reduction will be applied.
+                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+                - ``"sum"``: the output will be summed.
+            smooth_nr: a small constant added to the numerator to avoid zero.
+            smooth_dr: a small constant added to the denominator to avoid nan.
+            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
+                Defaults to False, 每个 batch item 独立计算损失后再 reduction。
+            weight: weights to apply to the voxels of each class. If None no weights are applied.
+                The input can be a single value (same weight for all classes), a sequence of values (the length
+                of the sequence should be the same as the number of classes. If not ``include_background``,
+                the number of classes should not include the background category class 0).
+                The value/values should be no less than 0. Defaults to None.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
+
+        Raises:
+            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
+            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
+                Incompatible values.
+
+        Example:
+            >>> import torch
+            >>> from lib.modules.iou_loss import IoULoss
+            >>> B, C, H, W = 7, 5, 3, 2
+            >>> y_input = torch.rand(B, C, H, W)
+            >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
+            >>> target = one_hot(target_idx[:, None, ...], num_classes=C)
+            >>> loss_fn = IoULoss(reduction='none')
+            >>> loss = loss_fn(y_input, target)
+        """
+        super().__init__(reduction=LossReduction(reduction).value)
+        if other_act is not None and not callable(other_act):
+            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
+        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
+            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
+
+        self.include_background = include_background
+        self.to_onehot_y = to_onehot_y
+        self.sigmoid = sigmoid
+        self.softmax = softmax
+        self.other_act = other_act
+        self.squared_pred = squared_pred
+        self.smooth_nr = float(smooth_nr)
+        self.smooth_dr = float(smooth_dr)
+        self.batch = batch
+        weight = torch.as_tensor(weight) if weight is not None else None
+        self.register_buffer("class_weight", weight)
+        self.class_weight: None | torch.Tensor
+        self.soft_label = soft_label
+
+    def forward(self, y_input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            y_input: the shape should be BNH[WD], where N is the number of classes.
+            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
+
+        Raises:
+            AssertionError: When input and target (after one hot transform if set) have different shapes.
+            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
+        """
+        if self.sigmoid:
+            y_input = torch.sigmoid(y_input)
+
+        n_pred_ch = y_input.shape[1]
+        if self.softmax:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `softmax=True` ignored.")
+            else:
+                y_input = torch.softmax(y_input, 1)
+
+        if self.other_act is not None:
+            y_input = self.other_act(y_input)
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                target = one_hot(target, num_classes=n_pred_ch)
+
+        if not self.include_background:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `include_background=False` ignored.")
+            else:
+                # if skipping background, removing first channel
+                target = target[:, 1:]
+                y_input = y_input[:, 1:]
+
+        if target.shape != y_input.shape:
+            raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({y_input.shape})")
+
+        # reducing only spatial dimensions (not batch nor channels)
+        reduce_axis: list[int] = torch.arange(2, len(y_input.shape)).tolist()
+        if self.batch:
+            # reducing spatial dimensions and batch
+            reduce_axis = [0] + reduce_axis
+
+        y_ord = 2 if self.squared_pred else 1
+        tp, fp, fn = compute_tp_fp_fn(y_input, target, reduce_axis, y_ord, self.soft_label)
+
+        # IoU 的核心公式:IoU = TP / (TP + FP + FN)
+        # 注意:与 Dice 不同,IoU 的分子没有系数 2,分母也少了 TP
+        numerator = tp + self.smooth_nr
+        denominator = tp + fp + fn + self.smooth_dr
+
+        iou: torch.Tensor = numerator / denominator
+        loss: torch.Tensor = 1 - iou  # IoU Loss = 1 - IoU
+
+        num_of_classes = target.shape[1]
+        if self.class_weight is not None and num_of_classes != 1:
+            # make sure the lengths of weights are equal to the number of classes
+            if self.class_weight.ndim == 0:
+                # noinspection PyAttributeOutsideInit
+                self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
+            else:
+                if self.class_weight.shape[0] != num_of_classes:
+                    raise ValueError(
+                        """the length of the `weight` sequence should be the same as the number of classes.
+                        If `include_background=False`, the weight should not include
+                        the background category class 0."""
+                    )
+            if self.class_weight.min() < 0:
+                raise ValueError("the value/values of the `weight` should be no less than 0.")
+            # apply class_weight to loss
+            loss = loss * self.class_weight.to(loss)
+
+        if self.reduction == LossReduction.MEAN.value:
+            loss = torch.mean(loss)  # the batch and channel average
+        elif self.reduction == LossReduction.SUM.value:
+            loss = torch.sum(loss)  # sum over the batch and channel dims
+        elif self.reduction == LossReduction.NONE.value:
+            # If we are not computing voxelwise loss components at least
+            # make sure a none reduction maintains a broadcastable shape
+            broadcast_shape = list(loss.shape[0:2]) + [1] * (len(y_input.shape) - 2)
+            loss = loss.view(broadcast_shape)
+        else:
+            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+
+        return loss
+
+
+class CombinedDiceCEIoULoss(_Loss):
+    """
+    四合一组合损失:DiceCE + IoU
+    - DiceCE: Dice + CrossEntropy (全局 + 局部)
+    - IoU: 交并比优化
+    """
+
+    def __init__(
+            self,
+            dice_weight: float = 1.0,
+            ce_weight: float = 1.0,
+            iou_weight: float = 1.0,
+            include_background: bool = True,
+            to_onehot_y: bool = False,
+            softmax: bool = False,
+            sigmoid: bool = False,
+    ):
+        super().__init__()
+
+        self.include_background = include_background
+        self.to_onehot_y = to_onehot_y
+        self.softmax = softmax
+
+        # DiceBCE Loss (Dice + BCE)
+        self.dice_ce_loss = DiceCELoss(
+            include_background=include_background,
+            to_onehot_y=to_onehot_y,
+            softmax=softmax,
+            sigmoid=sigmoid,
+            lambda_dice=dice_weight / (dice_weight + ce_weight),
+            lambda_ce=ce_weight / (dice_weight + ce_weight),
+            reduction="mean",
+        )
+        # IoU Loss
+        self.iou_loss = IoULoss(
+            include_background=include_background,
+            to_onehot_y=to_onehot_y,
+            softmax=softmax,
+            sigmoid=sigmoid,
+            reduction="mean",
+        )
+
+        # 外部权重
+        self.dice_ce_weight = dice_weight + ce_weight
+        self.iou_weight = iou_weight
+
+    def forward(self, y_input: torch.Tensor, target: torch.Tensor) -> tuple[float | Any, Any, Any]:
+        # DiceCE Loss
+        dice_ce_loss = self.dice_ce_loss(y_input, target)
+
+        # IoU Loss
+        iou_loss = self.iou_loss(y_input, target)
+
+        # 加权组合
+        total_loss = (
+                self.dice_ce_weight * dice_ce_loss +
+                self.iou_weight * iou_loss
+        )
+
+        return total_loss, dice_ce_loss, iou_loss

+ 397 - 0
lib/tools/loss.py

@@ -0,0 +1,397 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.fft as fft
+from monai.losses import DiceLoss, DiceCELoss, HausdorffDTLoss
+
+
+class FocalFrequencyLoss(nn.Module):
+    """
+    焦点频域损失函数 (Focal Frequency Loss)
+    
+    核心思想:
+        传统的空间域损失(如 MSE、Dice)主要关注像素级别的差异,而频域损失通过傅里叶变换
+        将图像转换到频率域,从频率角度衡量预测图像与真实图像的差异。
+        
+        该损失的创新点在于"焦点"机制:
+        1. 自动计算频谱权重矩阵,对不同频率成分赋予不同的重要性
+        2. 对难以重建的频率成分给予更高权重(类似 Focal Loss 的思想)
+        3. 可以捕捉图像的全局结构和纹理细节,弥补空间域损失的不足
+    
+    适用场景:
+        - 医学图像分割:增强边缘和纹理的恢复
+        - 图像超分辨率:重建高频细节
+        - 图像去噪/去模糊:平衡低频和高频信息
+        
+    参数说明:
+        loss_weight: 损失权重系数,用于平衡该损失与其他损失的重要性
+        alpha: 频谱权重的幂次参数,控制权重分布的陡峭程度
+               alpha 越大,困难频率成分的权重越突出
+        patch_factor: 图像分块因子,将图像分成多个小块分别进行 FFT
+                     值为 1 表示不分块,对整个图像做 FFT
+                     值大于 1 时,将图像分成 patch_factor×patch_factor 个小块
+        ave_spectrum: 是否对 batch 内的频谱进行平均,用于减少 batch 内差异
+        log_matrix: 是否对频谱差异取对数,用于压缩动态范围
+        batch_matrix: 权重归一化方式
+                     True: 在整个 batch 范围内归一化到 [0,1]
+                     False: 对每张图像单独归一化
+    """
+
+    def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False,
+                 batch_matrix=False):
+        """
+        初始化焦点频域损失函数的所有超参数
+        
+        Args:
+            loss_weight (float): 损失权重,默认 1.0
+            alpha (float): 频谱权重指数,默认 1.0
+            patch_factor (int): 图像分块因子,默认 1(不分块)
+            ave_spectrum (bool): 是否对 batch 频谱平均,默认 False
+            log_matrix (bool): 是否使用对数矩阵,默认 False
+            batch_matrix (bool): 是否使用 batch 级矩阵,默认 False
+        """
+        super(FocalFrequencyLoss, self).__init__()
+        self.loss_weight = loss_weight
+        self.alpha = alpha
+        self.patch_factor = patch_factor
+        self.ave_spectrum = ave_spectrum
+        self.log_matrix = log_matrix
+        self.batch_matrix = batch_matrix
+
+    def tensor2freq(self, x):
+        """
+        将空间域图像张量转换为频域表示
+        
+        工作原理:
+            1. 如果 patch_factor > 1,先将图像分割成多个小块
+            2. 对每个小块执行 2D 快速傅里叶变换 (FFT)
+            3. 将复数形式的 FFT 结果分解为实部和虚部
+        
+        傅里叶变换的物理意义:
+            - 低频成分:对应图像的平滑区域和整体轮廓
+            - 高频成分:对应图像的边缘、纹理和噪声
+            - 通过分析频谱,可以分离和处理不同频率的特征
+        
+        Args:
+            x (torch.Tensor): 输入图像张量,形状为 (N, C, H, W)
+                             N=batch_size, C=channels, H=height, W=width
+        
+        Returns:
+            freq (torch.Tensor): 频域表示,形状为 (N, patch_factor², C, H/pf, W/pf, 2)
+                                最后一维的 2 个通道分别是 [实部,虚部]
+                                patch_factor² 表示分成了多少个小块
+        
+        Example:
+            输入:x.shape = (4, 1, 256, 256), patch_factor=4
+            输出:freq.shape = (4, 16, 1, 64, 64, 2)
+                - 16 = 4×4 个小块
+                - 64×64 = 每个小块的尺寸
+                - 2 = [实部,虚部]
+        """
+        # 获取分块因子
+        patch_factor = self.patch_factor
+
+        # 获取输入图像的尺寸信息
+        _, _, h, w = x.shape
+
+        # 断言检查:确保图像尺寸可以被 patch_factor 整除
+        # 这是为了保证分块时每个小块大小一致,避免边界问题
+        assert h % patch_factor == 0 and w % patch_factor == 0, (
+            'Patch factor should be divisible by image height and width')
+
+        # 初始化列表用于存储所有小块的频域表示
+        patch_list = []
+
+        # 计算每个小块的高度和宽度
+        # 例如:原图 256×256, patch_factor=4 → 每个小块 64×64
+        patch_h = h // patch_factor
+        patch_w = w // patch_factor
+
+        # 双重循环遍历所有小块
+        # i 控制垂直方向的索引,j 控制水平方向的索引
+        for i in range(patch_factor):
+            for j in range(patch_factor):
+                # 切片操作:提取第 (i,j) 个小块
+                # 垂直方向:从 i*patch_h 到 (i+1)*patch_h
+                # 水平方向:从 j*patch_w 到 (j+1)*patch_w
+                # [:, :, ...] 保持 batch 和 channel 维度不变
+                patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
+
+        # 将所有小块堆叠成一个新的维度
+        # 原始形状:(N, C, patch_h, patch_w) 的列表,长度为 patch_factor²
+        # 堆叠后形状:(N, patch_factor², C, patch_h, patch_w)
+        # dim=1 表示在第 1 个维度(channel 之后)插入新的分块维度
+        y = torch.stack(patch_list, 1)
+
+        # 对每个小块执行 2D 快速傅里叶变换
+        # torch.fft.fft2: 计算二维离散傅里叶变换
+        # norm='ortho': 使用正交归一化,保证变换前后能量守恒
+        # 变换结果是一个复数张量,包含每个频率成分的振幅和相位信息
+        freq = torch.fft.fft2(y, norm='ortho')
+
+        # 将复数形式转换为实数表示,方便后续神经网络处理
+        # torch.stack([freq.real, freq.imag], -1): 将实部和虚部堆叠到最后一个维度
+        # freq.real: 复数的实部,代表余弦分量的幅度
+        # freq.imag: 复数的虚部,代表正弦分量的幅度
+        # 最终形状:(N, patch_factor², C, patch_h, patch_w, 2)
+        freq = torch.stack([freq.real, freq.imag], -1)
+
+        return freq
+
+    def loss_formulation(self, recon_freq, real_freq, matrix=None):
+        """
+        构建并计算焦点频域损失的核心公式
+        
+        核心思想:
+            传统频域损失直接计算预测频谱与真实频谱的距离(如 MSE)。
+            本方法引入"焦点"机制:根据每个频率成分的重建难度动态调整权重。
+            
+            权重计算逻辑:
+                1. 如果未提供预定义权重矩阵,则在线计算动态权重
+                2. 重建误差大的频率 → 高权重(重点关注)
+                3. 重建误差小的频率 → 低权重(减少关注)
+            这类似于 Focal Loss 中"关注难例"的思想
+        
+        Args:
+            recon_freq (torch.Tensor): 重建(预测)图像的频域表示
+                                      形状:(N, P, C, H, W, 2),P=patch_factor²
+            real_freq (torch.Tensor): 真实(目标)图像的频域表示
+                                    形状:(N, P, C, H, W, 2)
+            matrix (torch.Tensor, optional): 预定义的频谱权重矩阵
+                                           如果为 None,则动态计算权重
+        
+        Returns:
+            loss (torch.Tensor): 标量损失值
+        
+        工作流程:
+            Step 1: 确定权重矩阵(预定义或动态计算)
+            Step 2: 计算频谱距离(复数空间中的欧氏距离)
+            Step 3: 加权求和得到最终损失
+        """
+
+        # ==================== Step 1: 确定权重矩阵 ====================
+        if matrix is not None:
+            # 情况 A: 使用预定义的权重矩阵
+            # 这种模式允许外部指定固定的频率权重,适用于某些先验知识已知的场景
+            # 例如:人工指定某些频率更重要,或者使用其他算法计算的权重
+            weight_matrix = matrix.detach()
+            # .detach() 确保权重矩阵不参与梯度反向传播
+            # 这样权重是固定的,不会影响梯度的流动
+        else:
+            # 情况 B: 动态计算自适应权重矩阵(推荐模式)
+
+            # --- 子步骤 1: 计算初步的频谱差异 ---
+            # 逐元素计算预测频谱与真实频谱的差值的平方
+            # 这是一个复数差的平方,需要分别处理实部和虚部
+            # 形状:(N, P, C, H, W, 2)
+            matrix_tmp = (recon_freq - real_freq) ** 2
+
+            # --- 子步骤 2: 计算频谱幅度差异 ---
+            # 复数的模长公式:|a + bi| = sqrt(a² + b²)
+            # 这里计算的是每个频率成分的预测误差的幅度
+            # [..., 0]: 实部的平方, [..., 1]: 虚部的平方
+            # torch.sqrt(...): 开平方得到欧几里得距离
+            # ** self.alpha: 应用幂次变换,alpha 控制权重的分布特性
+            #   - alpha > 1: 放大差异,使大误差更突出
+            #   - alpha < 1: 缩小差异,使权重分布更均匀
+            #   - alpha = 1: 线性关系(默认情况)
+            matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
+
+            # --- 子步骤 3: 可选的对数变换 ---
+            if self.log_matrix:
+                # 对数变换的作用:压缩动态范围
+                # 当频谱差异的范围很大时(几个数量级),对数变换可以防止某些频率主导损失
+                # log(x + 1.0): 加 1 是为了避免 log(0) 的数值不稳定
+                matrix_tmp = torch.log(matrix_tmp + 1.0)
+
+            # --- 子步骤 4: 归一化权重到 [0, 1] 范围 ---
+            # 归一化的目的是确保权重有统一的尺度,便于控制和解释
+            if self.batch_matrix:
+                # 模式 A: Batch 级归一化
+                # 在整个 batch 的所有像素、所有频率上找最大值,然后统一归一化
+                # 优点:batch 内所有样本的权重在同一尺度
+                # 缺点:可能掩盖样本间的差异
+                matrix_tmp = matrix_tmp / matrix_tmp.max()
+            else:
+                # 模式 B: 样本级归一化(推荐)
+                # 对每个样本单独归一化,保持样本间的相对差异
+                # .max(-1).values: 沿最后一个维度(实部/虚部维度)取最大值
+                # .max(-1).values[:, :, :, None, None]: 再沿空间维度取最大值
+                # [:, :, :, None, None]: 添加维度以保持广播兼容性
+                # 最终每张图片的权重矩阵独立归一化到 [0,1]
+                matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
+
+            # --- 子步骤 5: 数值稳定性处理 ---
+            # 处理可能出现的 NaN 值(例如 0/0 的情况)
+            # 将 NaN 替换为 0,表示这些位置不参与加权
+            matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
+
+            # --- 子步骤 6: 截断到合法范围 ---
+            # 确保所有权值都在 [0, 1] 区间内
+            # torch.clamp: 将小于 0 的值设为 0,大于 1 的值设为 1
+            # 这是防御性编程,防止数值溢出导致训练不稳定
+            matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
+
+            # --- 子步骤 7: 创建最终的权重矩阵 ---
+            # .clone().detach() 创建副本并断开梯度连接
+            # 这样权重矩阵在本次前向传播中是固定的,不会自我影响
+            # 这是关键设计:权重基于当前误差计算,但不参与本次的梯度回传
+            weight_matrix = matrix_tmp.clone().detach()
+
+        # ==================== 权重矩阵有效性验证 ====================
+        # 断言检查:确保权重矩阵的所有值都在 [0, 1] 范围内
+        # 这是一个安全检查,帮助调试时发现问题
+        # .item() 将标量张量转换为 Python 浮点数,方便打印
+        assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
+                'The values of spectrum weight matrix should be in the range [0, 1], '
+                'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
+
+        # ==================== Step 2: 计算频谱距离 ====================
+        # 计算预测频谱与真实频谱的逐元素差异的平方
+        # 与前面计算 matrix_tmp 的第一步相同
+        # 形状:(N, P, C, H, W, 2)
+        tmp = (recon_freq - real_freq) ** 2
+
+        # 将实部和虚部的平方和相加,得到复数空间中的欧几里得距离平方
+        # 这就是每个频率成分的重建误差
+        # 形状:(N, P, C, H, W)
+        freq_distance = tmp[..., 0] + tmp[..., 1]
+
+        # ==================== Step 3: 加权求和得到损失 ====================
+        # 逐元素相乘:权重矩阵 × 频谱距离
+        # 效果:
+        #   - 权重高的频率(重建困难)→ 损失贡献大 → 梯度大 → 模型重点关注
+        #   - 权重低的频率(重建简单)→ 损失贡献小 → 梯度小 → 模型次要关注
+        # 形状:(N, P, C, H, W)
+        loss = weight_matrix * freq_distance
+
+        # 对所有维度取平均,得到标量损失值
+        # torch.mean(): 将整个张量压缩成一个标量
+        # 这样得到的损失可以直接用于反向传播
+        return torch.mean(loss)
+
+    def forward(self, pred, target, matrix=None):
+        """
+        焦点频域损失的前向传播计算
+        
+        这是损失函数的主要入口,当调用 loss_function(pred, target) 时执行此方法
+        
+        Args:
+            pred (torch.Tensor): 预测的图像张量
+                                形状:(N, C, H, W)
+                                N: batch size(批次大小)
+                                C: channels(通道数,对于灰度图 C=1,RGB 图 C=3)
+                                H: height(图像高度)
+                                W: width(图像宽度)
+            
+            target (torch.Tensor): 目标的图像张量(真实标签)
+                                  形状:(N, C, H, W),必须与 pred 完全相同
+            
+            matrix (torch.Tensor, optional): 预定义的频谱权重矩阵
+                                           如果提供,则使用该固定权重而非动态计算
+                                           默认:None(动态计算自适应权重)
+        
+        Returns:
+            torch.Tensor: 标量损失值(0 维张量),可以直接用于反向传播
+        
+        完整计算流程:
+            ┌─────────────────┐
+            │ 输入:pred, target │
+            └────────┬──────────┘
+                     │
+                     ▼
+            ┌─────────────────┐
+            │ Step 1: 傅里叶变换 │  tensor2freq()
+            │ pred → pred_freq  │  空间域 → 频域
+            │ target → target_freq│
+            └────────┬──────────┘
+                     │
+                     ▼
+            ┌─────────────────┐
+            │ Step 2: 可选的平均 │  if ave_spectrum
+            │ 对 batch 维度平均   │  减少样本间差异
+            └────────┬──────────┘
+                     │
+                     ▼
+            ┌─────────────────┐
+            │ Step 3: 计算损失  │  loss_formulation()
+            │ 动态权重 × 频谱距离 │  焦点机制核心
+            └────────┬──────────┘
+                     │
+                     ▼
+            ┌─────────────────┐
+            │ Step 4: 应用权重  │  × loss_weight
+            │ 返回最终损失值    │
+            └─────────────────┘
+        
+        物理意义解释:
+            1. 傅里叶变换:将图像从"像素空间"转换到"频率空间"
+               - 像素空间:关注每个点的亮度值
+               - 频率空间:关注图像的周期性模式(纹理、边缘、轮廓)
+            
+            2. 频谱比较:在频率空间中衡量预测与真实的差异
+               - 低频误差:反映整体结构的偏差
+               - 高频误差:反映细节纹理的偏差
+            
+            3. 焦点权重:自动识别并强调难以重建的频率成分
+               - 这是与传统频域损失(如简单的频谱 MSE)的关键区别
+               - 类似于注意力机制,让模型"聚焦"于困难频率
+        
+        """
+
+        # ==================== Step 1: 将预测图像转换为频域 ====================
+        # 调用 tensor2freq 方法对预测图像进行傅里叶变换
+        # 将空间域的像素表示转换为频域的频谱表示
+        # pred_freq 包含了预测图像在各个频率上的振幅和相位信息
+        pred_freq = self.tensor2freq(pred)
+
+        # ==================== Step 2: 将目标图像转换为频域 ====================
+        # 同样对真实标签图像进行傅里叶变换
+        # 这样我们就可以在频率空间中比较预测与真实的差异
+        target_freq = self.tensor2freq(target)
+
+        # ==================== Step 3: 可选的 Batch 频谱平均 ====================
+        if self.ave_spectrum:
+            # 如果启用了 ave_spectrum 选项,对 batch 维度(第 0 维)取平均
+            # keepdim=True 保持维度数量不变,只是将第 0 维的大小设为 1
+            # 
+            # 这个操作的效果:
+            #   - 原始形状:(N, P, C, H, W, 2)
+            #   - 平均后:(1, P, C, H, W, 2)
+            #   
+            # 为什么要这样做?
+            #   1. 减少 batch 内样本间的随机波动
+            #   2. 计算一个"平均频谱"作为代表
+            #   3. 在某些任务中可以提高训练稳定性
+            #   
+            # 注意:这会改变损失的语义,从"逐个样本的损失"变成"batch 级别的损失"
+            pred_freq = torch.mean(pred_freq, 0, keepdim=True)
+            target_freq = torch.mean(target_freq, 0, keepdim=True)
+
+        # ==================== Step 4: 计算最终的焦点频域损失 ====================
+        # 调用 loss_formulation 方法计算加权后的频域损失
+        # 该方法会:
+        #   1. 动态计算频谱权重矩阵(如果 matrix=None)
+        #   2. 计算预测频谱与真实频谱的距离
+        #   3. 用权重矩阵对距离加权,得到最终损失
+        #
+        # 返回值是一个标量张量,表示整个 batch 的平均损失
+        loss_value = self.loss_formulation(pred_freq, target_freq, matrix)
+
+        # ==================== Step 5: 应用损失权重系数 ====================
+        # 将计算得到的损失乘以预设的权重系数 loss_weight
+        # 
+        # 这个参数的作用:
+        #   - 在多损失联合训练时,平衡不同损失的重要性
+        #   - 例如:总损失 = 1.0 * DiceLoss + 0.1 * FocalFrequencyLoss
+        #   - 这样可以让频域损失作为辅助损失,不会主导训练过程
+        #
+        # 为什么需要这样做?
+        #   不同的损失函数量级可能差异很大
+        #   DiceLoss 可能在 0-1 之间,而频域损失可能在 0-100 之间
+        #   通过调整 loss_weight,可以确保各个损失在同一数量级
+
+        final_loss = loss_value * self.loss_weight
+
+        return final_loss

+ 1053 - 0
train.py

@@ -0,0 +1,1053 @@
+import argparse
+import os
+import time
+from datetime import datetime
+from pathlib import Path
+
+import monai
+import monai.utils
+import swanlab
+import torch
+from lib.model.model_v4_minute import Wavelet_FFT_SwinUNETR
+from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
+from monai.transforms import (
+    Compose, LoadImaged, ScaleIntensityd, RandFlipd, RandRotated, RandRotate90d,
+    EnsureChannelFirstd, ToTensord, Resized, Lambdad, RandZoomd, RandShiftIntensityd, RandGaussianNoised,
+    RandGaussianSmoothd, RandAdjustContrastd, RandHistogramShiftd,
+    RandAxisFlipd, RandCoarseDropoutd,
+)
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
+from lib.tools.combined_loss import CombinedDiceCEIoULoss
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Polyp Segmentation Model Training Script")
+
+    # ==================== Dataset-related Parameters ====================
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        required=True,
+        help="Dataset name"
+    )
+    parser.add_argument(
+        "--data_root",
+        type=str,
+        default=r"./data/Polyp-Detection-Dataset",
+        help="Root directory path of the dataset"
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=0,
+        help="Number of worker processes for data loader"
+    )
+    parser.add_argument(
+        "--pin_memory",
+        type=bool,
+        default=True,
+        help="Whether to enable pinned memory"
+    )
+    parser.add_argument(
+        "--target_spatial_size",
+        type=tuple,
+        default=(512, 512),
+        help="Target spatial size"
+    )
+    parser.add_argument(
+        "--dataset_enhanced",
+        type=bool,
+        default=True,
+        help="Whether to use enhanced data augmentation strategy"
+    )
+
+    # ==================== Model-related Parameters ====================
+    parser.add_argument(
+        "--in_channels",
+        type=int,
+        default=3,
+        help="Number of input image channels"
+    )
+    parser.add_argument(
+        "--out_channels",
+        type=int,
+        default=1,
+        help="Number of output foreground channels"
+    )
+    parser.add_argument(
+        "--feature_size",
+        type=int,
+        default=48,
+        help="Network feature dimension"
+    )
+    parser.add_argument(
+        "--spatial_dims",
+        type=int,
+        default=2,
+        choices=[2, 3],
+        help="Spatial dimension (2D or 3D)"
+    )
+    parser.add_argument(
+        "--no_wavelet",
+        action="store_false",
+        dest="use_wavelet",
+        help="Whether to enable wavelet enhancement module"
+    )
+    parser.add_argument(
+        "--wavelet_J",
+        type=int,
+        default=2,
+        help="Wavelet decomposition levels"
+    )
+    parser.add_argument(
+        "--wavelet_wave",
+        type=str,
+        default="db4",
+        help="Wavelet basis type"
+    )
+    parser.add_argument(
+        "--wavelet_reduction",
+        type=int,
+        default=16,
+        help="Wavelet attention compression ratio"
+    )
+    parser.add_argument(
+        "--no_fft",
+        action="store_false",
+        dest="use_fft",
+        help="Whether to enable FFT enhancement module"
+    )
+    parser.add_argument(
+        "--use_v2",
+        type=bool,
+        default=True,
+        help="Whether to enable Swin-UNETR v2 module"
+    )
+
+    # ==================== Training-related Parameters ====================
+    parser.add_argument(
+        "--max_epochs",
+        type=int,
+        default=1000,
+        help="Maximum number of training epochs"
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=4,
+        help="Batch size"
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Learning rate"
+    )
+    parser.add_argument(
+        "--weight_decay",
+        type=float,
+        default=1e-4,
+        help="Weight decay coefficient"
+    )
+
+    # ==================== Loss Function Parameters ====================
+
+    parser.add_argument(
+        "--dice_weight",
+        type=float,
+        default=1.0,
+        help="Dice loss weight"
+    )
+    parser.add_argument(
+        "--ce_weight",
+        type=float,
+        default=1.0,
+        help="Cross Entropy loss weight"
+    )
+    parser.add_argument(
+        "--iou_weight",
+        type=float,
+        default=1.0,
+        help="IoU loss weight"
+    )
+
+    # ==================== SwanLab Parameters ====================
+    parser.add_argument(
+        "--swanlab_project",
+        type=str,
+        default="polyp-segmentation-v4_minute",
+        help="SwanLab project name"
+    )
+    parser.add_argument(
+        "--swanlab_experiment",
+        type=str,
+        default=None,
+        help="SwanLab experiment name (default uses timestamp)"
+    )
+    parser.add_argument(
+        "--swanlab_log_dir",
+        type=str,
+        default="./swanlab_log",
+        help="SwanLab log directory"
+    )
+
+    # ==================== Saving and Loading Parameters ====================
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="./outputs_v4_minute",
+        help="Directory for saving model checkpoints"
+    )
+    parser.add_argument(
+        "--save_every",
+        type=int,
+        default=50,
+        help="Save model every N epochs"
+    )
+    # ==================== Early Stopping Parameters ====================
+    parser.add_argument(
+        "--early_stopping",
+        type=bool,
+        default=True,
+        help="Whether to enable early stopping"
+    )
+    parser.add_argument(
+        "--early_stopping_patience",
+        type=int,
+        default=100,
+        help="Early stopping patience (stop if validation metric doesn't improve for N rounds)"
+    )
+    parser.add_argument(
+        "--early_stopping_min_delta",
+        type=float,
+        default=1e-4,
+        help="Minimum improvement threshold (improvement below this value is considered no improvement)"
+    )
+    parser.add_argument(
+        "--early_stopping_monitor",
+        type=str,
+        default="dice",
+        choices=["dice", "iou", "metric", "loss"],
+        help="Metric to monitor for early stopping"
+    )
+    parser.add_argument(
+        "--resume",
+        type=str,
+        default=None,
+        help="Checkpoint path to resume training. If not specified, will automatically load the best Dice model (if exists)"
+    )
+    parser.add_argument(
+        "--no_auto_resume",
+        action="store_false",
+        dest="auto_resume",
+        help="Whether to enable auto-resume functionality (default loads best Dice model)"
+    )
+
+    # ==================== Other Parameters ====================
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Training device (cuda or cpu)"
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed"
+    )
+
+    return parser.parse_args()
+
+
+def find_best_checkpoint(args):
+    """
+    Find the best checkpoint file
+
+    Args:
+        args: Command line arguments
+
+    Returns:
+        str or None: Best checkpoint path, returns None if not exists
+    """
+    # Find the best Dice model
+    best_dice_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
+    if os.path.exists(best_dice_path):
+        print(f"Found best Dice model: {best_dice_path}")
+        return best_dice_path
+
+    # Find latest checkpoint
+    checkpoint_dir = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}")
+    if os.path.exists(checkpoint_dir):
+        checkpoints = sorted(
+            [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')],
+            key=lambda x: int(x.split('epoch=')[1].split('.')[0]) if 'epoch=' in x else -1,
+            reverse=True
+        )
+        if checkpoints:
+            latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[0])
+            print(f"Found latest checkpoint: {latest_checkpoint}")
+            return latest_checkpoint
+
+    # Find best overall model
+    best_metric_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
+    if os.path.exists(best_metric_path):
+        print(f"Found best overall model: {best_metric_path}")
+        return best_metric_path
+
+    return None
+
+
+def create_enhanced_transforms(target_spatial_size=(512, 512)):
+    """
+    Enhanced data augmentation strategy
+
+    Includes:
+    1. Geometric transformations: flip, rotation, scaling, cropping
+    2. Photometric transformations: brightness, contrast, gamma correction
+    3. Noise injection: Gaussian noise, low-resolution simulation
+    4. Regularization: Coarse Dropout
+    """
+
+    def convert_label_to_single_channel(label_tensor):
+        """Convert RGB labels to single-channel binary mask"""
+        single_channel = label_tensor[0:1, :, :]
+        binary_label = (single_channel > 127).float()
+        return binary_label
+
+    train_transforms = Compose([
+        # ========== Loading and Preprocessing ==========
+        LoadImaged(keys=["image", "label"]),
+        EnsureChannelFirstd(keys=["image", "label"]),
+        Lambdad(keys=["label"], func=convert_label_to_single_channel),
+
+        # ========== Spatial Transformations ==========
+        Resized(keys=["image", "label"], spatial_size=target_spatial_size,
+                mode=("bilinear", "nearest")),
+        ScaleIntensityd(keys=["image"]),
+
+        # --- Geometric Augmentation ---
+        # Random axis flip
+        RandAxisFlipd(keys=["image", "label"], prob=0.5),
+
+        # Random rotation (-15° to +15°)
+        RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5,
+                    keep_size=True, mode=("bilinear", "nearest")),
+
+        # Random 90-degree rotation
+        RandRotate90d(keys=["image", "label"], prob=0.5, max_k=2),
+
+        # Random zoom (0.8-1.2x) + cropping
+        RandZoomd(keys=["image", "label"], min_zoom=0.8, max_zoom=1.2,
+                  prob=0.5, mode=("bilinear", "nearest"), keep_size=True),
+
+        # ========== Photometric Transformations ==========
+        # Random brightness adjustment (±20%)
+        RandShiftIntensityd(keys=["image"], offsets=(-0.2, 0.2), prob=0.5),
+
+        # Random contrast adjustment (gamma 0.7-1.3)
+        RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.3), prob=0.5),
+
+        # Random histogram shift (simulate different staining/lighting conditions)
+        RandHistogramShiftd(keys=["image"], num_control_points=(5, 10),
+                            prob=0.3),
+
+        # ========== Noise and Quality Degradation ==========
+        # Random Gaussian smoothing (simulate blur)
+        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.0),
+                            sigma_y=(0.5, 1.0), prob=0.3),
+
+        # Random Gaussian noise
+        RandGaussianNoised(keys=["image"], mean=0.0, std=0.05, prob=0.3),
+
+        # Coarse Dropout (occlusion augmentation, improve robustness)
+        RandCoarseDropoutd(
+            keys=["image"],
+            holes=1,
+            max_holes=3,
+            spatial_size=(32, 32),
+            max_spatial_size=(64, 64),
+            prob=0.3
+        ),
+
+        # ========== Post-processing ==========
+        ToTensord(keys=["image", "label"]),
+    ])
+
+    return train_transforms
+
+
+def create_datasets(args):
+    """
+    Create training and validation datasets
+
+    Args:
+        args: Command line arguments
+
+    Returns:
+        tuple: (train_dataset, val_dataset)
+    """
+    print("=" * 60)
+    print("正在加载数据集...")
+    print("=" * 60)
+
+    def convert_label_to_single_channel(label_tensor):
+        """
+        Global function: Convert 3-channel RGB labels to 1-channel binary labels (0 or 1)
+        Input: label_tensor (shape: [3, H, W], value range 0-255)
+        Output: new_tensor (shape: [1, H, W], value range 0 or 1)
+        """
+        # 1. Extract first channel (R channel)
+        single_channel = label_tensor[0:1, :, :]
+
+        # 2. Binarization: pixels greater than 0 are set to 1 (assuming background is pure black 0, polyp is white or other color)
+        # This ensures all pixel values can only be 0 or 1, meeting the requirements for out_channels=2
+        binary_label = (single_channel > 127).float()
+
+        return binary_label
+
+    # Define training set transformations
+    train_transforms = Compose([
+        LoadImaged(keys=["image", "label"]),
+        EnsureChannelFirstd(keys=["image", "label"]),
+        # Convert labels to single-channel (take first channel or convert to grayscale)
+        Lambdad(keys=["label"], func=convert_label_to_single_channel),
+        Resized(keys=["image", "label"], spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
+        ScaleIntensityd(keys=["image"]),
+        RandFlipd(keys=["image", "label"], prob=0.5),
+        RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5),
+    ])
+    if args.dataset_enhanced:
+        train_transforms = create_enhanced_transforms(args.target_spatial_size)
+        print("✓ 使用增强数据增强策略")
+
+    # Define validation set transformations
+    val_transforms = Compose([
+        LoadImaged(keys=["image", "label"]),
+        EnsureChannelFirstd(keys=["image", "label"]),
+        # Convert labels to single-channel (take first channel or convert to grayscale)
+        Lambdad(keys=["label"], func=convert_label_to_single_channel),
+        Resized(keys=["image", "label"], spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
+        ScaleIntensityd(keys=["image"]),
+    ])
+
+    # Create training dataset
+    train_dataset = PolypDetectionDataset(
+        root_dir=Path(args.data_root) / args.dataset_name,
+        flag='train',
+        transform=train_transforms
+    )
+
+    # Create validation dataset
+    val_dataset = PolypDetectionDataset(
+        root_dir=Path(args.data_root) / args.dataset_name,
+        flag='val',
+        transform=val_transforms
+    )
+
+    print(f"✓ Training set size: {len(train_dataset)} samples")
+    print(f"✓ Validation set size: {len(val_dataset)} samples")
+    print(f"✓ Total samples: {len(train_dataset) + len(val_dataset)} samples")
+    print("=" * 60)
+
+    return train_dataset, val_dataset
+
+
+def create_dataloaders(args, train_dataset, val_dataset):
+    """
+    Create data loaders
+
+    Args:
+        args: Command line arguments
+        train_dataset: Training dataset
+        val_dataset: Validation dataset
+
+    Returns:
+        tuple: (train_loader, val_loader)
+    """
+    train_loader = monai.data.DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=True
+    )
+
+    val_loader = monai.data.DataLoader(
+        val_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=False
+    )
+
+    print(f"✓ Training data loader: {len(train_loader)} batches")
+    print(f"✓ Validation data loader: {len(val_loader)} batches")
+
+    return train_loader, val_loader
+
+
+def create_model(args):
+    """
+    Create the model
+
+    Args:
+        args: Command line arguments
+
+    Returns:
+        torch.nn.Module: Initialized model
+    """
+    print("\n" + "=" * 60)
+    print("Creating model...")
+
+    model = Wavelet_FFT_SwinUNETR(
+        in_channels=args.in_channels,
+        out_channels=args.out_channels,
+        feature_size=args.feature_size,
+        spatial_dims=args.spatial_dims,
+        wavelet_enhancement=args.use_wavelet,
+        wavelet_J=args.wavelet_J,
+        wavelet_wave=args.wavelet_wave,
+        wavelet_mode='symmetric',
+        wavelet_reduction=args.wavelet_reduction,
+        fft_enhancement=args.use_fft,
+        use_v2=args.use_v2
+    )
+
+    # Print model information
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    print(f"\n✓ Total model parameters: {total_params:,}")
+    print(f"✓ Trainable parameters: {trainable_params:,}")
+    print(f"✓ Using device: {args.device}")
+    print("=" * 60)
+
+    return model
+
+
+def create_loss_function(args):
+    """
+    Create loss function
+
+    Args:
+        args: Command line arguments
+
+    Returns:
+        Callable: Loss function
+    """
+    loss_fn = CombinedDiceCEIoULoss(
+        dice_weight=args.dice_weight,
+        ce_weight=args.ce_weight,
+        iou_weight=args.iou_weight,
+        include_background=True,
+        to_onehot_y=False,
+        softmax=False,
+        sigmoid=True,
+    )
+    return loss_fn
+
+
+def create_optimizer(args, model):
+    """
+    Create optimizer
+
+    Args:
+        args: Command line arguments
+        model: Model
+
+    Returns:
+        Optimizer: Optimizer
+    """
+    optimizer = AdamW(
+        model.parameters(),
+        lr=args.learning_rate,
+        weight_decay=args.weight_decay
+    )
+    scheduler = ReduceLROnPlateau(
+        optimizer,
+        mode='min',  # 验证损失越小越好
+        factor=0.5,  # 每次乘以 0.5
+        patience=20,  # 20 个 epoch 不下降则降低 LR
+        threshold=1e-4,  # 最小变化阈值
+        cooldown=5,  # 降低 LR 后的冷却期
+        min_lr=1e-6  # 学习率下限
+    )
+    print(f"✓ Optimizer: AdamW")
+    print(f"  - Learning rate: {args.learning_rate}")
+    print(f"  - Weight decay: {args.weight_decay}")
+    print(f"✓ Scheduler: ReduceLROnPlateau")
+    print(f"  - Mode: {scheduler.mode}")
+    print(f"  - Decay factor: {scheduler.factor}")
+    print(f"  - Patience: {scheduler.patience}")
+    print(f"  - Minimum change threshold: {scheduler.threshold}")
+    print(f"  - Cooldown period: {scheduler.cooldown}")
+    print(f"  - Minimum learning rate: {scheduler.min_lrs}")
+
+    return optimizer, scheduler
+
+
+def setup_swanlab(args):
+    """
+    Configure SwanLab experiment tracking
+
+    Args:
+        args: Command line arguments
+
+    Returns:
+        swanlab.Run: SwanLab run object
+    """
+    # If experiment name is not specified, use timestamp
+    if args.swanlab_experiment is None:
+        args.swanlab_experiment = "v2_" + args.dataset_name + "_" + datetime.now().strftime("%Y%m%d_%H%M%S")
+
+    # Create log directory
+    os.makedirs(args.swanlab_log_dir, exist_ok=True)
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Initialize SwanLab
+    run = swanlab.init(
+        project=args.swanlab_project,
+        experiment_name=args.swanlab_experiment,
+        logdir=args.swanlab_log_dir,
+        config=vars(args)
+    )
+
+    print(f"\n✓ SwanLab experiment initialized: {args.swanlab_experiment}")
+    print(f"  - Project: {args.swanlab_project}")
+    print(f"  - Log directory: {args.swanlab_log_dir}")
+
+    return run
+
+
+def main():
+    """
+    Main training function
+    """
+    # ==================== Step 1: Parse arguments ====================
+    args = parse_args()
+
+    # Set random seed for reproducibility
+    torch.manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+
+    print("\n" + "=" * 60)
+    print("Polyp Segmentation Model Training Started")
+    print("=" * 60)
+    print(f"Using device: {args.device}")
+    print(f"Batch size: {args.batch_size}")
+    print(f"Maximum epochs: {args.max_epochs}")
+    if args.early_stopping:
+        print(
+            f"Early stopping: enabled (patience={args.early_stopping_patience}, monitor={args.early_stopping_monitor})")
+
+    # ==================== Step 2: Initialize SwanLab ====================
+    run = setup_swanlab(args)
+
+    # ==================== Step 3: Create datasets and data loaders ====================
+    train_dataset, val_dataset = create_datasets(args)
+    train_loader, val_loader = create_dataloaders(args, train_dataset, val_dataset)
+
+    # ==================== Step 4: Create model, loss function, optimizer ====================
+    model = create_model(args)
+    model = model.to(args.device)
+
+    loss_function = create_loss_function(args)
+    optimizer, scheduler = create_optimizer(args, model)
+
+    # ==================== Step 5: Create evaluation metrics ====================
+    dice_metric = DiceMetric(reduction="mean")
+    iou_metric = MeanIoU(reduction="mean")
+    hd_metric = HausdorffDistanceMetric(reduction="mean")
+    # ==================== Step 6: Setup training loop ====================
+    best_dice = -1
+    best_dice_epoch = -1
+    best_metric = -1
+    best_metric_epoch = -1
+    best_iou = -1
+    best_iou_epoch = -1
+    epoch_loss_values = []
+
+    dice_metric_values = []
+    iou_metric_values = []
+    hd_metric_values = []
+    start_epoch = 0
+    # ==================== Early stopping related variables ====================
+    early_stopping_counter = 0
+    should_stop = False
+    has_restarted = False  # Flag indicating whether it has been restarted once
+
+    # ==================== Step 7: Resume training (if checkpoint exists) ====================
+    checkpoint_loaded = False
+    checkpoint = None
+    if args.resume:
+        # User specified checkpoint path
+        if not os.path.exists(args.resume):
+            raise FileNotFoundError(f"Checkpoint file does not exist: {args.resume}")
+
+        checkpoint_path = args.resume
+        print(f"\nResuming training from user-specified checkpoint: {checkpoint_path}")
+        checkpoint = torch.load(checkpoint_path, map_location=args.device)
+        checkpoint_loaded = True
+
+    elif args.auto_resume:
+        # Automatically find best checkpoint
+        checkpoint_path = find_best_checkpoint(args)
+        if checkpoint_path:
+            print(f"\nAuto-resume mode: loading {checkpoint_path}")
+            checkpoint = torch.load(checkpoint_path, map_location=args.device)
+            checkpoint_loaded = True
+        else:
+            print("\nNo checkpoints found, starting training from scratch")
+
+    if checkpoint_loaded:
+        # Load model weights - supports migration from v1 to v2
+        model_dict = model.state_dict()
+
+        # Try to load from checkpoint
+        try:
+            pretrained_dict = checkpoint["model_state_dict"]
+            print("✓ Model weights loaded from training checkpoint")
+        except KeyError:
+            pretrained_dict = checkpoint
+            print("Loading model weights from best Dice or best overall model")
+
+        # Filter and match parameters (handle structural changes from v1->v2)
+        matched_params = {}
+        unmatched_params = []
+        missing_params = []
+
+        for name, param in model_dict.items():
+            if name in pretrained_dict:
+                pretrained_param = pretrained_dict[name]
+                # Check if shape matches
+                if param.shape == pretrained_param.shape:
+                    matched_params[name] = pretrained_param
+                else:
+                    unmatched_params.append(f"{name} (shape mismatch: {param.shape} vs {pretrained_param.shape})")
+            else:
+                missing_params.append(name)
+
+        # Output loading statistics
+        print(f"\nWeight loading statistics:")
+        print(f"  ✓ Successfully matched parameters: {len(matched_params)}/{len(model_dict)}")
+        print(f"  ⚠ Shape mismatched parameters: {len(unmatched_params)}")
+        print(f"  ✗ Newly added parameters (randomly initialized): {len(missing_params)}")
+
+        if unmatched_params:
+            print(f"\nShape mismatched layers:")
+            for info in unmatched_params[:5]:  # Only show first 5
+                print(f"  - {info}")
+            if len(unmatched_params) > 5:
+                print(f"  ... {len(unmatched_params) - 5} more")
+
+        if missing_params:
+            print(f"\nNewly added layers (will be randomly initialized):")
+            for name in missing_params[:5]:  # Only show first 5
+                print(f"  - {name}")
+            if len(missing_params) > 5:
+                print(f"  ... {len(missing_params) - 5} more")
+
+        # Update pre-trained dictionary
+        model_dict.update(matched_params)
+
+        # Load matched parameters
+        model.load_state_dict(model_dict, strict=False)
+        print(f"\n✓ Model weights loaded (strict mode: False)")
+        print("=" * 60)
+
+        if "optimizer_state_dict" in checkpoint:
+            # Load optimizer state
+            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+            print("✓ Optimizer state loaded")
+
+        # Load epoch number
+        # if "epoch" in checkpoint:
+        #     start_epoch = checkpoint["epoch"] + 1  # Start from next epoch
+        #     print(f"✓ Training epoch restored to epoch {start_epoch}")
+
+        # Load best metrics
+        if "best_dice" in checkpoint:
+            best_dice = checkpoint["best_dice"]
+            best_dice_epoch = checkpoint["best_dice_epoch"]
+            print(f"✓ Best metrics restored: Dice={best_dice:.4f} (Epoch {best_dice_epoch})")
+
+        # Load historical loss and metric values (optional)
+        if "epoch_loss_values" in checkpoint:
+            epoch_loss_values = checkpoint["epoch_loss_values"]
+        if "dice_metric_values" in checkpoint:
+            dice_metric_values = checkpoint["dice_metric_values"]
+
+        # Load early stopping state
+        if args.early_stopping:
+            if "early_stopping_counter" in checkpoint:
+                early_stopping_counter = checkpoint["early_stopping_counter"]
+                print(f"✓ Early stopping counter restored: {early_stopping_counter}")
+            if "should_stop" in checkpoint and checkpoint["should_stop"]:
+                should_stop = False  # Even if marked as stopped, allow continued training
+                print("✓ Early stopping state reset, can continue training")
+
+        print(f"✓ Training will continue from epoch {start_epoch}")
+        print("=" * 60)
+
+    print("\n" + "=" * 60)
+    print("Starting training...")
+    print("=" * 60)
+
+    start_time = time.time()
+
+    try:
+        for epoch in range(start_epoch, run.config.max_epochs):
+            # ========== Check early stopping condition ==========
+            if should_stop and args.early_stopping:
+                print(f"\n{'=' * 60}")
+                print(f"Early stopping triggered! Training will terminate early at epoch {epoch + 1}")
+                print(f"{'=' * 60}")
+
+                if not has_restarted:
+                    # First early stopping: load best weights, restart training
+                    print("Early stopping detected, preparing to restart training from best model...")
+
+                    # 1. Find the best Dice model
+                    best_checkpoint_path = os.path.join(
+                        args.output_dir,
+                        f"best_dice_model_{args.dataset_name}.pt"
+                    )
+
+                    if os.path.exists(best_checkpoint_path):
+                        print(f"Loading best Dice model: {best_checkpoint_path}")
+                        checkpoint = torch.load(best_checkpoint_path, map_location=args.device)
+
+                        # 2. Load best weights
+                        model.load_state_dict(checkpoint)
+                        print("✓ Model weights restored to best state")
+
+                        # 3. Reset optimizer
+                        optimizer, scheduler = create_optimizer(args, model)
+                        print("✓ Optimizer has been reset")
+
+                        # 4. Reset early stopping counter
+                        early_stopping_counter = 0
+                        should_stop = False
+                        has_restarted = True
+
+                        print("✓ Training restarted from best model")
+                        print(f"{'=' * 60}\n")
+                        continue  # Skip break, continue to next epoch
+                    else:
+                        print(f"Warning: Best model file not found {best_checkpoint_path}")
+                        print("Will stop training directly")
+
+                # Second early stopping or best model not found: truly stop
+                print("Training has been restarted once after early stopping, now stopping training")
+                break
+
+            # ========== Training phase ==========
+            model.train()
+            step = 0
+            epoch_loss = 0
+            epoch_loss_dice_ce = 0
+            epoch_loss_iou = 0
+
+            for batch_data in train_loader:
+                step += 1
+                inputs = batch_data["image"].to(args.device)
+                targets = batch_data["label"].to(args.device)
+
+                optimizer.zero_grad()
+                outputs = model(inputs)
+                loss, loss_dice_ce, loss_iou = loss_function(outputs, targets)
+                loss.backward()
+                optimizer.step()
+
+                epoch_loss += loss.item()
+                epoch_loss_dice_ce += loss_dice_ce.item()
+                epoch_loss_iou += loss_iou.item()
+
+            # If this is the first epoch resumed from checkpoint, print notification
+            if epoch == start_epoch and start_epoch > 0:
+                print(f"\n✓ Training resumed from epoch {start_epoch}")
+            epoch_loss /= step
+            epoch_loss_dice_ce /= step
+            epoch_loss_iou /= step
+
+            epoch_loss_values.append(epoch_loss)
+            print(f"\nEpoch {epoch + 1}/{args.max_epochs} - Training loss: {epoch_loss:.4f}")
+            # Log to SwanLab
+            swanlab.log({
+                "train/loss": epoch_loss,
+                "train/loss_dice_ce": epoch_loss_dice_ce,
+                "train/loss_iou": epoch_loss_iou,
+                "train/lr": optimizer.param_groups[0]['lr'],
+            }, step=(epoch + 1))
+
+            # ========== Validation phase ==========
+            model.eval()
+            val_loss_total = 0
+            with torch.no_grad():
+                dice_metric.reset()
+                iou_metric.reset()
+                hd_metric.reset()
+
+                for val_data in val_loader:
+                    val_images = val_data["image"].to(args.device)
+                    val_labels = val_data["label"].to(args.device)
+
+                    val_outputs = model(val_images)
+                    # Calculate validation loss
+                    val_loss_batch, _, _ = loss_function(val_outputs, val_labels)
+                    val_loss_total += val_loss_batch.item()
+
+                    # Post-processing
+                    val_outputs = torch.sigmoid(val_outputs)
+                    val_outputs = (val_outputs > 0.5).int()
+
+                    # Calculate Dice score
+                    dice_metric(y_pred=val_outputs, y=val_labels)
+                    iou_metric(y_pred=val_outputs, y=val_labels)
+                    hd_metric(y_pred=val_outputs, y=val_labels)
+
+                # Calculate average validation loss
+                val_loss_avg = val_loss_total / len(val_loader)
+
+                # Update learning rate scheduler
+                scheduler.step(val_loss_avg)
+                current_lr = optimizer.param_groups[0]['lr']
+
+                # Aggregate results
+                mean_dice = dice_metric.aggregate().item()
+                dice_metric_values.append(mean_dice)
+                mean_iou = iou_metric.aggregate().item()
+                iou_metric_values.append(mean_iou)
+                mean_hd = hd_metric.aggregate().item()
+                hd_metric_values.append(mean_hd)
+
+                print(
+                    f"Epoch {epoch + 1} - Validation Dice: {mean_dice:.4f}, Validation loss: {val_loss_avg:.4f}, Current LR: {current_lr:.2e}")
+                swanlab.log({
+                    "val/loss": val_loss_avg,
+                    "val/mean_dice": mean_dice,
+                    "val/mean_iou": mean_iou,
+                    "val/mean_hd": mean_hd,
+                    "val/lr": current_lr,
+                }, step=(epoch + 1))
+
+                # ========== Early stopping check ==========
+                if args.early_stopping:
+                    # Get current monitored metric
+                    if args.early_stopping_monitor == "dice":
+                        current_score = mean_dice
+                        best_score = best_dice
+                        is_better = current_score > best_score + args.early_stopping_min_delta
+                    elif args.early_stopping_monitor == "iou":
+                        current_score = mean_iou
+                        best_score = best_iou
+                        is_better = current_score > best_score + args.early_stopping_min_delta
+                    elif args.early_stopping_monitor == "metric":
+                        normalized_hd = 1.0 / (1.0 + mean_hd)
+                        current_score = 1 * mean_dice + 1 * mean_iou + 1 * normalized_hd
+                        best_score = best_metric
+                        is_better = current_score > best_score + args.early_stopping_min_delta
+                    else:  # loss
+                        current_score = -val_loss_avg  # Lower loss is better, so take negative
+                        best_score = -min(epoch_loss_values) if epoch_loss_values else float('-inf')
+                        is_better = current_score > best_score + args.early_stopping_min_delta
+
+                    # Check if there is improvement
+                    if is_better:
+                        early_stopping_counter = 0
+                        print(
+                            f"  ✓ {args.early_stopping_monitor.upper()} metric improved: {current_score:.4f} > {best_score:.4f}")
+                    else:
+                        early_stopping_counter += 1
+                        print(
+                            f"  ⚠ {args.early_stopping_monitor.upper()} metric did not improve, counter: {early_stopping_counter}/{args.early_stopping_patience}")
+
+                        # Check if early stopping should be triggered
+                        if early_stopping_counter >= args.early_stopping_patience:
+                            should_stop = True
+
+                # Save best Dice model
+                if mean_dice > best_dice:
+                    best_dice = mean_dice
+                    best_dice_epoch = epoch + 1
+                    checkpoint_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
+                    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
+                    torch.save(model.state_dict(), checkpoint_path)
+                    print(
+                        f"✓ Found better Dice model! Dice: {mean_dice:.4f}, IoU: {mean_iou:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}")
+                # Save best IoU model
+                if mean_iou > best_iou:
+                    best_iou = mean_iou
+                    best_iou_epoch = epoch + 1
+                    checkpoint_path = os.path.join(args.output_dir, f"best_iou_model_{args.dataset_name}.pt")
+                    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
+                    torch.save(model.state_dict(), checkpoint_path)
+                    print(
+                        f"✓ Found better IoU model! IoU: {mean_iou:.4f}, Dice: {mean_dice:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}"
+                    )
+                # Save best overall model
+                normalized_hd = 1.0 / (1.0 + mean_hd)
+                mean_metric = (
+                        1 * mean_dice +
+                        1 * mean_iou +
+                        1 * normalized_hd
+                )
+                if mean_metric > best_metric:
+                    best_metric = mean_metric
+                    best_metric_epoch = epoch + 1
+                    checkpoint_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
+                    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
+                    torch.save(model.state_dict(), checkpoint_path)
+                    print(
+                        f"✓ Found better overall model! Overall score: {mean_metric:.4f}, Dice: {mean_dice:.4f}, IoU: {mean_iou:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}"
+                    )
+            # Periodically save checkpoint
+            if (epoch + 1) % args.save_every == 0:
+                checkpoint_path = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}",
+                                               f"checkpoint_epoch={epoch}.pt")
+                Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
+                torch.save({
+                    "epoch": epoch,
+                    "model_state_dict": model.state_dict(),
+                    "optimizer_state_dict": optimizer.state_dict(),
+                    "best_dice": best_dice,
+                    "best_dice_epoch": best_dice_epoch,
+                    "epoch_loss_values": epoch_loss_values,
+                    "dice_metric_values": dice_metric_values,
+                    "iou_metric_values": iou_metric_values,
+                    "hd_metric_values": hd_metric_values,
+                    "best_metric": best_metric,
+                    "best_metric_epoch": best_metric_epoch,
+                    "best_iou": best_iou,
+                    "best_iou_epoch": best_iou_epoch,
+                    "early_stopping_counter": early_stopping_counter,
+                    "should_stop": should_stop
+                }, checkpoint_path)
+                print(f"✓ Checkpoint saved: {checkpoint_path}")
+
+    except KeyboardInterrupt:
+        print("\nTraining interrupted by user")
+    finally:
+        end_time = time.time()
+        training_time = end_time - start_time
+
+        print("\n" + "=" * 60)
+        print("Training completed!")
+        print(f"Total training time: {training_time / 3600:.2f} hours")
+        print(f"Best validation Dice: {best_dice:.4f} (Epoch {best_dice_epoch})")
+        print("=" * 60)
+
+        # Close SwanLab
+        swanlab.finish()
+        print("✓ SwanLab experiment saved")
+
+
+if __name__ == "__main__":
+    main()