| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057 |
- import argparse
- import os
- import time
- from datetime import datetime
- from pathlib import Path
- import monai
- import monai.utils
- import swanlab
- import torch
- from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
- from monai.transforms import (
- Compose, LoadImaged, ScaleIntensityd, RandFlipd, RandRotated, RandRotate90d,
- EnsureChannelFirstd, ToTensord, Resized, Lambdad, RandZoomd, RandShiftIntensityd, RandGaussianNoised,
- RandGaussianSmoothd, RandAdjustContrastd, RandHistogramShiftd,
- RandAxisFlipd, RandCoarseDropoutd,
- )
- from torch.optim import AdamW
- from torch.optim.lr_scheduler import ReduceLROnPlateau
- from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
- from lib.model.model_v4_minute import Wavelet_FFT_SwinUNETR
- from lib.tools.combined_loss import CombinedDiceCEIoULoss
def str2bool(value):
    """Parse a command-line boolean option value.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for any
    non-empty string -- including ``"False"`` -- so such options could never be
    switched off. This parser accepts the usual spellings instead.

    Args:
        value: Raw option text, or an already-parsed bool (e.g. a ``const``).

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: The text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"无法解析的布尔值:{value!r}(可选 true/false、yes/no、1/0)")


def parse_args():
    """Parse the command-line arguments of the polyp segmentation training script.

    Boolean options take an explicit value (``--use_fft false``) or may be
    given bare (``--use_fft``), which means True. ``--target_spatial_size``
    takes two integers: height and width.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="息肉分割模型训练脚本")

    # ==================== Dataset ====================
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        help="数据集名称"
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default=r"./data/Polyp-Detection-Dataset",
        help="数据集根目录路径"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="数据加载器的工作进程数"
    )
    parser.add_argument(
        "--pin_memory",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="是否启用 pinned memory"
    )
    parser.add_argument(
        "--target_spatial_size",
        # type=tuple would split "512" into ('5', '1', '2'); read two ints instead.
        type=int,
        nargs=2,
        metavar=("H", "W"),
        default=(512, 512),
        help="目标空间大小(两个整数:高 宽)"
    )
    parser.add_argument(
        "--dataset_enhanced",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="是否使用高增强数据策略"
    )

    # ==================== Model ====================
    parser.add_argument(
        "--in_channels",
        type=int,
        default=3,
        help="输入图像通道数"
    )
    parser.add_argument(
        "--out_channels",
        type=int,
        default=1,
        help="输出前景"
    )
    parser.add_argument(
        "--feature_size",
        type=int,
        default=48,
        help="网络特征维度"
    )
    parser.add_argument(
        "--spatial_dims",
        type=int,
        default=2,
        choices=[2, 3],
        help="空间维度(2D 或 3D)"
    )
    parser.add_argument(
        "--use_wavelet",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="是否启用小波增强模块"
    )
    parser.add_argument(
        "--wavelet_J",
        type=int,
        default=2,
        help="小波分解层数"
    )
    parser.add_argument(
        "--wavelet_wave",
        type=str,
        default="db4",
        help="小波基类型"
    )
    parser.add_argument(
        "--wavelet_reduction",
        type=int,
        default=16,
        help="小波注意力压缩比例"
    )
    parser.add_argument(
        "--use_fft",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="是否启用 FFT 增强模块"
    )
    parser.add_argument(
        "--use_v2",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="是否启用 Swin-UNETR v2 模块"
    )

    # ==================== Training ====================
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=1000,
        help="最大训练轮数"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="批次大小"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="学习率"
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=1e-4,
        help="权重衰减系数"
    )

    # ==================== Loss weights ====================
    parser.add_argument(
        "--dice_weight",
        type=float,
        default=1.0,
        help="Dice 损失权重"
    )
    parser.add_argument(
        "--ce_weight",
        type=float,
        default=1.0,
        help="Cross Entropy 损失权重"
    )
    parser.add_argument(
        "--iou_weight",
        type=float,
        default=1.0,
        help="IoU 损失权重"
    )

    # ==================== SwanLab ====================
    parser.add_argument(
        "--swanlab_project",
        type=str,
        default="polyp-segmentation-v4_minute",
        help="SwanLab 项目名称"
    )
    parser.add_argument(
        "--swanlab_experiment",
        type=str,
        default=None,
        help="SwanLab 实验名称(默认使用时间戳)"
    )
    parser.add_argument(
        "--swanlab_log_dir",
        type=str,
        default="./swanlab_log",
        help="SwanLab 日志保存目录"
    )

    # ==================== Saving / loading ====================
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./outputs_v4_minute",
        help="模型检查点保存目录"
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=50,
        help="每隔多少个 epoch 保存一次模型"
    )

    # ==================== Early stopping ====================
    parser.add_argument(
        "--early_stopping",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="是否启用早停机制"
    )
    parser.add_argument(
        "--early_stopping_patience",
        type=int,
        default=100,
        help="早停耐心度(验证指标多少轮不改善则停止)"
    )
    parser.add_argument(
        "--early_stopping_min_delta",
        type=float,
        default=1e-4,
        help="最小改善阈值(指标提升小于此值视为无改善)"
    )
    parser.add_argument(
        "--early_stopping_monitor",
        type=str,
        default="dice",
        choices=["dice", "iou", "metric", "loss"],
        help="早停监控的指标"
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="恢复训练的检查点路径。如果未指定,将自动加载最佳 Dice 模型(如果存在)"
    )
    parser.add_argument(
        "--no_auto_resume",
        action="store_false",
        dest="auto_resume",
        # store_false: passing the flag sets auto_resume=False (default True).
        help="禁用自动恢复功能(默认会自动加载最佳 Dice 模型或最新检查点)"
    )

    # ==================== Misc ====================
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="训练设备(cuda 或 cpu)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="随机种子"
    )
    return parser.parse_args()
def find_best_checkpoint(args):
    """Locate a checkpoint file to resume training from.

    Candidates are tried in this order and the first that exists wins:

    1. ``<output_dir>/best_dice_model_<dataset_name>.pt``
    2. the ``*.pt`` file with the highest ``epoch=N`` number inside
       ``<output_dir>/checkpoints_<dataset_name>/`` (files without an
       ``epoch=`` marker rank lowest)
    3. ``<output_dir>/best_metric_model_<dataset_name>.pt``

    Args:
        args: Parsed arguments; reads ``output_dir`` and ``dataset_name``.

    Returns:
        str or None: Path of the chosen file, or None when nothing is found.
    """
    def _epoch_number(filename):
        # 'checkpoint_epoch=49.pt' -> 49; anything without the marker -> -1.
        if 'epoch=' not in filename:
            return -1
        return int(filename.split('epoch=')[1].split('.')[0])

    dice_file = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
    if os.path.exists(dice_file):
        print(f"找到最佳 Dice 模型:{dice_file}")
        return dice_file

    ckpt_dir = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}")
    candidates = [name for name in os.listdir(ckpt_dir) if name.endswith('.pt')] if os.path.exists(ckpt_dir) else []
    if candidates:
        newest = os.path.join(ckpt_dir, max(candidates, key=_epoch_number))
        print(f"找到最新检查点:{newest}")
        return newest

    metric_file = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
    if not os.path.exists(metric_file):
        return None
    print(f"找到最佳综合模型:{metric_file}")
    return metric_file
def create_enhanced_transforms(target_spatial_size=(512, 512)):
    """Build the heavy-augmentation training pipeline.

    Stages, applied in this order:

    1. Preprocessing: load, make channel-first, binarise the label (channel 0,
       threshold > 127), resize (bilinear image / nearest label), scale image
       intensity.
    2. Geometry on image and label: random axis flip, rotation with
       ``range_x=0.15`` (MONAI uses radians, so roughly +/-8.6 degrees),
       90-degree rotation (``max_k=2``), zoom 0.8-1.2 keeping the size.
    3. Photometry on the image only: intensity shift (+/-0.2), contrast
       (gamma 0.7-1.3), histogram shift.
    4. Degradation on the image only: Gaussian smoothing, Gaussian noise,
       coarse dropout (1-3 holes of 32x32 up to 64x64).
    5. Conversion to tensors.

    Args:
        target_spatial_size: Spatial size handed to ``Resized``.

    Returns:
        Compose: The training transform pipeline.
    """
    def _binarize_mask(mask):
        # Keep channel 0 of the (RGB) mask; pixels > 127 become 1.0, others 0.0.
        return (mask[0:1, :, :] > 127).float()

    paired = ["image", "label"]
    image_only = ["image"]
    # Smooth interpolation for the image, nearest for the label so it stays binary.
    interp = ("bilinear", "nearest")

    preprocessing = [
        LoadImaged(keys=paired),
        EnsureChannelFirstd(keys=paired),
        Lambdad(keys=["label"], func=_binarize_mask),
        Resized(keys=paired, spatial_size=target_spatial_size, mode=interp),
        ScaleIntensityd(keys=image_only),
    ]
    geometric = [
        RandAxisFlipd(keys=paired, prob=0.5),
        RandRotated(keys=paired, range_x=0.15, prob=0.5, keep_size=True, mode=interp),
        RandRotate90d(keys=paired, prob=0.5, max_k=2),
        RandZoomd(keys=paired, min_zoom=0.8, max_zoom=1.2, prob=0.5, mode=interp, keep_size=True),
    ]
    photometric = [
        RandShiftIntensityd(keys=image_only, offsets=(-0.2, 0.2), prob=0.5),
        RandAdjustContrastd(keys=image_only, gamma=(0.7, 1.3), prob=0.5),
        RandHistogramShiftd(keys=image_only, num_control_points=(5, 10), prob=0.3),
    ]
    degradation = [
        RandGaussianSmoothd(keys=image_only, sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), prob=0.3),
        RandGaussianNoised(keys=image_only, mean=0.0, std=0.05, prob=0.3),
        RandCoarseDropoutd(
            keys=image_only,
            holes=1,
            max_holes=3,
            spatial_size=(32, 32),
            max_spatial_size=(64, 64),
            prob=0.3,
        ),
    ]
    finalize = [ToTensord(keys=paired)]

    return Compose(preprocessing + geometric + photometric + degradation + finalize)
def create_datasets(args):
    """Create the training and validation datasets.

    The training pipeline is ``create_enhanced_transforms`` when
    ``args.dataset_enhanced`` is true. Otherwise it is a light pipeline:
    load, binarise the label, resize, scale, then random flip and rotation.
    The validation pipeline only loads, binarises, resizes and scales.

    Args:
        args: Parsed arguments; reads ``data_root``, ``dataset_name``,
            ``target_spatial_size`` and ``dataset_enhanced``.

    Returns:
        tuple: ``(train_dataset, val_dataset)``.
    """
    print("=" * 60)
    print("正在加载数据集...")
    print("=" * 60)

    def convert_label_to_single_channel(label_tensor):
        """Collapse a multi-channel mask to a single binary channel.

        Keeps channel 0 and maps values > 127 to 1.0 and everything else to
        0.0, giving the ``[1, H, W]`` 0/1 target used with ``out_channels=1``.
        Assumes mask pixels are on a 0-255 scale -- TODO confirm for every
        dataset.
        """
        single_channel = label_tensor[0:1, :, :]
        return (single_channel > 127).float()

    # Build only the training pipeline that is actually used.
    if args.dataset_enhanced:
        train_transforms = create_enhanced_transforms(args.target_spatial_size)
        print("✓ 使用增强数据增强策略")
    else:
        train_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Lambdad(keys=["label"], func=convert_label_to_single_channel),
            Resized(keys=["image", "label"], spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
            ScaleIntensityd(keys=["image"]),
            RandFlipd(keys=["image", "label"], prob=0.5),
            # Without an explicit mode both keys would be rotated bilinearly,
            # turning the 0/1 label into fractional values along the edges.
            RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5, mode=("bilinear", "nearest")),
        ])

    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Lambdad(keys=["label"], func=convert_label_to_single_channel),
        Resized(keys=["image", "label"], spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
        ScaleIntensityd(keys=["image"]),
    ])

    dataset_root = Path(args.data_root) / args.dataset_name
    train_dataset = PolypDetectionDataset(
        root_dir=dataset_root,
        flag='train',
        transform=train_transforms
    )
    val_dataset = PolypDetectionDataset(
        root_dir=dataset_root,
        flag='val',
        transform=val_transforms
    )

    print(f"✓ 训练集大小:{len(train_dataset)} 个样本")
    print(f"✓ 验证集大小:{len(val_dataset)} 个样本")
    print(f"✓ 总样本数:{len(train_dataset) + len(val_dataset)} 个样本")
    print("=" * 60)
    return train_dataset, val_dataset
def create_dataloaders(args, train_dataset, val_dataset):
    """Wrap the datasets in MONAI DataLoaders.

    Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory`` from
    ``args``. The training loader shuffles and drops the last incomplete
    batch. The validation loader keeps sample order and every sample.

    Args:
        args: Parsed arguments.
        train_dataset: Training dataset.
        val_dataset: Validation dataset.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    common = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": args.pin_memory,
    }
    train_loader = monai.data.DataLoader(train_dataset, shuffle=True, drop_last=True, **common)
    val_loader = monai.data.DataLoader(val_dataset, shuffle=False, drop_last=False, **common)

    for split_name, loader in (("训练", train_loader), ("验证", val_loader)):
        print(f"✓ {split_name}数据加载器:{len(loader)} 个 batch")
    return train_loader, val_loader
def create_model(args):
    """Instantiate ``Wavelet_FFT_SwinUNETR`` from the parsed arguments.

    The wavelet padding mode is fixed to ``'symmetric'``. Every other
    constructor argument comes from ``args``. The model is not moved to
    ``args.device`` here; the caller does that. Total and trainable parameter
    counts are printed.

    Args:
        args: Parsed arguments.

    Returns:
        torch.nn.Module: The newly constructed model.
    """
    print("\n" + "=" * 60)
    print("正在创建模型...")
    net = Wavelet_FFT_SwinUNETR(
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        spatial_dims=args.spatial_dims,
        wavelet_enhancement=args.use_wavelet,
        wavelet_J=args.wavelet_J,
        wavelet_wave=args.wavelet_wave,
        wavelet_mode='symmetric',
        wavelet_reduction=args.wavelet_reduction,
        fft_enhancement=args.use_fft,
        use_v2=args.use_v2,
    )

    n_total = 0
    n_trainable = 0
    for tensor in net.parameters():
        count = tensor.numel()
        n_total += count
        if tensor.requires_grad:
            n_trainable += count

    print(f"\n✓ 模型总参数量:{n_total:,}")
    print(f"✓ 可训练参数量:{n_trainable:,}")
    print(f"✓ 使用设备:{args.device}")
    print("=" * 60)
    return net
def create_loss_function(args):
    """Build the combined Dice + CE + IoU loss.

    The three component weights come from ``args``. The remaining flags are
    fixed and forwarded as-is: background included, no one-hot conversion of
    the target, sigmoid rather than softmax activation.

    Args:
        args: Parsed arguments; reads ``dice_weight``, ``ce_weight`` and
            ``iou_weight``.

    Returns:
        CombinedDiceCEIoULoss: The loss module.
    """
    component_weights = {
        "dice_weight": args.dice_weight,
        "ce_weight": args.ce_weight,
        "iou_weight": args.iou_weight,
    }
    return CombinedDiceCEIoULoss(
        **component_weights,
        include_background=True,
        to_onehot_y=False,
        softmax=False,
        sigmoid=True,
    )
def create_optimizer(args, model):
    """Create the AdamW optimizer and its ReduceLROnPlateau scheduler.

    The scheduler runs in ``'min'`` mode. It halves the learning rate after 20
    epochs without an improvement beyond the threshold of 1e-4 (PyTorch's
    default relative threshold mode). It then waits 5 cooldown epochs, and
    never lowers the rate below 1e-6. The configuration is printed.

    Args:
        args: Parsed arguments; reads ``learning_rate`` and ``weight_decay``.
        model: Module whose parameters are optimised.

    Returns:
        tuple: ``(optimizer, scheduler)``.
    """
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="min",       # the monitored value should go down
        factor=0.5,       # new_lr = lr * 0.5
        patience=20,      # epochs without improvement before a decay
        threshold=1e-4,   # minimum change counted as improvement
        cooldown=5,       # epochs to wait after a decay
        min_lr=1e-6,      # lower bound on the learning rate
    )

    summary = (
        "✓ 优化器:AdamW",
        f" - 学习率:{args.learning_rate}",
        f" - 权重衰减:{args.weight_decay}",
        "✓ 调度器:ReduceLROnPlateau",
        f" - 模式:{scheduler.mode}",
        f" - 衰减因子:{scheduler.factor}",
        f" - patience:{scheduler.patience}",
        f" - 最小变化阈值:{scheduler.threshold}",
        f" - 冷却期:{scheduler.cooldown}",
        f" - 最小学习率:{scheduler.min_lrs}",
    )
    for line in summary:
        print(line)
    return optimizer, scheduler
def setup_swanlab(args):
    """Initialise SwanLab experiment tracking.

    Side effects:

    - when ``args.swanlab_experiment`` is None it is set to
      ``v2_<dataset_name>_<YYYYmmdd_HHMMSS>``;
    - ``args.swanlab_log_dir`` and ``args.output_dir`` are created if missing;
    - the whole argument namespace is passed as the run config.

    Args:
        args: Parsed arguments (mutated as described above).

    Returns:
        The run object returned by ``swanlab.init``.
    """
    if args.swanlab_experiment is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.swanlab_experiment = "_".join(["v2", args.dataset_name, timestamp])

    for directory in (args.swanlab_log_dir, args.output_dir):
        os.makedirs(directory, exist_ok=True)

    run = swanlab.init(
        project=args.swanlab_project,
        experiment_name=args.swanlab_experiment,
        logdir=args.swanlab_log_dir,
        config=vars(args),
    )

    print(f"\n✓ SwanLab 实验已初始化:{args.swanlab_experiment}")
    print(f" - 项目:{args.swanlab_project}")
    print(f" - 日志目录:{args.swanlab_log_dir}")
    return run
def _load_resume_checkpoint(args):
    """Choose and load the checkpoint to resume from, if any.

    ``--resume`` takes precedence and must point to an existing file.
    Otherwise, when auto-resume is enabled, ``find_best_checkpoint`` picks a
    file.

    Returns:
        The object returned by ``torch.load``, or None when nothing is resumed.

    Raises:
        FileNotFoundError: ``--resume`` points to a missing file.
    """
    if args.resume:
        if not os.path.exists(args.resume):
            raise FileNotFoundError(f"检查点文件不存在:{args.resume}")
        print(f"\n正在从用户指定的检查点恢复训练:{args.resume}")
        return torch.load(args.resume, map_location=args.device)
    if args.auto_resume:
        checkpoint_path = find_best_checkpoint(args)
        if checkpoint_path:
            print(f"\n自动恢复模式:加载 {checkpoint_path}")
            return torch.load(checkpoint_path, map_location=args.device)
        print("\n未找到任何检查点,将从头开始训练")
    return None


def _load_matching_weights(model, checkpoint):
    """Copy every shape-compatible tensor from ``checkpoint`` into ``model``.

    ``checkpoint`` is either a full training checkpoint (weights under
    ``"model_state_dict"``) or a bare state dict such as the best-model files.
    Model entries that are absent from the checkpoint, or whose shape differs,
    keep their current values. This lets weights from an older architecture
    variant be reused. A matching summary is printed.
    """
    model_dict = model.state_dict()
    if "model_state_dict" in checkpoint:
        pretrained_dict = checkpoint["model_state_dict"]
        print("✓ 从训练检查点加载模型权重")
    else:
        pretrained_dict = checkpoint
        print("从最佳 Dice 或最佳综合模型权重中加载模型权重")

    matched_params = {}
    unmatched_params = []
    missing_params = []
    for name, param in model_dict.items():
        if name not in pretrained_dict:
            missing_params.append(name)
            continue
        pretrained_param = pretrained_dict[name]
        if param.shape == pretrained_param.shape:
            matched_params[name] = pretrained_param
        else:
            unmatched_params.append(f"{name} (形状不匹配:{param.shape} vs {pretrained_param.shape})")

    print("\n权重加载统计:")
    print(f" ✓ 成功匹配的参数:{len(matched_params)}/{len(model_dict)}")
    print(f" ⚠ 形状不匹配的 parameter: {len(unmatched_params)}")
    print(f" ✗ 新增的 parameter(随机初始化): {len(missing_params)}")
    if unmatched_params:
        print("\n形状不匹配的层:")
        for info in unmatched_params[:5]:  # show only the first 5
            print(f" - {info}")
        if len(unmatched_params) > 5:
            print(f" ... 还有 {len(unmatched_params) - 5} 个")
    if missing_params:
        print("\n新增的层 (将随机初始化):")
        for name in missing_params[:5]:  # show only the first 5
            print(f" - {name}")
        if len(missing_params) > 5:
            print(f" ... 还有 {len(missing_params) - 5} 个")

    model_dict.update(matched_params)
    model.load_state_dict(model_dict, strict=False)
    print("\n✓ 模型权重加载完成 (严格模式:False)")
    print("=" * 60)


def _save_model_weights(model, checkpoint_path):
    """Save ``model.state_dict()`` to ``checkpoint_path``, creating parent dirs."""
    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)


def _train_one_epoch(model, train_loader, loss_function, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    ``loss_function`` must return ``(total, dice_ce_part, iou_part)``.

    Returns:
        tuple[float, float, float]: Mean total loss, mean Dice+CE part and
        mean IoU part over the batches.

    Raises:
        RuntimeError: The loader yielded no batches. Averaging would otherwise
            divide by zero.
    """
    model.train()
    steps = 0
    total_loss = total_dice_ce = total_iou = 0.0
    for batch_data in train_loader:
        steps += 1
        inputs = batch_data["image"].to(device)
        targets = batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss, loss_dice_ce, loss_iou = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_dice_ce += loss_dice_ce.item()
        total_iou += loss_iou.item()
    if steps == 0:
        raise RuntimeError("训练数据加载器没有产生任何 batch,请检查训练集大小是否小于 batch_size(drop_last=True)")
    return total_loss / steps, total_dice_ce / steps, total_iou / steps


def _validate(model, val_loader, loss_function, device, dice_metric, iou_metric, hd_metric):
    """Evaluate the model on ``val_loader`` without gradients.

    Logits are passed through a sigmoid and thresholded at 0.5 before the
    Dice / IoU / Hausdorff metrics are accumulated. The loss is computed on
    the raw logits.

    Returns:
        tuple[float, float, float, float]: ``(mean_val_loss, mean_dice,
        mean_iou, mean_hd)``.

    Raises:
        RuntimeError: The loader yielded no batches.
    """
    model.eval()
    val_loss_total = 0.0
    num_batches = 0
    with torch.no_grad():
        dice_metric.reset()
        iou_metric.reset()
        hd_metric.reset()
        for val_data in val_loader:
            num_batches += 1
            val_images = val_data["image"].to(device)
            val_labels = val_data["label"].to(device)
            val_outputs = model(val_images)
            val_loss_batch, _, _ = loss_function(val_outputs, val_labels)
            val_loss_total += val_loss_batch.item()
            val_preds = (torch.sigmoid(val_outputs) > 0.5).int()
            dice_metric(y_pred=val_preds, y=val_labels)
            iou_metric(y_pred=val_preds, y=val_labels)
            hd_metric(y_pred=val_preds, y=val_labels)
    if num_batches == 0:
        raise RuntimeError("验证数据加载器为空,无法计算验证指标")
    return (
        val_loss_total / num_batches,
        dice_metric.aggregate().item(),
        iou_metric.aggregate().item(),
        hd_metric.aggregate().item(),
    )


def main():
    """Main training entry point.

    Steps: parse arguments, seed torch, start SwanLab, build data / model /
    loss / optimizer, optionally resume from a checkpoint, then train with
    per-epoch validation. The loop saves best-Dice / best-IoU / best-combined
    weights, writes a periodic full checkpoint, and applies early stopping.
    On the first early-stop trigger it reloads the best-Dice weights once,
    with a fresh optimizer and scheduler, before stopping for good.
    """
    # ==================== Step 1: arguments & seeding ====================
    args = parse_args()
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    print("\n" + "=" * 60)
    print("息肉分割模型训练开始")
    print("=" * 60)
    print(f"使用设备:{args.device}")
    print(f"批次大小:{args.batch_size}")
    print(f"最大轮数:{args.max_epochs}")
    if args.early_stopping:
        print(f"早停机制:启用 (patience={args.early_stopping_patience}, monitor={args.early_stopping_monitor})")

    # ==================== Step 2: SwanLab ====================
    run = setup_swanlab(args)

    # ==================== Step 3: data ====================
    train_dataset, val_dataset = create_datasets(args)
    train_loader, val_loader = create_dataloaders(args, train_dataset, val_dataset)

    # ==================== Step 4: model, loss, optimizer ====================
    model = create_model(args).to(args.device)
    loss_function = create_loss_function(args)
    optimizer, scheduler = create_optimizer(args, model)

    # ==================== Step 5: metrics ====================
    dice_metric = DiceMetric(reduction="mean")
    iou_metric = MeanIoU(reduction="mean")
    hd_metric = HausdorffDistanceMetric(reduction="mean")

    # ==================== Step 6: training state ====================
    best_dice, best_dice_epoch = -1, -1
    best_metric, best_metric_epoch = -1, -1
    best_iou, best_iou_epoch = -1, -1
    best_val_loss = float("inf")  # tracked for the "loss" early-stopping monitor
    epoch_loss_values = []
    dice_metric_values = []
    iou_metric_values = []
    hd_metric_values = []
    # The epoch counter is intentionally not restored from checkpoints (the
    # restore was disabled in the original script), so numbering restarts at 0.
    start_epoch = 0

    early_stopping_counter = 0
    should_stop = False
    has_restarted = False  # early-stop restart from the best model happens at most once

    # ==================== Step 7: resume ====================
    checkpoint = _load_resume_checkpoint(args)
    if checkpoint is not None:
        _load_matching_weights(model, checkpoint)

        if "optimizer_state_dict" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            print("✓ 优化器状态已加载")
        if "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            print("✓ 学习率调度器状态已加载")

        # Restore every "best" record, not just Dice. Otherwise the best-IoU
        # and best-combined files would be overwritten by worse models.
        if "best_dice" in checkpoint:
            best_dice = checkpoint["best_dice"]
            best_dice_epoch = checkpoint["best_dice_epoch"]
            print(f"✓ 最佳指标已恢复:Dice={best_dice:.4f} (Epoch {best_dice_epoch})")
        if "best_iou" in checkpoint:
            best_iou = checkpoint["best_iou"]
            if "best_iou_epoch" in checkpoint:
                best_iou_epoch = checkpoint["best_iou_epoch"]
        if "best_metric" in checkpoint:
            best_metric = checkpoint["best_metric"]
            if "best_metric_epoch" in checkpoint:
                best_metric_epoch = checkpoint["best_metric_epoch"]
        if "best_val_loss" in checkpoint:
            best_val_loss = checkpoint["best_val_loss"]

        if "epoch_loss_values" in checkpoint:
            epoch_loss_values = checkpoint["epoch_loss_values"]
        if "dice_metric_values" in checkpoint:
            dice_metric_values = checkpoint["dice_metric_values"]
        if "iou_metric_values" in checkpoint:
            iou_metric_values = checkpoint["iou_metric_values"]
        if "hd_metric_values" in checkpoint:
            hd_metric_values = checkpoint["hd_metric_values"]

        if args.early_stopping:
            if "early_stopping_counter" in checkpoint:
                early_stopping_counter = checkpoint["early_stopping_counter"]
                print(f"✓ 早停计数器已恢复:{early_stopping_counter}")
            if "has_restarted" in checkpoint:
                has_restarted = checkpoint["has_restarted"]
            # A saved stop flag is deliberately ignored: should_stop stays
            # False so that resumed training actually runs.
            if "should_stop" in checkpoint and checkpoint["should_stop"]:
                print("✓ 早停状态已重置,可继续训练")

        print(f"✓ 训练将从 epoch {start_epoch} 继续")
        print("=" * 60)

    print("\n" + "=" * 60)
    print("开始训练...")
    print("=" * 60)
    if start_epoch > 0:
        print(f"\n✓ 已从 epoch {start_epoch} 恢复训练")

    start_time = time.time()
    try:
        for epoch in range(start_epoch, run.config.max_epochs):
            # ========== early-stopping gate ==========
            if should_stop and args.early_stopping:
                print(f"\n{'=' * 60}")
                print(f"触发早停机制!训练将在 epoch {epoch + 1} 提前终止")
                print(f"{'=' * 60}")
                restarted = False
                if not has_restarted:
                    # First trigger: roll back to the best-Dice weights, reset the
                    # optimizer/scheduler and keep training in this same epoch.
                    print("检测到早停,准备从最佳模型重新开始训练...")
                    best_checkpoint_path = os.path.join(
                        args.output_dir,
                        f"best_dice_model_{args.dataset_name}.pt"
                    )
                    if os.path.exists(best_checkpoint_path):
                        print(f"加载最佳 Dice 模型:{best_checkpoint_path}")
                        best_state = torch.load(best_checkpoint_path, map_location=args.device)
                        # Shape-tolerant load: the file may predate an architecture change.
                        _load_matching_weights(model, best_state)
                        print("✓ 模型权重已恢复到最佳状态")
                        optimizer, scheduler = create_optimizer(args, model)
                        print("✓ 优化器已重置")
                        early_stopping_counter = 0
                        should_stop = False
                        has_restarted = True
                        restarted = True
                        print("✓ 已从最佳模型重新开始训练")
                        print(f"{'=' * 60}\n")
                    else:
                        print(f"警告:未找到最佳模型文件 {best_checkpoint_path}")
                        print("将直接停止训练")
                if not restarted:
                    if has_restarted:
                        print("早停后已重启过一次训练,现在停止训练")
                    break

            # ========== training ==========
            epoch_loss, epoch_loss_dice_ce, epoch_loss_iou = _train_one_epoch(
                model, train_loader, loss_function, optimizer, args.device
            )
            epoch_loss_values.append(epoch_loss)
            print(f"\nEpoch {epoch + 1}/{args.max_epochs} - 训练损失:{epoch_loss:.4f}")
            swanlab.log({
                "train/loss": epoch_loss,
                "train/loss_dice_ce": epoch_loss_dice_ce,
                "train/loss_iou": epoch_loss_iou,
                "train/lr": optimizer.param_groups[0]['lr'],
            }, step=(epoch + 1))

            # ========== validation ==========
            val_loss_avg, mean_dice, mean_iou, mean_hd = _validate(
                model, val_loader, loss_function, args.device, dice_metric, iou_metric, hd_metric
            )
            scheduler.step(val_loss_avg)
            current_lr = optimizer.param_groups[0]['lr']
            dice_metric_values.append(mean_dice)
            iou_metric_values.append(mean_iou)
            hd_metric_values.append(mean_hd)
            print(
                f"Epoch {epoch + 1} - 验证 Dice: {mean_dice:.4f}, 验证损失:{val_loss_avg:.4f}, 当前 LR: {current_lr:.2e}")
            swanlab.log({
                "val/loss": val_loss_avg,
                "val/mean_dice": mean_dice,
                "val/mean_iou": mean_iou,
                "val/mean_hd": mean_hd,
                "val/lr": current_lr,
            }, step=(epoch + 1))

            # Combined score: Dice + IoU + 1/(1+HD), where the last term lies in (0, 1].
            # NOTE(review): the Hausdorff metric may yield nan/inf for empty
            # predictions. A nan score fails every ">" comparison below, so such
            # an epoch never counts as an improvement.
            normalized_hd = 1.0 / (1.0 + mean_hd)
            mean_metric = mean_dice + mean_iou + normalized_hd

            # ========== early-stopping bookkeeping ==========
            # Compared against the bests *before* they are updated below.
            if args.early_stopping:
                monitor = args.early_stopping_monitor
                if monitor == "dice":
                    current_score, best_score = mean_dice, best_dice
                elif monitor == "iou":
                    current_score, best_score = mean_iou, best_iou
                elif monitor == "metric":
                    current_score, best_score = mean_metric, best_metric
                else:  # "loss": lower validation loss is better, so negate
                    current_score, best_score = -val_loss_avg, -best_val_loss
                is_better = current_score > best_score + args.early_stopping_min_delta

                if is_better:
                    early_stopping_counter = 0
                    print(
                        f" ✓ {monitor.upper()} 指标改善:{current_score:.4f} > {best_score:.4f}")
                else:
                    early_stopping_counter += 1
                    print(
                        f" ⚠ {monitor.upper()} 指标未改善,计数器:{early_stopping_counter}/{args.early_stopping_patience}")
                if early_stopping_counter >= args.early_stopping_patience:
                    should_stop = True

            if val_loss_avg < best_val_loss:
                best_val_loss = val_loss_avg

            # ========== best-model saving ==========
            if mean_dice > best_dice:
                best_dice = mean_dice
                best_dice_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
                _save_model_weights(model, checkpoint_path)
                print(
                    f"✓ 发现更好的Dice模型!Dice: {mean_dice:.4f},IoU: {mean_iou:.4f},HD: {mean_hd:.4f},已保存到 {checkpoint_path}")

            if mean_iou > best_iou:
                best_iou = mean_iou
                best_iou_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_iou_model_{args.dataset_name}.pt")
                _save_model_weights(model, checkpoint_path)
                print(
                    f"✓ 找到更好的IoU模型!IoU: {mean_iou:.4f},Dice: {mean_dice:.4f},HD: {mean_hd:.4f},已保存到 {checkpoint_path}"
                )

            if mean_metric > best_metric:
                best_metric = mean_metric
                best_metric_epoch = epoch + 1
                checkpoint_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
                _save_model_weights(model, checkpoint_path)
                print(
                    f"✓ 找到更好的综合模型!综合得分: {mean_metric:.4f},Dice: {mean_dice:.4f},IoU: {mean_iou:.4f},HD: {mean_hd:.4f},已保存到 {checkpoint_path}"
                )

            # ========== periodic full checkpoint ==========
            if (epoch + 1) % args.save_every == 0:
                checkpoint_path = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}",
                                               f"checkpoint_epoch={epoch}.pt")
                Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "best_dice": best_dice,
                    "best_dice_epoch": best_dice_epoch,
                    "epoch_loss_values": epoch_loss_values,
                    "dice_metric_values": dice_metric_values,
                    "iou_metric_values": iou_metric_values,
                    "hd_metric_values": hd_metric_values,
                    "best_metric": best_metric,
                    "best_metric_epoch": best_metric_epoch,
                    "best_iou": best_iou,
                    "best_iou_epoch": best_iou_epoch,
                    "best_val_loss": best_val_loss,
                    "early_stopping_counter": early_stopping_counter,
                    "should_stop": should_stop,
                    "has_restarted": has_restarted,
                }, checkpoint_path)
                print(f"✓ 检查点已保存:{checkpoint_path}")
    except KeyboardInterrupt:
        print("\n训练被用户中断")
    finally:
        training_time = time.time() - start_time
        print("\n" + "=" * 60)
        print("训练完成!")
        print(f"总训练时间:{training_time / 3600:.2f} 小时")
        print(f"最佳验证 Dice: {best_dice:.4f} (Epoch {best_dice_epoch})")
        print("=" * 60)
        swanlab.finish()
        print("✓ SwanLab 实验已保存")
if __name__ == "__main__":
    # Script entry point: start training only when run directly, not on import.
    main()
|