import argparse
import os
import time
from datetime import datetime
from pathlib import Path

import monai
import monai.utils
import swanlab
import torch
from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
from monai.transforms import (
    Compose,
    LoadImaged,
    ScaleIntensityd,
    RandFlipd,
    RandRotated,
    RandRotate90d,
    EnsureChannelFirstd,
    ToTensord,
    Resized,
    Lambdad,
    RandZoomd,
    RandShiftIntensityd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandAdjustContrastd,
    RandHistogramShiftd,
    RandAxisFlipd,
    RandCoarseDropoutd,
)
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
from lib.model.model_v4_minute import Wavelet_FFT_SwinUNETR
from lib.tools.combined_loss import CombinedDiceCEIoULoss

# Accepted spellings for boolean command-line values (compared lower-cased).
_TRUE_STRINGS = {"true", "t", "yes", "y", "1", "on"}
_FALSE_STRINGS = {"false", "f", "no", "n", "0", "off"}


def _str2bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` would call ``bool("False")``, which is
    ``True`` because every non-empty string is truthy, so such flags could
    never be switched off. This parser interprets the text instead.

    Args:
        value: Raw argument string; an actual ``bool`` is returned unchanged.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        f"无法解析的布尔值:{value!r}(可用 true/false、yes/no、1/0)"
    )


def _parse_spatial_size(value):
    """Parse a spatial size given as ``"H,W"``, ``"HxW"`` or a single ``"N"``.

    ``type=tuple`` would split the string into single characters
    (``tuple("256,256") == ('2', '5', '6', ',', ...)``), so a dedicated parser
    is required.

    Args:
        value: Raw argument string (a tuple/list is converted element-wise).

    Returns:
        tuple[int, ...]: Positive integer sizes; a single number ``N`` is
        expanded to the square 2D size ``(N, N)``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a list of positive ints.
    """
    if isinstance(value, (tuple, list)):
        return tuple(int(v) for v in value)
    parts = [p for p in str(value).lower().replace("x", ",").split(",") if p.strip()]
    try:
        size = tuple(int(p) for p in parts)
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f"无法解析的空间大小:{value!r}(示例:512,512)"
        ) from err
    if not size or any(s <= 0 for s in size):
        raise argparse.ArgumentTypeError(f"空间大小必须为正整数:{value!r}")
    if len(size) == 1:
        size = size * 2
    return size


def _convert_label_to_single_channel(label_tensor):
    """Convert a channel-first label image into a single-channel binary mask.

    Only the first channel is kept; pixels whose value is greater than 127
    become 1.0 and all others 0.0. Defined at module level (not nested inside
    a function) so the transform pipeline stays picklable when DataLoader
    workers are started with the ``spawn`` method (``num_workers > 0``).

    Args:
        label_tensor: Channel-first label of rank 3, presumably (3, H, W) RGB
            with values in 0-255 — TODO confirm against the dataset.

    Returns:
        Float tensor of shape (1, H, W) containing only 0.0 and 1.0.
    """
    first_channel = label_tensor[0:1, :, :]
    return (first_channel > 127).float()


def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="息肉分割模型训练脚本")

    # ==================== Dataset ====================
    parser.add_argument("--dataset_name", type=str, required=True, help="数据集名称")
    parser.add_argument(
        "--data_root", type=str, default=r"./data/Polyp-Detection-Dataset", help="数据集根目录路径"
    )
    parser.add_argument("--num_workers", type=int, default=0, help="数据加载器的工作进程数")
    parser.add_argument("--pin_memory", type=_str2bool, default=True, help="是否启用 pinned memory")
    parser.add_argument(
        "--target_spatial_size", type=_parse_spatial_size, default=(512, 512), help="目标空间大小"
    )
    parser.add_argument(
        "--dataset_enhanced", type=_str2bool, default=True, help="是否使用高增强数据策略"
    )

    # ==================== Model ====================
    parser.add_argument("--in_channels", type=int, default=3, help="输入图像通道数")
    parser.add_argument("--out_channels", type=int, default=1, help="输出前景")
    parser.add_argument("--feature_size", type=int, default=48, help="网络特征维度")
    parser.add_argument(
        "--spatial_dims", type=int, default=2, choices=[2, 3], help="空间维度(2D 或 3D)"
    )
    parser.add_argument("--use_wavelet", type=_str2bool, default=True, help="是否启用小波增强模块")
    parser.add_argument("--wavelet_J", type=int, default=2, help="小波分解层数")
    parser.add_argument("--wavelet_wave", type=str, default="db4", help="小波基类型")
    parser.add_argument("--wavelet_reduction", type=int, default=16, help="小波注意力压缩比例")
    parser.add_argument("--use_fft", type=_str2bool, default=True, help="是否启用 FFT 增强模块")
    parser.add_argument("--use_v2", type=_str2bool, default=True, help="是否启用 Swin-UNETR v2 模块")

    # ==================== Training ====================
    parser.add_argument("--max_epochs", type=int, default=1000, help="最大训练轮数")
    parser.add_argument("--batch_size", type=int, default=4, help="批次大小")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="学习率")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="权重衰减系数")

    # ==================== Loss weights ====================
    parser.add_argument("--dice_weight", type=float, default=1.0, help="Dice 损失权重")
    parser.add_argument("--ce_weight", type=float, default=1.0, help="Cross Entropy 损失权重")
    parser.add_argument("--iou_weight", type=float, default=1.0, help="IoU 损失权重")

    # ==================== SwanLab ====================
    parser.add_argument(
        "--swanlab_project", type=str, default="polyp-segmentation-v4_minute", help="SwanLab 项目名称"
    )
    parser.add_argument(
        "--swanlab_experiment", type=str, default=None, help="SwanLab 实验名称(默认使用时间戳)"
    )
    parser.add_argument(
        "--swanlab_log_dir", type=str, default="./swanlab_log", help="SwanLab 日志保存目录"
    )

    # ==================== Saving / loading ====================
    parser.add_argument(
        "--output_dir", type=str, default="./outputs_v4_minute", help="模型检查点保存目录"
    )
    parser.add_argument("--save_every", type=int, default=50, help="每隔多少个 epoch 保存一次模型")

    # ==================== Early stopping ====================
    parser.add_argument("--early_stopping", type=_str2bool, default=True, help="是否启用早停机制")
    parser.add_argument(
        "--early_stopping_patience", type=int, default=100,
        help="早停耐心度(验证指标多少轮不改善则停止)"
    )
    parser.add_argument(
        "--early_stopping_min_delta", type=float, default=1e-4,
        help="最小改善阈值(指标提升小于此值视为无改善)"
    )
    parser.add_argument(
        "--early_stopping_monitor", type=str, default="dice",
        choices=["dice", "iou", "metric", "loss"], help="早停监控的指标"
    )
    parser.add_argument(
        "--resume", type=str, default=None,
        help="恢复训练的检查点路径。如果未指定,将自动加载最佳 Dice 模型(如果存在)"
    )
    # Auto-resume is on by default; passing --no_auto_resume turns it off.
    parser.add_argument(
        "--no_auto_resume", action="store_false", dest="auto_resume",
        help="是否启用自动恢复功能(默认加载最佳 Dice 模型)"
    )

    # ==================== Misc ====================
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
        help="训练设备(cuda 或 cpu)"
    )
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    return parser.parse_args()


def _epoch_from_filename(filename):
    """Return the integer after ``epoch=`` in a checkpoint filename, or -1.

    Filenames without ``epoch=`` or with a non-integer epoch sort last instead
    of crashing the search.
    """
    if "epoch=" not in filename:
        return -1
    try:
        return int(filename.split("epoch=")[1].split(".")[0])
    except ValueError:
        return -1


def find_best_checkpoint(args):
    """Locate a checkpoint to resume from.

    Search order: the best-Dice weights, then the periodic checkpoint with the
    highest epoch number, then the best composite-score weights.

    Args:
        args: Parsed command-line arguments (uses ``output_dir`` and
            ``dataset_name``).

    Returns:
        str or None: Path to the checkpoint, or None if nothing was found.
    """
    best_dice_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
    if os.path.exists(best_dice_path):
        print(f"找到最佳 Dice 模型:{best_dice_path}")
        return best_dice_path

    checkpoint_dir = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}")
    if os.path.exists(checkpoint_dir):
        checkpoints = sorted(
            (f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")),
            key=_epoch_from_filename,
            reverse=True,
        )
        if checkpoints:
            latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[0])
            print(f"找到最新检查点:{latest_checkpoint}")
            return latest_checkpoint

    best_metric_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
    if os.path.exists(best_metric_path):
        print(f"找到最佳综合模型:{best_metric_path}")
        return best_metric_path

    return None


def create_enhanced_transforms(target_spatial_size=(512, 512)):
    """Build the heavy-augmentation training pipeline.

    Contents:
        1. Geometric: random axis flip, small-angle rotation, 90-degree
           rotation, zoom (output size kept).
        2. Photometric: intensity shift, gamma contrast, histogram shift.
        3. Degradation: Gaussian smoothing (blur) and Gaussian noise.
        4. Regularisation: coarse dropout on the image only.

    Args:
        target_spatial_size: Size images and labels are resized to.

    Returns:
        monai.transforms.Compose: The dictionary-based training pipeline.
    """
    return Compose([
        # ========== Loading and preprocessing ==========
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Lambdad(keys=["label"], func=_convert_label_to_single_channel),

        # ========== Spatial ==========
        Resized(keys=["image", "label"], spatial_size=target_spatial_size,
                mode=("bilinear", "nearest")),
        ScaleIntensityd(keys=["image"]),

        # --- Geometric augmentation; labels always use nearest interpolation ---
        RandAxisFlipd(keys=["image", "label"], prob=0.5),
        # range_x is in radians: 0.15 rad is roughly +/-8.6 degrees.
        RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5, keep_size=True,
                    mode=("bilinear", "nearest")),
        # Random quarter-turn rotation (up to max_k=2 quarter-turns).
        RandRotate90d(keys=["image", "label"], prob=0.5, max_k=2),
        # Zoom by 0.8-1.2x, then crop/pad back to the original size (keep_size).
        RandZoomd(keys=["image", "label"], min_zoom=0.8, max_zoom=1.2, prob=0.5,
                  mode=("bilinear", "nearest"), keep_size=True),

        # ========== Photometric ==========
        # Additive intensity offset in [-0.2, 0.2] (image already scaled by ScaleIntensityd).
        RandShiftIntensityd(keys=["image"], offsets=(-0.2, 0.2), prob=0.5),
        # Gamma contrast adjustment with gamma in [0.7, 1.3].
        RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.3), prob=0.5),
        # Histogram shift, intended to mimic different staining / lighting.
        RandHistogramShiftd(keys=["image"], num_control_points=(5, 10), prob=0.3),

        # ========== Noise and quality degradation ==========
        # Gaussian smoothing to simulate blur.
        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), prob=0.3),
        RandGaussianNoised(keys=["image"], mean=0.0, std=0.05, prob=0.3),
        # Coarse dropout (occlusion) for robustness; image only, label untouched.
        RandCoarseDropoutd(
            keys=["image"],
            holes=1,
            max_holes=3,
            spatial_size=(32, 32),
            max_spatial_size=(64, 64),
            prob=0.3,
        ),

        # ========== Post-processing ==========
        ToTensord(keys=["image", "label"]),
    ])


def create_datasets(args):
    """Create the training and validation datasets.

    Args:
        args: Parsed command-line arguments (uses ``data_root``,
            ``dataset_name``, ``target_spatial_size``, ``dataset_enhanced``).

    Returns:
        tuple: (train_dataset, val_dataset)
    """
    print("=" * 60)
    print("正在加载数据集...")
    print("=" * 60)

    if args.dataset_enhanced:
        train_transforms = create_enhanced_transforms(args.target_spatial_size)
        print("✓ 使用增强数据增强策略")
    else:
        train_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Lambdad(keys=["label"], func=_convert_label_to_single_channel),
            Resized(keys=["image", "label"], spatial_size=args.target_spatial_size,
                    mode=("bilinear", "nearest")),
            ScaleIntensityd(keys=["image"]),
            RandFlipd(keys=["image", "label"], prob=0.5),
            # Nearest interpolation for the label keeps the rotated mask binary.
            RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5,
                        mode=("bilinear", "nearest")),
        ])

    # Validation: deterministic preprocessing only, no augmentation.
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Lambdad(keys=["label"], func=_convert_label_to_single_channel),
        Resized(keys=["image", "label"], spatial_size=args.target_spatial_size,
                mode=("bilinear", "nearest")),
        ScaleIntensityd(keys=["image"]),
    ])

    dataset_root = Path(args.data_root) / args.dataset_name
    train_dataset = PolypDetectionDataset(
        root_dir=dataset_root,
        flag='train',
        transform=train_transforms
    )
    val_dataset = PolypDetectionDataset(
        root_dir=dataset_root,
        flag='val',
        transform=val_transforms
    )

    print(f"✓ 训练集大小:{len(train_dataset)} 个样本")
    print(f"✓ 验证集大小:{len(val_dataset)} 个样本")
    print(f"✓ 总样本数:{len(train_dataset) + len(val_dataset)} 个样本")
    print("=" * 60)
    return train_dataset, val_dataset


def create_dataloaders(args, train_dataset, val_dataset):
    """Create the training and validation data loaders.

    The training loader shuffles and drops the last incomplete batch; the
    validation loader keeps order and every sample.

    Args:
        args: Parsed command-line arguments.
        train_dataset: Training dataset.
        val_dataset: Validation dataset.

    Returns:
        tuple: (train_loader, val_loader)
    """
    train_loader = monai.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        drop_last=True
    )
    val_loader = monai.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        drop_last=False
    )
    print(f"✓ 训练数据加载器:{len(train_loader)} 个 batch")
    print(f"✓ 验证数据加载器:{len(val_loader)} 个 batch")
    return train_loader, val_loader


def create_model(args):
    """Build the Wavelet/FFT-enhanced Swin-UNETR model and print its size.

    Args:
        args: Parsed command-line arguments.

    Returns:
        torch.nn.Module: The (CPU-resident) model.
    """
    print("\n" + "=" * 60)
    print("正在创建模型...")
    model = Wavelet_FFT_SwinUNETR(
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        spatial_dims=args.spatial_dims,
        wavelet_enhancement=args.use_wavelet,
        wavelet_J=args.wavelet_J,
        wavelet_wave=args.wavelet_wave,
        wavelet_mode='symmetric',
        wavelet_reduction=args.wavelet_reduction,
        fft_enhancement=args.use_fft,
        use_v2=args.use_v2
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n✓ 模型总参数量:{total_params:,}")
    print(f"✓ 可训练参数量:{trainable_params:,}")
    print(f"✓ 使用设备:{args.device}")
    print("=" * 60)
    return model


def create_loss_function(args):
    """Create the combined Dice + CE + IoU loss (sigmoid, single foreground).

    Args:
        args: Parsed command-line arguments (uses the three loss weights).

    Returns:
        Callable: Loss object; at its call sites it returns a 3-tuple
        ``(total, dice_ce_part, iou_part)``.
    """
    return CombinedDiceCEIoULoss(
        dice_weight=args.dice_weight,
        ce_weight=args.ce_weight,
        iou_weight=args.iou_weight,
        include_background=True,
        to_onehot_y=False,
        softmax=False,
        sigmoid=True,
    )


def create_optimizer(args, model):
    """Create the AdamW optimizer and a ReduceLROnPlateau scheduler.

    The scheduler is stepped with the validation loss (``mode='min'``).

    Args:
        args: Parsed command-line arguments.
        model: Model whose parameters are optimised.

    Returns:
        tuple: (optimizer, scheduler)
    """
    optimizer = AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',        # lower validation loss is better
        factor=0.5,        # halve the LR on each reduction
        patience=20,       # epochs without improvement before reducing
        threshold=1e-4,    # minimum change counted as improvement
        cooldown=5,        # epochs to wait after a reduction
        min_lr=1e-6        # LR floor
    )

    print(f"✓ 优化器:AdamW")
    print(f" - 学习率:{args.learning_rate}")
    print(f" - 权重衰减:{args.weight_decay}")
    print(f"✓ 调度器:ReduceLROnPlateau")
    print(f" - 模式:{scheduler.mode}")
    print(f" - 衰减因子:{scheduler.factor}")
    print(f" - patience:{scheduler.patience}")
    print(f" - 最小变化阈值:{scheduler.threshold}")
    print(f" - 冷却期:{scheduler.cooldown}")
    print(f" - 最小学习率:{scheduler.min_lrs}")
    return optimizer, scheduler


def setup_swanlab(args):
    """Initialise SwanLab experiment tracking.

    Side effects: fills ``args.swanlab_experiment`` with a timestamped name
    when it is None, and creates the log and output directories.

    Args:
        args: Parsed command-line arguments.

    Returns:
        The object returned by ``swanlab.init``.
    """
    if args.swanlab_experiment is None:
        args.swanlab_experiment = (
            "v2_" + args.dataset_name + "_" + datetime.now().strftime("%Y%m%d_%H%M%S")
        )

    os.makedirs(args.swanlab_log_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    run = swanlab.init(
        project=args.swanlab_project,
        experiment_name=args.swanlab_experiment,
        logdir=args.swanlab_log_dir,
        config=vars(args)
    )
    print(f"\n✓ SwanLab 实验已初始化:{args.swanlab_experiment}")
    print(f" - 项目:{args.swanlab_project}")
    print(f" - 日志目录:{args.swanlab_log_dir}")
    return run


def _load_matched_weights(model, checkpoint):
    """Load every checkpoint tensor whose name and shape match ``model``.

    Accepts either a training checkpoint (dict with ``model_state_dict``) or
    a bare state dict (as written for the best-Dice/IoU/metric models).
    Entries with a different shape, or missing from the checkpoint, keep the
    model's current (random) initialisation.

    Args:
        model: Model to update in place.
        checkpoint: Object returned by ``torch.load``.

    Returns:
        bool: True when every model entry was found with a matching shape,
        i.e. the checkpoint's architecture agrees with the current model.
    """
    model_dict = model.state_dict()
    if "model_state_dict" in checkpoint:
        pretrained_dict = checkpoint["model_state_dict"]
        print("✓ 从训练检查点加载模型权重")
    else:
        pretrained_dict = checkpoint
        print("从最佳 Dice 或最佳综合模型权重中加载模型权重")

    # Keep only name- and shape-compatible entries (tolerates v1 -> v2 changes).
    matched_params = {}
    unmatched_params = []
    missing_params = []
    for name, param in model_dict.items():
        if name not in pretrained_dict:
            missing_params.append(name)
            continue
        pretrained_param = pretrained_dict[name]
        if param.shape == pretrained_param.shape:
            matched_params[name] = pretrained_param
        else:
            unmatched_params.append(
                f"{name} (形状不匹配:{param.shape} vs {pretrained_param.shape})"
            )

    print(f"\n权重加载统计:")
    print(f" ✓ 成功匹配的参数:{len(matched_params)}/{len(model_dict)}")
    print(f" ⚠ 形状不匹配的 parameter: {len(unmatched_params)}")
    print(f" ✗ 新增的 parameter(随机初始化): {len(missing_params)}")

    if unmatched_params:
        print(f"\n形状不匹配的层:")
        for info in unmatched_params[:5]:  # show at most 5
            print(f" - {info}")
        if len(unmatched_params) > 5:
            print(f" ... 还有 {len(unmatched_params) - 5} 个")

    if missing_params:
        print(f"\n新增的层 (将随机初始化):")
        for name in missing_params[:5]:  # show at most 5
            print(f" - {name}")
        if len(missing_params) > 5:
            print(f" ... 还有 {len(missing_params) - 5} 个")

    model_dict.update(matched_params)
    model.load_state_dict(model_dict, strict=False)
    print(f"\n✓ 模型权重加载完成 (严格模式:False)")
    print("=" * 60)
    return not unmatched_params and not missing_params


def _composite_score(mean_dice, mean_iou, mean_hd):
    """Return Dice + IoU + 1 / (1 + HD): higher is better.

    NOTE(review): if the Hausdorff metric is NaN (e.g. an empty prediction),
    the score is NaN and never compares as an improvement.
    """
    return mean_dice + mean_iou + 1.0 / (1.0 + mean_hd)


def _save_state_dict(model, checkpoint_path):
    """Save ``model.state_dict()`` to ``checkpoint_path``, creating parent dirs."""
    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)


def _train_one_epoch(model, train_loader, loss_function, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        tuple[float, float, float]: Per-batch means of the total loss and of
        the two loss components returned by ``loss_function``.

    Raises:
        RuntimeError: If the loader yields no batches (e.g. the dataset is
            smaller than ``batch_size`` with ``drop_last=True``), which would
            otherwise surface as a ZeroDivisionError.
    """
    model.train()
    steps = 0
    total_loss = total_dice_ce = total_iou = 0.0
    for batch_data in train_loader:
        steps += 1
        inputs = batch_data["image"].to(device)
        targets = batch_data["label"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss, loss_dice_ce, loss_iou = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_dice_ce += loss_dice_ce.item()
        total_iou += loss_iou.item()

    if steps == 0:
        raise RuntimeError("训练数据加载器没有产生任何 batch(训练集样本数可能小于 batch_size)")
    return total_loss / steps, total_dice_ce / steps, total_iou / steps


def _validate(model, val_loader, loss_function, dice_metric, iou_metric, hd_metric, device):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Logits are binarised with ``sigmoid(logits) > 0.5`` before the metrics.

    Returns:
        tuple[float, float, float, float]: (mean validation loss, mean Dice,
        mean IoU, mean Hausdorff distance).

    Raises:
        RuntimeError: If the validation loader yields no batches.
    """
    model.eval()
    val_loss_total = 0.0
    num_batches = 0
    with torch.no_grad():
        for metric in (dice_metric, iou_metric, hd_metric):
            metric.reset()
        for val_data in val_loader:
            val_images = val_data["image"].to(device)
            val_labels = val_data["label"].to(device)
            val_outputs = model(val_images)

            val_loss_batch, _, _ = loss_function(val_outputs, val_labels)
            val_loss_total += val_loss_batch.item()
            num_batches += 1

            val_preds = (torch.sigmoid(val_outputs) > 0.5).int()
            dice_metric(y_pred=val_preds, y=val_labels)
            iou_metric(y_pred=val_preds, y=val_labels)
            hd_metric(y_pred=val_preds, y=val_labels)

    if num_batches == 0:
        raise RuntimeError("验证数据加载器为空,无法计算验证指标")
    return (
        val_loss_total / num_batches,
        dice_metric.aggregate().item(),
        iou_metric.aggregate().item(),
        hd_metric.aggregate().item(),
    )


def main():
    """Train the segmentation model end to end.

    Parses arguments, sets up tracking, data, model and optimiser, optionally
    resumes from a checkpoint, then trains with validation, best-model
    saving, periodic checkpoints and early stopping (with one restart from
    the best-Dice weights before actually stopping).
    """
    # ==================== Step 1: arguments ====================
    args = parse_args()

    # Seed torch RNGs; MONAI's DataLoader derives transform seeds from torch's generator.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    print("\n" + "=" * 60)
    print("息肉分割模型训练开始")
    print("=" * 60)
    print(f"使用设备:{args.device}")
    print(f"批次大小:{args.batch_size}")
    print(f"最大轮数:{args.max_epochs}")
    if args.early_stopping:
        print(f"早停机制:启用 (patience={args.early_stopping_patience}, monitor={args.early_stopping_monitor})")

    # ==================== Step 2: SwanLab ====================
    run = setup_swanlab(args)

    # ==================== Step 3: data ====================
    train_dataset, val_dataset = create_datasets(args)
    train_loader, val_loader = create_dataloaders(args, train_dataset, val_dataset)

    # ==================== Step 4: model, loss, optimiser ====================
    model = create_model(args).to(args.device)
    loss_function = create_loss_function(args)
    optimizer, scheduler = create_optimizer(args, model)

    # ==================== Step 5: metrics ====================
    dice_metric = DiceMetric(reduction="mean")
    iou_metric = MeanIoU(reduction="mean")
    hd_metric = HausdorffDistanceMetric(reduction="mean")

    # ==================== Step 6: training state ====================
    best_dice = -1
    best_dice_epoch = -1
    best_metric = -1
    best_metric_epoch = -1
    best_iou = -1
    best_iou_epoch = -1
    best_val_loss = float("inf")  # reference for the "loss" early-stopping monitor

    epoch_loss_values = []
    dice_metric_values = []
    iou_metric_values = []
    hd_metric_values = []
    start_epoch = 0

    early_stopping_counter = 0
    should_stop = False
    has_restarted = False  # early stopping restarts from the best model at most once

    # ==================== Step 7: resume ====================
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = None
    if args.resume:
        if not os.path.exists(args.resume):
            raise FileNotFoundError(f"检查点文件不存在:{args.resume}")
        checkpoint_path = args.resume
        print(f"\n正在从用户指定的检查点恢复训练:{checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=args.device)
    elif args.auto_resume:
        checkpoint_path = find_best_checkpoint(args)
        if checkpoint_path:
            print(f"\n自动恢复模式:加载 {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=args.device)
        else:
            print("\n未找到任何检查点,将从头开始训练")

    if checkpoint is not None:
        weights_fully_matched = _load_matched_weights(model, checkpoint)

        if "optimizer_state_dict" in checkpoint:
            # Optimizer state is indexed by parameter position; after an
            # architecture change it no longer lines up with the model.
            if weights_fully_matched:
                try:
                    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                    print("✓ 优化器状态已加载")
                except ValueError as err:
                    print(f"⚠ 优化器状态与当前模型不兼容,已跳过加载:{err}")
            else:
                print("⚠ 模型结构与检查点不一致,跳过加载优化器状态")

        # The epoch counter is deliberately not restored: training restarts at start_epoch = 0.

        if "best_dice" in checkpoint:
            best_dice = checkpoint["best_dice"]
            best_dice_epoch = checkpoint["best_dice_epoch"]
            print(f"✓ 最佳指标已恢复:Dice={best_dice:.4f} (Epoch {best_dice_epoch})")
        if "best_iou" in checkpoint:
            best_iou = checkpoint["best_iou"]
            best_iou_epoch = checkpoint.get("best_iou_epoch", -1)
        if "best_metric" in checkpoint:
            best_metric = checkpoint["best_metric"]
            best_metric_epoch = checkpoint.get("best_metric_epoch", -1)
        if "best_val_loss" in checkpoint:
            best_val_loss = checkpoint["best_val_loss"]

        # Metric histories (optional).
        if "epoch_loss_values" in checkpoint:
            epoch_loss_values = checkpoint["epoch_loss_values"]
        if "dice_metric_values" in checkpoint:
            dice_metric_values = checkpoint["dice_metric_values"]
        if "iou_metric_values" in checkpoint:
            iou_metric_values = checkpoint["iou_metric_values"]
        if "hd_metric_values" in checkpoint:
            hd_metric_values = checkpoint["hd_metric_values"]

        if args.early_stopping:
            if "early_stopping_counter" in checkpoint:
                early_stopping_counter = checkpoint["early_stopping_counter"]
                print(f"✓ 早停计数器已恢复:{early_stopping_counter}")
            if checkpoint.get("should_stop"):
                # A stored stop flag is intentionally ignored so training can continue.
                print("✓ 早停状态已重置,可继续训练")

        print(f"✓ 训练将从 epoch {start_epoch} 继续")
        print("=" * 60)

    print("\n" + "=" * 60)
    print("开始训练...")
    print("=" * 60)

    start_time = time.time()
    try:
        for epoch in range(start_epoch, run.config.max_epochs):
            # ========== Early-stopping gate ==========
            if should_stop and args.early_stopping:
                print(f"\n{'=' * 60}")
                print(f"触发早停机制!训练将在 epoch {epoch + 1} 提前终止")
                print(f"{'=' * 60}")

                restarted = False
                if not has_restarted:
                    # First trigger: reload the best-Dice weights, reset the
                    # optimiser and keep training.
                    print("检测到早停,准备从最佳模型重新开始训练...")
                    best_checkpoint_path = os.path.join(
                        args.output_dir, f"best_dice_model_{args.dataset_name}.pt"
                    )
                    if os.path.exists(best_checkpoint_path):
                        print(f"加载最佳 Dice 模型:{best_checkpoint_path}")
                        best_state = torch.load(best_checkpoint_path, map_location=args.device)
                        model.load_state_dict(best_state)
                        print("✓ 模型权重已恢复到最佳状态")

                        optimizer, scheduler = create_optimizer(args, model)
                        print("✓ 优化器已重置")

                        early_stopping_counter = 0
                        should_stop = False
                        has_restarted = True
                        restarted = True
                        print("✓ 已从最佳模型重新开始训练")
                        print(f"{'=' * 60}\n")
                    else:
                        print(f"警告:未找到最佳模型文件 {best_checkpoint_path}")
                        print("将直接停止训练")

                if not restarted:
                    # Second trigger, or no best model to restart from.
                    print("早停后已重启过一次训练,现在停止训练")
                    break
                # After a restart, fall through and train this epoch instead of skipping it.

            # ========== Training ==========
            if epoch == start_epoch and start_epoch > 0:
                print(f"\n✓ 已从 epoch {start_epoch} 恢复训练")

            epoch_loss, epoch_loss_dice_ce, epoch_loss_iou = _train_one_epoch(
                model, train_loader, loss_function, optimizer, args.device
            )
            epoch_loss_values.append(epoch_loss)
            print(f"\nEpoch {epoch + 1}/{args.max_epochs} - 训练损失:{epoch_loss:.4f}")

            swanlab.log({
                "train/loss": epoch_loss,
                "train/loss_dice_ce": epoch_loss_dice_ce,
                "train/loss_iou": epoch_loss_iou,
                "train/lr": optimizer.param_groups[0]['lr'],
            }, step=(epoch + 1))

            # ========== Validation ==========
            val_loss_avg, mean_dice, mean_iou, mean_hd = _validate(
                model, val_loader, loss_function, dice_metric, iou_metric, hd_metric, args.device
            )
            scheduler.step(val_loss_avg)
            current_lr = optimizer.param_groups[0]['lr']

            dice_metric_values.append(mean_dice)
            iou_metric_values.append(mean_iou)
            hd_metric_values.append(mean_hd)

            print(
                f"Epoch {epoch + 1} - 验证 Dice: {mean_dice:.4f}, 验证损失:{val_loss_avg:.4f}, 当前 LR: {current_lr:.2e}")

            swanlab.log({
                "val/loss": val_loss_avg,
                "val/mean_dice": mean_dice,
                "val/mean_iou": mean_iou,
                "val/mean_hd": mean_hd,
                "val/lr": current_lr,
            }, step=(epoch + 1))

            mean_metric = _composite_score(mean_dice, mean_iou, mean_hd)

            # ========== Early-stopping bookkeeping ==========
            # Compare against the best value *before* this epoch's update below.
            if args.early_stopping:
                monitor = args.early_stopping_monitor
                if monitor == "dice":
                    current_score, best_score = mean_dice, best_dice
                elif monitor == "iou":
                    current_score, best_score = mean_iou, best_iou
                elif monitor == "metric":
                    current_score, best_score = mean_metric, best_metric
                else:  # "loss": lower validation loss is better, so compare negated values
                    current_score, best_score = -val_loss_avg, -best_val_loss

                if current_score > best_score + args.early_stopping_min_delta:
                    early_stopping_counter = 0
                    print(
                        f" ✓ {args.early_stopping_monitor.upper()} 指标改善:{current_score:.4f} > {best_score:.4f}")
                else:
                    early_stopping_counter += 1
                    print(
                        f" ⚠ {args.early_stopping_monitor.upper()} 指标未改善,计数器:{early_stopping_counter}/{args.early_stopping_patience}")

                if early_stopping_counter >= args.early_stopping_patience:
                    should_stop = True

            best_val_loss = min(best_val_loss, val_loss_avg)

            # ========== Best-model saving ==========
            if mean_dice > best_dice:
                best_dice = mean_dice
                best_dice_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
                _save_state_dict(model, checkpoint_path)
                print(
                    f"✓ 发现更好的Dice模型!Dice: {mean_dice:.4f},IoU: {mean_iou:.4f},HD: {mean_hd:.4f},已保存到 {checkpoint_path}")

            if mean_iou > best_iou:
                best_iou = mean_iou
                best_iou_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_iou_model_{args.dataset_name}.pt")
                _save_state_dict(model, checkpoint_path)
                print(
                    f"✓ 找到更好的IoU模型!IoU: {mean_iou:.4f},Dice: {mean_dice:.4f},HD: {mean_hd:.4f},已保存到 {checkpoint_path}"
                )

            if mean_metric > best_metric:
                best_metric = mean_metric
                best_metric_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
                _save_state_dict(model, checkpoint_path)
                print(
                    f"✓ 找到更好的综合模型!综合得分: {mean_metric:.4f},Dice: {mean_dice:.4f},IoU: {mean_iou:.4f},HD: {mean_hd:.4f},已保存到 {checkpoint_path}"
                )

            # ========== Periodic full checkpoint ==========
            if (epoch + 1) % args.save_every == 0:
                checkpoint_path = os.path.join(
                    args.output_dir, f"checkpoints_{args.dataset_name}", f"checkpoint_epoch={epoch}.pt"
                )
                Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "best_dice": best_dice,
                    "best_dice_epoch": best_dice_epoch,
                    "epoch_loss_values": epoch_loss_values,
                    "dice_metric_values": dice_metric_values,
                    "iou_metric_values": iou_metric_values,
                    "hd_metric_values": hd_metric_values,
                    "best_metric": best_metric,
                    "best_metric_epoch": best_metric_epoch,
                    "best_iou": best_iou,
                    "best_iou_epoch": best_iou_epoch,
                    "best_val_loss": best_val_loss,
                    "early_stopping_counter": early_stopping_counter,
                    "should_stop": should_stop
                }, checkpoint_path)
                print(f"✓ 检查点已保存:{checkpoint_path}")

    except KeyboardInterrupt:
        print("\n训练被用户中断")

    finally:
        training_time = time.time() - start_time
        print("\n" + "=" * 60)
        print("训练完成!")
        print(f"总训练时间:{training_time / 3600:.2f} 小时")
        print(f"最佳验证 Dice: {best_dice:.4f} (Epoch {best_dice_epoch})")
        print("=" * 60)

        swanlab.finish()
        print("✓ SwanLab 实验已保存")


if __name__ == "__main__":
    main()