train.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057
  1. import argparse
  2. import os
  3. import time
  4. from datetime import datetime
  5. from pathlib import Path
  6. import monai
  7. import monai.utils
  8. import swanlab
  9. import torch
  10. from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
  11. from monai.transforms import (
  12. Compose, LoadImaged, ScaleIntensityd, RandFlipd, RandRotated, RandRotate90d,
  13. EnsureChannelFirstd, ToTensord, Resized, Lambdad, RandZoomd, RandShiftIntensityd, RandGaussianNoised,
  14. RandGaussianSmoothd, RandAdjustContrastd, RandHistogramShiftd,
  15. RandAxisFlipd, RandCoarseDropoutd,
  16. )
  17. from torch.optim import AdamW
  18. from torch.optim.lr_scheduler import ReduceLROnPlateau
  19. from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
  20. from lib.model.model_v4_minute import Wavelet_FFT_SwinUNETR
  21. from lib.tools.combined_loss import CombinedDiceCEIoULoss
  22. def parse_args():
  23. """
  24. 解析命令行参数
  25. Returns:
  26. argparse.Namespace: 解析后的参数对象
  27. """
  28. parser = argparse.ArgumentParser(description="息肉分割模型训练脚本")
  29. # ==================== 数据集相关参数 ====================
  30. parser.add_argument(
  31. "--dataset_name",
  32. type=str,
  33. required=True,
  34. help="数据集名称"
  35. )
  36. parser.add_argument(
  37. "--data_root",
  38. type=str,
  39. default=r"./data/Polyp-Detection-Dataset",
  40. help="数据集根目录路径"
  41. )
  42. parser.add_argument(
  43. "--num_workers",
  44. type=int,
  45. default=0,
  46. help="数据加载器的工作进程数"
  47. )
  48. parser.add_argument(
  49. "--pin_memory",
  50. type=bool,
  51. default=True,
  52. help="是否启用 pinned memory"
  53. )
  54. parser.add_argument(
  55. "--target_spatial_size",
  56. type=tuple,
  57. default=(512, 512),
  58. help="目标空间大小"
  59. )
  60. parser.add_argument(
  61. "--dataset_enhanced",
  62. type=bool,
  63. default=True,
  64. help="是否使用高增强数据策略"
  65. )
  66. # ==================== 模型相关参数 ====================
  67. parser.add_argument(
  68. "--in_channels",
  69. type=int,
  70. default=3,
  71. help="输入图像通道数"
  72. )
  73. parser.add_argument(
  74. "--out_channels",
  75. type=int,
  76. default=1,
  77. help="输出前景"
  78. )
  79. parser.add_argument(
  80. "--feature_size",
  81. type=int,
  82. default=48,
  83. help="网络特征维度"
  84. )
  85. parser.add_argument(
  86. "--spatial_dims",
  87. type=int,
  88. default=2,
  89. choices=[2, 3],
  90. help="空间维度(2D 或 3D)"
  91. )
  92. parser.add_argument(
  93. "--use_wavelet",
  94. type=bool,
  95. default=True,
  96. help="是否启用小波增强模块"
  97. )
  98. parser.add_argument(
  99. "--wavelet_J",
  100. type=int,
  101. default=2,
  102. help="小波分解层数"
  103. )
  104. parser.add_argument(
  105. "--wavelet_wave",
  106. type=str,
  107. default="db4",
  108. help="小波基类型"
  109. )
  110. parser.add_argument(
  111. "--wavelet_reduction",
  112. type=int,
  113. default=16,
  114. help="小波注意力压缩比例"
  115. )
  116. parser.add_argument(
  117. "--use_fft",
  118. type=bool,
  119. default=True,
  120. help="是否启用 FFT 增强模块"
  121. )
  122. parser.add_argument(
  123. "--use_v2",
  124. type=bool,
  125. default=True,
  126. help="是否启用 Swin-UNETR v2 模块"
  127. )
  128. # ==================== 训练相关参数 ====================
  129. parser.add_argument(
  130. "--max_epochs",
  131. type=int,
  132. default=1000,
  133. help="最大训练轮数"
  134. )
  135. parser.add_argument(
  136. "--batch_size",
  137. type=int,
  138. default=4,
  139. help="批次大小"
  140. )
  141. parser.add_argument(
  142. "--learning_rate",
  143. type=float,
  144. default=1e-4,
  145. help="学习率"
  146. )
  147. parser.add_argument(
  148. "--weight_decay",
  149. type=float,
  150. default=1e-4,
  151. help="权重衰减系数"
  152. )
  153. # ==================== 损失函数参数 ====================
  154. parser.add_argument(
  155. "--dice_weight",
  156. type=float,
  157. default=1.0,
  158. help="Dice 损失权重"
  159. )
  160. parser.add_argument(
  161. "--ce_weight",
  162. type=float,
  163. default=1.0,
  164. help="Cross Entropy 损失权重"
  165. )
  166. parser.add_argument(
  167. "--iou_weight",
  168. type=float,
  169. default=1.0,
  170. help="IoU 损失权重"
  171. )
  172. # ==================== SwanLab 参数 ====================
  173. parser.add_argument(
  174. "--swanlab_project",
  175. type=str,
  176. default="polyp-segmentation-v4_minute",
  177. help="SwanLab 项目名称"
  178. )
  179. parser.add_argument(
  180. "--swanlab_experiment",
  181. type=str,
  182. default=None,
  183. help="SwanLab 实验名称(默认使用时间戳)"
  184. )
  185. parser.add_argument(
  186. "--swanlab_log_dir",
  187. type=str,
  188. default="./swanlab_log",
  189. help="SwanLab 日志保存目录"
  190. )
  191. # ==================== 保存与加载参数 ====================
  192. parser.add_argument(
  193. "--output_dir",
  194. type=str,
  195. default="./outputs_v4_minute",
  196. help="模型检查点保存目录"
  197. )
  198. parser.add_argument(
  199. "--save_every",
  200. type=int,
  201. default=50,
  202. help="每隔多少个 epoch 保存一次模型"
  203. )
  204. # ==================== 早停机制参数 ====================
  205. parser.add_argument(
  206. "--early_stopping",
  207. type=bool,
  208. default=True,
  209. help="是否启用早停机制"
  210. )
  211. parser.add_argument(
  212. "--early_stopping_patience",
  213. type=int,
  214. default=100,
  215. help="早停耐心度(验证指标多少轮不改善则停止)"
  216. )
  217. parser.add_argument(
  218. "--early_stopping_min_delta",
  219. type=float,
  220. default=1e-4,
  221. help="最小改善阈值(指标提升小于此值视为无改善)"
  222. )
  223. parser.add_argument(
  224. "--early_stopping_monitor",
  225. type=str,
  226. default="dice",
  227. choices=["dice", "iou", "metric", "loss"],
  228. help="早停监控的指标"
  229. )
  230. parser.add_argument(
  231. "--resume",
  232. type=str,
  233. default=None,
  234. help="恢复训练的检查点路径。如果未指定,将自动加载最佳 Dice 模型(如果存在)"
  235. )
  236. parser.add_argument(
  237. "--no_auto_resume",
  238. action="store_false",
  239. dest="auto_resume",
  240. help="是否启用自动恢复功能(默认加载最佳 Dice 模型)"
  241. )
  242. # ==================== 其他参数 ====================
  243. parser.add_argument(
  244. "--device",
  245. type=str,
  246. default="cuda" if torch.cuda.is_available() else "cpu",
  247. help="训练设备(cuda 或 cpu)"
  248. )
  249. parser.add_argument(
  250. "--seed",
  251. type=int,
  252. default=42,
  253. help="随机种子"
  254. )
  255. return parser.parse_args()
  256. def find_best_checkpoint(args):
  257. """
  258. 查找最佳检查点文件
  259. Args:
  260. args: 命令行参数
  261. Returns:
  262. str or None: 最佳检查点路径,如果不存在则返回 None
  263. """
  264. # 查找最佳 Dice 模型
  265. best_dice_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
  266. if os.path.exists(best_dice_path):
  267. print(f"找到最佳 Dice 模型:{best_dice_path}")
  268. return best_dice_path
  269. # 查找最近的检查点
  270. checkpoint_dir = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}")
  271. if os.path.exists(checkpoint_dir):
  272. checkpoints = sorted(
  273. [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')],
  274. key=lambda x: int(x.split('epoch=')[1].split('.')[0]) if 'epoch=' in x else -1,
  275. reverse=True
  276. )
  277. if checkpoints:
  278. latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[0])
  279. print(f"找到最新检查点:{latest_checkpoint}")
  280. return latest_checkpoint
  281. # 查找最佳综合模型
  282. best_metric_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
  283. if os.path.exists(best_metric_path):
  284. print(f"找到最佳综合模型:{best_metric_path}")
  285. return best_metric_path
  286. return None
  287. def create_enhanced_transforms(target_spatial_size=(512, 512)):
  288. """
  289. 高增强版数据增强策略
  290. 包含:
  291. 1. 几何变换:翻转、旋转、缩放、裁剪
  292. 2. 光度变换:亮度、对比度、gamma 校正
  293. 3. 噪声注入:高斯噪声、低分辨率模拟
  294. 4. 正则化:Coarse Dropout
  295. """
  296. def convert_label_to_single_channel(label_tensor):
  297. """将 RGB 标签转为单通道二值掩码"""
  298. single_channel = label_tensor[0:1, :, :]
  299. binary_label = (single_channel > 127).float()
  300. return binary_label
  301. train_transforms = Compose([
  302. # ========== 加载与预处理 ==========
  303. LoadImaged(keys=["image", "label"]),
  304. EnsureChannelFirstd(keys=["image", "label"]),
  305. Lambdad(keys=["label"], func=convert_label_to_single_channel),
  306. # ========== 空间变换 ==========
  307. Resized(keys=["image", "label"], spatial_size=target_spatial_size,
  308. mode=("bilinear", "nearest")),
  309. ScaleIntensityd(keys=["image"]),
  310. # --- 几何增强 ---
  311. # 随机轴翻转
  312. RandAxisFlipd(keys=["image", "label"], prob=0.5),
  313. # 随机旋转(-15° 到 +15°)
  314. RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5,
  315. keep_size=True, mode=("bilinear", "nearest")),
  316. # 随机 90 度旋转
  317. RandRotate90d(keys=["image", "label"], prob=0.5, max_k=2),
  318. # 随机缩放(0.8-1.2 倍)+ 裁剪
  319. RandZoomd(keys=["image", "label"], min_zoom=0.8, max_zoom=1.2,
  320. prob=0.5, mode=("bilinear", "nearest"), keep_size=True),
  321. # ========== 光度变换 ==========
  322. # 随机亮度调整(±20%)
  323. RandShiftIntensityd(keys=["image"], offsets=(-0.2, 0.2), prob=0.5),
  324. # 随机对比度调整(gamma 0.7-1.3)
  325. RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.3), prob=0.5),
  326. # 随机直方图偏移(模拟不同染色/光照条件)
  327. RandHistogramShiftd(keys=["image"], num_control_points=(5, 10),
  328. prob=0.3),
  329. # ========== 噪声与质量退化 ==========
  330. # 随机高斯平滑(模拟模糊)
  331. RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.0),
  332. sigma_y=(0.5, 1.0), prob=0.3),
  333. # 随机高斯噪声
  334. RandGaussianNoised(keys=["image"], mean=0.0, std=0.05, prob=0.3),
  335. # Coarse Dropout(遮挡增强,提升鲁棒性)
  336. RandCoarseDropoutd(
  337. keys=["image"],
  338. holes=1,
  339. max_holes=3,
  340. spatial_size=(32, 32),
  341. max_spatial_size=(64, 64),
  342. prob=0.3
  343. ),
  344. # ========== 后处理 ==========
  345. ToTensord(keys=["image", "label"]),
  346. ])
  347. return train_transforms
  348. def create_datasets(args):
  349. """
  350. 创建训练集和验证集
  351. Args:
  352. args: 命令行参数
  353. Returns:
  354. tuple: (train_dataset, val_dataset)
  355. """
  356. print("=" * 60)
  357. print("正在加载数据集...")
  358. print("=" * 60)
  359. def convert_label_to_single_channel(label_tensor):
  360. """
  361. 全局函数:将 3 通道 RGB 标签转为 1 通道二值标签 (0 或 1)
  362. 输入:label_tensor (shape: [3, H, W], 值域 0-255)
  363. 输出:new_tensor (shape: [1, H, W], 值域 0 或 1)
  364. """
  365. # 1. 提取第一个通道 (R 通道)
  366. single_channel = label_tensor[0:1, :, :]
  367. # 2. 二值化处理:大于 0 的像素设为 1 (假设背景是纯黑 0,息肉是白色或其他颜色)
  368. # 这样确保所有像素值只能是 0 或 1,符合 out_channels=2 的要求
  369. binary_label = (single_channel > 127).float()
  370. return binary_label
  371. # 定义训练集变换
  372. train_transforms = Compose([
  373. LoadImaged(keys=["image", "label"]),
  374. EnsureChannelFirstd(keys=["image", "label"]),
  375. # 将标签转换为单通道(取第一个通道或转换为灰度)
  376. Lambdad(keys=["label"], func=convert_label_to_single_channel),
  377. Resized(keys=["image", "label"], spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
  378. ScaleIntensityd(keys=["image"]),
  379. RandFlipd(keys=["image", "label"], prob=0.5),
  380. RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5),
  381. ])
  382. if args.dataset_enhanced:
  383. train_transforms = create_enhanced_transforms(args.target_spatial_size)
  384. print("✓ 使用增强数据增强策略")
  385. # 定义验证集变换
  386. val_transforms = Compose([
  387. LoadImaged(keys=["image", "label"]),
  388. EnsureChannelFirstd(keys=["image", "label"]),
  389. # 将标签转换为单通道(取第一个通道或转换为灰度)
  390. Lambdad(keys=["label"], func=convert_label_to_single_channel),
  391. Resized(keys=["image", "label"], spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
  392. ScaleIntensityd(keys=["image"]),
  393. ])
  394. # 创建训练集
  395. train_dataset = PolypDetectionDataset(
  396. root_dir=Path(args.data_root) / args.dataset_name,
  397. flag='train',
  398. transform=train_transforms
  399. )
  400. # 创建验证集
  401. val_dataset = PolypDetectionDataset(
  402. root_dir=Path(args.data_root) / args.dataset_name,
  403. flag='val',
  404. transform=val_transforms
  405. )
  406. print(f"✓ 训练集大小:{len(train_dataset)} 个样本")
  407. print(f"✓ 验证集大小:{len(val_dataset)} 个样本")
  408. print(f"✓ 总样本数:{len(train_dataset) + len(val_dataset)} 个样本")
  409. print("=" * 60)
  410. return train_dataset, val_dataset
  411. def create_dataloaders(args, train_dataset, val_dataset):
  412. """
  413. 创建数据加载器
  414. Args:
  415. args: 命令行参数
  416. train_dataset: 训练集
  417. val_dataset: 验证集
  418. Returns:
  419. tuple: (train_loader, val_loader)
  420. """
  421. train_loader = monai.data.DataLoader(
  422. train_dataset,
  423. batch_size=args.batch_size,
  424. shuffle=True,
  425. num_workers=args.num_workers,
  426. pin_memory=args.pin_memory,
  427. drop_last=True
  428. )
  429. val_loader = monai.data.DataLoader(
  430. val_dataset,
  431. batch_size=args.batch_size,
  432. shuffle=False,
  433. num_workers=args.num_workers,
  434. pin_memory=args.pin_memory,
  435. drop_last=False
  436. )
  437. print(f"✓ 训练数据加载器:{len(train_loader)} 个 batch")
  438. print(f"✓ 验证数据加载器:{len(val_loader)} 个 batch")
  439. return train_loader, val_loader
  440. def create_model(args):
  441. """
  442. 创建模型
  443. Args:
  444. args: 命令行参数
  445. Returns:
  446. torch.nn.Module: 初始化好的模型
  447. """
  448. print("\n" + "=" * 60)
  449. print("正在创建模型...")
  450. model = Wavelet_FFT_SwinUNETR(
  451. in_channels=args.in_channels,
  452. out_channels=args.out_channels,
  453. feature_size=args.feature_size,
  454. spatial_dims=args.spatial_dims,
  455. wavelet_enhancement=args.use_wavelet,
  456. wavelet_J=args.wavelet_J,
  457. wavelet_wave=args.wavelet_wave,
  458. wavelet_mode='symmetric',
  459. wavelet_reduction=args.wavelet_reduction,
  460. fft_enhancement=args.use_fft,
  461. use_v2=args.use_v2
  462. )
  463. # 打印模型信息
  464. total_params = sum(p.numel() for p in model.parameters())
  465. trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  466. print(f"\n✓ 模型总参数量:{total_params:,}")
  467. print(f"✓ 可训练参数量:{trainable_params:,}")
  468. print(f"✓ 使用设备:{args.device}")
  469. print("=" * 60)
  470. return model
  471. def create_loss_function(args):
  472. """
  473. 创建损失函数
  474. Args:
  475. args: 命令行参数
  476. Returns:
  477. Callable: 损失函数
  478. """
  479. loss_fn = CombinedDiceCEIoULoss(
  480. dice_weight=args.dice_weight,
  481. ce_weight=args.ce_weight,
  482. iou_weight=args.iou_weight,
  483. include_background=True,
  484. to_onehot_y=False,
  485. softmax=False,
  486. sigmoid=True,
  487. )
  488. return loss_fn
  489. def create_optimizer(args, model):
  490. """
  491. 创建优化器
  492. Args:
  493. args: 命令行参数
  494. model: 模型
  495. Returns:
  496. Optimizer: 优化器
  497. """
  498. optimizer = AdamW(
  499. model.parameters(),
  500. lr=args.learning_rate,
  501. weight_decay=args.weight_decay
  502. )
  503. scheduler = ReduceLROnPlateau(
  504. optimizer,
  505. mode='min', # 验证损失越小越好
  506. factor=0.5, # 每次乘以 0.5
  507. patience=20, # 20 个 epoch 不下降则降低 LR
  508. threshold=1e-4, # 最小变化阈值
  509. cooldown=5, # 降低 LR 后的冷却期
  510. min_lr=1e-6 # 学习率下限
  511. )
  512. print(f"✓ 优化器:AdamW")
  513. print(f" - 学习率:{args.learning_rate}")
  514. print(f" - 权重衰减:{args.weight_decay}")
  515. print(f"✓ 调度器:ReduceLROnPlateau")
  516. print(f" - 模式:{scheduler.mode}")
  517. print(f" - 衰减因子:{scheduler.factor}")
  518. print(f" - patience:{scheduler.patience}")
  519. print(f" - 最小变化阈值:{scheduler.threshold}")
  520. print(f" - 冷却期:{scheduler.cooldown}")
  521. print(f" - 最小学习率:{scheduler.min_lrs}")
  522. return optimizer, scheduler
  523. def setup_swanlab(args):
  524. """
  525. 配置 SwanLab 实验跟踪
  526. Args:
  527. args: 命令行参数
  528. Returns:
  529. swanlab.Run: SwanLab 运行对象
  530. """
  531. # 如果没有指定实验名称,使用时间戳
  532. if args.swanlab_experiment is None:
  533. args.swanlab_experiment = "v2_" + args.dataset_name + "_" + datetime.now().strftime("%Y%m%d_%H%M%S")
  534. # 创建日志目录
  535. os.makedirs(args.swanlab_log_dir, exist_ok=True)
  536. os.makedirs(args.output_dir, exist_ok=True)
  537. # 初始化 SwanLab
  538. run = swanlab.init(
  539. project=args.swanlab_project,
  540. experiment_name=args.swanlab_experiment,
  541. logdir=args.swanlab_log_dir,
  542. config=vars(args)
  543. )
  544. print(f"\n✓ SwanLab 实验已初始化:{args.swanlab_experiment}")
  545. print(f" - 项目:{args.swanlab_project}")
  546. print(f" - 日志目录:{args.swanlab_log_dir}")
  547. return run
  548. def main():
  549. """
  550. 主训练函数
  551. """
  552. # ==================== Step 1: 解析参数 ====================
  553. args = parse_args()
  554. # 设置随机种子以确保可重复性
  555. torch.manual_seed(args.seed)
  556. if torch.cuda.is_available():
  557. torch.cuda.manual_seed_all(args.seed)
  558. print("\n" + "=" * 60)
  559. print("息肉分割模型训练开始")
  560. print("=" * 60)
  561. print(f"使用设备:{args.device}")
  562. print(f"批次大小:{args.batch_size}")
  563. print(f"最大轮数:{args.max_epochs}")
  564. if args.early_stopping:
  565. print(f"早停机制:启用 (patience={args.early_stopping_patience}, monitor={args.early_stopping_monitor})")
  566. # ==================== Step 2: 初始化 SwanLab ====================
  567. run = setup_swanlab(args)
  568. # ==================== Step 3: 创建数据集和数据加载器 ====================
  569. train_dataset, val_dataset = create_datasets(args)
  570. train_loader, val_loader = create_dataloaders(args, train_dataset, val_dataset)
  571. # ==================== Step 4: 创建模型、损失函数、优化器 ====================
  572. model = create_model(args)
  573. model = model.to(args.device)
  574. loss_function = create_loss_function(args)
  575. optimizer, scheduler = create_optimizer(args, model)
  576. # ==================== Step 5: 创建评估指标 ====================
  577. dice_metric = DiceMetric(reduction="mean")
  578. iou_metric = MeanIoU(reduction="mean")
  579. hd_metric = HausdorffDistanceMetric(reduction="mean")
  580. # ==================== Step 6: 设置训练循环 ====================
  581. best_dice = -1
  582. best_dice_epoch = -1
  583. best_metric = -1
  584. best_metric_epoch = -1
  585. best_iou = -1
  586. best_iou_epoch = -1
  587. epoch_loss_values = []
  588. dice_metric_values = []
  589. iou_metric_values = []
  590. hd_metric_values = []
  591. start_epoch = 0
  592. # ==================== 早停机制相关变量 ====================
  593. early_stopping_counter = 0
  594. should_stop = False
  595. has_restarted = False # 标记是否已经重启过一次
  596. # ==================== Step 7: 恢复训练(如果有检查点) ====================
  597. checkpoint_loaded = False
  598. checkpoint = None
  599. if args.resume:
  600. # 用户指定了检查点路径
  601. if not os.path.exists(args.resume):
  602. raise FileNotFoundError(f"检查点文件不存在:{args.resume}")
  603. checkpoint_path = args.resume
  604. print(f"\n正在从用户指定的检查点恢复训练:{checkpoint_path}")
  605. checkpoint = torch.load(checkpoint_path, map_location=args.device)
  606. checkpoint_loaded = True
  607. elif args.auto_resume:
  608. # 自动查找最佳检查点
  609. checkpoint_path = find_best_checkpoint(args)
  610. if checkpoint_path:
  611. print(f"\n自动恢复模式:加载 {checkpoint_path}")
  612. checkpoint = torch.load(checkpoint_path, map_location=args.device)
  613. checkpoint_loaded = True
  614. else:
  615. print("\n未找到任何检查点,将从头开始训练")
  616. if checkpoint_loaded:
  617. # 加载模型权重 - 支持 v1 到 v2 的迁移
  618. model_dict = model.state_dict()
  619. # 尝试从检查点加载
  620. try:
  621. pretrained_dict = checkpoint["model_state_dict"]
  622. print("✓ 从训练检查点加载模型权重")
  623. except KeyError:
  624. pretrained_dict = checkpoint
  625. print("从最佳 Dice 或最佳综合模型权重中加载模型权重")
  626. # 过滤和匹配参数(处理 v1->v2 的结构变化)
  627. matched_params = {}
  628. unmatched_params = []
  629. missing_params = []
  630. for name, param in model_dict.items():
  631. if name in pretrained_dict:
  632. pretrained_param = pretrained_dict[name]
  633. # 检查形状是否匹配
  634. if param.shape == pretrained_param.shape:
  635. matched_params[name] = pretrained_param
  636. else:
  637. unmatched_params.append(f"{name} (形状不匹配:{param.shape} vs {pretrained_param.shape})")
  638. else:
  639. missing_params.append(name)
  640. # 输出加载统计信息
  641. print(f"\n权重加载统计:")
  642. print(f" ✓ 成功匹配的参数:{len(matched_params)}/{len(model_dict)}")
  643. print(f" ⚠ 形状不匹配的 parameter: {len(unmatched_params)}")
  644. print(f" ✗ 新增的 parameter(随机初始化): {len(missing_params)}")
  645. if unmatched_params:
  646. print(f"\n形状不匹配的层:")
  647. for info in unmatched_params[:5]: # 只显示前 5 个
  648. print(f" - {info}")
  649. if len(unmatched_params) > 5:
  650. print(f" ... 还有 {len(unmatched_params) - 5} 个")
  651. if missing_params:
  652. print(f"\n新增的层 (将随机初始化):")
  653. for name in missing_params[:5]: # 只显示前 5 个
  654. print(f" - {name}")
  655. if len(missing_params) > 5:
  656. print(f" ... 还有 {len(missing_params) - 5} 个")
  657. # 更新预训练字典
  658. model_dict.update(matched_params)
  659. # 加载匹配的参数
  660. model.load_state_dict(model_dict, strict=False)
  661. print(f"\n✓ 模型权重加载完成 (严格模式:False)")
  662. print("=" * 60)
  663. if "optimizer_state_dict" in checkpoint:
  664. # 加载优化器状态
  665. optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
  666. print("✓ 优化器状态已加载")
  667. # 加载轮数
  668. # if "epoch" in checkpoint:
  669. # start_epoch = checkpoint["epoch"] + 1 # 从下一个 epoch 开始
  670. # print(f"✓ 训练轮数已恢复到 epoch {start_epoch}")
  671. # 加载最佳指标
  672. if "best_dice" in checkpoint:
  673. best_dice = checkpoint["best_dice"]
  674. best_dice_epoch = checkpoint["best_dice_epoch"]
  675. print(f"✓ 最佳指标已恢复:Dice={best_dice:.4f} (Epoch {best_dice_epoch})")
  676. # 加载历史损失和指标值(可选)
  677. if "epoch_loss_values" in checkpoint:
  678. epoch_loss_values = checkpoint["epoch_loss_values"]
  679. if "dice_metric_values" in checkpoint:
  680. dice_metric_values = checkpoint["dice_metric_values"]
  681. # 加载早停状态
  682. if args.early_stopping:
  683. if "early_stopping_counter" in checkpoint:
  684. early_stopping_counter = checkpoint["early_stopping_counter"]
  685. print(f"✓ 早停计数器已恢复:{early_stopping_counter}")
  686. if "should_stop" in checkpoint and checkpoint["should_stop"]:
  687. should_stop = False # 即使标记为停止,也允许继续训练
  688. print("✓ 早停状态已重置,可继续训练")
  689. print(f"✓ 训练将从 epoch {start_epoch} 继续")
  690. print("=" * 60)
  691. print("\n" + "=" * 60)
  692. print("开始训练...")
  693. print("=" * 60)
  694. start_time = time.time()
  695. try:
  696. for epoch in range(start_epoch, run.config.max_epochs):
  697. # ========== 检查早停条件 ==========
  698. if should_stop and args.early_stopping:
  699. print(f"\n{'=' * 60}")
  700. print(f"触发早停机制!训练将在 epoch {epoch + 1} 提前终止")
  701. print(f"{'=' * 60}")
  702. if not has_restarted:
  703. # 第一次早停:加载最佳权重,重启训练
  704. print("检测到早停,准备从最佳模型重新开始训练...")
  705. # 1. 查找最佳 Dice 模型
  706. best_checkpoint_path = os.path.join(
  707. args.output_dir,
  708. f"best_dice_model_{args.dataset_name}.pt"
  709. )
  710. if os.path.exists(best_checkpoint_path):
  711. print(f"加载最佳 Dice 模型:{best_checkpoint_path}")
  712. checkpoint = torch.load(best_checkpoint_path, map_location=args.device)
  713. # 2. 加载最佳权重
  714. model.load_state_dict(checkpoint)
  715. print("✓ 模型权重已恢复到最佳状态")
  716. # 3. 重置优化器
  717. optimizer, scheduler = create_optimizer(args, model)
  718. print("✓ 优化器已重置")
  719. # 4. 重置早停计数器
  720. early_stopping_counter = 0
  721. should_stop = False
  722. has_restarted = True
  723. print("✓ 已从最佳模型重新开始训练")
  724. print(f"{'=' * 60}\n")
  725. continue # 跳过 break,继续下一轮 epoch
  726. else:
  727. print(f"警告:未找到最佳模型文件 {best_checkpoint_path}")
  728. print("将直接停止训练")
  729. # 第二次早停或找不到最佳模型:真正停止
  730. print("早停后已重启过一次训练,现在停止训练")
  731. break
  732. # ========== 训练阶段 ==========
  733. model.train()
  734. step = 0
  735. epoch_loss = 0
  736. epoch_loss_dice_ce = 0
  737. epoch_loss_iou = 0
  738. for batch_data in train_loader:
  739. step += 1
  740. inputs = batch_data["image"].to(args.device)
  741. targets = batch_data["label"].to(args.device)
  742. optimizer.zero_grad()
  743. outputs = model(inputs)
  744. loss, loss_dice_ce, loss_iou = loss_function(outputs, targets)
  745. loss.backward()
  746. optimizer.step()
  747. epoch_loss += loss.item()
  748. epoch_loss_dice_ce += loss_dice_ce.item()
  749. epoch_loss_iou += loss_iou.item()
  750. # 如果是从检查点恢复的第一个 epoch,打印提示信息
  751. if epoch == start_epoch and start_epoch > 0:
  752. print(f"\n✓ 已从 epoch {start_epoch} 恢复训练")
  753. epoch_loss /= step
  754. epoch_loss_dice_ce /= step
  755. epoch_loss_iou /= step
  756. epoch_loss_values.append(epoch_loss)
  757. print(f"\nEpoch {epoch + 1}/{args.max_epochs} - 训练损失:{epoch_loss:.4f}")
  758. # 记录到 SwanLab
  759. swanlab.log({
  760. "train/loss": epoch_loss,
  761. "train/loss_dice_ce": epoch_loss_dice_ce,
  762. "train/loss_iou": epoch_loss_iou,
  763. "train/lr": optimizer.param_groups[0]['lr'],
  764. }, step=(epoch + 1))
  765. # ========== 验证阶段 ==========
  766. model.eval()
  767. val_loss_total = 0
  768. with torch.no_grad():
  769. dice_metric.reset()
  770. iou_metric.reset()
  771. hd_metric.reset()
  772. for val_data in val_loader:
  773. val_images = val_data["image"].to(args.device)
  774. val_labels = val_data["label"].to(args.device)
  775. val_outputs = model(val_images)
  776. # 计算验证损失
  777. val_loss_batch, _, _ = loss_function(val_outputs, val_labels)
  778. val_loss_total += val_loss_batch.item()
  779. # 后处理
  780. val_outputs = torch.sigmoid(val_outputs)
  781. val_outputs = (val_outputs > 0.5).int()
  782. # 计算 Dice 分数
  783. dice_metric(y_pred=val_outputs, y=val_labels)
  784. iou_metric(y_pred=val_outputs, y=val_labels)
  785. hd_metric(y_pred=val_outputs, y=val_labels)
  786. # 计算平均验证损失
  787. val_loss_avg = val_loss_total / len(val_loader)
  788. # 更新学习率调度器
  789. scheduler.step(val_loss_avg)
  790. current_lr = optimizer.param_groups[0]['lr']
  791. # 聚合结果
  792. mean_dice = dice_metric.aggregate().item()
  793. dice_metric_values.append(mean_dice)
  794. mean_iou = iou_metric.aggregate().item()
  795. iou_metric_values.append(mean_iou)
  796. mean_hd = hd_metric.aggregate().item()
  797. hd_metric_values.append(mean_hd)
  798. print(
  799. f"Epoch {epoch + 1} - 验证 Dice: {mean_dice:.4f}, 验证损失:{val_loss_avg:.4f}, 当前 LR: {current_lr:.2e}")
  800. swanlab.log({
  801. "val/loss": val_loss_avg,
  802. "val/mean_dice": mean_dice,
  803. "val/mean_iou": mean_iou,
  804. "val/mean_hd": mean_hd,
  805. "val/lr": current_lr,
  806. }, step=(epoch + 1))
  807. # ========== 早停机制检查 ==========
  808. if args.early_stopping:
  809. # 获取当前监控指标
  810. if args.early_stopping_monitor == "dice":
  811. current_score = mean_dice
  812. best_score = best_dice
  813. is_better = current_score > best_score + args.early_stopping_min_delta
  814. elif args.early_stopping_monitor == "iou":
  815. current_score = mean_iou
  816. best_score = best_iou
  817. is_better = current_score > best_score + args.early_stopping_min_delta
  818. elif args.early_stopping_monitor == "metric":
  819. normalized_hd = 1.0 / (1.0 + mean_hd)
  820. current_score = 1 * mean_dice + 1 * mean_iou + 1 * normalized_hd
  821. best_score = best_metric
  822. is_better = current_score > best_score + args.early_stopping_min_delta
  823. else: # loss
  824. current_score = -val_loss_avg # 损失越小越好,所以取负
  825. best_score = -min(epoch_loss_values) if epoch_loss_values else float('-inf')
  826. is_better = current_score > best_score + args.early_stopping_min_delta
  827. # 检查是否有改善
  828. if is_better:
  829. early_stopping_counter = 0
  830. print(
  831. f" ✓ {args.early_stopping_monitor.upper()} 指标改善:{current_score:.4f} > {best_score:.4f}")
  832. else:
  833. early_stopping_counter += 1
  834. print(
  835. f" ⚠ {args.early_stopping_monitor.upper()} 指标未改善,计数器:{early_stopping_counter}/{args.early_stopping_patience}")
  836. # 检查是否触发早停
  837. if early_stopping_counter >= args.early_stopping_patience:
  838. should_stop = True
  839. # 保存最佳Dice模型
  840. if mean_dice > best_dice:
  841. best_dice = mean_dice
  842. best_dice_epoch = epoch + 1
  843. checkpoint_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
  844. Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
  845. torch.save(model.state_dict(), checkpoint_path)
  846. print(
  847. f"✓ 发现更好的Dice模型!Dice: {mean_dice:.4f},IoU: {mean_iou:.4f},HD: {mean_hd:.4f},已保存到 {checkpoint_path}")
  848. # 保存最佳IoU模型
  849. if mean_iou > best_iou:
  850. best_iou = mean_iou
  851. best_iou_epoch = epoch + 1
  852. checkpoint_path = os.path.join(args.output_dir, f"best_iou_model_{args.dataset_name}.pt")
  853. Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
  854. torch.save(model.state_dict(), checkpoint_path)
  855. print(
  856. f"✓ 找到更好的IoU模型!IoU: {mean_iou:.4f},Dice: {mean_dice:.4f},HD: {mean_hd:.4f},已保存到 {checkpoint_path}"
  857. )
  858. # 保存最佳综合模型
  859. normalized_hd = 1.0 / (1.0 + mean_hd)
  860. mean_metric = (
  861. 1 * mean_dice +
  862. 1 * mean_iou +
  863. 1 * normalized_hd
  864. )
  865. if mean_metric > best_metric:
  866. best_metric = mean_metric
  867. best_metric_epoch = epoch + 1
  868. checkpoint_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
  869. Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
  870. torch.save(model.state_dict(), checkpoint_path)
  871. print(
  872. f"✓ 找到更好的综合模型!综合得分: {mean_metric:.4f},Dice: {mean_dice:.4f},IoU: {mean_iou:.4f},HD: {mean_hd:.4f},已保存到 {checkpoint_path}"
  873. )
  874. # 定期保存检查点
  875. if (epoch + 1) % args.save_every == 0:
  876. checkpoint_path = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}",
  877. f"checkpoint_epoch={epoch}.pt")
  878. Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
  879. torch.save({
  880. "epoch": epoch,
  881. "model_state_dict": model.state_dict(),
  882. "optimizer_state_dict": optimizer.state_dict(),
  883. "best_dice": best_dice,
  884. "best_dice_epoch": best_dice_epoch,
  885. "epoch_loss_values": epoch_loss_values,
  886. "dice_metric_values": dice_metric_values,
  887. "iou_metric_values": iou_metric_values,
  888. "hd_metric_values": hd_metric_values,
  889. "best_metric": best_metric,
  890. "best_metric_epoch": best_metric_epoch,
  891. "best_iou": best_iou,
  892. "best_iou_epoch": best_iou_epoch,
  893. "early_stopping_counter": early_stopping_counter,
  894. "should_stop": should_stop
  895. }, checkpoint_path)
  896. print(f"✓ 检查点已保存:{checkpoint_path}")
  897. except KeyboardInterrupt:
  898. print("\n训练被用户中断")
  899. finally:
  900. end_time = time.time()
  901. training_time = end_time - start_time
  902. print("\n" + "=" * 60)
  903. print("训练完成!")
  904. print(f"总训练时间:{training_time / 3600:.2f} 小时")
  905. print(f"最佳验证 Dice: {best_dice:.4f} (Epoch {best_dice_epoch})")
  906. print("=" * 60)
  907. # 关闭 SwanLab
  908. swanlab.finish()
  909. print("✓ SwanLab 实验已保存")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()