| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053 |
- import argparse
- import os
- import time
- from datetime import datetime
- from pathlib import Path
- import monai
- import monai.utils
- import swanlab
- import torch
- from lib.model.model_v4_minute import Wavelet_FFT_SwinUNETR
- from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
- from monai.transforms import (
- Compose, LoadImaged, ScaleIntensityd, RandFlipd, RandRotated, RandRotate90d,
- EnsureChannelFirstd, ToTensord, Resized, Lambdad, RandZoomd, RandShiftIntensityd, RandGaussianNoised,
- RandGaussianSmoothd, RandAdjustContrastd, RandHistogramShiftd,
- RandAxisFlipd, RandCoarseDropoutd,
- )
- from torch.optim import AdamW
- from torch.optim.lr_scheduler import ReduceLROnPlateau
- from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
- from lib.tools.combined_loss import CombinedDiceCEIoULoss
def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` evaluates ``bool("False")``, which is
    ``True`` for every non-empty string, so such options can never be
    switched off from the command line. This converter accepts the usual
    textual spellings instead.

    Args:
        value: Raw command-line token (an actual ``bool`` is passed through).

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Build the command-line interface of the training script and parse it.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace: Parsed options. ``target_spatial_size`` is always a
        tuple of ints; every boolean option is a real ``bool``.
    """
    parser = argparse.ArgumentParser(description="Polyp Segmentation Model Training Script")

    # ==================== Dataset-related Parameters ====================
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        help="Dataset name"
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default=r"./data/Polyp-Detection-Dataset",
        help="Root directory path of the dataset"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Number of worker processes for data loader"
    )
    parser.add_argument(
        "--pin_memory",
        type=_str2bool,
        nargs="?",
        const=True,  # a bare ``--pin_memory`` means True
        default=True,
        help="Whether to enable pinned memory"
    )
    parser.add_argument(
        "--target_spatial_size",
        type=int,
        nargs="+",  # was type=tuple, which split "512" into ('5', '1', '2')
        metavar="N",
        default=(512, 512),
        help="Target spatial size, one integer per spatial dimension (e.g. 512 512)"
    )
    parser.add_argument(
        "--dataset_enhanced",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Whether to use enhanced data augmentation strategy"
    )

    # ==================== Model-related Parameters ====================
    parser.add_argument(
        "--in_channels",
        type=int,
        default=3,
        help="Number of input image channels"
    )
    parser.add_argument(
        "--out_channels",
        type=int,
        default=1,
        help="Number of output foreground channels"
    )
    parser.add_argument(
        "--feature_size",
        type=int,
        default=48,
        help="Network feature dimension"
    )
    parser.add_argument(
        "--spatial_dims",
        type=int,
        default=2,
        choices=[2, 3],
        help="Spatial dimension (2D or 3D)"
    )
    parser.add_argument(
        "--no_wavelet",
        action="store_false",
        dest="use_wavelet",
        help="Disable the wavelet enhancement module (enabled by default)"
    )
    parser.add_argument(
        "--wavelet_J",
        type=int,
        default=2,
        help="Wavelet decomposition levels"
    )
    parser.add_argument(
        "--wavelet_wave",
        type=str,
        default="db4",
        help="Wavelet basis type"
    )
    parser.add_argument(
        "--wavelet_reduction",
        type=int,
        default=16,
        help="Wavelet attention compression ratio"
    )
    parser.add_argument(
        "--no_fft",
        action="store_false",
        dest="use_fft",
        help="Disable the FFT enhancement module (enabled by default)"
    )
    parser.add_argument(
        "--use_v2",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Whether to enable Swin-UNETR v2 module"
    )

    # ==================== Training-related Parameters ====================
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=1000,
        help="Maximum number of training epochs"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Batch size"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate"
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=1e-4,
        help="Weight decay coefficient"
    )

    # ==================== Loss Function Parameters ====================
    parser.add_argument(
        "--dice_weight",
        type=float,
        default=1.0,
        help="Dice loss weight"
    )
    parser.add_argument(
        "--ce_weight",
        type=float,
        default=1.0,
        help="Cross Entropy loss weight"
    )
    parser.add_argument(
        "--iou_weight",
        type=float,
        default=1.0,
        help="IoU loss weight"
    )

    # ==================== SwanLab Parameters ====================
    parser.add_argument(
        "--swanlab_project",
        type=str,
        default="polyp-segmentation-v4_minute",
        help="SwanLab project name"
    )
    parser.add_argument(
        "--swanlab_experiment",
        type=str,
        default=None,
        help="SwanLab experiment name (default uses timestamp)"
    )
    parser.add_argument(
        "--swanlab_log_dir",
        type=str,
        default="./swanlab_log",
        help="SwanLab log directory"
    )

    # ==================== Saving and Loading Parameters ====================
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./outputs_v4_minute",
        help="Directory for saving model checkpoints"
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=50,
        help="Save model every N epochs"
    )

    # ==================== Early Stopping Parameters ====================
    parser.add_argument(
        "--early_stopping",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Whether to enable early stopping"
    )
    parser.add_argument(
        "--early_stopping_patience",
        type=int,
        default=100,
        help="Early stopping patience (stop if validation metric doesn't improve for N rounds)"
    )
    parser.add_argument(
        "--early_stopping_min_delta",
        type=float,
        default=1e-4,
        help="Minimum improvement threshold (improvement below this value is considered no improvement)"
    )
    parser.add_argument(
        "--early_stopping_monitor",
        type=str,
        default="dice",
        choices=["dice", "iou", "metric", "loss"],
        help="Metric to monitor for early stopping"
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Checkpoint path to resume training. If not specified, will automatically load the best Dice model (if exists)"
    )
    parser.add_argument(
        "--no_auto_resume",
        action="store_false",
        dest="auto_resume",
        help="Disable auto-resume (by default the best Dice model is loaded when it exists)"
    )

    # ==================== Other Parameters ====================
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Training device (cuda or cpu)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed"
    )

    args = parser.parse_args(argv)
    # nargs="+" yields a list; normalise so the value has the same type as the default.
    args.target_spatial_size = tuple(args.target_spatial_size)
    return args
def find_best_checkpoint(args):
    """Locate a checkpoint to resume from inside ``args.output_dir``.

    Candidates are tried in this order; the first one that exists wins:

    1. ``best_dice_model_<dataset_name>.pt``
    2. the ``.pt`` file with the highest ``epoch=<N>`` number inside
       ``checkpoints_<dataset_name>/``
    3. ``best_metric_model_<dataset_name>.pt``

    Args:
        args: Object exposing ``output_dir`` and ``dataset_name``.

    Returns:
        str or None: Path of the selected checkpoint, or None when no
        candidate exists.
    """

    def _epoch_of(filename):
        # Files without a parsable "epoch=<int>" rank last instead of raising
        # ValueError (e.g. "checkpoint_epoch=49_backup.pt").
        if "epoch=" not in filename:
            return -1
        token = filename.split("epoch=")[1].split(".")[0]
        try:
            return int(token)
        except ValueError:
            return -1

    # 1. Best Dice model
    best_dice_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
    if os.path.exists(best_dice_path):
        print(f"Found best Dice model: {best_dice_path}")
        return best_dice_path

    # 2. Latest periodic checkpoint
    checkpoint_dir = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}")
    if os.path.isdir(checkpoint_dir):
        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
        if checkpoints:
            latest_checkpoint = os.path.join(checkpoint_dir, max(checkpoints, key=_epoch_of))
            print(f"Found latest checkpoint: {latest_checkpoint}")
            return latest_checkpoint

    # 3. Best overall model
    best_metric_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
    if os.path.exists(best_metric_path):
        print(f"Found best overall model: {best_metric_path}")
        return best_metric_path

    return None
def convert_label_to_single_channel(label_tensor):
    """Turn a channel-first label image into a single-channel binary mask.

    Defined at module level (instead of as a closure inside each factory) so
    the ``Lambdad`` transform holding it can be pickled when the data loader
    uses worker processes started with the "spawn" method.

    Args:
        label_tensor: Channel-first label; only channel 0 is read.
            Presumably values are in 0-255 as loaded from an image file -
            TODO confirm against the dataset.

    Returns:
        Float tensor with a single channel: 1.0 where channel 0 is strictly
        greater than 127, 0.0 elsewhere.
    """
    single_channel = label_tensor[0:1, :, :]
    return (single_channel > 127).float()


def create_enhanced_transforms(target_spatial_size=(512, 512)):
    """Build the enhanced training augmentation pipeline.

    The pipeline contains, in order:
        1. Loading, channel-first conversion, label binarisation, resize and
           intensity scaling.
        2. Geometric augmentation: axis flip, rotation, 90-degree rotation,
           zoom.
        3. Photometric augmentation: intensity shift, gamma contrast,
           histogram shift.
        4. Degradation: Gaussian smoothing, Gaussian noise, coarse dropout
           (image only).

    Args:
        target_spatial_size: Spatial size every image/label is resized to.

    Returns:
        monai.transforms.Compose: The training transform.
    """
    return Compose([
        # ========== Loading and Preprocessing ==========
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Lambdad(keys=["label"], func=convert_label_to_single_channel),
        # ========== Spatial Transformations ==========
        Resized(keys=["image", "label"], spatial_size=target_spatial_size,
                mode=("bilinear", "nearest")),
        ScaleIntensityd(keys=["image"]),
        # --- Geometric Augmentation ---
        # Random axis flip
        RandAxisFlipd(keys=["image", "label"], prob=0.5),
        # Random rotation. range_x is expressed in radians, so 0.15 is roughly
        # +/-8.6 degrees, not the +/-15 degrees the earlier comment claimed.
        # NOTE(review): use about 0.26 if +/-15 degrees was really intended.
        RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5,
                    keep_size=True, mode=("bilinear", "nearest")),
        # Random 90-degree rotation
        RandRotate90d(keys=["image", "label"], prob=0.5, max_k=2),
        # Random zoom (0.8-1.2x), output kept at the input size
        RandZoomd(keys=["image", "label"], min_zoom=0.8, max_zoom=1.2,
                  prob=0.5, mode=("bilinear", "nearest"), keep_size=True),
        # ========== Photometric Transformations ==========
        # Random intensity offset in [-0.2, 0.2]
        RandShiftIntensityd(keys=["image"], offsets=(-0.2, 0.2), prob=0.5),
        # Random contrast adjustment (gamma 0.7-1.3)
        RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.3), prob=0.5),
        # Random histogram shift (simulate different staining/lighting conditions)
        RandHistogramShiftd(keys=["image"], num_control_points=(5, 10),
                            prob=0.3),
        # ========== Noise and Quality Degradation ==========
        # Random Gaussian smoothing (simulate blur)
        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.0),
                            sigma_y=(0.5, 1.0), prob=0.3),
        # Random Gaussian noise
        RandGaussianNoised(keys=["image"], mean=0.0, std=0.05, prob=0.3),
        # Coarse dropout (occlusion augmentation); applied to the image only,
        # the label is left intact.
        RandCoarseDropoutd(
            keys=["image"],
            holes=1,
            max_holes=3,
            spatial_size=(32, 32),
            max_spatial_size=(64, 64),
            prob=0.3
        ),
        # ========== Post-processing ==========
        ToTensord(keys=["image", "label"]),
    ])


def create_datasets(args):
    """Create the training and validation datasets.

    Args:
        args: Parsed command-line options. Reads ``target_spatial_size``,
            ``dataset_enhanced``, ``data_root`` and ``dataset_name``.

    Returns:
        tuple: ``(train_dataset, val_dataset)``.
    """
    print("=" * 60)
    print("正在加载数据集...")
    print("=" * 60)

    if args.dataset_enhanced:
        train_transforms = create_enhanced_transforms(args.target_spatial_size)
        print("✓ 使用增强数据增强策略")
    else:
        # Basic pipeline: preprocessing plus a light flip/rotation augmentation.
        train_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            # Keep the first label channel and binarise it at 127.
            Lambdad(keys=["label"], func=convert_label_to_single_channel),
            Resized(keys=["image", "label"], spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
            ScaleIntensityd(keys=["image"]),
            RandFlipd(keys=["image", "label"], prob=0.5),
            RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5),
        ])

    # Validation uses the deterministic preprocessing only.
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Keep the first label channel and binarise it at 127.
        Lambdad(keys=["label"], func=convert_label_to_single_channel),
        Resized(keys=["image", "label"], spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
        ScaleIntensityd(keys=["image"]),
    ])

    dataset_root = Path(args.data_root) / args.dataset_name
    train_dataset = PolypDetectionDataset(
        root_dir=dataset_root,
        flag='train',
        transform=train_transforms
    )
    val_dataset = PolypDetectionDataset(
        root_dir=dataset_root,
        flag='val',
        transform=val_transforms
    )

    print(f"✓ Training set size: {len(train_dataset)} samples")
    print(f"✓ Validation set size: {len(val_dataset)} samples")
    print(f"✓ Total samples: {len(train_dataset) + len(val_dataset)} samples")
    print("=" * 60)
    return train_dataset, val_dataset
def create_dataloaders(args, train_dataset, val_dataset):
    """Wrap both datasets in MONAI data loaders.

    The training loader shuffles and drops the last incomplete batch; the
    validation loader keeps dataset order and every sample.

    Args:
        args: Parsed options; reads ``batch_size``, ``num_workers`` and
            ``pin_memory``.
        train_dataset: Dataset used for training.
        val_dataset: Dataset used for validation.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    # Options shared by both loaders.
    common_options = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": args.pin_memory,
    }

    train_loader = monai.data.DataLoader(
        train_dataset, shuffle=True, drop_last=True, **common_options
    )
    val_loader = monai.data.DataLoader(
        val_dataset, shuffle=False, drop_last=False, **common_options
    )

    train_batches = len(train_loader)
    val_batches = len(val_loader)
    print(f"✓ Training data loader: {train_batches} batches")
    print(f"✓ Validation data loader: {val_batches} batches")
    return train_loader, val_loader
def create_model(args):
    """Instantiate the segmentation network and print its parameter counts.

    Args:
        args: Parsed options; the model-related fields (channels, feature
            size, spatial dims, wavelet/FFT switches, ``use_v2``) and
            ``device`` are read.

    Returns:
        torch.nn.Module: The freshly constructed model (not yet moved to a
        device).
    """
    separator = "=" * 60
    print("\n" + separator)
    print("Creating model...")

    model_options = {
        "in_channels": args.in_channels,
        "out_channels": args.out_channels,
        "feature_size": args.feature_size,
        "spatial_dims": args.spatial_dims,
        "wavelet_enhancement": args.use_wavelet,
        "wavelet_J": args.wavelet_J,
        "wavelet_wave": args.wavelet_wave,
        "wavelet_mode": 'symmetric',
        "wavelet_reduction": args.wavelet_reduction,
        "fft_enhancement": args.use_fft,
        "use_v2": args.use_v2,
    }
    model = Wavelet_FFT_SwinUNETR(**model_options)

    # Count all parameters and the trainable subset in a single pass.
    total_params = 0
    trainable_params = 0
    for parameter in model.parameters():
        element_count = parameter.numel()
        total_params += element_count
        if parameter.requires_grad:
            trainable_params += element_count

    print(f"\n✓ Total model parameters: {total_params:,}")
    print(f"✓ Trainable parameters: {trainable_params:,}")
    print(f"✓ Using device: {args.device}")
    print(separator)
    return model
def create_loss_function(args):
    """Build the combined Dice / cross-entropy / IoU loss.

    The loss is configured for sigmoid outputs (no softmax, no one-hot
    conversion of the target) and includes the background channel.

    Args:
        args: Parsed options; reads ``dice_weight``, ``ce_weight`` and
            ``iou_weight``.

    Returns:
        Callable: The configured ``CombinedDiceCEIoULoss`` instance.
    """
    term_weights = {
        "dice_weight": args.dice_weight,
        "ce_weight": args.ce_weight,
        "iou_weight": args.iou_weight,
    }
    return CombinedDiceCEIoULoss(
        include_background=True,
        to_onehot_y=False,
        softmax=False,
        sigmoid=True,
        **term_weights,
    )
def create_optimizer(args, model):
    """Create the AdamW optimizer and its ReduceLROnPlateau scheduler.

    Args:
        args: Parsed options; reads ``learning_rate`` and ``weight_decay``.
        model: Model whose parameters are optimised.

    Returns:
        tuple: ``(optimizer, scheduler)``.
    """
    optimizer = AdamW(
        model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay
    )

    plateau_options = {
        "mode": 'min',      # a smaller validation loss is better
        "factor": 0.5,      # multiply the learning rate by 0.5 on each reduction
        "patience": 20,     # reduce after 20 epochs without improvement
        "threshold": 1e-4,  # minimum change that counts as an improvement
        "cooldown": 5,      # epochs to wait after a reduction
        "min_lr": 1e-6,     # lower bound of the learning rate
    }
    scheduler = ReduceLROnPlateau(optimizer, **plateau_options)

    summary_lines = (
        "✓ Optimizer: AdamW",
        f" - Learning rate: {args.learning_rate}",
        f" - Weight decay: {args.weight_decay}",
        "✓ Scheduler: ReduceLROnPlateau",
        f" - Mode: {scheduler.mode}",
        f" - Decay factor: {scheduler.factor}",
        f" - Patience: {scheduler.patience}",
        f" - Minimum change threshold: {scheduler.threshold}",
        f" - Cooldown period: {scheduler.cooldown}",
        f" - Minimum learning rate: {scheduler.min_lrs}",
    )
    for line in summary_lines:
        print(line)
    return optimizer, scheduler
def setup_swanlab(args):
    """Initialise SwanLab experiment tracking.

    When no experiment name was given, one is generated from the dataset name
    and the current timestamp and written back to ``args``. The SwanLab log
    directory and the output directory are created if missing.

    Args:
        args: Parsed options; reads ``swanlab_project``,
            ``swanlab_experiment``, ``swanlab_log_dir``, ``output_dir`` and
            ``dataset_name``. ``swanlab_experiment`` may be updated in place.

    Returns:
        The object returned by ``swanlab.init``.
    """
    # Fall back to a timestamp-based experiment name.
    if args.swanlab_experiment is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.swanlab_experiment = "v2_" + args.dataset_name + "_" + timestamp

    # Make sure both target directories exist.
    for directory in (args.swanlab_log_dir, args.output_dir):
        os.makedirs(directory, exist_ok=True)

    init_options = {
        "project": args.swanlab_project,
        "experiment_name": args.swanlab_experiment,
        "logdir": args.swanlab_log_dir,
        "config": vars(args),
    }
    run = swanlab.init(**init_options)

    print(f"\n✓ SwanLab experiment initialized: {args.swanlab_experiment}")
    print(f" - Project: {args.swanlab_project}")
    print(f" - Log directory: {args.swanlab_log_dir}")
    return run
def _print_truncated(header, entries, limit=5):
    """Print ``header`` followed by at most ``limit`` entries and a remainder count."""
    if not entries:
        return
    print(header)
    for entry in entries[:limit]:
        print(f" - {entry}")
    if len(entries) > limit:
        print(f" ... {len(entries) - limit} more")


def _load_matching_weights(model, checkpoint):
    """Copy every name- and shape-compatible tensor from ``checkpoint`` into ``model``.

    ``checkpoint`` is either a full training checkpoint (weights under the
    ``"model_state_dict"`` key) or a bare state dict as written for the
    best-Dice / best-IoU / best-overall models. Parameters that are missing
    or whose shape differs keep the model's current values.
    """
    model_dict = model.state_dict()
    try:
        pretrained_dict = checkpoint["model_state_dict"]
        print("✓ Model weights loaded from training checkpoint")
    except KeyError:
        pretrained_dict = checkpoint
        print("Loading model weights from best Dice or best overall model")

    matched_params = {}
    unmatched_params = []
    missing_params = []
    for name, param in model_dict.items():
        if name not in pretrained_dict:
            missing_params.append(name)
            continue
        pretrained_param = pretrained_dict[name]
        if param.shape == pretrained_param.shape:
            matched_params[name] = pretrained_param
        else:
            unmatched_params.append(f"{name} (shape mismatch: {param.shape} vs {pretrained_param.shape})")

    print(f"\nWeight loading statistics:")
    print(f" ✓ Successfully matched parameters: {len(matched_params)}/{len(model_dict)}")
    print(f" ⚠ Shape mismatched parameters: {len(unmatched_params)}")
    print(f" ✗ Newly added parameters (randomly initialized): {len(missing_params)}")
    _print_truncated(f"\nShape mismatched layers:", unmatched_params)
    _print_truncated(f"\nNewly added layers (will be randomly initialized):", missing_params)

    model_dict.update(matched_params)
    model.load_state_dict(model_dict, strict=False)
    print(f"\n✓ Model weights loaded (strict mode: False)")
    print("=" * 60)


def _save_weights(model, checkpoint_path):
    """Write ``model.state_dict()`` to ``checkpoint_path``, creating parent directories."""
    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)


def main():
    """Run the complete training procedure.

    Steps: parse arguments and seed torch, initialise SwanLab, build datasets,
    loaders, model, loss, optimizer and metrics, optionally restore state from
    a checkpoint, then alternate training and validation epochs while logging
    to SwanLab, saving the best Dice / IoU / overall models and periodic
    checkpoints. With early stopping enabled, the first trigger reloads the
    best Dice model and restarts once with a fresh optimizer; the second
    trigger ends training.

    Raises:
        FileNotFoundError: If ``--resume`` points to a missing file.
        ValueError: If the training or validation loader yields no batches.
    """
    # ==================== Step 1: Parse arguments ====================
    args = parse_args()
    # Seed torch (CPU and, when available, every CUDA device).
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    print("\n" + "=" * 60)
    print("Polyp Segmentation Model Training Started")
    print("=" * 60)
    print(f"Using device: {args.device}")
    print(f"Batch size: {args.batch_size}")
    print(f"Maximum epochs: {args.max_epochs}")
    if args.early_stopping:
        print(
            f"Early stopping: enabled (patience={args.early_stopping_patience}, monitor={args.early_stopping_monitor})")

    # ==================== Step 2: Initialize SwanLab ====================
    run = setup_swanlab(args)

    # ==================== Step 3: Create datasets and data loaders ====================
    train_dataset, val_dataset = create_datasets(args)
    train_loader, val_loader = create_dataloaders(args, train_dataset, val_dataset)

    # ==================== Step 4: Create model, loss function, optimizer ====================
    model = create_model(args)
    model = model.to(args.device)
    loss_function = create_loss_function(args)
    optimizer, scheduler = create_optimizer(args, model)

    # ==================== Step 5: Create evaluation metrics ====================
    dice_metric = DiceMetric(reduction="mean")
    iou_metric = MeanIoU(reduction="mean")
    hd_metric = HausdorffDistanceMetric(reduction="mean")

    # ==================== Step 6: Setup training loop ====================
    best_dice = -1
    best_dice_epoch = -1
    best_metric = -1
    best_metric_epoch = -1
    best_iou = -1
    best_iou_epoch = -1
    # Lowest validation loss so far; reference value for the "loss" monitor.
    best_val_loss = float("inf")
    epoch_loss_values = []
    dice_metric_values = []
    iou_metric_values = []
    hd_metric_values = []
    start_epoch = 0

    # ==================== Early stopping related variables ====================
    early_stopping_counter = 0
    should_stop = False
    has_restarted = False  # True once training was restarted from the best model

    # ==================== Step 7: Resume training (if checkpoint exists) ====================
    checkpoint = None
    if args.resume:
        # User specified checkpoint path
        if not os.path.exists(args.resume):
            raise FileNotFoundError(f"Checkpoint file does not exist: {args.resume}")
        checkpoint_path = args.resume
        print(f"\nResuming training from user-specified checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=args.device)
    elif args.auto_resume:
        # Automatically find best checkpoint
        checkpoint_path = find_best_checkpoint(args)
        if checkpoint_path:
            print(f"\nAuto-resume mode: loading {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=args.device)
        else:
            print("\nNo checkpoints found, starting training from scratch")

    if checkpoint is not None:
        # Load model weights, tolerating structural differences between versions.
        _load_matching_weights(model, checkpoint)

        if "optimizer_state_dict" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            print("✓ Optimizer state loaded")

        # NOTE: the stored epoch number is deliberately not restored (that code
        # was disabled in the original), so a resumed run starts at epoch 0.

        # Restore best metrics
        if "best_dice" in checkpoint:
            best_dice = checkpoint["best_dice"]
            best_dice_epoch = checkpoint["best_dice_epoch"]
            print(f"✓ Best metrics restored: Dice={best_dice:.4f} (Epoch {best_dice_epoch})")
        # The periodic checkpoint also stores these values; restore them so a
        # resumed run does not overwrite a better saved model with a worse one.
        if "best_iou" in checkpoint:
            best_iou = checkpoint["best_iou"]
            best_iou_epoch = checkpoint.get("best_iou_epoch", best_iou_epoch)
        if "best_metric" in checkpoint:
            best_metric = checkpoint["best_metric"]
            best_metric_epoch = checkpoint.get("best_metric_epoch", best_metric_epoch)
        if "best_val_loss" in checkpoint:
            best_val_loss = checkpoint["best_val_loss"]

        # Restore historical loss and metric values (optional)
        if "epoch_loss_values" in checkpoint:
            epoch_loss_values = checkpoint["epoch_loss_values"]
        if "dice_metric_values" in checkpoint:
            dice_metric_values = checkpoint["dice_metric_values"]
        if "iou_metric_values" in checkpoint:
            iou_metric_values = checkpoint["iou_metric_values"]
        if "hd_metric_values" in checkpoint:
            hd_metric_values = checkpoint["hd_metric_values"]

        # Restore early stopping state
        if args.early_stopping:
            if "early_stopping_counter" in checkpoint:
                early_stopping_counter = checkpoint["early_stopping_counter"]
                print(f"✓ Early stopping counter restored: {early_stopping_counter}")
            if "should_stop" in checkpoint and checkpoint["should_stop"]:
                should_stop = False  # Even if marked as stopped, allow continued training
                print("✓ Early stopping state reset, can continue training")
        print(f"✓ Training will continue from epoch {start_epoch}")
        print("=" * 60)

    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60)
    start_time = time.time()
    try:
        # An empty loader would otherwise surface as a bare ZeroDivisionError
        # when the epoch averages are computed.
        if len(train_loader) == 0 or len(val_loader) == 0:
            raise ValueError(
                "Training or validation data loader yields no batches "
                f"(train batches: {len(train_loader)}, val batches: {len(val_loader)}); "
                "the training loader drops the last incomplete batch, so the training "
                "set must contain at least batch_size samples"
            )

        for epoch in range(start_epoch, run.config.max_epochs):
            # ========== Check early stopping condition ==========
            if should_stop and args.early_stopping:
                print(f"\n{'=' * 60}")
                print(f"Early stopping triggered! Training will terminate early at epoch {epoch + 1}")
                print(f"{'=' * 60}")
                if not has_restarted:
                    # First early stopping: load best weights, restart training
                    print("Early stopping detected, preparing to restart training from best model...")
                    # 1. Find the best Dice model
                    best_checkpoint_path = os.path.join(
                        args.output_dir,
                        f"best_dice_model_{args.dataset_name}.pt"
                    )
                    if os.path.exists(best_checkpoint_path):
                        print(f"Loading best Dice model: {best_checkpoint_path}")
                        checkpoint = torch.load(best_checkpoint_path, map_location=args.device)
                        # 2. Load best weights (this file is a bare state dict)
                        model.load_state_dict(checkpoint)
                        print("✓ Model weights restored to best state")
                        # 3. Reset optimizer and scheduler
                        optimizer, scheduler = create_optimizer(args, model)
                        print("✓ Optimizer has been reset")
                        # 4. Reset early stopping state
                        early_stopping_counter = 0
                        should_stop = False
                        has_restarted = True
                        print("✓ Training restarted from best model")
                        print(f"{'=' * 60}\n")
                        continue  # This epoch index is skipped; training resumes with the next one
                    print(f"Warning: Best model file not found {best_checkpoint_path}")
                    print("Will stop training directly")
                else:
                    # Second early stopping: truly stop
                    print("Training has been restarted once after early stopping, now stopping training")
                break

            # ========== Training phase ==========
            model.train()
            step = 0
            epoch_loss = 0
            epoch_loss_dice_ce = 0
            epoch_loss_iou = 0
            for batch_data in train_loader:
                step += 1
                inputs = batch_data["image"].to(args.device)
                targets = batch_data["label"].to(args.device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss, loss_dice_ce, loss_iou = loss_function(outputs, targets)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                epoch_loss_dice_ce += loss_dice_ce.item()
                epoch_loss_iou += loss_iou.item()

            # Only fires when start_epoch is non-zero (epoch restore is currently disabled).
            if epoch == start_epoch and start_epoch > 0:
                print(f"\n✓ Training resumed from epoch {start_epoch}")

            epoch_loss /= step
            epoch_loss_dice_ce /= step
            epoch_loss_iou /= step
            epoch_loss_values.append(epoch_loss)
            print(f"\nEpoch {epoch + 1}/{args.max_epochs} - Training loss: {epoch_loss:.4f}")

            # Log to SwanLab
            swanlab.log({
                "train/loss": epoch_loss,
                "train/loss_dice_ce": epoch_loss_dice_ce,
                "train/loss_iou": epoch_loss_iou,
                "train/lr": optimizer.param_groups[0]['lr'],
            }, step=(epoch + 1))

            # ========== Validation phase ==========
            model.eval()
            val_loss_total = 0
            with torch.no_grad():
                dice_metric.reset()
                iou_metric.reset()
                hd_metric.reset()
                for val_data in val_loader:
                    val_images = val_data["image"].to(args.device)
                    val_labels = val_data["label"].to(args.device)
                    val_outputs = model(val_images)
                    # Validation loss is computed on the raw outputs
                    val_loss_batch, _, _ = loss_function(val_outputs, val_labels)
                    val_loss_total += val_loss_batch.item()
                    # Post-processing: sigmoid, then threshold at 0.5
                    val_outputs = torch.sigmoid(val_outputs)
                    val_outputs = (val_outputs > 0.5).int()
                    # Accumulate metrics
                    dice_metric(y_pred=val_outputs, y=val_labels)
                    iou_metric(y_pred=val_outputs, y=val_labels)
                    hd_metric(y_pred=val_outputs, y=val_labels)

            # Average validation loss over the batches
            val_loss_avg = val_loss_total / len(val_loader)
            # Update learning rate scheduler
            scheduler.step(val_loss_avg)
            current_lr = optimizer.param_groups[0]['lr']

            # Aggregate results
            mean_dice = dice_metric.aggregate().item()
            dice_metric_values.append(mean_dice)
            mean_iou = iou_metric.aggregate().item()
            iou_metric_values.append(mean_iou)
            mean_hd = hd_metric.aggregate().item()
            hd_metric_values.append(mean_hd)
            print(
                f"Epoch {epoch + 1} - Validation Dice: {mean_dice:.4f}, Validation loss: {val_loss_avg:.4f}, Current LR: {current_lr:.2e}")
            swanlab.log({
                "val/loss": val_loss_avg,
                "val/mean_dice": mean_dice,
                "val/mean_iou": mean_iou,
                "val/mean_hd": mean_hd,
                "val/lr": current_lr,
            }, step=(epoch + 1))

            # Overall score: Dice + IoU + a Hausdorff term mapped to (0, 1],
            # where a smaller distance gives a larger term.
            normalized_hd = 1.0 / (1.0 + mean_hd)
            mean_metric = (
                1 * mean_dice +
                1 * mean_iou +
                1 * normalized_hd
            )

            # ========== Early stopping check ==========
            if args.early_stopping:
                # Select the monitored score; larger is always better here.
                if args.early_stopping_monitor == "dice":
                    current_score = mean_dice
                    best_score = best_dice
                elif args.early_stopping_monitor == "iou":
                    current_score = mean_iou
                    best_score = best_iou
                elif args.early_stopping_monitor == "metric":
                    current_score = mean_metric
                    best_score = best_metric
                else:  # loss
                    # Lower loss is better, so negate. The reference is the best
                    # *validation* loss of earlier epochs; the original compared
                    # against the training-loss history, which already contained
                    # the current epoch.
                    current_score = -val_loss_avg
                    best_score = -best_val_loss
                is_better = current_score > best_score + args.early_stopping_min_delta

                if is_better:
                    early_stopping_counter = 0
                    print(
                        f" ✓ {args.early_stopping_monitor.upper()} metric improved: {current_score:.4f} > {best_score:.4f}")
                else:
                    early_stopping_counter += 1
                    print(
                        f" ⚠ {args.early_stopping_monitor.upper()} metric did not improve, counter: {early_stopping_counter}/{args.early_stopping_patience}")
                # The stop itself is handled at the top of the next epoch.
                if early_stopping_counter >= args.early_stopping_patience:
                    should_stop = True

            # Update only after the early-stopping comparison above.
            best_val_loss = min(best_val_loss, val_loss_avg)

            # Save best Dice model
            if mean_dice > best_dice:
                best_dice = mean_dice
                best_dice_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
                _save_weights(model, checkpoint_path)
                print(
                    f"✓ Found better Dice model! Dice: {mean_dice:.4f}, IoU: {mean_iou:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}")

            # Save best IoU model
            if mean_iou > best_iou:
                best_iou = mean_iou
                best_iou_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_iou_model_{args.dataset_name}.pt")
                _save_weights(model, checkpoint_path)
                print(
                    f"✓ Found better IoU model! IoU: {mean_iou:.4f}, Dice: {mean_dice:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}"
                )

            # Save best overall model
            if mean_metric > best_metric:
                best_metric = mean_metric
                best_metric_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
                _save_weights(model, checkpoint_path)
                print(
                    f"✓ Found better overall model! Overall score: {mean_metric:.4f}, Dice: {mean_dice:.4f}, IoU: {mean_iou:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}"
                )

            # Periodically save a full checkpoint
            if (epoch + 1) % args.save_every == 0:
                # The file name carries the zero-based epoch index.
                checkpoint_path = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}",
                                               f"checkpoint_epoch={epoch}.pt")
                Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "best_dice": best_dice,
                    "best_dice_epoch": best_dice_epoch,
                    "epoch_loss_values": epoch_loss_values,
                    "dice_metric_values": dice_metric_values,
                    "iou_metric_values": iou_metric_values,
                    "hd_metric_values": hd_metric_values,
                    "best_metric": best_metric,
                    "best_metric_epoch": best_metric_epoch,
                    "best_iou": best_iou,
                    "best_iou_epoch": best_iou_epoch,
                    "best_val_loss": best_val_loss,
                    "early_stopping_counter": early_stopping_counter,
                    "should_stop": should_stop
                }, checkpoint_path)
                print(f"✓ Checkpoint saved: {checkpoint_path}")
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    finally:
        end_time = time.time()
        training_time = end_time - start_time
        print("\n" + "=" * 60)
        print("Training completed!")
        print(f"Total training time: {training_time / 3600:.2f} hours")
        print(f"Best validation Dice: {best_dice:.4f} (Epoch {best_dice_epoch})")
        print("=" * 60)
        # Close SwanLab here so the run is finalised even when training fails.
        swanlab.finish()
        print("✓ SwanLab experiment saved")


if __name__ == "__main__":
    main()
|