"""Training script for the Wavelet/FFT Swin-UNETR polyp segmentation model.

The script parses command line options, builds MONAI transform pipelines and
data loaders, trains the model with a combined Dice/CE/IoU loss, validates
every epoch, tracks metrics in SwanLab and writes best-model / periodic
checkpoints to ``--output_dir``.
"""
import argparse
import os
import time
from datetime import datetime
from pathlib import Path

import monai
import monai.utils
import swanlab
import torch
from lib.model.model_v4_minute import Wavelet_FFT_SwinUNETR
from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
from monai.transforms import (
    Compose,
    LoadImaged,
    ScaleIntensityd,
    RandFlipd,
    RandRotated,
    RandRotate90d,
    EnsureChannelFirstd,
    ToTensord,
    Resized,
    Lambdad,
    RandZoomd,
    RandShiftIntensityd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandAdjustContrastd,
    RandHistogramShiftd,
    RandAxisFlipd,
    RandCoarseDropoutd,
)
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
from lib.tools.combined_loss import CombinedDiceCEIoULoss


def _str2bool(value):
    """Parse a command line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _parse_spatial_size(value):
    """Parse a spatial size such as ``512,512`` or ``512x512`` into an int tuple.

    ``type=tuple`` would split the raw string into single characters, so the
    components are parsed explicitly here.

    Raises:
        argparse.ArgumentTypeError: If fewer than two components are given or
            a component is not a positive integer.
    """
    if isinstance(value, (tuple, list)):
        return tuple(int(v) for v in value)
    cleaned = str(value).strip().strip("()[]")
    for separator in ("x", "X", " "):
        cleaned = cleaned.replace(separator, ",")
    parts = [part for part in cleaned.split(",") if part]
    try:
        size = tuple(int(part) for part in parts)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"Spatial size must be integers like '512,512', got {value!r}"
        ) from None
    if len(size) < 2 or any(dim <= 0 for dim in size):
        raise argparse.ArgumentTypeError(
            f"Spatial size needs at least two positive integers, got {value!r}"
        )
    return size


def parse_args():
    """Build the argument parser and return the parsed command line namespace."""
    parser = argparse.ArgumentParser(description="Polyp Segmentation Model Training Script")

    # ==================== Dataset-related Parameters ====================
    parser.add_argument(
        "--dataset_name", type=str, required=True,
        help="Dataset name"
    )
    parser.add_argument(
        "--data_root", type=str, default=r"./data/Polyp-Detection-Dataset",
        help="Root directory path of the dataset"
    )
    parser.add_argument(
        "--num_workers", type=int, default=0,
        help="Number of worker processes for data loader"
    )
    parser.add_argument(
        "--pin_memory", type=_str2bool, default=True,
        help="Whether to enable pinned memory"
    )
    parser.add_argument(
        "--target_spatial_size", type=_parse_spatial_size, default=(512, 512),
        help="Target spatial size, e.g. 512,512"
    )
    parser.add_argument(
        "--dataset_enhanced", type=_str2bool, default=True,
        help="Whether to use enhanced data augmentation strategy"
    )

    # ==================== Model-related Parameters ====================
    parser.add_argument(
        "--in_channels", type=int, default=3,
        help="Number of input image channels"
    )
    parser.add_argument(
        "--out_channels", type=int, default=1,
        help="Number of output foreground channels"
    )
    parser.add_argument(
        "--feature_size", type=int, default=48,
        help="Network feature dimension"
    )
    parser.add_argument(
        "--spatial_dims", type=int, default=2, choices=[2, 3],
        help="Spatial dimension (2D or 3D)"
    )
    # store_false flags: the module is enabled unless the flag is passed.
    parser.add_argument(
        "--no_wavelet", action="store_false", dest="use_wavelet",
        help="Disable the wavelet enhancement module (enabled by default)"
    )
    parser.add_argument(
        "--wavelet_J", type=int, default=2,
        help="Wavelet decomposition levels"
    )
    parser.add_argument(
        "--wavelet_wave", type=str, default="db4",
        help="Wavelet basis type"
    )
    parser.add_argument(
        "--wavelet_reduction", type=int, default=16,
        help="Wavelet attention compression ratio"
    )
    parser.add_argument(
        "--no_fft", action="store_false", dest="use_fft",
        help="Disable the FFT enhancement module (enabled by default)"
    )
    parser.add_argument(
        "--use_v2", type=_str2bool, default=True,
        help="Whether to enable Swin-UNETR v2 module"
    )

    # ==================== Training-related Parameters ====================
    parser.add_argument(
        "--max_epochs", type=int, default=1000,
        help="Maximum number of training epochs"
    )
    parser.add_argument(
        "--batch_size", type=int, default=4,
        help="Batch size"
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-4,
        help="Learning rate"
    )
    parser.add_argument(
        "--weight_decay", type=float, default=1e-4,
        help="Weight decay coefficient"
    )

    # ==================== Loss Function Parameters ====================
    parser.add_argument(
        "--dice_weight", type=float, default=1.0,
        help="Dice loss weight"
    )
    parser.add_argument(
        "--ce_weight", type=float, default=1.0,
        help="Cross Entropy loss weight"
    )
    parser.add_argument(
        "--iou_weight", type=float, default=1.0,
        help="IoU loss weight"
    )

    # ==================== SwanLab Parameters ====================
    parser.add_argument(
        "--swanlab_project", type=str, default="polyp-segmentation-v4_minute",
        help="SwanLab project name"
    )
    parser.add_argument(
        "--swanlab_experiment", type=str, default=None,
        help="SwanLab experiment name (default uses timestamp)"
    )
    parser.add_argument(
        "--swanlab_log_dir", type=str, default="./swanlab_log",
        help="SwanLab log directory"
    )

    # ==================== Saving and Loading Parameters ====================
    parser.add_argument(
        "--output_dir", type=str, default="./outputs_v4_minute",
        help="Directory for saving model checkpoints"
    )
    parser.add_argument(
        "--save_every", type=int, default=50,
        help="Save model every N epochs"
    )

    # ==================== Early Stopping Parameters ====================
    parser.add_argument(
        "--early_stopping", type=_str2bool, default=True,
        help="Whether to enable early stopping"
    )
    parser.add_argument(
        "--early_stopping_patience", type=int, default=100,
        help="Early stopping patience (stop if validation metric doesn't improve for N rounds)"
    )
    parser.add_argument(
        "--early_stopping_min_delta", type=float, default=1e-4,
        help="Minimum improvement threshold (improvement below this value is considered no improvement)"
    )
    parser.add_argument(
        "--early_stopping_monitor", type=str, default="dice",
        choices=["dice", "iou", "metric", "loss"],
        help="Metric to monitor for early stopping"
    )
    parser.add_argument(
        "--resume", type=str, default=None,
        help="Checkpoint path to resume training. "
             "If not specified, will automatically load the best Dice model (if exists)"
    )
    parser.add_argument(
        "--no_auto_resume", action="store_false", dest="auto_resume",
        help="Disable auto-resume (by default the best Dice model is loaded if it exists)"
    )

    # ==================== Other Parameters ====================
    parser.add_argument(
        "--device", type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Training device (cuda or cpu)"
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed"
    )
    return parser.parse_args()


def find_best_checkpoint(args):
    """Locate a checkpoint to auto-resume from.

    Candidates are tried in this order: the best Dice model, the periodic
    checkpoint with the highest ``epoch=`` number, then the best overall model.

    Args:
        args: Command line arguments (uses ``output_dir`` and ``dataset_name``).

    Returns:
        str or None: Path of the first candidate found, or None if none exists.
    """
    # 1. Best Dice model
    best_dice_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
    if os.path.exists(best_dice_path):
        print(f"Found best Dice model: {best_dice_path}")
        return best_dice_path

    # 2. Latest periodic checkpoint
    checkpoint_dir = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}")
    if os.path.exists(checkpoint_dir):

        def epoch_of(filename):
            # Files without an "epoch=" marker sort last (key -1).
            if 'epoch=' not in filename:
                return -1
            return int(filename.split('epoch=')[1].split('.')[0])

        checkpoints = sorted(
            (f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')),
            key=epoch_of,
            reverse=True,
        )
        if checkpoints:
            latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[0])
            print(f"Found latest checkpoint: {latest_checkpoint}")
            return latest_checkpoint

    # 3. Best overall model
    best_metric_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
    if os.path.exists(best_metric_path):
        print(f"Found best overall model: {best_metric_path}")
        return best_metric_path

    return None


def convert_label_to_single_channel(label_tensor):
    """Convert a channel-first label image into a single-channel binary mask.

    Only the first channel is kept; pixels strictly greater than 127 become
    1.0 and all others 0.0. Defined at module level (rather than as a closure)
    so the transform pipelines can share one implementation.
    """
    single_channel = label_tensor[0:1, :, :]
    return (single_channel > 127).float()


def create_enhanced_transforms(target_spatial_size=(512, 512)):
    """Build the enhanced training augmentation pipeline.

    Includes:
        1. Geometric transformations: flip, rotation, zoom
        2. Photometric transformations: intensity shift, contrast, histogram shift
        3. Noise injection: Gaussian smoothing and Gaussian noise
        4. Regularization: Coarse Dropout

    Args:
        target_spatial_size: Spatial size images and labels are resized to.

    Returns:
        Compose: The training transform pipeline.
    """
    both = ["image", "label"]
    image_only = ["image"]
    # Images are interpolated bilinearly, labels with nearest neighbour so the
    # mask stays binary.
    interp = ("bilinear", "nearest")

    train_transforms = Compose([
        # ========== Loading and Preprocessing ==========
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys=both),
        Lambdad(keys=["label"], func=convert_label_to_single_channel),

        # ========== Spatial Transformations ==========
        Resized(keys=both, spatial_size=target_spatial_size, mode=interp),
        ScaleIntensityd(keys=image_only),

        # --- Geometric Augmentation ---
        # Random axis flip
        RandAxisFlipd(keys=both, prob=0.5),
        # Random rotation; range_x is given in radians, so 0.15 is about +/-8.6 degrees
        RandRotated(keys=both, range_x=0.15, prob=0.5, keep_size=True, mode=interp),
        # Random 90-degree rotation
        RandRotate90d(keys=both, prob=0.5, max_k=2),
        # Random zoom (0.8-1.2x), output kept at the original size
        RandZoomd(keys=both, min_zoom=0.8, max_zoom=1.2, prob=0.5, mode=interp, keep_size=True),

        # ========== Photometric Transformations ==========
        # Random additive intensity offset in [-0.2, 0.2]
        RandShiftIntensityd(keys=image_only, offsets=(-0.2, 0.2), prob=0.5),
        # Random contrast adjustment (gamma 0.7-1.3)
        RandAdjustContrastd(keys=image_only, gamma=(0.7, 1.3), prob=0.5),
        # Random histogram shift (simulate different staining/lighting conditions)
        RandHistogramShiftd(keys=image_only, num_control_points=(5, 10), prob=0.3),

        # ========== Noise and Quality Degradation ==========
        # Random Gaussian smoothing (simulate blur)
        RandGaussianSmoothd(keys=image_only, sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), prob=0.3),
        # Random Gaussian noise
        RandGaussianNoised(keys=image_only, mean=0.0, std=0.05, prob=0.3),
        # Coarse Dropout (occlusion augmentation, improve robustness)
        RandCoarseDropoutd(
            keys=image_only,
            holes=1,
            max_holes=3,
            spatial_size=(32, 32),
            max_spatial_size=(64, 64),
            prob=0.3
        ),

        # ========== Post-processing ==========
        ToTensord(keys=both),
    ])
    return train_transforms


def create_datasets(args):
    """Create the training and validation datasets.

    Args:
        args: Command line arguments (uses ``data_root``, ``dataset_name``,
            ``target_spatial_size`` and ``dataset_enhanced``).

    Returns:
        tuple: (train_dataset, val_dataset)
    """
    print("=" * 60)
    print("正在加载数据集...")
    print("=" * 60)

    both = ["image", "label"]

    def deterministic_steps():
        # Loading, label binarisation, resizing and intensity scaling shared
        # by the basic training pipeline and the validation pipeline.
        return [
            LoadImaged(keys=both),
            EnsureChannelFirstd(keys=both),
            Lambdad(keys=["label"], func=convert_label_to_single_channel),
            Resized(keys=both, spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
            ScaleIntensityd(keys=["image"]),
        ]

    if args.dataset_enhanced:
        train_transforms = create_enhanced_transforms(args.target_spatial_size)
        print("✓ 使用增强数据增强策略")
    else:
        # Basic augmentation: random flip plus a small random rotation.
        train_transforms = Compose(deterministic_steps() + [
            RandFlipd(keys=both, prob=0.5),
            RandRotated(keys=both, range_x=0.15, prob=0.5),
        ])

    # Validation uses no random augmentation.
    val_transforms = Compose(deterministic_steps())

    dataset_root = Path(args.data_root) / args.dataset_name
    train_dataset = PolypDetectionDataset(
        root_dir=dataset_root,
        flag='train',
        transform=train_transforms
    )
    val_dataset = PolypDetectionDataset(
        root_dir=dataset_root,
        flag='val',
        transform=val_transforms
    )

    print(f"✓ Training set size: {len(train_dataset)} samples")
    print(f"✓ Validation set size: {len(val_dataset)} samples")
    print(f"✓ Total samples: {len(train_dataset) + len(val_dataset)} samples")
    print("=" * 60)
    return train_dataset, val_dataset


def create_dataloaders(args, train_dataset, val_dataset):
    """Create the training and validation data loaders.

    The training loader shuffles and drops the last incomplete batch; the
    validation loader keeps order and every sample.

    Args:
        args: Command line arguments
        train_dataset: Training dataset
        val_dataset: Validation dataset

    Returns:
        tuple: (train_loader, val_loader)
    """
    train_loader = monai.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        drop_last=True
    )
    val_loader = monai.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        drop_last=False
    )
    print(f"✓ Training data loader: {len(train_loader)} batches")
    print(f"✓ Validation data loader: {len(val_loader)} batches")
    return train_loader, val_loader


def create_model(args):
    """Instantiate the Wavelet_FFT_SwinUNETR model and print its parameter counts.

    Args:
        args: Command line arguments

    Returns:
        torch.nn.Module: Initialized model (not yet moved to a device)
    """
    print("\n" + "=" * 60)
    print("Creating model...")
    model = Wavelet_FFT_SwinUNETR(
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        spatial_dims=args.spatial_dims,
        wavelet_enhancement=args.use_wavelet,
        wavelet_J=args.wavelet_J,
        wavelet_wave=args.wavelet_wave,
        wavelet_mode='symmetric',
        wavelet_reduction=args.wavelet_reduction,
        fft_enhancement=args.use_fft,
        use_v2=args.use_v2
    )

    # Print model information
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n✓ Total model parameters: {total_params:,}")
    print(f"✓ Trainable parameters: {trainable_params:,}")
    print(f"✓ Using device: {args.device}")
    print("=" * 60)
    return model


def create_loss_function(args):
    """Create the combined Dice / Cross-Entropy / IoU loss.

    Args:
        args: Command line arguments (uses the three loss weights)

    Returns:
        Callable: Loss function
    """
    loss_fn = CombinedDiceCEIoULoss(
        dice_weight=args.dice_weight,
        ce_weight=args.ce_weight,
        iou_weight=args.iou_weight,
        include_background=True,
        to_onehot_y=False,
        softmax=False,
        sigmoid=True,
    )
    return loss_fn


def create_optimizer(args, model):
    """Create the AdamW optimizer and its ReduceLROnPlateau scheduler.

    Args:
        args: Command line arguments
        model: Model whose parameters are optimized

    Returns:
        tuple: (optimizer, scheduler)
    """
    optimizer = AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',       # lower validation loss is better
        factor=0.5,       # multiply the LR by 0.5 on each reduction
        patience=20,      # reduce the LR after 20 epochs without improvement
        threshold=1e-4,   # minimum change that counts as an improvement
        cooldown=5,       # epochs to wait after a reduction
        min_lr=1e-6       # lower bound of the learning rate
    )

    print("✓ Optimizer: AdamW")
    print(f" - Learning rate: {args.learning_rate}")
    print(f" - Weight decay: {args.weight_decay}")
    print("✓ Scheduler: ReduceLROnPlateau")
    print(f" - Mode: {scheduler.mode}")
    print(f" - Decay factor: {scheduler.factor}")
    print(f" - Patience: {scheduler.patience}")
    print(f" - Minimum change threshold: {scheduler.threshold}")
    print(f" - Cooldown period: {scheduler.cooldown}")
    print(f" - Minimum learning rate: {scheduler.min_lrs}")
    return optimizer, scheduler


def setup_swanlab(args):
    """Configure SwanLab experiment tracking.

    Fills in a timestamped experiment name when none was given (mutating
    ``args``) and creates the log and output directories.

    Args:
        args: Command line arguments

    Returns:
        The object returned by ``swanlab.init``.
    """
    # If experiment name is not specified, use timestamp
    if args.swanlab_experiment is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.swanlab_experiment = "v2_" + args.dataset_name + "_" + timestamp

    # Create log and output directories
    os.makedirs(args.swanlab_log_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # Initialize SwanLab
    run = swanlab.init(
        project=args.swanlab_project,
        experiment_name=args.swanlab_experiment,
        logdir=args.swanlab_log_dir,
        config=vars(args)
    )
    print(f"\n✓ SwanLab experiment initialized: {args.swanlab_experiment}")
    print(f" - Project: {args.swanlab_project}")
    print(f" - Log directory: {args.swanlab_log_dir}")
    return run


def _load_matching_weights(model, pretrained_dict):
    """Load every parameter from ``pretrained_dict`` whose name and shape match.

    Parameters missing from the checkpoint or with a different shape keep the
    model's current values. A summary of the matching is printed.
    """
    model_dict = model.state_dict()
    matched_params = {}
    unmatched_params = []
    missing_params = []

    for name, param in model_dict.items():
        if name not in pretrained_dict:
            missing_params.append(name)
            continue
        pretrained_param = pretrained_dict[name]
        if param.shape == pretrained_param.shape:
            matched_params[name] = pretrained_param
        else:
            unmatched_params.append(
                f"{name} (shape mismatch: {param.shape} vs {pretrained_param.shape})"
            )

    # Output loading statistics
    print("\nWeight loading statistics:")
    print(f" ✓ Successfully matched parameters: {len(matched_params)}/{len(model_dict)}")
    print(f" ⚠ Shape mismatched parameters: {len(unmatched_params)}")
    print(f" ✗ Newly added parameters (randomly initialized): {len(missing_params)}")

    if unmatched_params:
        print("\nShape mismatched layers:")
        for info in unmatched_params[:5]:  # Only show first 5
            print(f" - {info}")
        if len(unmatched_params) > 5:
            print(f" ... {len(unmatched_params) - 5} more")

    if missing_params:
        print("\nNewly added layers (will be randomly initialized):")
        for name in missing_params[:5]:  # Only show first 5
            print(f" - {name}")
        if len(missing_params) > 5:
            print(f" ... {len(missing_params) - 5} more")

    # Overlay the matched tensors on the model's own state and load the result
    model_dict.update(matched_params)
    model.load_state_dict(model_dict, strict=False)
    print("\n✓ Model weights loaded (strict mode: False)")
    print("=" * 60)


def _save_state_dict(model, checkpoint_path):
    """Write the model's state dict to ``checkpoint_path``, creating parent dirs."""
    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)


def main():
    """Run the full training procedure.

    Steps: parse arguments, initialise SwanLab, build data/model/optimizer,
    optionally resume from a checkpoint, then train and validate each epoch
    with early stopping, best-model saving and periodic checkpoints.
    """
    # ==================== Step 1: Parse arguments ====================
    args = parse_args()

    # Seed PyTorch's generators.
    # NOTE(review): only torch is seeded here; whether the MONAI random
    # transforms become reproducible from this alone is not shown - confirm.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    print("\n" + "=" * 60)
    print("Polyp Segmentation Model Training Started")
    print("=" * 60)
    print(f"Using device: {args.device}")
    print(f"Batch size: {args.batch_size}")
    print(f"Maximum epochs: {args.max_epochs}")
    if args.early_stopping:
        print(
            f"Early stopping: enabled (patience={args.early_stopping_patience}, monitor={args.early_stopping_monitor})")

    # ==================== Step 2: Initialize SwanLab ====================
    run = setup_swanlab(args)

    # ==================== Step 3: Create datasets and data loaders ====================
    train_dataset, val_dataset = create_datasets(args)
    train_loader, val_loader = create_dataloaders(args, train_dataset, val_dataset)

    # ==================== Step 4: Create model, loss function, optimizer ====================
    model = create_model(args)
    model = model.to(args.device)
    loss_function = create_loss_function(args)
    optimizer, scheduler = create_optimizer(args, model)

    # ==================== Step 5: Create evaluation metrics ====================
    dice_metric = DiceMetric(reduction="mean")
    iou_metric = MeanIoU(reduction="mean")
    hd_metric = HausdorffDistanceMetric(reduction="mean")

    # ==================== Step 6: Setup training loop ====================
    best_dice = -1
    best_dice_epoch = -1
    best_metric = -1
    best_metric_epoch = -1
    best_iou = -1
    best_iou_epoch = -1
    # Best (lowest) validation loss seen so far; used by the "loss" monitor.
    best_val_loss = float('inf')
    epoch_loss_values = []
    dice_metric_values = []
    iou_metric_values = []
    hd_metric_values = []
    start_epoch = 0

    # ==================== Early stopping related variables ====================
    early_stopping_counter = 0
    should_stop = False
    has_restarted = False  # Flag indicating whether it has been restarted once

    # ==================== Step 7: Resume training (if checkpoint exists) ====================
    checkpoint = None
    if args.resume:
        # User specified checkpoint path
        if not os.path.exists(args.resume):
            raise FileNotFoundError(f"Checkpoint file does not exist: {args.resume}")
        checkpoint_path = args.resume
        print(f"\nResuming training from user-specified checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=args.device)
    elif args.auto_resume:
        # Automatically find best checkpoint
        checkpoint_path = find_best_checkpoint(args)
        if checkpoint_path:
            print(f"\nAuto-resume mode: loading {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=args.device)
        else:
            print("\nNo checkpoints found, starting training from scratch")

    if checkpoint is not None:
        # A periodic checkpoint wraps the weights under "model_state_dict";
        # the best-model files are bare state dicts.
        try:
            pretrained_dict = checkpoint["model_state_dict"]
            print("✓ Model weights loaded from training checkpoint")
        except KeyError:
            pretrained_dict = checkpoint
            print("Loading model weights from best Dice or best overall model")

        # Filter and match parameters (handle structural changes from v1->v2)
        _load_matching_weights(model, pretrained_dict)

        if "optimizer_state_dict" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            print("✓ Optimizer state loaded")

        # Restore the LR scheduler so plateau tracking survives a resume.
        # Older checkpoints do not contain this key and are simply skipped.
        if "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            print("✓ Scheduler state loaded")

        # Epoch restoration is disabled in this script: start_epoch stays 0
        # and epoch numbering restarts after a resume.

        # Load best metrics
        if "best_dice" in checkpoint:
            best_dice = checkpoint["best_dice"]
            best_dice_epoch = checkpoint.get("best_dice_epoch", best_dice_epoch)
            print(f"✓ Best metrics restored: Dice={best_dice:.4f} (Epoch {best_dice_epoch})")
        # Without these, the first resumed epoch would overwrite the saved
        # best IoU / best overall models regardless of its score.
        if "best_iou" in checkpoint:
            best_iou = checkpoint["best_iou"]
            best_iou_epoch = checkpoint.get("best_iou_epoch", best_iou_epoch)
        if "best_metric" in checkpoint:
            best_metric = checkpoint["best_metric"]
            best_metric_epoch = checkpoint.get("best_metric_epoch", best_metric_epoch)
        if "best_val_loss" in checkpoint:
            best_val_loss = checkpoint["best_val_loss"]

        # Load historical loss and metric values (optional)
        if "epoch_loss_values" in checkpoint:
            epoch_loss_values = checkpoint["epoch_loss_values"]
        if "dice_metric_values" in checkpoint:
            dice_metric_values = checkpoint["dice_metric_values"]
        if "iou_metric_values" in checkpoint:
            iou_metric_values = checkpoint["iou_metric_values"]
        if "hd_metric_values" in checkpoint:
            hd_metric_values = checkpoint["hd_metric_values"]

        # Load early stopping state
        if args.early_stopping:
            if "early_stopping_counter" in checkpoint:
                early_stopping_counter = checkpoint["early_stopping_counter"]
                print(f"✓ Early stopping counter restored: {early_stopping_counter}")
            if "should_stop" in checkpoint and checkpoint["should_stop"]:
                should_stop = False  # Even if marked as stopped, allow continued training
                print("✓ Early stopping state reset, can continue training")

        print(f"✓ Training will continue from epoch {start_epoch}")
        print("=" * 60)

    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60)
    start_time = time.time()

    try:
        # Fail with a clear message instead of a ZeroDivisionError when an
        # epoch average would be taken over zero batches.
        if len(train_loader) == 0:
            raise ValueError(
                "Training data loader has no batches "
                "(drop_last=True: the training set must hold at least batch_size samples)"
            )
        if len(val_loader) == 0:
            raise ValueError("Validation data loader has no batches")

        for epoch in range(start_epoch, run.config.max_epochs):
            # ========== Check early stopping condition ==========
            if should_stop and args.early_stopping:
                print(f"\n{'=' * 60}")
                print(f"Early stopping triggered! Training will terminate early at epoch {epoch + 1}")
                print(f"{'=' * 60}")

                if has_restarted:
                    # Second early stopping: truly stop
                    print("Training has been restarted once after early stopping, now stopping training")
                    break

                # First early stopping: load best weights, restart training
                print("Early stopping detected, preparing to restart training from best model...")
                best_checkpoint_path = os.path.join(
                    args.output_dir, f"best_dice_model_{args.dataset_name}.pt"
                )
                if not os.path.exists(best_checkpoint_path):
                    print(f"Warning: Best model file not found {best_checkpoint_path}")
                    print("Will stop training directly")
                    break

                print(f"Loading best Dice model: {best_checkpoint_path}")
                best_state = torch.load(best_checkpoint_path, map_location=args.device)
                model.load_state_dict(best_state)
                print("✓ Model weights restored to best state")

                # Fresh optimizer and scheduler (learning rate back to its initial value)
                optimizer, scheduler = create_optimizer(args, model)
                print("✓ Optimizer has been reset")

                early_stopping_counter = 0
                should_stop = False
                has_restarted = True
                print("✓ Training restarted from best model")
                print(f"{'=' * 60}\n")
                # Fall through: this epoch is trained with the restored weights
                # rather than being skipped.

            # ========== Training phase ==========
            model.train()
            step = 0
            epoch_loss = 0
            epoch_loss_dice_ce = 0
            epoch_loss_iou = 0

            for batch_data in train_loader:
                step += 1
                inputs = batch_data["image"].to(args.device)
                targets = batch_data["label"].to(args.device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss, loss_dice_ce, loss_iou = loss_function(outputs, targets)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                epoch_loss_dice_ce += loss_dice_ce.item()
                epoch_loss_iou += loss_iou.item()

            # If this is the first epoch resumed from checkpoint, print notification
            if epoch == start_epoch and start_epoch > 0:
                print(f"\n✓ Training resumed from epoch {start_epoch}")

            epoch_loss /= step
            epoch_loss_dice_ce /= step
            epoch_loss_iou /= step
            epoch_loss_values.append(epoch_loss)
            print(f"\nEpoch {epoch + 1}/{args.max_epochs} - Training loss: {epoch_loss:.4f}")

            # Log to SwanLab
            swanlab.log({
                "train/loss": epoch_loss,
                "train/loss_dice_ce": epoch_loss_dice_ce,
                "train/loss_iou": epoch_loss_iou,
                "train/lr": optimizer.param_groups[0]['lr'],
            }, step=(epoch + 1))

            # ========== Validation phase ==========
            model.eval()
            val_loss_total = 0
            with torch.no_grad():
                dice_metric.reset()
                iou_metric.reset()
                hd_metric.reset()

                for val_data in val_loader:
                    val_images = val_data["image"].to(args.device)
                    val_labels = val_data["label"].to(args.device)
                    val_outputs = model(val_images)

                    # Validation loss is computed on the raw model outputs
                    val_loss_batch, _, _ = loss_function(val_outputs, val_labels)
                    val_loss_total += val_loss_batch.item()

                    # Post-processing: sigmoid then threshold at 0.5
                    val_outputs = torch.sigmoid(val_outputs)
                    val_outputs = (val_outputs > 0.5).int()

                    # Accumulate metrics on the binarised prediction
                    dice_metric(y_pred=val_outputs, y=val_labels)
                    iou_metric(y_pred=val_outputs, y=val_labels)
                    hd_metric(y_pred=val_outputs, y=val_labels)

            # Calculate average validation loss
            val_loss_avg = val_loss_total / len(val_loader)

            # Update learning rate scheduler
            scheduler.step(val_loss_avg)
            current_lr = optimizer.param_groups[0]['lr']

            # Aggregate results
            mean_dice = dice_metric.aggregate().item()
            dice_metric_values.append(mean_dice)
            mean_iou = iou_metric.aggregate().item()
            iou_metric_values.append(mean_iou)
            mean_hd = hd_metric.aggregate().item()
            hd_metric_values.append(mean_hd)

            # Overall score: Dice + IoU + a Hausdorff term mapped into (0, 1]
            normalized_hd = 1.0 / (1.0 + mean_hd)
            mean_metric = mean_dice + mean_iou + normalized_hd

            print(
                f"Epoch {epoch + 1} - Validation Dice: {mean_dice:.4f}, Validation loss: {val_loss_avg:.4f}, Current LR: {current_lr:.2e}")

            swanlab.log({
                "val/loss": val_loss_avg,
                "val/mean_dice": mean_dice,
                "val/mean_iou": mean_iou,
                "val/mean_hd": mean_hd,
                "val/lr": current_lr,
            }, step=(epoch + 1))

            # ========== Early stopping check ==========
            # Runs before the best_* values are updated below, so the current
            # epoch is compared against the previous best.
            if args.early_stopping:
                monitor = args.early_stopping_monitor
                if monitor == "dice":
                    current_score, best_score = mean_dice, best_dice
                elif monitor == "iou":
                    current_score, best_score = mean_iou, best_iou
                elif monitor == "metric":
                    current_score, best_score = mean_metric, best_metric
                else:  # loss
                    # Lower loss is better, so negate; compare against the
                    # best *validation* loss, not the training loss history.
                    current_score, best_score = -val_loss_avg, -best_val_loss
                is_better = current_score > best_score + args.early_stopping_min_delta

                if is_better:
                    early_stopping_counter = 0
                    print(
                        f" ✓ {args.early_stopping_monitor.upper()} metric improved: {current_score:.4f} > {best_score:.4f}")
                else:
                    early_stopping_counter += 1
                    print(
                        f" ⚠ {args.early_stopping_monitor.upper()} metric did not improve, counter: {early_stopping_counter}/{args.early_stopping_patience}")

                # Check if early stopping should be triggered
                if early_stopping_counter >= args.early_stopping_patience:
                    should_stop = True

            if val_loss_avg < best_val_loss:
                best_val_loss = val_loss_avg

            # Save best Dice model
            if mean_dice > best_dice:
                best_dice = mean_dice
                best_dice_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
                _save_state_dict(model, checkpoint_path)
                print(
                    f"✓ Found better Dice model! Dice: {mean_dice:.4f}, IoU: {mean_iou:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}")

            # Save best IoU model
            if mean_iou > best_iou:
                best_iou = mean_iou
                best_iou_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_iou_model_{args.dataset_name}.pt")
                _save_state_dict(model, checkpoint_path)
                print(
                    f"✓ Found better IoU model! IoU: {mean_iou:.4f}, Dice: {mean_dice:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}"
                )

            # Save best overall model
            if mean_metric > best_metric:
                best_metric = mean_metric
                best_metric_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
                _save_state_dict(model, checkpoint_path)
                print(
                    f"✓ Found better overall model! Overall score: {mean_metric:.4f}, Dice: {mean_dice:.4f}, IoU: {mean_iou:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}"
                )

            # Periodically save checkpoint
            if (epoch + 1) % args.save_every == 0:
                checkpoint_path = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}",
                                               f"checkpoint_epoch={epoch}.pt")
                Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "best_dice": best_dice,
                    "best_dice_epoch": best_dice_epoch,
                    "epoch_loss_values": epoch_loss_values,
                    "dice_metric_values": dice_metric_values,
                    "iou_metric_values": iou_metric_values,
                    "hd_metric_values": hd_metric_values,
                    "best_metric": best_metric,
                    "best_metric_epoch": best_metric_epoch,
                    "best_iou": best_iou,
                    "best_iou_epoch": best_iou_epoch,
                    "best_val_loss": best_val_loss,
                    "early_stopping_counter": early_stopping_counter,
                    "should_stop": should_stop
                }, checkpoint_path)
                print(f"✓ Checkpoint saved: {checkpoint_path}")

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    finally:
        # Runs on normal completion, interruption and errors alike.
        end_time = time.time()
        training_time = end_time - start_time
        print("\n" + "=" * 60)
        print("Training completed!")
        print(f"Total training time: {training_time / 3600:.2f} hours")
        print(f"Best validation Dice: {best_dice:.4f} (Epoch {best_dice_epoch})")
        print("=" * 60)

        # Close SwanLab
        swanlab.finish()
        print("✓ SwanLab experiment saved")


if __name__ == "__main__":
    main()