|
|
@@ -0,0 +1,1053 @@
|
|
|
+import argparse
|
|
|
+import os
|
|
|
+import time
|
|
|
+from datetime import datetime
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+import monai
|
|
|
+import monai.utils
|
|
|
+import swanlab
|
|
|
+import torch
|
|
|
+from lib.model.model_v4_minute import Wavelet_FFT_SwinUNETR
|
|
|
+from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
|
|
|
+from monai.transforms import (
|
|
|
+ Compose, LoadImaged, ScaleIntensityd, RandFlipd, RandRotated, RandRotate90d,
|
|
|
+ EnsureChannelFirstd, ToTensord, Resized, Lambdad, RandZoomd, RandShiftIntensityd, RandGaussianNoised,
|
|
|
+ RandGaussianSmoothd, RandAdjustContrastd, RandHistogramShiftd,
|
|
|
+ RandAxisFlipd, RandCoarseDropoutd,
|
|
|
+)
|
|
|
+from torch.optim import AdamW
|
|
|
+from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
|
+
|
|
|
+from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
|
|
|
+from lib.tools.combined_loss import CombinedDiceCEIoULoss
|
|
|
+
|
|
|
+
|
|
|
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string, so such a flag could never be switched off.

    Args:
        value: Raw token from the command line (or an actual ``bool``).

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _spatial_size(value):
    """Parse a spatial size such as ``"512"``, ``"512,512"`` or ``"512x512"``.

    ``type=tuple`` would split the raw string into characters
    (``"512"`` -> ``('5', '1', '2')``); this converter yields integers.

    Args:
        value: Raw token from the command line (or an existing tuple/list).

    Returns:
        tuple[int, ...]: Positive integer sizes. A single number ``N`` is
        expanded to the square ``(N, N)``.

    Raises:
        argparse.ArgumentTypeError: If a component is not a positive integer.
    """
    if isinstance(value, (tuple, list)):
        return tuple(int(v) for v in value)
    parts = [p.strip() for p in str(value).lower().replace("x", ",").split(",")]
    try:
        dims = tuple(int(p) for p in parts)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"Expected sizes like '512,512' or '512x512', got {value!r}"
        ) from None
    if any(d <= 0 for d in dims):
        raise argparse.ArgumentTypeError(
            f"Spatial sizes must be positive, got {value!r}"
        )
    if len(dims) == 1:
        dims = dims * 2
    return dims


def parse_args():
    """Build the command-line interface of the training script and parse it.

    Boolean options (``--pin_memory``, ``--dataset_enhanced``, ``--use_v2``,
    ``--early_stopping``) accept explicit values such as ``true``/``false``;
    ``--target_spatial_size`` accepts ``512``, ``512,512`` or ``512x512``.
    The ``--no_wavelet``, ``--no_fft`` and ``--no_auto_resume`` switches turn
    off features that are enabled by default.

    Returns:
        argparse.Namespace: Parsed arguments taken from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Polyp Segmentation Model Training Script")

    # ==================== Dataset-related Parameters ====================
    parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name")
    parser.add_argument(
        "--data_root",
        type=str,
        default=r"./data/Polyp-Detection-Dataset",
        help="Root directory path of the dataset",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Number of worker processes for data loader",
    )
    parser.add_argument(
        "--pin_memory",
        type=_str2bool,
        default=True,
        help="Whether to enable pinned memory",
    )
    parser.add_argument(
        "--target_spatial_size",
        type=_spatial_size,
        default=(512, 512),
        help="Target spatial size",
    )
    parser.add_argument(
        "--dataset_enhanced",
        type=_str2bool,
        default=True,
        help="Whether to use enhanced data augmentation strategy",
    )

    # ==================== Model-related Parameters ====================
    parser.add_argument("--in_channels", type=int, default=3, help="Number of input image channels")
    parser.add_argument(
        "--out_channels",
        type=int,
        default=1,
        help="Number of output foreground channels",
    )
    parser.add_argument("--feature_size", type=int, default=48, help="Network feature dimension")
    parser.add_argument(
        "--spatial_dims",
        type=int,
        default=2,
        choices=[2, 3],
        help="Spatial dimension (2D or 3D)",
    )
    parser.add_argument(
        "--no_wavelet",
        action="store_false",
        dest="use_wavelet",
        help="Disable the wavelet enhancement module (enabled by default)",
    )
    parser.add_argument("--wavelet_J", type=int, default=2, help="Wavelet decomposition levels")
    parser.add_argument("--wavelet_wave", type=str, default="db4", help="Wavelet basis type")
    parser.add_argument(
        "--wavelet_reduction",
        type=int,
        default=16,
        help="Wavelet attention compression ratio",
    )
    parser.add_argument(
        "--no_fft",
        action="store_false",
        dest="use_fft",
        help="Disable the FFT enhancement module (enabled by default)",
    )
    parser.add_argument(
        "--use_v2",
        type=_str2bool,
        default=True,
        help="Whether to enable Swin-UNETR v2 module",
    )

    # ==================== Training-related Parameters ====================
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=1000,
        help="Maximum number of training epochs",
    )
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=1e-4,
        help="Weight decay coefficient",
    )

    # ==================== Loss Function Parameters ====================
    parser.add_argument("--dice_weight", type=float, default=1.0, help="Dice loss weight")
    parser.add_argument("--ce_weight", type=float, default=1.0, help="Cross Entropy loss weight")
    parser.add_argument("--iou_weight", type=float, default=1.0, help="IoU loss weight")

    # ==================== SwanLab Parameters ====================
    parser.add_argument(
        "--swanlab_project",
        type=str,
        default="polyp-segmentation-v4_minute",
        help="SwanLab project name",
    )
    parser.add_argument(
        "--swanlab_experiment",
        type=str,
        default=None,
        help="SwanLab experiment name (default uses timestamp)",
    )
    parser.add_argument(
        "--swanlab_log_dir",
        type=str,
        default="./swanlab_log",
        help="SwanLab log directory",
    )

    # ==================== Saving and Loading Parameters ====================
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./outputs_v4_minute",
        help="Directory for saving model checkpoints",
    )
    parser.add_argument("--save_every", type=int, default=50, help="Save model every N epochs")

    # ==================== Early Stopping Parameters ====================
    parser.add_argument(
        "--early_stopping",
        type=_str2bool,
        default=True,
        help="Whether to enable early stopping",
    )
    parser.add_argument(
        "--early_stopping_patience",
        type=int,
        default=100,
        help="Early stopping patience (stop if validation metric doesn't improve for N rounds)",
    )
    parser.add_argument(
        "--early_stopping_min_delta",
        type=float,
        default=1e-4,
        help="Minimum improvement threshold (improvement below this value is considered no improvement)",
    )
    parser.add_argument(
        "--early_stopping_monitor",
        type=str,
        default="dice",
        choices=["dice", "iou", "metric", "loss"],
        help="Metric to monitor for early stopping",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Checkpoint path to resume training. If not specified, will automatically load the best Dice model (if exists)",
    )
    parser.add_argument(
        "--no_auto_resume",
        action="store_false",
        dest="auto_resume",
        help="Disable auto-resume (by default the best Dice model is loaded if it exists)",
    )

    # ==================== Other Parameters ====================
    parser.add_argument(
        "--device",
        type=str,
        # Evaluated when the parser is built: prefer CUDA when it is available.
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Training device (cuda or cpu)",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    return parser.parse_args()
|
|
|
+
|
|
|
+
|
|
|
def find_best_checkpoint(args):
    """Locate a checkpoint to resume from inside ``args.output_dir``.

    Candidates are tried in this order and the first existing one wins:

    1. ``best_dice_model_<dataset_name>.pt``
    2. the ``.pt`` file with the highest ``epoch=<N>`` number inside
       ``checkpoints_<dataset_name>/``
    3. ``best_metric_model_<dataset_name>.pt``

    Args:
        args: Command line arguments; ``output_dir`` and ``dataset_name``
            are read.

    Returns:
        str or None: Path of the selected checkpoint, or ``None`` when no
        candidate exists.
    """
    import re  # local import so this function does not rely on file-level imports

    # 1. Best Dice model
    best_dice_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
    if os.path.exists(best_dice_path):
        print(f"Found best Dice model: {best_dice_path}")
        return best_dice_path

    # 2. Latest periodic checkpoint
    checkpoint_dir = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}")
    if os.path.isdir(checkpoint_dir):
        epoch_pattern = re.compile(r"epoch=(\d+)")

        def epoch_of(filename):
            # Files without a numeric "epoch=<N>" tag rank below every tagged
            # file instead of crashing the int() conversion.
            match = epoch_pattern.search(filename)
            return int(match.group(1)) if match else -1

        candidates = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
        if candidates:
            latest_checkpoint = os.path.join(checkpoint_dir, max(candidates, key=epoch_of))
            print(f"Found latest checkpoint: {latest_checkpoint}")
            return latest_checkpoint

    # 3. Best overall model
    best_metric_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
    if os.path.exists(best_metric_path):
        print(f"Found best overall model: {best_metric_path}")
        return best_metric_path

    return None
|
|
|
+
|
|
|
+
|
|
|
def create_enhanced_transforms(target_spatial_size=(512, 512)):
    """Assemble the enhanced training augmentation pipeline.

    The pipeline is built in stages that run in this order:

    1. Loading and preprocessing: load, channel-first, label binarisation,
       resize, image intensity scaling.
    2. Geometric augmentation: axis flip, small rotation, 90-degree
       rotation, zoom.
    3. Photometric augmentation: intensity shift, gamma contrast,
       histogram shift.
    4. Degradation and regularisation: Gaussian smoothing, Gaussian noise,
       coarse dropout.
    5. Tensor conversion.

    Args:
        target_spatial_size: Spatial size passed to ``Resized``.

    Returns:
        Compose: The composed dictionary transform for the keys
        ``"image"`` and ``"label"``.
    """

    def to_binary_mask(label_tensor):
        """Keep the first channel and threshold it at 127 into a 0/1 float mask."""
        first_channel = label_tensor[0:1, :, :]
        return (first_channel > 127).float()

    paired_keys = ["image", "label"]
    image_key = ["image"]
    # Image is interpolated bilinearly; the label uses nearest to stay binary.
    paired_modes = ("bilinear", "nearest")

    # ========== Loading and Preprocessing ==========
    preprocessing = [
        LoadImaged(keys=paired_keys),
        EnsureChannelFirstd(keys=paired_keys),
        Lambdad(keys=["label"], func=to_binary_mask),
        Resized(keys=paired_keys, spatial_size=target_spatial_size, mode=paired_modes),
        ScaleIntensityd(keys=image_key),
    ]

    # ========== Geometric Augmentation ==========
    geometric = [
        # Flip along a randomly chosen axis
        RandAxisFlipd(keys=paired_keys, prob=0.5),
        # Small random rotation; range_x is in radians (0.15 rad is about 8.6 degrees)
        RandRotated(
            keys=paired_keys,
            range_x=0.15,
            prob=0.5,
            keep_size=True,
            mode=paired_modes,
        ),
        # Random rotation by a multiple of 90 degrees (at most 2 quarter turns)
        RandRotate90d(keys=paired_keys, prob=0.5, max_k=2),
        # Random zoom between 0.8x and 1.2x, output size kept
        RandZoomd(
            keys=paired_keys,
            min_zoom=0.8,
            max_zoom=1.2,
            prob=0.5,
            mode=paired_modes,
            keep_size=True,
        ),
    ]

    # ========== Photometric Augmentation (image only) ==========
    photometric = [
        # Additive intensity offset drawn from [-0.2, 0.2]
        RandShiftIntensityd(keys=image_key, offsets=(-0.2, 0.2), prob=0.5),
        # Gamma-based contrast change, gamma in [0.7, 1.3]
        RandAdjustContrastd(keys=image_key, gamma=(0.7, 1.3), prob=0.5),
        # Histogram shift (simulate different staining/lighting conditions)
        RandHistogramShiftd(keys=image_key, num_control_points=(5, 10), prob=0.3),
    ]

    # ========== Noise, Blur and Occlusion (image only) ==========
    degradation = [
        # Gaussian smoothing to simulate blur
        RandGaussianSmoothd(
            keys=image_key,
            sigma_x=(0.5, 1.0),
            sigma_y=(0.5, 1.0),
            prob=0.3,
        ),
        # Additive Gaussian noise
        RandGaussianNoised(keys=image_key, mean=0.0, std=0.05, prob=0.3),
        # Coarse dropout: 1-3 holes between 32x32 and 64x64 (occlusion robustness)
        RandCoarseDropoutd(
            keys=image_key,
            holes=1,
            max_holes=3,
            spatial_size=(32, 32),
            max_spatial_size=(64, 64),
            prob=0.3,
        ),
    ]

    # ========== Post-processing ==========
    finalisation = [ToTensord(keys=paired_keys)]

    return Compose(preprocessing + geometric + photometric + degradation + finalisation)
|
|
|
+
|
|
|
+
|
|
|
def create_datasets(args):
    """Create the training and validation datasets.

    The training pipeline is the enhanced augmentation set when
    ``args.dataset_enhanced`` is true, otherwise a basic flip + rotation
    pipeline. Validation uses only the deterministic preprocessing steps.

    Args:
        args: Command line arguments; ``data_root``, ``dataset_name``,
            ``target_spatial_size`` and ``dataset_enhanced`` are read.

    Returns:
        tuple: (train_dataset, val_dataset)
    """
    print("=" * 60)
    print("正在加载数据集...")
    print("=" * 60)

    def convert_label_to_single_channel(label_tensor):
        """Reduce a channel-first label to a single-channel 0/1 float mask.

        Only the first channel is kept; values above 127 become 1 and all
        other values become 0.
        """
        single_channel = label_tensor[0:1, :, :]
        return (single_channel > 127).float()

    # Image is interpolated bilinearly; the label uses nearest so the mask
    # stays strictly binary after resampling.
    paired_modes = ("bilinear", "nearest")

    def deterministic_steps():
        """Return fresh instances of the preprocessing shared by train and val."""
        return [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Lambdad(keys=["label"], func=convert_label_to_single_channel),
            Resized(
                keys=["image", "label"],
                spatial_size=args.target_spatial_size,
                mode=paired_modes,
            ),
            ScaleIntensityd(keys=["image"]),
        ]

    # Training transformations: build only the pipeline that will be used.
    if args.dataset_enhanced:
        train_transforms = create_enhanced_transforms(args.target_spatial_size)
        print("✓ 使用增强数据增强策略")
    else:
        train_transforms = Compose(deterministic_steps() + [
            RandFlipd(keys=["image", "label"], prob=0.5),
            # Explicit per-key modes: without them the label would be rotated
            # with the bilinear default and pick up fractional values.
            RandRotated(
                keys=["image", "label"],
                range_x=0.15,
                prob=0.5,
                mode=paired_modes,
            ),
        ])

    # Validation transformations: no random augmentation.
    val_transforms = Compose(deterministic_steps())

    dataset_root = Path(args.data_root) / args.dataset_name

    train_dataset = PolypDetectionDataset(
        root_dir=dataset_root,
        flag='train',
        transform=train_transforms,
    )
    val_dataset = PolypDetectionDataset(
        root_dir=dataset_root,
        flag='val',
        transform=val_transforms,
    )

    print(f"✓ Training set size: {len(train_dataset)} samples")
    print(f"✓ Validation set size: {len(val_dataset)} samples")
    print(f"✓ Total samples: {len(train_dataset) + len(val_dataset)} samples")
    print("=" * 60)

    return train_dataset, val_dataset
|
|
|
+
|
|
|
+
|
|
|
def create_dataloaders(args, train_dataset, val_dataset):
    """Wrap both datasets in MONAI data loaders.

    The training loader shuffles and drops the last incomplete batch; the
    validation loader keeps the dataset order and every sample.

    Args:
        args: Command line arguments; ``batch_size``, ``num_workers`` and
            ``pin_memory`` are read.
        train_dataset: Training dataset
        val_dataset: Validation dataset

    Returns:
        tuple: (train_loader, val_loader)
    """
    # Settings that both loaders have in common.
    common_settings = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": args.pin_memory,
    }

    train_loader = monai.data.DataLoader(
        train_dataset, shuffle=True, drop_last=True, **common_settings
    )
    val_loader = monai.data.DataLoader(
        val_dataset, shuffle=False, drop_last=False, **common_settings
    )

    print(f"✓ Training data loader: {len(train_loader)} batches")
    print(f"✓ Validation data loader: {len(val_loader)} batches")

    return train_loader, val_loader
|
|
|
+
|
|
|
+
|
|
|
def create_model(args):
    """Instantiate ``Wavelet_FFT_SwinUNETR`` from the parsed arguments.

    Also prints the total and trainable parameter counts and the configured
    device. The model is not moved to the device here.

    Args:
        args: Command line arguments holding the model hyper-parameters.

    Returns:
        torch.nn.Module: Initialized model
    """
    print("\n" + "=" * 60)
    print("Creating model...")

    model_settings = {
        "in_channels": args.in_channels,
        "out_channels": args.out_channels,
        "feature_size": args.feature_size,
        "spatial_dims": args.spatial_dims,
        "wavelet_enhancement": args.use_wavelet,
        "wavelet_J": args.wavelet_J,
        "wavelet_wave": args.wavelet_wave,
        "wavelet_mode": 'symmetric',
        "wavelet_reduction": args.wavelet_reduction,
        "fft_enhancement": args.use_fft,
        "use_v2": args.use_v2,
    }
    model = Wavelet_FFT_SwinUNETR(**model_settings)

    # Count all parameters and, separately, those that require gradients.
    total_params = 0
    trainable_params = 0
    for parameter in model.parameters():
        count = parameter.numel()
        total_params += count
        if parameter.requires_grad:
            trainable_params += count

    print(f"\n✓ Total model parameters: {total_params:,}")
    print(f"✓ Trainable parameters: {trainable_params:,}")
    print(f"✓ Using device: {args.device}")
    print("=" * 60)

    return model
|
|
|
+
|
|
|
+
|
|
|
def create_loss_function(args):
    """Build the combined Dice / cross-entropy / IoU loss.

    The loss is configured with a sigmoid activation (no softmax, no one-hot
    conversion of the target) and includes the background channel.

    Args:
        args: Command line arguments; ``dice_weight``, ``ce_weight`` and
            ``iou_weight`` are read.

    Returns:
        Callable: Loss function
    """
    term_weights = {
        "dice_weight": args.dice_weight,
        "ce_weight": args.ce_weight,
        "iou_weight": args.iou_weight,
    }
    return CombinedDiceCEIoULoss(
        include_background=True,
        to_onehot_y=False,
        softmax=False,
        sigmoid=True,
        **term_weights,
    )
|
|
|
+
|
|
|
+
|
|
|
def create_optimizer(args, model):
    """Create the AdamW optimizer and its ReduceLROnPlateau scheduler.

    Args:
        args: Command line arguments; ``learning_rate`` and ``weight_decay``
            are read.
        model: Model whose parameters are optimized.

    Returns:
        tuple: (optimizer, scheduler)
    """
    optimizer = AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    plateau_settings = {
        "mode": 'min',       # the monitored value (validation loss) should decrease
        "factor": 0.5,       # multiply the LR by 0.5 on every reduction
        "patience": 20,      # epochs without improvement before reducing the LR
        "threshold": 1e-4,   # minimum change that counts as an improvement
        "cooldown": 5,       # epochs to wait after a reduction
        "min_lr": 1e-6,      # lower bound of the learning rate
    }
    scheduler = ReduceLROnPlateau(optimizer, **plateau_settings)

    summary_lines = (
        "✓ Optimizer: AdamW",
        f" - Learning rate: {args.learning_rate}",
        f" - Weight decay: {args.weight_decay}",
        "✓ Scheduler: ReduceLROnPlateau",
        f" - Mode: {scheduler.mode}",
        f" - Decay factor: {scheduler.factor}",
        f" - Patience: {scheduler.patience}",
        f" - Minimum change threshold: {scheduler.threshold}",
        f" - Cooldown period: {scheduler.cooldown}",
        f" - Minimum learning rate: {scheduler.min_lrs}",
    )
    for line in summary_lines:
        print(line)

    return optimizer, scheduler
|
|
|
+
|
|
|
+
|
|
|
def setup_swanlab(args):
    """Initialise SwanLab experiment tracking.

    When no experiment name was given, ``args.swanlab_experiment`` is set in
    place to ``v2_<dataset_name>_<timestamp>``. The SwanLab log directory and
    the output directory are created if they do not exist.

    Args:
        args: Command line arguments; the full namespace is passed to
            SwanLab as the run configuration.

    Returns:
        swanlab.Run: SwanLab run object
    """
    # Fall back to a timestamped experiment name.
    if args.swanlab_experiment is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.swanlab_experiment = f"v2_{args.dataset_name}_{timestamp}"

    # Make sure both the log and the output directory exist.
    for directory in (args.swanlab_log_dir, args.output_dir):
        os.makedirs(directory, exist_ok=True)

    run = swanlab.init(
        project=args.swanlab_project,
        experiment_name=args.swanlab_experiment,
        logdir=args.swanlab_log_dir,
        config=vars(args),
    )

    print(f"\n✓ SwanLab experiment initialized: {args.swanlab_experiment}")
    print(f" - Project: {args.swanlab_project}")
    print(f" - Log directory: {args.swanlab_log_dir}")

    return run
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ """
|
|
|
+ Main training function
|
|
|
+ """
|
|
|
+ # ==================== Step 1: Parse arguments ====================
|
|
|
+ args = parse_args()
|
|
|
+
|
|
|
+ # Set random seed for reproducibility
|
|
|
+ torch.manual_seed(args.seed)
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ torch.cuda.manual_seed_all(args.seed)
|
|
|
+
|
|
|
+ print("\n" + "=" * 60)
|
|
|
+ print("Polyp Segmentation Model Training Started")
|
|
|
+ print("=" * 60)
|
|
|
+ print(f"Using device: {args.device}")
|
|
|
+ print(f"Batch size: {args.batch_size}")
|
|
|
+ print(f"Maximum epochs: {args.max_epochs}")
|
|
|
+ if args.early_stopping:
|
|
|
+ print(
|
|
|
+ f"Early stopping: enabled (patience={args.early_stopping_patience}, monitor={args.early_stopping_monitor})")
|
|
|
+
|
|
|
+ # ==================== Step 2: Initialize SwanLab ====================
|
|
|
+ run = setup_swanlab(args)
|
|
|
+
|
|
|
+ # ==================== Step 3: Create datasets and data loaders ====================
|
|
|
+ train_dataset, val_dataset = create_datasets(args)
|
|
|
+ train_loader, val_loader = create_dataloaders(args, train_dataset, val_dataset)
|
|
|
+
|
|
|
+ # ==================== Step 4: Create model, loss function, optimizer ====================
|
|
|
+ model = create_model(args)
|
|
|
+ model = model.to(args.device)
|
|
|
+
|
|
|
+ loss_function = create_loss_function(args)
|
|
|
+ optimizer, scheduler = create_optimizer(args, model)
|
|
|
+
|
|
|
+ # ==================== Step 5: Create evaluation metrics ====================
|
|
|
+ dice_metric = DiceMetric(reduction="mean")
|
|
|
+ iou_metric = MeanIoU(reduction="mean")
|
|
|
+ hd_metric = HausdorffDistanceMetric(reduction="mean")
|
|
|
+ # ==================== Step 6: Setup training loop ====================
|
|
|
+ best_dice = -1
|
|
|
+ best_dice_epoch = -1
|
|
|
+ best_metric = -1
|
|
|
+ best_metric_epoch = -1
|
|
|
+ best_iou = -1
|
|
|
+ best_iou_epoch = -1
|
|
|
+ epoch_loss_values = []
|
|
|
+
|
|
|
+ dice_metric_values = []
|
|
|
+ iou_metric_values = []
|
|
|
+ hd_metric_values = []
|
|
|
+ start_epoch = 0
|
|
|
+ # ==================== Early stopping related variables ====================
|
|
|
+ early_stopping_counter = 0
|
|
|
+ should_stop = False
|
|
|
+ has_restarted = False # Flag indicating whether it has been restarted once
|
|
|
+
|
|
|
+ # ==================== Step 7: Resume training (if checkpoint exists) ====================
|
|
|
+ checkpoint_loaded = False
|
|
|
+ checkpoint = None
|
|
|
+ if args.resume:
|
|
|
+ # User specified checkpoint path
|
|
|
+ if not os.path.exists(args.resume):
|
|
|
+ raise FileNotFoundError(f"Checkpoint file does not exist: {args.resume}")
|
|
|
+
|
|
|
+ checkpoint_path = args.resume
|
|
|
+ print(f"\nResuming training from user-specified checkpoint: {checkpoint_path}")
|
|
|
+ checkpoint = torch.load(checkpoint_path, map_location=args.device)
|
|
|
+ checkpoint_loaded = True
|
|
|
+
|
|
|
+ elif args.auto_resume:
|
|
|
+ # Automatically find best checkpoint
|
|
|
+ checkpoint_path = find_best_checkpoint(args)
|
|
|
+ if checkpoint_path:
|
|
|
+ print(f"\nAuto-resume mode: loading {checkpoint_path}")
|
|
|
+ checkpoint = torch.load(checkpoint_path, map_location=args.device)
|
|
|
+ checkpoint_loaded = True
|
|
|
+ else:
|
|
|
+ print("\nNo checkpoints found, starting training from scratch")
|
|
|
+
|
|
|
+ if checkpoint_loaded:
|
|
|
+ # Load model weights - supports migration from v1 to v2
|
|
|
+ model_dict = model.state_dict()
|
|
|
+
|
|
|
+ # Try to load from checkpoint
|
|
|
+ try:
|
|
|
+ pretrained_dict = checkpoint["model_state_dict"]
|
|
|
+ print("✓ Model weights loaded from training checkpoint")
|
|
|
+ except KeyError:
|
|
|
+ pretrained_dict = checkpoint
|
|
|
+ print("Loading model weights from best Dice or best overall model")
|
|
|
+
|
|
|
+ # Filter and match parameters (handle structural changes from v1->v2)
|
|
|
+ matched_params = {}
|
|
|
+ unmatched_params = []
|
|
|
+ missing_params = []
|
|
|
+
|
|
|
+ for name, param in model_dict.items():
|
|
|
+ if name in pretrained_dict:
|
|
|
+ pretrained_param = pretrained_dict[name]
|
|
|
+ # Check if shape matches
|
|
|
+ if param.shape == pretrained_param.shape:
|
|
|
+ matched_params[name] = pretrained_param
|
|
|
+ else:
|
|
|
+ unmatched_params.append(f"{name} (shape mismatch: {param.shape} vs {pretrained_param.shape})")
|
|
|
+ else:
|
|
|
+ missing_params.append(name)
|
|
|
+
|
|
|
+ # Output loading statistics
|
|
|
+ print(f"\nWeight loading statistics:")
|
|
|
+ print(f" ✓ Successfully matched parameters: {len(matched_params)}/{len(model_dict)}")
|
|
|
+ print(f" ⚠ Shape mismatched parameters: {len(unmatched_params)}")
|
|
|
+ print(f" ✗ Newly added parameters (randomly initialized): {len(missing_params)}")
|
|
|
+
|
|
|
+ if unmatched_params:
|
|
|
+ print(f"\nShape mismatched layers:")
|
|
|
+ for info in unmatched_params[:5]: # Only show first 5
|
|
|
+ print(f" - {info}")
|
|
|
+ if len(unmatched_params) > 5:
|
|
|
+ print(f" ... {len(unmatched_params) - 5} more")
|
|
|
+
|
|
|
+ if missing_params:
|
|
|
+ print(f"\nNewly added layers (will be randomly initialized):")
|
|
|
+ for name in missing_params[:5]: # Only show first 5
|
|
|
+ print(f" - {name}")
|
|
|
+ if len(missing_params) > 5:
|
|
|
+ print(f" ... {len(missing_params) - 5} more")
|
|
|
+
|
|
|
+ # Update pre-trained dictionary
|
|
|
+ model_dict.update(matched_params)
|
|
|
+
|
|
|
+ # Load matched parameters
|
|
|
+ model.load_state_dict(model_dict, strict=False)
|
|
|
+ print(f"\n✓ Model weights loaded (strict mode: False)")
|
|
|
+ print("=" * 60)
|
|
|
+
|
|
|
+ if "optimizer_state_dict" in checkpoint:
|
|
|
+ # Load optimizer state
|
|
|
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
|
|
+ print("✓ Optimizer state loaded")
|
|
|
+
|
|
|
+ # Load epoch number
|
|
|
+ # if "epoch" in checkpoint:
|
|
|
+ # start_epoch = checkpoint["epoch"] + 1 # Start from next epoch
|
|
|
+ # print(f"✓ Training epoch restored to epoch {start_epoch}")
|
|
|
+
|
|
|
+ # Load best metrics
|
|
|
+ if "best_dice" in checkpoint:
|
|
|
+ best_dice = checkpoint["best_dice"]
|
|
|
+ best_dice_epoch = checkpoint["best_dice_epoch"]
|
|
|
+ print(f"✓ Best metrics restored: Dice={best_dice:.4f} (Epoch {best_dice_epoch})")
|
|
|
+
|
|
|
+ # Load historical loss and metric values (optional)
|
|
|
+ if "epoch_loss_values" in checkpoint:
|
|
|
+ epoch_loss_values = checkpoint["epoch_loss_values"]
|
|
|
+ if "dice_metric_values" in checkpoint:
|
|
|
+ dice_metric_values = checkpoint["dice_metric_values"]
|
|
|
+
|
|
|
+ # Load early stopping state
|
|
|
+ if args.early_stopping:
|
|
|
+ if "early_stopping_counter" in checkpoint:
|
|
|
+ early_stopping_counter = checkpoint["early_stopping_counter"]
|
|
|
+ print(f"✓ Early stopping counter restored: {early_stopping_counter}")
|
|
|
+ if "should_stop" in checkpoint and checkpoint["should_stop"]:
|
|
|
+ should_stop = False # Even if marked as stopped, allow continued training
|
|
|
+ print("✓ Early stopping state reset, can continue training")
|
|
|
+
|
|
|
+ print(f"✓ Training will continue from epoch {start_epoch}")
|
|
|
+ print("=" * 60)
|
|
|
+
|
|
|
+ print("\n" + "=" * 60)
|
|
|
+ print("Starting training...")
|
|
|
+ print("=" * 60)
|
|
|
+
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ try:
|
|
|
+ for epoch in range(start_epoch, run.config.max_epochs):
|
|
|
+ # ========== Check early stopping condition ==========
|
|
|
+ if should_stop and args.early_stopping:
|
|
|
+ print(f"\n{'=' * 60}")
|
|
|
+ print(f"Early stopping triggered! Training will terminate early at epoch {epoch + 1}")
|
|
|
+ print(f"{'=' * 60}")
|
|
|
+
|
|
|
+ if not has_restarted:
|
|
|
+ # First early stopping: load best weights, restart training
|
|
|
+ print("Early stopping detected, preparing to restart training from best model...")
|
|
|
+
|
|
|
+ # 1. Find the best Dice model
|
|
|
+ best_checkpoint_path = os.path.join(
|
|
|
+ args.output_dir,
|
|
|
+ f"best_dice_model_{args.dataset_name}.pt"
|
|
|
+ )
|
|
|
+
|
|
|
+ if os.path.exists(best_checkpoint_path):
|
|
|
+ print(f"Loading best Dice model: {best_checkpoint_path}")
|
|
|
+ checkpoint = torch.load(best_checkpoint_path, map_location=args.device)
|
|
|
+
|
|
|
+ # 2. Load best weights
|
|
|
+ model.load_state_dict(checkpoint)
|
|
|
+ print("✓ Model weights restored to best state")
|
|
|
+
|
|
|
+ # 3. Reset optimizer
|
|
|
+ optimizer, scheduler = create_optimizer(args, model)
|
|
|
+ print("✓ Optimizer has been reset")
|
|
|
+
|
|
|
+ # 4. Reset early stopping counter
|
|
|
+ early_stopping_counter = 0
|
|
|
+ should_stop = False
|
|
|
+ has_restarted = True
|
|
|
+
|
|
|
+ print("✓ Training restarted from best model")
|
|
|
+ print(f"{'=' * 60}\n")
|
|
|
+ continue # Skip break, continue to next epoch
|
|
|
+ else:
|
|
|
+ print(f"Warning: Best model file not found {best_checkpoint_path}")
|
|
|
+ print("Will stop training directly")
|
|
|
+
|
|
|
+ # Second early stopping or best model not found: truly stop
|
|
|
+ print("Training has been restarted once after early stopping, now stopping training")
|
|
|
+ break
|
|
|
+
|
|
|
+ # ========== Training phase ==========
|
|
|
+ model.train()
|
|
|
+ step = 0
|
|
|
+ epoch_loss = 0
|
|
|
+ epoch_loss_dice_ce = 0
|
|
|
+ epoch_loss_iou = 0
|
|
|
+
|
|
|
+ for batch_data in train_loader:
|
|
|
+ step += 1
|
|
|
+ inputs = batch_data["image"].to(args.device)
|
|
|
+ targets = batch_data["label"].to(args.device)
|
|
|
+
|
|
|
+ optimizer.zero_grad()
|
|
|
+ outputs = model(inputs)
|
|
|
+ loss, loss_dice_ce, loss_iou = loss_function(outputs, targets)
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+
|
|
|
+ epoch_loss += loss.item()
|
|
|
+ epoch_loss_dice_ce += loss_dice_ce.item()
|
|
|
+ epoch_loss_iou += loss_iou.item()
|
|
|
+
|
|
|
+ # If this is the first epoch resumed from checkpoint, print notification
|
|
|
+ if epoch == start_epoch and start_epoch > 0:
|
|
|
+ print(f"\n✓ Training resumed from epoch {start_epoch}")
|
|
|
+ epoch_loss /= step
|
|
|
+ epoch_loss_dice_ce /= step
|
|
|
+ epoch_loss_iou /= step
|
|
|
+
|
|
|
+ epoch_loss_values.append(epoch_loss)
|
|
|
+ print(f"\nEpoch {epoch + 1}/{args.max_epochs} - Training loss: {epoch_loss:.4f}")
|
|
|
+ # Log to SwanLab
|
|
|
+ swanlab.log({
|
|
|
+ "train/loss": epoch_loss,
|
|
|
+ "train/loss_dice_ce": epoch_loss_dice_ce,
|
|
|
+ "train/loss_iou": epoch_loss_iou,
|
|
|
+ "train/lr": optimizer.param_groups[0]['lr'],
|
|
|
+ }, step=(epoch + 1))
|
|
|
+
|
|
|
+ # ========== Validation phase ==========
|
|
|
+ model.eval()
|
|
|
+ val_loss_total = 0
|
|
|
+ with torch.no_grad():
|
|
|
+ dice_metric.reset()
|
|
|
+ iou_metric.reset()
|
|
|
+ hd_metric.reset()
|
|
|
+
|
|
|
+ for val_data in val_loader:
|
|
|
+ val_images = val_data["image"].to(args.device)
|
|
|
+ val_labels = val_data["label"].to(args.device)
|
|
|
+
|
|
|
+ val_outputs = model(val_images)
|
|
|
+ # Calculate validation loss
|
|
|
+ val_loss_batch, _, _ = loss_function(val_outputs, val_labels)
|
|
|
+ val_loss_total += val_loss_batch.item()
|
|
|
+
|
|
|
+ # Post-processing
|
|
|
+ val_outputs = torch.sigmoid(val_outputs)
|
|
|
+ val_outputs = (val_outputs > 0.5).int()
|
|
|
+
|
|
|
+ # Calculate Dice score
|
|
|
+ dice_metric(y_pred=val_outputs, y=val_labels)
|
|
|
+ iou_metric(y_pred=val_outputs, y=val_labels)
|
|
|
+ hd_metric(y_pred=val_outputs, y=val_labels)
|
|
|
+
|
|
|
+ # Calculate average validation loss
|
|
|
+ val_loss_avg = val_loss_total / len(val_loader)
|
|
|
+
|
|
|
+ # Update learning rate scheduler
|
|
|
+ scheduler.step(val_loss_avg)
|
|
|
+ current_lr = optimizer.param_groups[0]['lr']
|
|
|
+
|
|
|
+ # Aggregate results
|
|
|
+ mean_dice = dice_metric.aggregate().item()
|
|
|
+ dice_metric_values.append(mean_dice)
|
|
|
+ mean_iou = iou_metric.aggregate().item()
|
|
|
+ iou_metric_values.append(mean_iou)
|
|
|
+ mean_hd = hd_metric.aggregate().item()
|
|
|
+ hd_metric_values.append(mean_hd)
|
|
|
+
|
|
|
+ print(
|
|
|
+ f"Epoch {epoch + 1} - Validation Dice: {mean_dice:.4f}, Validation loss: {val_loss_avg:.4f}, Current LR: {current_lr:.2e}")
|
|
|
+ swanlab.log({
|
|
|
+ "val/loss": val_loss_avg,
|
|
|
+ "val/mean_dice": mean_dice,
|
|
|
+ "val/mean_iou": mean_iou,
|
|
|
+ "val/mean_hd": mean_hd,
|
|
|
+ "val/lr": current_lr,
|
|
|
+ }, step=(epoch + 1))
|
|
|
+
|
|
|
+ # ========== Early stopping check ==========
|
|
|
+ if args.early_stopping:
|
|
|
+ # Get current monitored metric
|
|
|
+ if args.early_stopping_monitor == "dice":
|
|
|
+ current_score = mean_dice
|
|
|
+ best_score = best_dice
|
|
|
+ is_better = current_score > best_score + args.early_stopping_min_delta
|
|
|
+ elif args.early_stopping_monitor == "iou":
|
|
|
+ current_score = mean_iou
|
|
|
+ best_score = best_iou
|
|
|
+ is_better = current_score > best_score + args.early_stopping_min_delta
|
|
|
+ elif args.early_stopping_monitor == "metric":
|
|
|
+ normalized_hd = 1.0 / (1.0 + mean_hd)
|
|
|
+ current_score = 1 * mean_dice + 1 * mean_iou + 1 * normalized_hd
|
|
|
+ best_score = best_metric
|
|
|
+ is_better = current_score > best_score + args.early_stopping_min_delta
|
|
|
+ else: # loss
|
|
|
+ current_score = -val_loss_avg # Lower loss is better, so take negative
|
|
|
+ best_score = -min(epoch_loss_values) if epoch_loss_values else float('-inf')
|
|
|
+ is_better = current_score > best_score + args.early_stopping_min_delta
|
|
|
+
|
|
|
+ # Check if there is improvement
|
|
|
+ if is_better:
|
|
|
+ early_stopping_counter = 0
|
|
|
+ print(
|
|
|
+ f" ✓ {args.early_stopping_monitor.upper()} metric improved: {current_score:.4f} > {best_score:.4f}")
|
|
|
+ else:
|
|
|
+ early_stopping_counter += 1
|
|
|
+ print(
|
|
|
+ f" ⚠ {args.early_stopping_monitor.upper()} metric did not improve, counter: {early_stopping_counter}/{args.early_stopping_patience}")
|
|
|
+
|
|
|
+ # Check if early stopping should be triggered
|
|
|
+ if early_stopping_counter >= args.early_stopping_patience:
|
|
|
+ should_stop = True
|
|
|
+
|
|
|
+ # Save best Dice model
|
|
|
+ if mean_dice > best_dice:
|
|
|
+ best_dice = mean_dice
|
|
|
+ best_dice_epoch = epoch + 1
|
|
|
+ checkpoint_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
|
|
|
+ Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
|
|
|
+ torch.save(model.state_dict(), checkpoint_path)
|
|
|
+ print(
|
|
|
+ f"✓ Found better Dice model! Dice: {mean_dice:.4f}, IoU: {mean_iou:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}")
|
|
|
+ # Save best IoU model
|
|
|
+ if mean_iou > best_iou:
|
|
|
+ best_iou = mean_iou
|
|
|
+ best_iou_epoch = epoch + 1
|
|
|
+ checkpoint_path = os.path.join(args.output_dir, f"best_iou_model_{args.dataset_name}.pt")
|
|
|
+ Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
|
|
|
+ torch.save(model.state_dict(), checkpoint_path)
|
|
|
+ print(
|
|
|
+ f"✓ Found better IoU model! IoU: {mean_iou:.4f}, Dice: {mean_dice:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}"
|
|
|
+ )
|
|
|
+ # Save best overall model
|
|
|
+ normalized_hd = 1.0 / (1.0 + mean_hd)
|
|
|
+ mean_metric = (
|
|
|
+ 1 * mean_dice +
|
|
|
+ 1 * mean_iou +
|
|
|
+ 1 * normalized_hd
|
|
|
+ )
|
|
|
+ if mean_metric > best_metric:
|
|
|
+ best_metric = mean_metric
|
|
|
+ best_metric_epoch = epoch + 1
|
|
|
+ checkpoint_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
|
|
|
+ Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
|
|
|
+ torch.save(model.state_dict(), checkpoint_path)
|
|
|
+ print(
|
|
|
+ f"✓ Found better overall model! Overall score: {mean_metric:.4f}, Dice: {mean_dice:.4f}, IoU: {mean_iou:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}"
|
|
|
+ )
|
|
|
+ # Periodically save checkpoint
|
|
|
+ if (epoch + 1) % args.save_every == 0:
|
|
|
+ checkpoint_path = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}",
|
|
|
+ f"checkpoint_epoch={epoch}.pt")
|
|
|
+ Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
|
|
|
+ torch.save({
|
|
|
+ "epoch": epoch,
|
|
|
+ "model_state_dict": model.state_dict(),
|
|
|
+ "optimizer_state_dict": optimizer.state_dict(),
|
|
|
+ "best_dice": best_dice,
|
|
|
+ "best_dice_epoch": best_dice_epoch,
|
|
|
+ "epoch_loss_values": epoch_loss_values,
|
|
|
+ "dice_metric_values": dice_metric_values,
|
|
|
+ "iou_metric_values": iou_metric_values,
|
|
|
+ "hd_metric_values": hd_metric_values,
|
|
|
+ "best_metric": best_metric,
|
|
|
+ "best_metric_epoch": best_metric_epoch,
|
|
|
+ "best_iou": best_iou,
|
|
|
+ "best_iou_epoch": best_iou_epoch,
|
|
|
+ "early_stopping_counter": early_stopping_counter,
|
|
|
+ "should_stop": should_stop
|
|
|
+ }, checkpoint_path)
|
|
|
+ print(f"✓ Checkpoint saved: {checkpoint_path}")
|
|
|
+
|
|
|
+ except KeyboardInterrupt:
|
|
|
+ print("\nTraining interrupted by user")
|
|
|
+ finally:
|
|
|
+ end_time = time.time()
|
|
|
+ training_time = end_time - start_time
|
|
|
+
|
|
|
+ print("\n" + "=" * 60)
|
|
|
+ print("Training completed!")
|
|
|
+ print(f"Total training time: {training_time / 3600:.2f} hours")
|
|
|
+ print(f"Best validation Dice: {best_dice:.4f} (Epoch {best_dice_epoch})")
|
|
|
+ print("=" * 60)
|
|
|
+
|
|
|
+ # Close SwanLab
|
|
|
+ swanlab.finish()
|
|
|
+ print("✓ SwanLab experiment saved")
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":  # entry-point guard: call main() only when this file is executed as a script, not when it is imported
|
|
|
+ main()
|