# train.py - Polyp segmentation model training script
  1. import argparse
  2. import os
  3. import time
  4. from datetime import datetime
  5. from pathlib import Path
  6. import monai
  7. import monai.utils
  8. import swanlab
  9. import torch
  10. from lib.model.model_v4_minute import Wavelet_FFT_SwinUNETR
  11. from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
  12. from monai.transforms import (
  13. Compose, LoadImaged, ScaleIntensityd, RandFlipd, RandRotated, RandRotate90d,
  14. EnsureChannelFirstd, ToTensord, Resized, Lambdad, RandZoomd, RandShiftIntensityd, RandGaussianNoised,
  15. RandGaussianSmoothd, RandAdjustContrastd, RandHistogramShiftd,
  16. RandAxisFlipd, RandCoarseDropoutd,
  17. )
  18. from torch.optim import AdamW
  19. from torch.optim.lr_scheduler import ReduceLROnPlateau
  20. from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
  21. from lib.tools.combined_loss import CombinedDiceCEIoULoss
  22. def parse_args():
  23. parser = argparse.ArgumentParser(description="Polyp Segmentation Model Training Script")
  24. # ==================== Dataset-related Parameters ====================
  25. parser.add_argument(
  26. "--dataset_name",
  27. type=str,
  28. required=True,
  29. help="Dataset name"
  30. )
  31. parser.add_argument(
  32. "--data_root",
  33. type=str,
  34. default=r"./data/Polyp-Detection-Dataset",
  35. help="Root directory path of the dataset"
  36. )
  37. parser.add_argument(
  38. "--num_workers",
  39. type=int,
  40. default=0,
  41. help="Number of worker processes for data loader"
  42. )
  43. parser.add_argument(
  44. "--pin_memory",
  45. type=bool,
  46. default=True,
  47. help="Whether to enable pinned memory"
  48. )
  49. parser.add_argument(
  50. "--target_spatial_size",
  51. type=tuple,
  52. default=(512, 512),
  53. help="Target spatial size"
  54. )
  55. parser.add_argument(
  56. "--dataset_enhanced",
  57. type=bool,
  58. default=True,
  59. help="Whether to use enhanced data augmentation strategy"
  60. )
  61. # ==================== Model-related Parameters ====================
  62. parser.add_argument(
  63. "--in_channels",
  64. type=int,
  65. default=3,
  66. help="Number of input image channels"
  67. )
  68. parser.add_argument(
  69. "--out_channels",
  70. type=int,
  71. default=1,
  72. help="Number of output foreground channels"
  73. )
  74. parser.add_argument(
  75. "--feature_size",
  76. type=int,
  77. default=48,
  78. help="Network feature dimension"
  79. )
  80. parser.add_argument(
  81. "--spatial_dims",
  82. type=int,
  83. default=2,
  84. choices=[2, 3],
  85. help="Spatial dimension (2D or 3D)"
  86. )
  87. parser.add_argument(
  88. "--no_wavelet",
  89. action="store_false",
  90. dest="use_wavelet",
  91. help="Whether to enable wavelet enhancement module"
  92. )
  93. parser.add_argument(
  94. "--wavelet_J",
  95. type=int,
  96. default=2,
  97. help="Wavelet decomposition levels"
  98. )
  99. parser.add_argument(
  100. "--wavelet_wave",
  101. type=str,
  102. default="db4",
  103. help="Wavelet basis type"
  104. )
  105. parser.add_argument(
  106. "--wavelet_reduction",
  107. type=int,
  108. default=16,
  109. help="Wavelet attention compression ratio"
  110. )
  111. parser.add_argument(
  112. "--no_fft",
  113. action="store_false",
  114. dest="use_fft",
  115. help="Whether to enable FFT enhancement module"
  116. )
  117. parser.add_argument(
  118. "--use_v2",
  119. type=bool,
  120. default=True,
  121. help="Whether to enable Swin-UNETR v2 module"
  122. )
  123. # ==================== Training-related Parameters ====================
  124. parser.add_argument(
  125. "--max_epochs",
  126. type=int,
  127. default=1000,
  128. help="Maximum number of training epochs"
  129. )
  130. parser.add_argument(
  131. "--batch_size",
  132. type=int,
  133. default=4,
  134. help="Batch size"
  135. )
  136. parser.add_argument(
  137. "--learning_rate",
  138. type=float,
  139. default=1e-4,
  140. help="Learning rate"
  141. )
  142. parser.add_argument(
  143. "--weight_decay",
  144. type=float,
  145. default=1e-4,
  146. help="Weight decay coefficient"
  147. )
  148. # ==================== Loss Function Parameters ====================
  149. parser.add_argument(
  150. "--dice_weight",
  151. type=float,
  152. default=1.0,
  153. help="Dice loss weight"
  154. )
  155. parser.add_argument(
  156. "--ce_weight",
  157. type=float,
  158. default=1.0,
  159. help="Cross Entropy loss weight"
  160. )
  161. parser.add_argument(
  162. "--iou_weight",
  163. type=float,
  164. default=1.0,
  165. help="IoU loss weight"
  166. )
  167. # ==================== SwanLab Parameters ====================
  168. parser.add_argument(
  169. "--swanlab_project",
  170. type=str,
  171. default="polyp-segmentation-v4_minute",
  172. help="SwanLab project name"
  173. )
  174. parser.add_argument(
  175. "--swanlab_experiment",
  176. type=str,
  177. default=None,
  178. help="SwanLab experiment name (default uses timestamp)"
  179. )
  180. parser.add_argument(
  181. "--swanlab_log_dir",
  182. type=str,
  183. default="./swanlab_log",
  184. help="SwanLab log directory"
  185. )
  186. # ==================== Saving and Loading Parameters ====================
  187. parser.add_argument(
  188. "--output_dir",
  189. type=str,
  190. default="./outputs_v4_minute",
  191. help="Directory for saving model checkpoints"
  192. )
  193. parser.add_argument(
  194. "--save_every",
  195. type=int,
  196. default=50,
  197. help="Save model every N epochs"
  198. )
  199. # ==================== Early Stopping Parameters ====================
  200. parser.add_argument(
  201. "--early_stopping",
  202. type=bool,
  203. default=True,
  204. help="Whether to enable early stopping"
  205. )
  206. parser.add_argument(
  207. "--early_stopping_patience",
  208. type=int,
  209. default=100,
  210. help="Early stopping patience (stop if validation metric doesn't improve for N rounds)"
  211. )
  212. parser.add_argument(
  213. "--early_stopping_min_delta",
  214. type=float,
  215. default=1e-4,
  216. help="Minimum improvement threshold (improvement below this value is considered no improvement)"
  217. )
  218. parser.add_argument(
  219. "--early_stopping_monitor",
  220. type=str,
  221. default="dice",
  222. choices=["dice", "iou", "metric", "loss"],
  223. help="Metric to monitor for early stopping"
  224. )
  225. parser.add_argument(
  226. "--resume",
  227. type=str,
  228. default=None,
  229. help="Checkpoint path to resume training. If not specified, will automatically load the best Dice model (if exists)"
  230. )
  231. parser.add_argument(
  232. "--no_auto_resume",
  233. action="store_false",
  234. dest="auto_resume",
  235. help="Whether to enable auto-resume functionality (default loads best Dice model)"
  236. )
  237. # ==================== Other Parameters ====================
  238. parser.add_argument(
  239. "--device",
  240. type=str,
  241. default="cuda" if torch.cuda.is_available() else "cpu",
  242. help="Training device (cuda or cpu)"
  243. )
  244. parser.add_argument(
  245. "--seed",
  246. type=int,
  247. default=42,
  248. help="Random seed"
  249. )
  250. return parser.parse_args()
  251. def find_best_checkpoint(args):
  252. """
  253. Find the best checkpoint file
  254. Args:
  255. args: Command line arguments
  256. Returns:
  257. str or None: Best checkpoint path, returns None if not exists
  258. """
  259. # Find the best Dice model
  260. best_dice_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
  261. if os.path.exists(best_dice_path):
  262. print(f"Found best Dice model: {best_dice_path}")
  263. return best_dice_path
  264. # Find latest checkpoint
  265. checkpoint_dir = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}")
  266. if os.path.exists(checkpoint_dir):
  267. checkpoints = sorted(
  268. [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')],
  269. key=lambda x: int(x.split('epoch=')[1].split('.')[0]) if 'epoch=' in x else -1,
  270. reverse=True
  271. )
  272. if checkpoints:
  273. latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[0])
  274. print(f"Found latest checkpoint: {latest_checkpoint}")
  275. return latest_checkpoint
  276. # Find best overall model
  277. best_metric_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
  278. if os.path.exists(best_metric_path):
  279. print(f"Found best overall model: {best_metric_path}")
  280. return best_metric_path
  281. return None
  282. def create_enhanced_transforms(target_spatial_size=(512, 512)):
  283. """
  284. Enhanced data augmentation strategy
  285. Includes:
  286. 1. Geometric transformations: flip, rotation, scaling, cropping
  287. 2. Photometric transformations: brightness, contrast, gamma correction
  288. 3. Noise injection: Gaussian noise, low-resolution simulation
  289. 4. Regularization: Coarse Dropout
  290. """
  291. def convert_label_to_single_channel(label_tensor):
  292. """Convert RGB labels to single-channel binary mask"""
  293. single_channel = label_tensor[0:1, :, :]
  294. binary_label = (single_channel > 127).float()
  295. return binary_label
  296. train_transforms = Compose([
  297. # ========== Loading and Preprocessing ==========
  298. LoadImaged(keys=["image", "label"]),
  299. EnsureChannelFirstd(keys=["image", "label"]),
  300. Lambdad(keys=["label"], func=convert_label_to_single_channel),
  301. # ========== Spatial Transformations ==========
  302. Resized(keys=["image", "label"], spatial_size=target_spatial_size,
  303. mode=("bilinear", "nearest")),
  304. ScaleIntensityd(keys=["image"]),
  305. # --- Geometric Augmentation ---
  306. # Random axis flip
  307. RandAxisFlipd(keys=["image", "label"], prob=0.5),
  308. # Random rotation (-15° to +15°)
  309. RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5,
  310. keep_size=True, mode=("bilinear", "nearest")),
  311. # Random 90-degree rotation
  312. RandRotate90d(keys=["image", "label"], prob=0.5, max_k=2),
  313. # Random zoom (0.8-1.2x) + cropping
  314. RandZoomd(keys=["image", "label"], min_zoom=0.8, max_zoom=1.2,
  315. prob=0.5, mode=("bilinear", "nearest"), keep_size=True),
  316. # ========== Photometric Transformations ==========
  317. # Random brightness adjustment (±20%)
  318. RandShiftIntensityd(keys=["image"], offsets=(-0.2, 0.2), prob=0.5),
  319. # Random contrast adjustment (gamma 0.7-1.3)
  320. RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.3), prob=0.5),
  321. # Random histogram shift (simulate different staining/lighting conditions)
  322. RandHistogramShiftd(keys=["image"], num_control_points=(5, 10),
  323. prob=0.3),
  324. # ========== Noise and Quality Degradation ==========
  325. # Random Gaussian smoothing (simulate blur)
  326. RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.0),
  327. sigma_y=(0.5, 1.0), prob=0.3),
  328. # Random Gaussian noise
  329. RandGaussianNoised(keys=["image"], mean=0.0, std=0.05, prob=0.3),
  330. # Coarse Dropout (occlusion augmentation, improve robustness)
  331. RandCoarseDropoutd(
  332. keys=["image"],
  333. holes=1,
  334. max_holes=3,
  335. spatial_size=(32, 32),
  336. max_spatial_size=(64, 64),
  337. prob=0.3
  338. ),
  339. # ========== Post-processing ==========
  340. ToTensord(keys=["image", "label"]),
  341. ])
  342. return train_transforms
  343. def create_datasets(args):
  344. """
  345. Create training and validation datasets
  346. Args:
  347. args: Command line arguments
  348. Returns:
  349. tuple: (train_dataset, val_dataset)
  350. """
  351. print("=" * 60)
  352. print("正在加载数据集...")
  353. print("=" * 60)
  354. def convert_label_to_single_channel(label_tensor):
  355. """
  356. Global function: Convert 3-channel RGB labels to 1-channel binary labels (0 or 1)
  357. Input: label_tensor (shape: [3, H, W], value range 0-255)
  358. Output: new_tensor (shape: [1, H, W], value range 0 or 1)
  359. """
  360. # 1. Extract first channel (R channel)
  361. single_channel = label_tensor[0:1, :, :]
  362. # 2. Binarization: pixels greater than 0 are set to 1 (assuming background is pure black 0, polyp is white or other color)
  363. # This ensures all pixel values can only be 0 or 1, meeting the requirements for out_channels=2
  364. binary_label = (single_channel > 127).float()
  365. return binary_label
  366. # Define training set transformations
  367. train_transforms = Compose([
  368. LoadImaged(keys=["image", "label"]),
  369. EnsureChannelFirstd(keys=["image", "label"]),
  370. # Convert labels to single-channel (take first channel or convert to grayscale)
  371. Lambdad(keys=["label"], func=convert_label_to_single_channel),
  372. Resized(keys=["image", "label"], spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
  373. ScaleIntensityd(keys=["image"]),
  374. RandFlipd(keys=["image", "label"], prob=0.5),
  375. RandRotated(keys=["image", "label"], range_x=0.15, prob=0.5),
  376. ])
  377. if args.dataset_enhanced:
  378. train_transforms = create_enhanced_transforms(args.target_spatial_size)
  379. print("✓ 使用增强数据增强策略")
  380. # Define validation set transformations
  381. val_transforms = Compose([
  382. LoadImaged(keys=["image", "label"]),
  383. EnsureChannelFirstd(keys=["image", "label"]),
  384. # Convert labels to single-channel (take first channel or convert to grayscale)
  385. Lambdad(keys=["label"], func=convert_label_to_single_channel),
  386. Resized(keys=["image", "label"], spatial_size=args.target_spatial_size, mode=("bilinear", "nearest")),
  387. ScaleIntensityd(keys=["image"]),
  388. ])
  389. # Create training dataset
  390. train_dataset = PolypDetectionDataset(
  391. root_dir=Path(args.data_root) / args.dataset_name,
  392. flag='train',
  393. transform=train_transforms
  394. )
  395. # Create validation dataset
  396. val_dataset = PolypDetectionDataset(
  397. root_dir=Path(args.data_root) / args.dataset_name,
  398. flag='val',
  399. transform=val_transforms
  400. )
  401. print(f"✓ Training set size: {len(train_dataset)} samples")
  402. print(f"✓ Validation set size: {len(val_dataset)} samples")
  403. print(f"✓ Total samples: {len(train_dataset) + len(val_dataset)} samples")
  404. print("=" * 60)
  405. return train_dataset, val_dataset
  406. def create_dataloaders(args, train_dataset, val_dataset):
  407. """
  408. Create data loaders
  409. Args:
  410. args: Command line arguments
  411. train_dataset: Training dataset
  412. val_dataset: Validation dataset
  413. Returns:
  414. tuple: (train_loader, val_loader)
  415. """
  416. train_loader = monai.data.DataLoader(
  417. train_dataset,
  418. batch_size=args.batch_size,
  419. shuffle=True,
  420. num_workers=args.num_workers,
  421. pin_memory=args.pin_memory,
  422. drop_last=True
  423. )
  424. val_loader = monai.data.DataLoader(
  425. val_dataset,
  426. batch_size=args.batch_size,
  427. shuffle=False,
  428. num_workers=args.num_workers,
  429. pin_memory=args.pin_memory,
  430. drop_last=False
  431. )
  432. print(f"✓ Training data loader: {len(train_loader)} batches")
  433. print(f"✓ Validation data loader: {len(val_loader)} batches")
  434. return train_loader, val_loader
  435. def create_model(args):
  436. """
  437. Create the model
  438. Args:
  439. args: Command line arguments
  440. Returns:
  441. torch.nn.Module: Initialized model
  442. """
  443. print("\n" + "=" * 60)
  444. print("Creating model...")
  445. model = Wavelet_FFT_SwinUNETR(
  446. in_channels=args.in_channels,
  447. out_channels=args.out_channels,
  448. feature_size=args.feature_size,
  449. spatial_dims=args.spatial_dims,
  450. wavelet_enhancement=args.use_wavelet,
  451. wavelet_J=args.wavelet_J,
  452. wavelet_wave=args.wavelet_wave,
  453. wavelet_mode='symmetric',
  454. wavelet_reduction=args.wavelet_reduction,
  455. fft_enhancement=args.use_fft,
  456. use_v2=args.use_v2
  457. )
  458. # Print model information
  459. total_params = sum(p.numel() for p in model.parameters())
  460. trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  461. print(f"\n✓ Total model parameters: {total_params:,}")
  462. print(f"✓ Trainable parameters: {trainable_params:,}")
  463. print(f"✓ Using device: {args.device}")
  464. print("=" * 60)
  465. return model
  466. def create_loss_function(args):
  467. """
  468. Create loss function
  469. Args:
  470. args: Command line arguments
  471. Returns:
  472. Callable: Loss function
  473. """
  474. loss_fn = CombinedDiceCEIoULoss(
  475. dice_weight=args.dice_weight,
  476. ce_weight=args.ce_weight,
  477. iou_weight=args.iou_weight,
  478. include_background=True,
  479. to_onehot_y=False,
  480. softmax=False,
  481. sigmoid=True,
  482. )
  483. return loss_fn
  484. def create_optimizer(args, model):
  485. """
  486. Create optimizer
  487. Args:
  488. args: Command line arguments
  489. model: Model
  490. Returns:
  491. Optimizer: Optimizer
  492. """
  493. optimizer = AdamW(
  494. model.parameters(),
  495. lr=args.learning_rate,
  496. weight_decay=args.weight_decay
  497. )
  498. scheduler = ReduceLROnPlateau(
  499. optimizer,
  500. mode='min', # 验证损失越小越好
  501. factor=0.5, # 每次乘以 0.5
  502. patience=20, # 20 个 epoch 不下降则降低 LR
  503. threshold=1e-4, # 最小变化阈值
  504. cooldown=5, # 降低 LR 后的冷却期
  505. min_lr=1e-6 # 学习率下限
  506. )
  507. print(f"✓ Optimizer: AdamW")
  508. print(f" - Learning rate: {args.learning_rate}")
  509. print(f" - Weight decay: {args.weight_decay}")
  510. print(f"✓ Scheduler: ReduceLROnPlateau")
  511. print(f" - Mode: {scheduler.mode}")
  512. print(f" - Decay factor: {scheduler.factor}")
  513. print(f" - Patience: {scheduler.patience}")
  514. print(f" - Minimum change threshold: {scheduler.threshold}")
  515. print(f" - Cooldown period: {scheduler.cooldown}")
  516. print(f" - Minimum learning rate: {scheduler.min_lrs}")
  517. return optimizer, scheduler
  518. def setup_swanlab(args):
  519. """
  520. Configure SwanLab experiment tracking
  521. Args:
  522. args: Command line arguments
  523. Returns:
  524. swanlab.Run: SwanLab run object
  525. """
  526. # If experiment name is not specified, use timestamp
  527. if args.swanlab_experiment is None:
  528. args.swanlab_experiment = "v2_" + args.dataset_name + "_" + datetime.now().strftime("%Y%m%d_%H%M%S")
  529. # Create log directory
  530. os.makedirs(args.swanlab_log_dir, exist_ok=True)
  531. os.makedirs(args.output_dir, exist_ok=True)
  532. # Initialize SwanLab
  533. run = swanlab.init(
  534. project=args.swanlab_project,
  535. experiment_name=args.swanlab_experiment,
  536. logdir=args.swanlab_log_dir,
  537. config=vars(args)
  538. )
  539. print(f"\n✓ SwanLab experiment initialized: {args.swanlab_experiment}")
  540. print(f" - Project: {args.swanlab_project}")
  541. print(f" - Log directory: {args.swanlab_log_dir}")
  542. return run
  543. def main():
  544. """
  545. Main training function
  546. """
  547. # ==================== Step 1: Parse arguments ====================
  548. args = parse_args()
  549. # Set random seed for reproducibility
  550. torch.manual_seed(args.seed)
  551. if torch.cuda.is_available():
  552. torch.cuda.manual_seed_all(args.seed)
  553. print("\n" + "=" * 60)
  554. print("Polyp Segmentation Model Training Started")
  555. print("=" * 60)
  556. print(f"Using device: {args.device}")
  557. print(f"Batch size: {args.batch_size}")
  558. print(f"Maximum epochs: {args.max_epochs}")
  559. if args.early_stopping:
  560. print(
  561. f"Early stopping: enabled (patience={args.early_stopping_patience}, monitor={args.early_stopping_monitor})")
  562. # ==================== Step 2: Initialize SwanLab ====================
  563. run = setup_swanlab(args)
  564. # ==================== Step 3: Create datasets and data loaders ====================
  565. train_dataset, val_dataset = create_datasets(args)
  566. train_loader, val_loader = create_dataloaders(args, train_dataset, val_dataset)
  567. # ==================== Step 4: Create model, loss function, optimizer ====================
  568. model = create_model(args)
  569. model = model.to(args.device)
  570. loss_function = create_loss_function(args)
  571. optimizer, scheduler = create_optimizer(args, model)
  572. # ==================== Step 5: Create evaluation metrics ====================
  573. dice_metric = DiceMetric(reduction="mean")
  574. iou_metric = MeanIoU(reduction="mean")
  575. hd_metric = HausdorffDistanceMetric(reduction="mean")
  576. # ==================== Step 6: Setup training loop ====================
  577. best_dice = -1
  578. best_dice_epoch = -1
  579. best_metric = -1
  580. best_metric_epoch = -1
  581. best_iou = -1
  582. best_iou_epoch = -1
  583. epoch_loss_values = []
  584. dice_metric_values = []
  585. iou_metric_values = []
  586. hd_metric_values = []
  587. start_epoch = 0
  588. # ==================== Early stopping related variables ====================
  589. early_stopping_counter = 0
  590. should_stop = False
  591. has_restarted = False # Flag indicating whether it has been restarted once
  592. # ==================== Step 7: Resume training (if checkpoint exists) ====================
  593. checkpoint_loaded = False
  594. checkpoint = None
  595. if args.resume:
  596. # User specified checkpoint path
  597. if not os.path.exists(args.resume):
  598. raise FileNotFoundError(f"Checkpoint file does not exist: {args.resume}")
  599. checkpoint_path = args.resume
  600. print(f"\nResuming training from user-specified checkpoint: {checkpoint_path}")
  601. checkpoint = torch.load(checkpoint_path, map_location=args.device)
  602. checkpoint_loaded = True
  603. elif args.auto_resume:
  604. # Automatically find best checkpoint
  605. checkpoint_path = find_best_checkpoint(args)
  606. if checkpoint_path:
  607. print(f"\nAuto-resume mode: loading {checkpoint_path}")
  608. checkpoint = torch.load(checkpoint_path, map_location=args.device)
  609. checkpoint_loaded = True
  610. else:
  611. print("\nNo checkpoints found, starting training from scratch")
  612. if checkpoint_loaded:
  613. # Load model weights - supports migration from v1 to v2
  614. model_dict = model.state_dict()
  615. # Try to load from checkpoint
  616. try:
  617. pretrained_dict = checkpoint["model_state_dict"]
  618. print("✓ Model weights loaded from training checkpoint")
  619. except KeyError:
  620. pretrained_dict = checkpoint
  621. print("Loading model weights from best Dice or best overall model")
  622. # Filter and match parameters (handle structural changes from v1->v2)
  623. matched_params = {}
  624. unmatched_params = []
  625. missing_params = []
  626. for name, param in model_dict.items():
  627. if name in pretrained_dict:
  628. pretrained_param = pretrained_dict[name]
  629. # Check if shape matches
  630. if param.shape == pretrained_param.shape:
  631. matched_params[name] = pretrained_param
  632. else:
  633. unmatched_params.append(f"{name} (shape mismatch: {param.shape} vs {pretrained_param.shape})")
  634. else:
  635. missing_params.append(name)
  636. # Output loading statistics
  637. print(f"\nWeight loading statistics:")
  638. print(f" ✓ Successfully matched parameters: {len(matched_params)}/{len(model_dict)}")
  639. print(f" ⚠ Shape mismatched parameters: {len(unmatched_params)}")
  640. print(f" ✗ Newly added parameters (randomly initialized): {len(missing_params)}")
  641. if unmatched_params:
  642. print(f"\nShape mismatched layers:")
  643. for info in unmatched_params[:5]: # Only show first 5
  644. print(f" - {info}")
  645. if len(unmatched_params) > 5:
  646. print(f" ... {len(unmatched_params) - 5} more")
  647. if missing_params:
  648. print(f"\nNewly added layers (will be randomly initialized):")
  649. for name in missing_params[:5]: # Only show first 5
  650. print(f" - {name}")
  651. if len(missing_params) > 5:
  652. print(f" ... {len(missing_params) - 5} more")
  653. # Update pre-trained dictionary
  654. model_dict.update(matched_params)
  655. # Load matched parameters
  656. model.load_state_dict(model_dict, strict=False)
  657. print(f"\n✓ Model weights loaded (strict mode: False)")
  658. print("=" * 60)
  659. if "optimizer_state_dict" in checkpoint:
  660. # Load optimizer state
  661. optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
  662. print("✓ Optimizer state loaded")
  663. # Load epoch number
  664. # if "epoch" in checkpoint:
  665. # start_epoch = checkpoint["epoch"] + 1 # Start from next epoch
  666. # print(f"✓ Training epoch restored to epoch {start_epoch}")
  667. # Load best metrics
  668. if "best_dice" in checkpoint:
  669. best_dice = checkpoint["best_dice"]
  670. best_dice_epoch = checkpoint["best_dice_epoch"]
  671. print(f"✓ Best metrics restored: Dice={best_dice:.4f} (Epoch {best_dice_epoch})")
  672. # Load historical loss and metric values (optional)
  673. if "epoch_loss_values" in checkpoint:
  674. epoch_loss_values = checkpoint["epoch_loss_values"]
  675. if "dice_metric_values" in checkpoint:
  676. dice_metric_values = checkpoint["dice_metric_values"]
  677. # Load early stopping state
  678. if args.early_stopping:
  679. if "early_stopping_counter" in checkpoint:
  680. early_stopping_counter = checkpoint["early_stopping_counter"]
  681. print(f"✓ Early stopping counter restored: {early_stopping_counter}")
  682. if "should_stop" in checkpoint and checkpoint["should_stop"]:
  683. should_stop = False # Even if marked as stopped, allow continued training
  684. print("✓ Early stopping state reset, can continue training")
  685. print(f"✓ Training will continue from epoch {start_epoch}")
  686. print("=" * 60)
  687. print("\n" + "=" * 60)
  688. print("Starting training...")
  689. print("=" * 60)
  690. start_time = time.time()
  691. try:
  692. for epoch in range(start_epoch, run.config.max_epochs):
  693. # ========== Check early stopping condition ==========
  694. if should_stop and args.early_stopping:
  695. print(f"\n{'=' * 60}")
  696. print(f"Early stopping triggered! Training will terminate early at epoch {epoch + 1}")
  697. print(f"{'=' * 60}")
  698. if not has_restarted:
  699. # First early stopping: load best weights, restart training
  700. print("Early stopping detected, preparing to restart training from best model...")
  701. # 1. Find the best Dice model
  702. best_checkpoint_path = os.path.join(
  703. args.output_dir,
  704. f"best_dice_model_{args.dataset_name}.pt"
  705. )
  706. if os.path.exists(best_checkpoint_path):
  707. print(f"Loading best Dice model: {best_checkpoint_path}")
  708. checkpoint = torch.load(best_checkpoint_path, map_location=args.device)
  709. # 2. Load best weights
  710. model.load_state_dict(checkpoint)
  711. print("✓ Model weights restored to best state")
  712. # 3. Reset optimizer
  713. optimizer, scheduler = create_optimizer(args, model)
  714. print("✓ Optimizer has been reset")
  715. # 4. Reset early stopping counter
  716. early_stopping_counter = 0
  717. should_stop = False
  718. has_restarted = True
  719. print("✓ Training restarted from best model")
  720. print(f"{'=' * 60}\n")
  721. continue # Skip break, continue to next epoch
  722. else:
  723. print(f"Warning: Best model file not found {best_checkpoint_path}")
  724. print("Will stop training directly")
  725. # Second early stopping or best model not found: truly stop
  726. print("Training has been restarted once after early stopping, now stopping training")
  727. break
  728. # ========== Training phase ==========
  729. model.train()
  730. step = 0
  731. epoch_loss = 0
  732. epoch_loss_dice_ce = 0
  733. epoch_loss_iou = 0
  734. for batch_data in train_loader:
  735. step += 1
  736. inputs = batch_data["image"].to(args.device)
  737. targets = batch_data["label"].to(args.device)
  738. optimizer.zero_grad()
  739. outputs = model(inputs)
  740. loss, loss_dice_ce, loss_iou = loss_function(outputs, targets)
  741. loss.backward()
  742. optimizer.step()
  743. epoch_loss += loss.item()
  744. epoch_loss_dice_ce += loss_dice_ce.item()
  745. epoch_loss_iou += loss_iou.item()
  746. # If this is the first epoch resumed from checkpoint, print notification
  747. if epoch == start_epoch and start_epoch > 0:
  748. print(f"\n✓ Training resumed from epoch {start_epoch}")
  749. epoch_loss /= step
  750. epoch_loss_dice_ce /= step
  751. epoch_loss_iou /= step
  752. epoch_loss_values.append(epoch_loss)
  753. print(f"\nEpoch {epoch + 1}/{args.max_epochs} - Training loss: {epoch_loss:.4f}")
  754. # Log to SwanLab
  755. swanlab.log({
  756. "train/loss": epoch_loss,
  757. "train/loss_dice_ce": epoch_loss_dice_ce,
  758. "train/loss_iou": epoch_loss_iou,
  759. "train/lr": optimizer.param_groups[0]['lr'],
  760. }, step=(epoch + 1))
  761. # ========== Validation phase ==========
  762. model.eval()
  763. val_loss_total = 0
  764. with torch.no_grad():
  765. dice_metric.reset()
  766. iou_metric.reset()
  767. hd_metric.reset()
  768. for val_data in val_loader:
  769. val_images = val_data["image"].to(args.device)
  770. val_labels = val_data["label"].to(args.device)
  771. val_outputs = model(val_images)
  772. # Calculate validation loss
  773. val_loss_batch, _, _ = loss_function(val_outputs, val_labels)
  774. val_loss_total += val_loss_batch.item()
  775. # Post-processing
  776. val_outputs = torch.sigmoid(val_outputs)
  777. val_outputs = (val_outputs > 0.5).int()
  778. # Calculate Dice score
  779. dice_metric(y_pred=val_outputs, y=val_labels)
  780. iou_metric(y_pred=val_outputs, y=val_labels)
  781. hd_metric(y_pred=val_outputs, y=val_labels)
  782. # Calculate average validation loss
  783. val_loss_avg = val_loss_total / len(val_loader)
  784. # Update learning rate scheduler
  785. scheduler.step(val_loss_avg)
  786. current_lr = optimizer.param_groups[0]['lr']
  787. # Aggregate results
  788. mean_dice = dice_metric.aggregate().item()
  789. dice_metric_values.append(mean_dice)
  790. mean_iou = iou_metric.aggregate().item()
  791. iou_metric_values.append(mean_iou)
  792. mean_hd = hd_metric.aggregate().item()
  793. hd_metric_values.append(mean_hd)
  794. print(
  795. f"Epoch {epoch + 1} - Validation Dice: {mean_dice:.4f}, Validation loss: {val_loss_avg:.4f}, Current LR: {current_lr:.2e}")
  796. swanlab.log({
  797. "val/loss": val_loss_avg,
  798. "val/mean_dice": mean_dice,
  799. "val/mean_iou": mean_iou,
  800. "val/mean_hd": mean_hd,
  801. "val/lr": current_lr,
  802. }, step=(epoch + 1))
  803. # ========== Early stopping check ==========
  804. if args.early_stopping:
  805. # Get current monitored metric
  806. if args.early_stopping_monitor == "dice":
  807. current_score = mean_dice
  808. best_score = best_dice
  809. is_better = current_score > best_score + args.early_stopping_min_delta
  810. elif args.early_stopping_monitor == "iou":
  811. current_score = mean_iou
  812. best_score = best_iou
  813. is_better = current_score > best_score + args.early_stopping_min_delta
  814. elif args.early_stopping_monitor == "metric":
  815. normalized_hd = 1.0 / (1.0 + mean_hd)
  816. current_score = 1 * mean_dice + 1 * mean_iou + 1 * normalized_hd
  817. best_score = best_metric
  818. is_better = current_score > best_score + args.early_stopping_min_delta
  819. else: # loss
  820. current_score = -val_loss_avg # Lower loss is better, so take negative
  821. best_score = -min(epoch_loss_values) if epoch_loss_values else float('-inf')
  822. is_better = current_score > best_score + args.early_stopping_min_delta
  823. # Check if there is improvement
  824. if is_better:
  825. early_stopping_counter = 0
  826. print(
  827. f" ✓ {args.early_stopping_monitor.upper()} metric improved: {current_score:.4f} > {best_score:.4f}")
  828. else:
  829. early_stopping_counter += 1
  830. print(
  831. f" ⚠ {args.early_stopping_monitor.upper()} metric did not improve, counter: {early_stopping_counter}/{args.early_stopping_patience}")
  832. # Check if early stopping should be triggered
  833. if early_stopping_counter >= args.early_stopping_patience:
  834. should_stop = True
  835. # Save best Dice model
  836. if mean_dice > best_dice:
  837. best_dice = mean_dice
  838. best_dice_epoch = epoch + 1
  839. checkpoint_path = os.path.join(args.output_dir, f"best_dice_model_{args.dataset_name}.pt")
  840. Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
  841. torch.save(model.state_dict(), checkpoint_path)
  842. print(
  843. f"✓ Found better Dice model! Dice: {mean_dice:.4f}, IoU: {mean_iou:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}")
  844. # Save best IoU model
  845. if mean_iou > best_iou:
  846. best_iou = mean_iou
  847. best_iou_epoch = epoch + 1
  848. checkpoint_path = os.path.join(args.output_dir, f"best_iou_model_{args.dataset_name}.pt")
  849. Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
  850. torch.save(model.state_dict(), checkpoint_path)
  851. print(
  852. f"✓ Found better IoU model! IoU: {mean_iou:.4f}, Dice: {mean_dice:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}"
  853. )
  854. # Save best overall model
  855. normalized_hd = 1.0 / (1.0 + mean_hd)
  856. mean_metric = (
  857. 1 * mean_dice +
  858. 1 * mean_iou +
  859. 1 * normalized_hd
  860. )
  861. if mean_metric > best_metric:
  862. best_metric = mean_metric
  863. best_metric_epoch = epoch + 1
  864. checkpoint_path = os.path.join(args.output_dir, f"best_metric_model_{args.dataset_name}.pt")
  865. Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
  866. torch.save(model.state_dict(), checkpoint_path)
  867. print(
  868. f"✓ Found better overall model! Overall score: {mean_metric:.4f}, Dice: {mean_dice:.4f}, IoU: {mean_iou:.4f}, HD: {mean_hd:.4f}, saved to {checkpoint_path}"
  869. )
  870. # Periodically save checkpoint
  871. if (epoch + 1) % args.save_every == 0:
  872. checkpoint_path = os.path.join(args.output_dir, f"checkpoints_{args.dataset_name}",
  873. f"checkpoint_epoch={epoch}.pt")
  874. Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
  875. torch.save({
  876. "epoch": epoch,
  877. "model_state_dict": model.state_dict(),
  878. "optimizer_state_dict": optimizer.state_dict(),
  879. "best_dice": best_dice,
  880. "best_dice_epoch": best_dice_epoch,
  881. "epoch_loss_values": epoch_loss_values,
  882. "dice_metric_values": dice_metric_values,
  883. "iou_metric_values": iou_metric_values,
  884. "hd_metric_values": hd_metric_values,
  885. "best_metric": best_metric,
  886. "best_metric_epoch": best_metric_epoch,
  887. "best_iou": best_iou,
  888. "best_iou_epoch": best_iou_epoch,
  889. "early_stopping_counter": early_stopping_counter,
  890. "should_stop": should_stop
  891. }, checkpoint_path)
  892. print(f"✓ Checkpoint saved: {checkpoint_path}")
  893. except KeyboardInterrupt:
  894. print("\nTraining interrupted by user")
  895. finally:
  896. end_time = time.time()
  897. training_time = end_time - start_time
  898. print("\n" + "=" * 60)
  899. print("Training completed!")
  900. print(f"Total training time: {training_time / 3600:.2f} hours")
  901. print(f"Best validation Dice: {best_dice:.4f} (Epoch {best_dice_epoch})")
  902. print("=" * 60)
  903. # Close SwanLab
  904. swanlab.finish()
  905. print("✓ SwanLab experiment saved")
  906. if __name__ == "__main__":
  907. main()