| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564 |
- import argparse
- import os
- from datetime import datetime
- from pathlib import Path
- import cv2
- import numpy as np
- import torch
- from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
- from monai.transforms import (
- Compose, LoadImaged, ScaleIntensityd, EnsureChannelFirstd,
- ToTensord, Resized, Lambdad
- )
- from torch.utils.data import DataLoader
- from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
- from lib.model.model import Wavelet_FFT_SwinUNETR
def _str2bool(value):
    """
    Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True`` for any
    non-empty string, so such flags could never be switched off from the command line.

    Args:
        value: Raw command-line text (an already-boolean value is returned unchanged).

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean (true/false), got {value!r}")


def _parse_spatial_size(value):
    """
    Parse a spatial size such as ``"512,512"``, ``"(512, 512)"`` or ``"512 512"``.

    ``type=tuple`` would split the string into single characters
    (``('5', '1', '2', ',', ...)``), which is not a usable size.

    Args:
        value: Raw command-line text, or an already-parsed tuple/list (e.g. the default).

    Returns:
        tuple[int, ...]: The parsed size.

    Raises:
        argparse.ArgumentTypeError: If the text does not contain integers only.
    """
    if isinstance(value, (tuple, list)):
        return tuple(int(v) for v in value)
    message = f"Invalid spatial size {value!r}; expected comma-separated integers like '512,512'"
    cleaned = str(value).strip().strip("()[]").replace(",", " ")
    try:
        size = tuple(int(token) for token in cleaned.split())
    except ValueError:
        raise argparse.ArgumentTypeError(message) from None
    if not size:
        raise argparse.ArgumentTypeError(message)
    return size


def parse_args(argv=None):
    """
    Parse command-line options for the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace: Parsed options. Boolean options are real ``bool`` values
        (``--use_fft False`` now really disables FFT), and ``target_spatial_size``
        is a tuple of ints.
    """
    parser = argparse.ArgumentParser(description="Polyp Segmentation Model Evaluation Script")
    # ==================== Dataset-related Parameters ====================
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        help="Dataset name"
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default=r"./data/Polyp-Detection-Dataset",
        help="Root directory path of the dataset"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Number of worker processes for data loader"
    )
    parser.add_argument(
        "--pin_memory",
        type=_str2bool,
        default=True,
        help="Whether to enable pinned memory"
    )
    parser.add_argument(
        "--target_spatial_size",
        type=_parse_spatial_size,
        default=(512, 512),
        help="Target spatial size, format like: '512,512' or '(512,512)'"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (smaller batch size recommended for evaluation)"
    )
    # ==================== Model-related Parameters ====================
    parser.add_argument(
        "--in_channels",
        type=int,
        default=3,
        help="Number of input image channels"
    )
    parser.add_argument(
        "--out_channels",
        type=int,
        default=1,
        help="Number of output foreground channels"
    )
    parser.add_argument(
        "--feature_size",
        type=int,
        default=48,
        help="Network feature dimension"
    )
    parser.add_argument(
        "--spatial_dims",
        type=int,
        default=2,
        choices=[2, 3],
        help="Spatial dimension (2D or 3D)"
    )
    parser.add_argument(
        "--use_wavelet",
        type=_str2bool,
        default=True,
        help="Whether to enable wavelet enhancement module"
    )
    parser.add_argument(
        "--wavelet_J",
        type=int,
        default=2,
        help="Wavelet decomposition levels"
    )
    parser.add_argument(
        "--wavelet_wave",
        type=str,
        default="db4",
        help="Wavelet basis type"
    )
    parser.add_argument(
        "--wavelet_reduction",
        type=int,
        default=16,
        help="Wavelet attention compression ratio"
    )
    parser.add_argument(
        "--use_fft",
        type=_str2bool,
        default=True,
        help="Whether to enable FFT enhancement module"
    )
    parser.add_argument(
        "--use_v2",
        type=_str2bool,
        default=True,
        help="Whether to enable Swin-UNETR v2 module"
    )
    # ==================== Model Loading Parameters ====================
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device used for evaluation (cuda or cpu)"
    )
    # ==================== Other Parameters ====================
    parser.add_argument(
        "--save_results",
        type=_str2bool,
        default=False,
        help="是否保存预测结果"
    )
    parser.add_argument(
        "--dir_flag",
        type=str,
        default="_minute",
        help="Prediction result save file suffix"
    )
    parser.add_argument(
        "--results_dir",
        type=str,
        default="./evaluation_results",
        help="Directory for saving prediction results"
    )
    parser.add_argument(
        "--outputs_dir",
        type=str,
        default="./outputs",
        help="Directory containing the trained model checkpoints (dir_flag is appended)"
    )
    parser.add_argument(
        "--save_visualization",
        type=_str2bool,
        default=True,
        help="Whether to save visualization results"
    )
    parser.add_argument(
        "--vis_num_samples",
        type=int,
        default=1000,
        help="Number of samples to save for visualization"
    )
    parser.add_argument(
        "--best_metric",
        type=_str2bool,
        default=False,
        help="Load best overall model, False means load best Dice model by default"
    )
    return parser.parse_args(argv)
def create_val_transform(target_spatial_size=(512, 512)):
    """
    Build the MONAI preprocessing pipeline applied to validation samples.

    The pipeline runs these steps in order:
    - load the image/label files and move channels first;
    - reduce the label to one binary channel;
    - resize both to ``target_spatial_size`` (bilinear for the image, nearest for the
      label so it stays binary);
    - rescale image intensities with ``ScaleIntensityd`` (default arguments);
    - convert both to tensors.

    Args:
        target_spatial_size: Spatial size handed to ``Resized`` (e.g. ``(512, 512)``).

    Returns:
        Compose: The composed validation transform.
    """
    paired_keys = ("image", "label")

    def binarize_label(label):
        """Keep only the first label channel and threshold it at 127 into a {0, 1} float mask."""
        # NOTE(review): the 127 threshold assumes label intensities on a 0-255 scale
        # (e.g. 8-bit PNG masks); a 0/1-valued mask would become all background. Confirm.
        first_channel = label[0:1, :, :]
        return (first_channel > 127).float()

    pipeline = [
        LoadImaged(keys=paired_keys),
        EnsureChannelFirstd(keys=paired_keys),
        Lambdad(keys=["label"], func=binarize_label),
        Resized(keys=paired_keys, spatial_size=target_spatial_size,
                mode=("bilinear", "nearest")),
        ScaleIntensityd(keys=["image"]),
        ToTensord(keys=paired_keys),
    ]
    return Compose(pipeline)
def create_dataloader(args, dataset):
    """
    Wrap ``dataset`` in a DataLoader configured for evaluation.

    Args:
        args: Parsed options; ``batch_size``, ``num_workers`` and ``pin_memory`` are read.
        dataset: Any dataset accepted by ``torch.utils.data.DataLoader``.

    Returns:
        DataLoader: Loader that keeps dataset order and yields every sample.
    """
    loader_options = {
        "batch_size": args.batch_size,
        "shuffle": False,      # keep dataset order so sample indices stay stable
        "num_workers": args.num_workers,
        "pin_memory": args.pin_memory,
        "drop_last": False,    # evaluate every sample, including a short final batch
    }
    return DataLoader(dataset, **loader_options)
def load_model(args, checkpoint_path):
    """
    Build ``Wavelet_FFT_SwinUNETR`` from the parsed options and load trained weights.

    The checkpoint may be a dict with a ``"model_state_dict"`` entry (a training
    checkpoint) or a bare state dict; both forms are handled.

    Args:
        args: Parsed options; the model hyper-parameters and ``device`` are read.
        checkpoint_path: Path of the ``.pt`` file to load.

    Returns:
        The network moved to ``args.device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    print(f"\nLoading model: {checkpoint_path}")

    network = Wavelet_FFT_SwinUNETR(
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        spatial_dims=args.spatial_dims,
        wavelet_enhancement=args.use_wavelet,
        wavelet_J=args.wavelet_J,
        wavelet_wave=args.wavelet_wave,
        wavelet_mode='symmetric',
        wavelet_reduction=args.wavelet_reduction,
        fft_enhancement=args.use_fft,
        use_v2=args.use_v2
    )

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Model file does not exist: {checkpoint_path}")

    # NOTE: torch.load is pickle-based; only load checkpoints from trusted sources.
    payload = torch.load(checkpoint_path, map_location=args.device)
    is_training_checkpoint = isinstance(payload, dict) and "model_state_dict" in payload
    state_dict = payload["model_state_dict"] if is_training_checkpoint else payload
    network.load_state_dict(state_dict)
    if is_training_checkpoint:
        print("✓ Model weights loaded from training checkpoint")
    else:
        print("✓ Model weights loaded directly")

    network = network.to(args.device)
    network.eval()
    print("✓ Model loaded and set to evaluation mode")
    print(f"✓ 使用设备:{args.device}")
    return network
def _print_debug_batch(images, outputs, labels):
    """Print shapes and unique values of the first evaluated batch (one-off debug aid)."""
    print(f"\n{'=' * 60}")
    print("[Detailed Debug Info - Batch 0]")
    print(f"{'=' * 60}")
    print(f"Input image size: {images.shape}")
    print(f"Output image size after post-processing: {outputs.shape}")
    print(f"Unique values: {torch.unique(outputs)}")
    print("\n--- Labels ---")
    print(f"Labels shape: {labels.shape}")
    print(f"Labels unique values: {torch.unique(labels)}")
    print(f"{'=' * 60}\n")


def evaluate_model(model, dataloader, args):
    """
    Evaluate a segmentation model and return aggregate metrics.

    Each batch is passed through ``model``. The output goes through ``sigmoid`` and a
    0.5 threshold to form a binary mask. That mask is accumulated into MONAI Dice,
    mean IoU, Hausdorff and 95th-percentile Hausdorff metrics. If
    ``args.save_visualization`` is set, side-by-side images are also written for up to
    ``args.vis_num_samples`` samples via ``save_visualization``.

    Args:
        model: Network already on ``args.device``; its output is treated as logits
            (sigmoid is applied here).
        dataloader: Iterable of dicts holding ``"image"`` and ``"label"`` tensors.
        args: Parsed options; reads ``device``, ``save_visualization``, ``results_dir``,
            ``dataset_name`` and ``vis_num_samples``.

    Returns:
        dict: ``mDice``, ``mIoU``, ``mHD``, ``mHD95`` (floats) and ``total_samples`` (int).

    Raises:
        ValueError: If ``dataloader`` yields no samples, so nothing can be aggregated.
    """
    print("\n" + "=" * 60)
    print("Starting evaluation...")
    print("=" * 60)

    # Initialize metrics
    dice_metric = DiceMetric(reduction="mean")
    iou_metric = MeanIoU(reduction="mean")
    # NOTE(review): MONAI may yield nan/inf Hausdorff distances when a prediction or
    # label mask is empty; check how that affects mHD/mHD95 for your data.
    hd_metric = HausdorffDistanceMetric(reduction="mean")
    hd95_metric = HausdorffDistanceMetric(reduction="mean", percentile=95)
    all_metrics = (dice_metric, iou_metric, hd_metric, hd95_metric)

    total_samples = 0
    debug_printed = False
    saved_vis_count = 0
    vis_dir = None

    # Create visualization save directory
    if args.save_visualization:
        vis_dir = os.path.join(args.results_dir, f"visualization_{args.dataset_name}")
        os.makedirs(vis_dir, exist_ok=True)
        print(f"✓ Visualization results will be saved to: {vis_dir}")

    num_batches = len(dataloader)
    with torch.no_grad():
        for metric in all_metrics:
            metric.reset()

        for batch_idx, batch_data in enumerate(dataloader):
            images = batch_data["image"].to(args.device)
            labels = batch_data["label"].to(args.device)
            batch_size = images.shape[0]

            # Forward pass, then binarize the logits: sigmoid + 0.5 threshold.
            outputs = (torch.sigmoid(model(images)) > 0.5).int()

            if not debug_printed:
                debug_printed = True
                _print_debug_batch(images, outputs, labels)

            for metric in all_metrics:
                metric(y_pred=outputs, y=labels)

            if args.save_visualization and saved_vis_count < args.vis_num_samples:
                save_visualization(
                    images=images,
                    labels=labels,
                    predictions=outputs,
                    save_dir=vis_dir,
                    batch_idx=batch_idx,
                    max_samples=args.vis_num_samples - saved_vis_count,
                    # Running count = global index of this batch's first sample; unlike
                    # batch_idx * batch_size it stays correct for a short final batch.
                    sample_offset=total_samples,
                )
                saved_vis_count += batch_size

            if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == num_batches:
                print(f"进度:{batch_idx + 1}/{num_batches} batches")
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("Dataloader yielded no samples; cannot aggregate evaluation metrics.")

    results = {
        "mDice": dice_metric.aggregate().item(),
        "mIoU": iou_metric.aggregate().item(),
        "mHD": hd_metric.aggregate().item(),
        "mHD95": hd95_metric.aggregate().item(),
        "total_samples": total_samples,
    }
    return results


def _mask_to_bw_rgb(mask):
    """Convert a mask tensor to a 3-channel uint8 black/white image (foreground = 255).

    A leading singleton channel is dropped; values above 0.5 count as foreground.
    """
    mask_np = mask[0].numpy() if mask.shape[0] == 1 else mask.numpy()
    bw = ((mask_np > 0.5) * 255).astype(np.uint8)
    return np.stack([bw] * 3, axis=-1)


def save_visualization(images, labels, predictions, save_dir, batch_idx, max_samples,
                       sample_offset=None):
    """
    Save side-by-side [original | ground truth | prediction] PNGs for a batch.

    Saving is best-effort. A failure on one sample is reported and the remaining
    samples are still processed.

    Args:
        images: Input image batch, indexed as [B, C, H, W].
        labels: Ground-truth masks, indexed as [B, 1, H, W].
        predictions: Predicted masks, indexed as [B, 1, H, W].
        save_dir: Directory the PNG files are written to.
        batch_idx: Index of this batch. It is only used to derive file names when
            ``sample_offset`` is not given.
        max_samples: Maximum number of samples from this batch to save.
        sample_offset: Global dataset index of ``images[0]``. When omitted it falls
            back to ``batch_idx * images.shape[0]``. That fallback is only correct if
            all earlier batches had the same size; a short final batch would overwrite
            earlier files.
    """
    batch_size = images.shape[0]
    if sample_offset is None:
        sample_offset = batch_idx * batch_size
    titles = ('Original', 'Ground Truth', 'Prediction')

    for i in range(min(batch_size, max_samples)):
        sample_idx = sample_offset + i
        try:
            image = images[i].detach().cpu()
            label = labels[i].detach().cpu()
            prediction = predictions[i].detach().cpu()

            # Image -> uint8 RGB (H, W, 3) for display
            if image.shape[0] == 1:  # grayscale
                gray = np.clip(image[0].numpy() * 255, 0, 255)
                image_rgb = np.stack([gray] * 3, axis=-1).astype(np.uint8)
            else:  # RGB image
                image_np = image.numpy().transpose(1, 2, 0)
                # Per-image min-max stretch to [0, 255] purely for display.
                lo, hi = image_np.min(), image_np.max()
                image_rgb = np.clip((image_np - lo) / (hi - lo + 1e-8) * 255, 0, 255).astype(np.uint8)

            combined = np.hstack([image_rgb, _mask_to_bw_rgb(label), _mask_to_bw_rgb(prediction)])

            # Centre each panel's title at the top of its column.
            w = image_rgb.shape[1]
            for col, title in enumerate(titles):
                text_w = cv2.getTextSize(title, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0][0]
                combined = cv2.putText(combined, title, (col * w + (w - text_w) // 2, 30),
                                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

            save_path = os.path.join(save_dir, f'sample_{sample_idx:04d}.png')
            # cv2.imwrite signals failure through its return value, not an exception.
            if not cv2.imwrite(save_path, cv2.cvtColor(combined, cv2.COLOR_RGB2BGR)):
                print(f"⚠️ Failed to write visualization: {save_path}")
        except Exception as e:
            # Best-effort: a bad sample must not abort the whole evaluation.
            print(f"⚠️ Error saving sample {sample_idx}: {e}")
            continue
def print_results(results, dataset_name, checkpoint_path, model):
    """
    Print a human-readable summary of the evaluation to stdout.

    Args:
        results: Metrics dict with ``mDice``, ``mIoU``, ``mHD``, ``mHD95`` and
            ``total_samples``.
        dataset_name: Name of the evaluated dataset.
        checkpoint_path: Checkpoint that was evaluated.
        model: Evaluated network; only its parameter count is reported.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60
    param_count = sum(p.numel() for p in model.parameters())
    report = [
        "",
        heavy_rule,
        "Evaluation Results",
        heavy_rule,
        f"Model parameters: {param_count}",
        f"Dataset: {dataset_name}",
        f"Model: {checkpoint_path}",
        f"Number of samples: {results['total_samples']}",
        light_rule,
        f"mDice (Mean Dice Coefficient): {results['mDice']:.3f}",
        f"mIoU (Mean Intersection over Union): {results['mIoU']:.3f}",
        f"mHD (Mean Hausdorff Distance): {results['mHD']:.3f}",
        f"mHD95 (95% Hausdorff Distance): {results['mHD95']:.3f}",
        heavy_rule,
    ]
    print("\n".join(report))
def save_results(results, dataset_name, checkpoint_path, results_dir, model):
    """
    Write the evaluation report to ``<results_dir>/eval_<dataset_name>.txt``.

    The directory is created if needed. An existing report for the same dataset is
    overwritten.

    Args:
        results: Metrics dict with ``mDice``, ``mIoU``, ``mHD``, ``mHD95`` and
            ``total_samples``.
        dataset_name: Name of the evaluated dataset (also used in the file name).
        checkpoint_path: Checkpoint that was evaluated.
        results_dir: Output directory for the report.
        model: Evaluated network; only its parameter count is reported.
    """
    os.makedirs(results_dir, exist_ok=True)
    report_path = os.path.join(results_dir, f"eval_{dataset_name}.txt")

    heavy_rule = "=" * 60
    light_rule = "-" * 60
    with open(report_path, 'w', encoding='utf-8') as handle:
        report_lines = [
            heavy_rule,
            "Polyp Segmentation Model Evaluation Report",
            heavy_rule,
            "",
            f"Model parameters: {sum(p.numel() for p in model.parameters())}",
            f"Dataset: {dataset_name}",
            f"Model checkpoint: {checkpoint_path}",
            f"Evaluation time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"Number of samples: {results['total_samples']}",
            "",
            light_rule,
            "Evaluation Metrics:",
            light_rule,
            f"mDice (Mean Dice Coefficient): {results['mDice']:.3f}",
            f"mIoU (Mean Intersection over Union): {results['mIoU']:.3f}",
            f"mHD (Mean Hausdorff Distance): {results['mHD']:.3f}",
            f"mHD95 (95% Hausdorff Distance): {results['mHD95']:.3f}",
            light_rule,
        ]
        handle.writelines(f"{line}\n" for line in report_lines)

    print(f"\n✓ Evaluation results saved to: {report_path}")
def main():
    """
    Evaluate a trained polyp segmentation checkpoint on the validation split.

    Steps:
    - parse options and resolve the checkpoint path;
    - build the validation dataset and loader;
    - load the model and compute the metrics;
    - print the metrics, and also save them when ``--save_results`` is set.
    """
    # ==================== Step 1: options and checkpoint location ====================
    args = parse_args()
    checkpoint_dir = Path(args.outputs_dir + args.dir_flag)
    # A truthy best_metric selects the "best overall" checkpoint, otherwise the best-Dice one.
    checkpoint_stem = "best_metric_model" if args.best_metric else "best_dice_model"
    checkpoint_path = checkpoint_dir / f"{checkpoint_stem}_{args.dataset_name}.pt"

    banner = "=" * 60
    print("\n" + banner)
    print("Polyp Segmentation Model Evaluation")
    print(banner)
    print(f"Dataset: {args.dataset_name}")
    print(f"Model checkpoint: {checkpoint_path}")

    # ==================== Step 2: validation data ====================
    print("\nLoading validation set...")
    transform = create_val_transform(args.target_spatial_size)
    dataset_root = Path(args.data_root) / args.dataset_name
    val_dataset = PolypDetectionDataset(root_dir=dataset_root, flag='val', transform=transform)
    val_loader = create_dataloader(args, val_dataset)
    print(f"✓ Validation set size: {len(val_dataset)} samples")
    print(f"✓ Data loader: {len(val_loader)} batches")

    # ==================== Steps 3-4: model and metrics ====================
    model = load_model(args, checkpoint_path)
    results = evaluate_model(model, val_loader, args)

    # ==================== Step 5: report ====================
    print_results(results, args.dataset_name, checkpoint_path, model)
    if args.save_results:
        report_dir = args.results_dir + args.dir_flag
        save_results(results, args.dataset_name, checkpoint_path, report_dir, model)

    print("\n" + banner)
    print("Evaluation completed!")
    print(banner)
# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|