eval.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. import argparse
  2. import os
  3. from datetime import datetime
  4. from pathlib import Path
  5. import cv2
  6. import numpy as np
  7. import torch
  8. from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
  9. from monai.transforms import (
  10. Compose, LoadImaged, ScaleIntensityd, EnsureChannelFirstd,
  11. ToTensord, Resized, Lambdad
  12. )
  13. from torch.utils.data import DataLoader
  14. from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
  15. from lib.model.model import Wavelet_FFT_SwinUNETR
  16. def parse_args():
  17. parser = argparse.ArgumentParser(description="Polyp Segmentation Model Evaluation Script")
  18. # ==================== Dataset-related Parameters ====================
  19. parser.add_argument(
  20. "--dataset_name",
  21. type=str,
  22. required=True,
  23. help="Dataset name"
  24. )
  25. parser.add_argument(
  26. "--data_root",
  27. type=str,
  28. default=r"./data/Polyp-Detection-Dataset",
  29. help="Root directory path of the dataset"
  30. )
  31. parser.add_argument(
  32. "--num_workers",
  33. type=int,
  34. default=0,
  35. help="Number of worker processes for data loader"
  36. )
  37. parser.add_argument(
  38. "--pin_memory",
  39. type=bool,
  40. default=True,
  41. help="Whether to enable pinned memory"
  42. )
  43. parser.add_argument(
  44. "--target_spatial_size",
  45. type=tuple,
  46. default=(512, 512),
  47. help="Target spatial size, format like: '512,512' or '(512,512)'"
  48. )
  49. parser.add_argument(
  50. "--batch_size",
  51. type=int,
  52. default=1,
  53. help="Batch size (smaller batch size recommended for evaluation)"
  54. )
  55. # ==================== Model-related Parameters ====================
  56. parser.add_argument(
  57. "--in_channels",
  58. type=int,
  59. default=3,
  60. help="Number of input image channels"
  61. )
  62. parser.add_argument(
  63. "--out_channels",
  64. type=int,
  65. default=1,
  66. help="Number of output foreground channels"
  67. )
  68. parser.add_argument(
  69. "--feature_size",
  70. type=int,
  71. default=48,
  72. help="Network feature dimension"
  73. )
  74. parser.add_argument(
  75. "--spatial_dims",
  76. type=int,
  77. default=2,
  78. choices=[2, 3],
  79. help="Spatial dimension (2D or 3D)"
  80. )
  81. parser.add_argument(
  82. "--use_wavelet",
  83. type=bool,
  84. default=True,
  85. help="Whether to enable wavelet enhancement module"
  86. )
  87. parser.add_argument(
  88. "--wavelet_J",
  89. type=int,
  90. default=2,
  91. help="Wavelet decomposition levels"
  92. )
  93. parser.add_argument(
  94. "--wavelet_wave",
  95. type=str,
  96. default="db4",
  97. help="Wavelet basis type"
  98. )
  99. parser.add_argument(
  100. "--wavelet_reduction",
  101. type=int,
  102. default=16,
  103. help="Wavelet attention compression ratio"
  104. )
  105. parser.add_argument(
  106. "--use_fft",
  107. type=bool,
  108. default=True,
  109. help="Whether to enable FFT enhancement module"
  110. )
  111. parser.add_argument(
  112. "--use_v2",
  113. type=bool,
  114. default=True,
  115. help="Whether to enable Swin-UNETR v2 module"
  116. )
  117. # ==================== Model Loading Parameters ====================
  118. parser.add_argument(
  119. "--device",
  120. type=str,
  121. default="cuda" if torch.cuda.is_available() else "cpu",
  122. help="Training device (cuda or cpu)"
  123. )
  124. # ==================== Other Parameters ====================
  125. parser.add_argument(
  126. "--save_results",
  127. type=bool,
  128. default=False,
  129. help="是否保存预测结果"
  130. )
  131. parser.add_argument(
  132. "--dir_flag",
  133. type=str,
  134. default="_minute",
  135. help="Prediction result save file suffix"
  136. )
  137. parser.add_argument(
  138. "--results_dir",
  139. type=str,
  140. default="./evaluation_results",
  141. help="Directory for saving prediction results"
  142. )
  143. parser.add_argument(
  144. "--outputs_dir",
  145. type=str,
  146. default="./outputs",
  147. help="是否保存预测结果"
  148. )
  149. parser.add_argument(
  150. "--save_visualization",
  151. type=bool,
  152. default=True,
  153. help="Whether to save visualization results"
  154. )
  155. parser.add_argument(
  156. "--vis_num_samples",
  157. type=int,
  158. default=1000,
  159. help="Number of samples to save for visualization"
  160. )
  161. parser.add_argument(
  162. "--best_metric",
  163. type=str,
  164. default=False,
  165. help="Load best overall model, False means load best Dice model by default"
  166. )
  167. return parser.parse_args()
  168. def create_val_transform(target_spatial_size=(512, 512)):
  169. """
  170. Create validation set transformations
  171. Args:
  172. target_spatial_size: Target spatial size
  173. Returns:
  174. Compose: Validation transformation composition
  175. """
  176. def convert_label_to_single_channel(label_tensor):
  177. """Convert RGB labels to single-channel binary mask"""
  178. single_channel = label_tensor[0:1, :, :]
  179. binary_label = (single_channel > 127).float()
  180. return binary_label
  181. val_transforms = Compose([
  182. LoadImaged(keys=["image", "label"]),
  183. EnsureChannelFirstd(keys=["image", "label"]),
  184. Lambdad(keys=["label"], func=convert_label_to_single_channel),
  185. Resized(keys=["image", "label"], spatial_size=target_spatial_size,
  186. mode=("bilinear", "nearest")),
  187. ScaleIntensityd(keys=["image"]),
  188. ToTensord(keys=["image", "label"]),
  189. ])
  190. return val_transforms
  191. def create_dataloader(args, dataset):
  192. """
  193. Create data loader
  194. Args:
  195. args: Command line arguments
  196. dataset: Dataset
  197. Returns:
  198. DataLoader: Data loader
  199. """
  200. loader = DataLoader(
  201. dataset,
  202. batch_size=args.batch_size,
  203. shuffle=False,
  204. num_workers=args.num_workers,
  205. pin_memory=args.pin_memory,
  206. drop_last=False
  207. )
  208. return loader
  209. def load_model(args, checkpoint_path):
  210. """
  211. Load model
  212. Args:
  213. args: Command line arguments
  214. checkpoint_path: Checkpoint path
  215. Returns:
  216. model: Model with loaded weights
  217. """
  218. print(f"\nLoading model: {checkpoint_path}")
  219. # Create model
  220. model = Wavelet_FFT_SwinUNETR(
  221. in_channels=args.in_channels,
  222. out_channels=args.out_channels,
  223. feature_size=args.feature_size,
  224. spatial_dims=args.spatial_dims,
  225. wavelet_enhancement=args.use_wavelet,
  226. wavelet_J=args.wavelet_J,
  227. wavelet_wave=args.wavelet_wave,
  228. wavelet_mode='symmetric',
  229. wavelet_reduction=args.wavelet_reduction,
  230. fft_enhancement=args.use_fft,
  231. use_v2=args.use_v2
  232. )
  233. # Load weights
  234. if not os.path.exists(checkpoint_path):
  235. raise FileNotFoundError(f"Model file does not exist: {checkpoint_path}")
  236. checkpoint = torch.load(checkpoint_path, map_location=args.device)
  237. # Check checkpoint format
  238. if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
  239. model.load_state_dict(checkpoint["model_state_dict"])
  240. print("✓ Model weights loaded from training checkpoint")
  241. else:
  242. model.load_state_dict(checkpoint)
  243. print("✓ Model weights loaded directly")
  244. model = model.to(args.device)
  245. model.eval()
  246. print(f"✓ Model loaded and set to evaluation mode")
  247. print(f"✓ 使用设备:{args.device}")
  248. return model
  249. def evaluate_model(model, dataloader, args):
  250. """
  251. Evaluate model performance
  252. Args:
  253. model: Model
  254. dataloader: Data loader
  255. args: Command line arguments
  256. Returns:
  257. dict: Dictionary containing various metrics
  258. """
  259. print("\n" + "=" * 60)
  260. print("Starting evaluation...")
  261. print("=" * 60)
  262. # Initialize metrics
  263. dice_metric = DiceMetric(reduction="mean")
  264. iou_metric = MeanIoU(reduction="mean")
  265. hd_metric = HausdorffDistanceMetric(reduction="mean")
  266. hd95_metric = HausdorffDistanceMetric(reduction="mean", percentile=95)
  267. total_samples = 0
  268. flag = 0
  269. saved_vis_count = 0
  270. vis_dir = None
  271. # Create visualization save directory
  272. if args.save_visualization:
  273. vis_dir = os.path.join(args.results_dir, f"visualization_{args.dataset_name}")
  274. os.makedirs(vis_dir, exist_ok=True)
  275. print(f"✓ Visualization results will be saved to: {vis_dir}")
  276. with torch.no_grad():
  277. dice_metric.reset()
  278. iou_metric.reset()
  279. hd_metric.reset()
  280. hd95_metric.reset()
  281. for batch_idx, batch_data in enumerate(dataloader):
  282. images = batch_data["image"].to(args.device)
  283. labels = batch_data["label"].to(args.device)
  284. # Forward propagation
  285. outputs = model(images) # [B, 1, H, W]
  286. # Post-processing
  287. outputs = torch.sigmoid(outputs)
  288. outputs = (outputs > 0.5).int()
  289. if flag == 0:
  290. flag = 1
  291. print(f"\n{'=' * 60}")
  292. print(f"[Detailed Debug Info - Batch 0]")
  293. print(f"{'=' * 60}")
  294. print(f"Input image size: {images.shape}")
  295. print(f"Output image size after post-processing: {outputs.shape}")
  296. print(f"Unique values: {torch.unique(outputs)}")
  297. print(f"\n--- Labels ---")
  298. print(f"Labels shape: {labels.shape}")
  299. print(f"Labels unique values: {torch.unique(labels)}")
  300. print(f"{'=' * 60}\n")
  301. # Calculate metrics for current batch - use directly without additional processing
  302. dice_metric(y_pred=outputs, y=labels)
  303. iou_metric(y_pred=outputs, y=labels)
  304. hd_metric(y_pred=outputs, y=labels)
  305. hd95_metric(y_pred=outputs, y=labels)
  306. # Save visualization results
  307. if args.save_visualization and saved_vis_count < args.vis_num_samples:
  308. save_visualization(
  309. images=images,
  310. labels=labels,
  311. predictions=outputs,
  312. save_dir=vis_dir,
  313. batch_idx=batch_idx,
  314. max_samples=args.vis_num_samples - saved_vis_count
  315. )
  316. saved_vis_count += images.shape[0]
  317. # Print progress
  318. if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(dataloader):
  319. print(f"进度:{batch_idx + 1}/{len(dataloader)} batches")
  320. total_samples += images.shape[0]
  321. # Aggregate all metrics
  322. mean_dice = dice_metric.aggregate().item()
  323. mean_iou = iou_metric.aggregate().item()
  324. mean_hd = hd_metric.aggregate().item()
  325. mean_hd95 = hd95_metric.aggregate().item()
  326. results = {
  327. "mDice": mean_dice,
  328. "mIoU": mean_iou,
  329. "mHD": mean_hd,
  330. "mHD95": mean_hd95,
  331. "total_samples": total_samples,
  332. }
  333. return results
  334. def save_visualization(images, labels, predictions, save_dir, batch_idx, max_samples):
  335. """
  336. Save visualization results: original image, ground truth label, prediction label combined
  337. Args:
  338. images: Input image batch [B, C, H, W]
  339. labels: Ground truth labels [B, 1, H, W]
  340. predictions: Prediction labels [B, 1, H, W]
  341. save_dir: Save directory
  342. batch_idx: Batch index
  343. max_samples: Maximum samples to save
  344. """
  345. for i in range(min(images.shape[0], max_samples)):
  346. try:
  347. # Extract single sample
  348. image = images[i].cpu()
  349. label = labels[i].cpu()
  350. prediction = predictions[i].cpu()
  351. # Image processing: de-normalize and convert to RGB
  352. if image.shape[0] == 1: # 灰度图
  353. image_np = image[0].numpy() * 255
  354. image_rgb = np.stack([image_np] * 3, axis=-1).astype(np.uint8)
  355. else: # RGB image
  356. image_np = image.numpy().transpose(1, 2, 0)
  357. # De-normalize (assuming z-score normalization, approximate handling)
  358. image_np = np.clip((image_np - image_np.min()) / (image_np.max() - image_np.min() + 1e-8) * 255, 0, 255)
  359. image_rgb = image_np.astype(np.uint8)
  360. # Label processing: convert to binary mask
  361. label_np = label[0].numpy() if label.shape[0] == 1 else label.numpy()
  362. label_binary = (label_np > 0.5).astype(np.float32)
  363. # Prediction processing: convert to binary mask
  364. pred_np = prediction[0].numpy() if prediction.shape[0] == 1 else prediction.numpy()
  365. pred_binary = (pred_np > 0.5).astype(np.float32)
  366. # Create pure black and white label image
  367. label_bw = (label_binary * 255).astype(np.uint8)
  368. label_bw_3ch = np.stack([label_bw] * 3, axis=-1) # Convert to 3 channels for concatenation
  369. # Create pure black and white prediction image
  370. pred_bw = (pred_binary * 255).astype(np.uint8)
  371. pred_bw_3ch = np.stack([pred_bw] * 3, axis=-1) # Convert to 3 channels for concatenation
  372. # Horizontally concatenate three images: original, ground truth B&W, prediction B&W
  373. combined = np.hstack([image_rgb, label_bw_3ch, pred_bw_3ch])
  374. # Add text annotations (at the top of the image)
  375. h, w = image_rgb.shape[:2]
  376. # Calculate text width for centering
  377. orig_text_size = cv2.getTextSize('Original', cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
  378. gt_text_size = cv2.getTextSize('Ground Truth', cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
  379. pred_text_size = cv2.getTextSize('Prediction', cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
  380. combined = cv2.putText(combined, 'Original', ((w - orig_text_size[0]) // 2, 30),
  381. cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
  382. combined = cv2.putText(combined, 'Ground Truth', (w + (w - gt_text_size[0]) // 2, 30),
  383. cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
  384. combined = cv2.putText(combined, 'Prediction', (2 * w + (w - pred_text_size[0]) // 2, 30),
  385. cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
  386. # Save
  387. sample_idx = batch_idx * images.shape[0] + i
  388. save_path = os.path.join(save_dir, f'sample_{sample_idx:04d}.png')
  389. cv2.imwrite(save_path, cv2.cvtColor(combined, cv2.COLOR_RGB2BGR))
  390. except Exception as e:
  391. print(f"⚠️ Error saving sample {i}: {e}")
  392. continue
  393. def print_results(results, dataset_name, checkpoint_path, model):
  394. """
  395. Print evaluation results
  396. Args:
  397. results: Evaluation results dictionary
  398. dataset_name: Dataset name
  399. checkpoint_path: Model checkpoint path
  400. model: Model
  401. """
  402. print("\n" + "=" * 60)
  403. print("Evaluation Results")
  404. print("=" * 60)
  405. print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
  406. print(f"Dataset: {dataset_name}")
  407. print(f"Model: {checkpoint_path}")
  408. print(f"Number of samples: {results['total_samples']}")
  409. print("-" * 60)
  410. print(f"mDice (Mean Dice Coefficient): {results['mDice']:.3f}")
  411. print(f"mIoU (Mean Intersection over Union): {results['mIoU']:.3f}")
  412. print(f"mHD (Mean Hausdorff Distance): {results['mHD']:.3f}")
  413. print(f"mHD95 (95% Hausdorff Distance): {results['mHD95']:.3f}")
  414. print("=" * 60)
  415. def save_results(results, dataset_name, checkpoint_path, results_dir, model):
  416. """
  417. Save evaluation results to file
  418. Args:
  419. results: Evaluation results dictionary
  420. dataset_name: Dataset name
  421. checkpoint_path: Model checkpoint path
  422. results_dir: Results save directory
  423. model: Model
  424. """
  425. os.makedirs(results_dir, exist_ok=True)
  426. # Generate filename
  427. result_file = os.path.join(results_dir, f"eval_{dataset_name}.txt")
  428. with open(result_file, 'w', encoding='utf-8') as f:
  429. f.write("=" * 60 + "\n")
  430. f.write("Polyp Segmentation Model Evaluation Report\n")
  431. f.write("=" * 60 + "\n\n")
  432. f.write(f"Model parameters: {sum(p.numel() for p in model.parameters())}\n")
  433. f.write(f"Dataset: {dataset_name}\n")
  434. f.write(f"Model checkpoint: {checkpoint_path}\n")
  435. f.write(f"Evaluation time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
  436. f.write(f"Number of samples: {results['total_samples']}\n\n")
  437. f.write("-" * 60 + "\n")
  438. f.write("Evaluation Metrics:\n")
  439. f.write("-" * 60 + "\n")
  440. f.write(f"mDice (Mean Dice Coefficient): {results['mDice']:.3f}\n")
  441. f.write(f"mIoU (Mean Intersection over Union): {results['mIoU']:.3f}\n")
  442. f.write(f"mHD (Mean Hausdorff Distance): {results['mHD']:.3f}\n")
  443. f.write(f"mHD95 (95% Hausdorff Distance): {results['mHD95']:.3f}\n")
  444. f.write("-" * 60 + "\n")
  445. print(f"\n✓ Evaluation results saved to: {result_file}")
  446. def main():
  447. """
  448. Main evaluation function
  449. """
  450. # ==================== Step 1: Parse arguments ====================
  451. args = parse_args()
  452. checkpoint_path = Path(args.outputs_dir + args.dir_flag) / f"best_dice_model_{args.dataset_name}.pt"
  453. if args.best_metric:
  454. checkpoint_path = Path(args.outputs_dir + args.dir_flag) / f"best_metric_model_{args.dataset_name}.pt"
  455. print("\n" + "=" * 60)
  456. print("Polyp Segmentation Model Evaluation")
  457. print("=" * 60)
  458. print(f"Dataset: {args.dataset_name}")
  459. print(f"Model checkpoint: {checkpoint_path}")
  460. # ==================== Step 2: Create validation set and data loader ====================
  461. print("\nLoading validation set...")
  462. val_transform = create_val_transform(args.target_spatial_size)
  463. val_dataset = PolypDetectionDataset(
  464. root_dir=Path(args.data_root) / args.dataset_name,
  465. flag='val',
  466. transform=val_transform
  467. )
  468. val_loader = create_dataloader(args, val_dataset)
  469. print(f"✓ Validation set size: {len(val_dataset)} samples")
  470. print(f"✓ Data loader: {len(val_loader)} batches")
  471. # ==================== Step 3: Load model ====================
  472. model = load_model(args, checkpoint_path)
  473. # ==================== Step 4: Evaluate model ====================
  474. results = evaluate_model(model, val_loader, args)
  475. # ==================== Step 5: Print and save results ====================
  476. print_results(results, args.dataset_name, checkpoint_path, model)
  477. if args.save_results:
  478. save_results(results, args.dataset_name, checkpoint_path, args.results_dir + args.dir_flag, model)
  479. print("\n" + "=" * 60)
  480. print("Evaluation completed!")
  481. print("=" * 60)
  482. if __name__ == "__main__":
  483. main()