import argparse
import os
from datetime import datetime
from pathlib import Path

import cv2
import numpy as np
import torch
from monai.metrics import DiceMetric, MeanIoU, HausdorffDistanceMetric
from monai.transforms import (
    Compose,
    LoadImaged,
    ScaleIntensityd,
    EnsureChannelFirstd,
    ToTensord,
    Resized,
    Lambdad,
)
from torch.utils.data import DataLoader

from datasets.PolypDetectionDataset.PolypDetectionDataset import PolypDetectionDataset
from lib.model.model import Wavelet_FFT_SwinUNETR


_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "1", "on"})
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "0", "off", ""})


def _str2bool(value):
    """Convert a command-line string such as "true" / "0" / "no" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"); this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _parse_spatial_size(value):
    """Convert "512,512" or "(512,512)" to a tuple of positive ints.

    ``type=tuple`` would split the string into single characters, so the text is
    parsed explicitly here.

    Raises:
        argparse.ArgumentTypeError: If the text is not a comma-separated list of
            positive integers.
    """
    if isinstance(value, (tuple, list)):
        return tuple(int(v) for v in value)
    text = str(value).strip().strip("()[]")
    parts = [piece.strip() for piece in text.split(",") if piece.strip()]
    try:
        size = tuple(int(piece) for piece in parts)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"Expected comma-separated integers such as '512,512', got {value!r}"
        ) from None
    if not size or any(dim <= 0 for dim in size):
        raise argparse.ArgumentTypeError(
            f"Spatial size must contain positive integers, got {value!r}"
        )
    return size


def parse_args(argv=None):
    """
    Build the command-line parser and parse the evaluation arguments

    Args:
        argv: Optional list of argument strings; None (default) makes argparse
            read from sys.argv

    Returns:
        argparse.Namespace: Parsed arguments
    """
    parser = argparse.ArgumentParser(description="Polyp Segmentation Model Evaluation Script")

    # ==================== Dataset-related Parameters ====================
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        help="Dataset name"
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default=r"./data/Polyp-Detection-Dataset",
        help="Root directory path of the dataset"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Number of worker processes for data loader"
    )
    parser.add_argument(
        "--pin_memory",
        type=_str2bool,
        default=True,
        help="Whether to enable pinned memory"
    )
    parser.add_argument(
        "--target_spatial_size",
        type=_parse_spatial_size,
        default=(512, 512),
        help="Target spatial size, format like: '512,512' or '(512,512)'"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (smaller batch size recommended for evaluation)"
    )

    # ==================== Model-related Parameters ====================
    parser.add_argument(
        "--in_channels",
        type=int,
        default=3,
        help="Number of input image channels"
    )
    parser.add_argument(
        "--out_channels",
        type=int,
        default=1,
        help="Number of output foreground channels"
    )
    parser.add_argument(
        "--feature_size",
        type=int,
        default=48,
        help="Network feature dimension"
    )
    parser.add_argument(
        "--spatial_dims",
        type=int,
        default=2,
        choices=[2, 3],
        help="Spatial dimension (2D or 3D)"
    )
    parser.add_argument(
        "--use_wavelet",
        type=_str2bool,
        default=True,
        help="Whether to enable wavelet enhancement module"
    )
    parser.add_argument(
        "--wavelet_J",
        type=int,
        default=2,
        help="Wavelet decomposition levels"
    )
    parser.add_argument(
        "--wavelet_wave",
        type=str,
        default="db4",
        help="Wavelet basis type"
    )
    parser.add_argument(
        "--wavelet_reduction",
        type=int,
        default=16,
        help="Wavelet attention compression ratio"
    )
    parser.add_argument(
        "--use_fft",
        type=_str2bool,
        default=True,
        help="Whether to enable FFT enhancement module"
    )
    parser.add_argument(
        "--use_v2",
        type=_str2bool,
        default=True,
        help="Whether to enable Swin-UNETR v2 module"
    )

    # ==================== Model Loading Parameters ====================
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Training device (cuda or cpu)"
    )

    # ==================== Other Parameters ====================
    parser.add_argument(
        "--save_results",
        type=_str2bool,
        default=False,
        help="是否保存预测结果"
    )
    parser.add_argument(
        "--dir_flag",
        type=str,
        default="_v4_minute",
        help="Prediction result save file suffix"
    )
    parser.add_argument(
        "--results_dir",
        type=str,
        default="./evaluation_results",
        help="Directory for saving prediction results"
    )
    parser.add_argument(
        "--outputs_dir",
        type=str,
        default="./outputs",
        help="Directory prefix holding model checkpoints (dir_flag is appended to it)"
    )
    parser.add_argument(
        "--save_visualization",
        type=_str2bool,
        default=True,
        help="Whether to save visualization results"
    )
    parser.add_argument(
        "--vis_num_samples",
        type=int,
        default=1000,
        help="Number of samples to save for visualization"
    )
    parser.add_argument(
        "--best_metric",
        type=_str2bool,
        default=False,
        help="Load best overall model, False means load best Dice model by default"
    )

    return parser.parse_args(argv)


def create_val_transform(target_spatial_size=(512, 512)):
    """
    Create validation set transformations

    Args:
        target_spatial_size: Target spatial size

    Returns:
        Compose: Validation transformation composition
    """

    def convert_label_to_single_channel(label_tensor):
        """Convert RGB labels to single-channel binary mask"""
        # Keep only the first channel and binarise it at 127.
        single_channel = label_tensor[0:1, :, :]
        binary_label = (single_channel > 127).float()
        return binary_label

    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Lambdad(keys=["label"], func=convert_label_to_single_channel),
        # Bilinear for the image, nearest for the label so the mask stays binary.
        Resized(keys=["image", "label"], spatial_size=target_spatial_size, mode=("bilinear", "nearest")),
        ScaleIntensityd(keys=["image"]),
        ToTensord(keys=["image", "label"]),
    ])

    return val_transforms


def create_dataloader(args, dataset):
    """
    Create data loader

    Args:
        args: Command line arguments (batch_size, num_workers, pin_memory are read)
        dataset: Dataset

    Returns:
        DataLoader: Data loader that iterates in order and keeps the last partial batch
    """
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        drop_last=False
    )
    return loader


def load_model(args, checkpoint_path):
    """
    Load model

    Args:
        args: Command line arguments
        checkpoint_path: Checkpoint path

    Returns:
        model: Model with loaded weights, moved to args.device and set to eval mode

    Raises:
        FileNotFoundError: If checkpoint_path does not exist
    """
    print(f"\nLoading model: {checkpoint_path}")

    # Fail fast before building the network when the checkpoint is missing.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Model file does not exist: {checkpoint_path}")

    # Create model
    model = Wavelet_FFT_SwinUNETR(
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        spatial_dims=args.spatial_dims,
        wavelet_enhancement=args.use_wavelet,
        wavelet_J=args.wavelet_J,
        wavelet_wave=args.wavelet_wave,
        wavelet_mode='symmetric',
        wavelet_reduction=args.wavelet_reduction,
        fft_enhancement=args.use_fft,
        use_v2=args.use_v2
    )

    # Load weights
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=args.device)

    # Check checkpoint format: either a training checkpoint dict or a bare state dict
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
        print("✓ Model weights loaded from training checkpoint")
    else:
        model.load_state_dict(checkpoint)
        print("✓ Model weights loaded directly")

    model = model.to(args.device)
    model.eval()

    print("✓ Model loaded and set to evaluation mode")
    print(f"✓ 使用设备:{args.device}")

    return model


def evaluate_model(model, dataloader, args):
    """
    Evaluate model performance

    Args:
        model: Model
        dataloader: Data loader yielding dicts with "image" and "label" entries
        args: Command line arguments

    Returns:
        dict: Dictionary containing mDice, mIoU, mHD, mHD95 and total_samples

    Raises:
        ValueError: If the data loader yields no samples
    """
    print("\n" + "=" * 60)
    print("Starting evaluation...")
    print("=" * 60)

    # Initialize metrics
    dice_metric = DiceMetric(reduction="mean")
    iou_metric = MeanIoU(reduction="mean")
    hd_metric = HausdorffDistanceMetric(reduction="mean")
    hd95_metric = HausdorffDistanceMetric(reduction="mean", percentile=95)
    metrics = (dice_metric, iou_metric, hd_metric, hd95_metric)

    total_samples = 0
    debug_printed = False
    saved_vis_count = 0
    vis_dir = None

    # Create visualization save directory
    if args.save_visualization:
        vis_dir = os.path.join(args.results_dir, f"visualization_{args.dataset_name}")
        os.makedirs(vis_dir, exist_ok=True)
        print(f"✓ Visualization results will be saved to: {vis_dir}")

    with torch.no_grad():
        for metric in metrics:
            metric.reset()

        for batch_idx, batch_data in enumerate(dataloader):
            images = batch_data["image"].to(args.device)
            labels = batch_data["label"].to(args.device)
            batch_count = images.shape[0]

            # Forward propagation
            outputs = model(images)  # [B, 1, H, W]

            # Post-processing: logits -> probabilities -> hard 0/1 mask
            outputs = torch.sigmoid(outputs)
            outputs = (outputs > 0.5).int()

            # One-off shape/value dump for the first batch only
            if not debug_printed:
                debug_printed = True
                print(f"\n{'=' * 60}")
                print("[Detailed Debug Info - Batch 0]")
                print(f"{'=' * 60}")
                print(f"Input image size: {images.shape}")
                print(f"Output image size after post-processing: {outputs.shape}")
                print(f"Unique values: {torch.unique(outputs)}")
                print("\n--- Labels ---")
                print(f"Labels shape: {labels.shape}")
                print(f"Labels unique values: {torch.unique(labels)}")
                print(f"{'=' * 60}\n")

            # Calculate metrics for current batch - use directly without additional processing
            for metric in metrics:
                metric(y_pred=outputs, y=labels)

            # Save visualization results
            if args.save_visualization and saved_vis_count < args.vis_num_samples:
                save_visualization(
                    images=images,
                    labels=labels,
                    predictions=outputs,
                    save_dir=vis_dir,
                    batch_idx=batch_idx,
                    max_samples=args.vis_num_samples - saved_vis_count,
                    # Running offset keeps file names unique even when the
                    # last batch is smaller than the others.
                    start_index=total_samples,
                )
                saved_vis_count += batch_count

            # Print progress
            if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(dataloader):
                print(f"进度:{batch_idx + 1}/{len(dataloader)} batches")

            total_samples += batch_count

    if total_samples == 0:
        raise ValueError("Data loader produced no samples; nothing to evaluate")

    # Aggregate all metrics
    mean_dice = dice_metric.aggregate().item()
    mean_iou = iou_metric.aggregate().item()
    mean_hd = hd_metric.aggregate().item()
    mean_hd95 = hd95_metric.aggregate().item()

    results = {
        "mDice": mean_dice,
        "mIoU": mean_iou,
        "mHD": mean_hd,
        "mHD95": mean_hd95,
        "total_samples": total_samples,
    }

    return results


def save_visualization(images, labels, predictions, save_dir, batch_idx, max_samples, start_index=None):
    """
    Save visualization results: original image, ground truth label, prediction label combined

    Each sample is written as one PNG with the three panels side by side.
    A failure on one sample is reported and the remaining samples are still processed.

    Args:
        images: Input image batch [B, C, H, W]
        labels: Ground truth labels [B, 1, H, W]
        predictions: Prediction labels [B, 1, H, W]
        save_dir: Save directory
        batch_idx: Batch index
        max_samples: Maximum samples to save
        start_index: Index of the first sample of this batch in the whole
            dataset, used for file naming. None (default) falls back to
            batch_idx * B, which can collide when the final batch is smaller
            than the others.
    """
    batch_count = images.shape[0]
    base_index = batch_idx * batch_count if start_index is None else start_index
    font = cv2.FONT_HERSHEY_SIMPLEX

    for i in range(min(batch_count, max_samples)):
        try:
            # Extract single sample
            image = images[i].cpu()
            label = labels[i].cpu()
            prediction = predictions[i].cpu()

            # Image processing: bring values into 0-255 and convert to RGB
            if image.shape[0] == 1:  # Grayscale image
                image_np = image[0].numpy() * 255
                image_rgb = np.stack([image_np] * 3, axis=-1).astype(np.uint8)
            else:  # RGB image
                image_np = image.numpy().transpose(1, 2, 0)
                # Min-max rescale to 0-255; the epsilon guards against a constant image
                image_np = np.clip((image_np - image_np.min()) / (image_np.max() - image_np.min() + 1e-8) * 255, 0, 255)
                image_rgb = image_np.astype(np.uint8)

            # Label processing: convert to binary mask
            label_np = label[0].numpy() if label.shape[0] == 1 else label.numpy()
            label_binary = (label_np > 0.5).astype(np.float32)

            # Prediction processing: convert to binary mask
            pred_np = prediction[0].numpy() if prediction.shape[0] == 1 else prediction.numpy()
            pred_binary = (pred_np > 0.5).astype(np.float32)

            # Create pure black and white label image
            label_bw = (label_binary * 255).astype(np.uint8)
            label_bw_3ch = np.stack([label_bw] * 3, axis=-1)  # Convert to 3 channels for concatenation

            # Create pure black and white prediction image
            pred_bw = (pred_binary * 255).astype(np.uint8)
            pred_bw_3ch = np.stack([pred_bw] * 3, axis=-1)  # Convert to 3 channels for concatenation

            # Horizontally concatenate three images: original, ground truth B&W, prediction B&W
            combined = np.hstack([image_rgb, label_bw_3ch, pred_bw_3ch])

            # Add text annotations (at the top of the image)
            h, w = image_rgb.shape[:2]

            # Calculate text width for centering
            orig_text_size = cv2.getTextSize('Original', font, 0.6, 2)[0]
            gt_text_size = cv2.getTextSize('Ground Truth', font, 0.6, 2)[0]
            pred_text_size = cv2.getTextSize('Prediction', font, 0.6, 2)[0]

            combined = cv2.putText(combined, 'Original', ((w - orig_text_size[0]) // 2, 30),
                                   font, 0.6, (0, 255, 0), 2)
            combined = cv2.putText(combined, 'Ground Truth', (w + (w - gt_text_size[0]) // 2, 30),
                                   font, 0.6, (0, 255, 0), 2)
            combined = cv2.putText(combined, 'Prediction', (2 * w + (w - pred_text_size[0]) // 2, 30),
                                   font, 0.6, (0, 255, 0), 2)

            # Save (the panel is assembled as RGB; cv2.imwrite expects BGR)
            sample_idx = base_index + i
            save_path = os.path.join(save_dir, f'sample_{sample_idx:04d}.png')
            cv2.imwrite(save_path, cv2.cvtColor(combined, cv2.COLOR_RGB2BGR))

        except Exception as e:
            # Best effort: visualization must not abort the evaluation run
            print(f"⚠️  Error saving sample {i}: {e}")
            continue


def print_results(results, dataset_name, checkpoint_path, model):
    """
    Print evaluation results

    Args:
        results: Evaluation results dictionary
        dataset_name: Dataset name
        checkpoint_path: Model checkpoint path
        model: Model (used only to count its parameters)
    """
    print("\n" + "=" * 60)
    print("Evaluation Results")
    print("=" * 60)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
    print(f"Dataset: {dataset_name}")
    print(f"Model: {checkpoint_path}")
    print(f"Number of samples: {results['total_samples']}")
    print("-" * 60)
    print(f"mDice (Mean Dice Coefficient): {results['mDice']:.3f}")
    print(f"mIoU (Mean Intersection over Union): {results['mIoU']:.3f}")
    print(f"mHD (Mean Hausdorff Distance): {results['mHD']:.3f}")
    print(f"mHD95 (95% Hausdorff Distance): {results['mHD95']:.3f}")
    print("=" * 60)


def save_results(results, dataset_name, checkpoint_path, results_dir, model):
    """
    Save evaluation results to file

    Writes eval_<dataset_name>.txt inside results_dir, creating the directory
    if needed and overwriting any existing report.

    Args:
        results: Evaluation results dictionary
        dataset_name: Dataset name
        checkpoint_path: Model checkpoint path
        results_dir: Results save directory
        model: Model (used only to count its parameters)
    """
    os.makedirs(results_dir, exist_ok=True)

    # Generate filename
    result_file = os.path.join(results_dir, f"eval_{dataset_name}.txt")

    with open(result_file, 'w', encoding='utf-8') as f:
        f.write("=" * 60 + "\n")
        f.write("Polyp Segmentation Model Evaluation Report\n")
        f.write("=" * 60 + "\n\n")
        f.write(f"Model parameters: {sum(p.numel() for p in model.parameters())}\n")
        f.write(f"Dataset: {dataset_name}\n")
        f.write(f"Model checkpoint: {checkpoint_path}\n")
        f.write(f"Evaluation time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Number of samples: {results['total_samples']}\n\n")
        f.write("-" * 60 + "\n")
        f.write("Evaluation Metrics:\n")
        f.write("-" * 60 + "\n")
        f.write(f"mDice (Mean Dice Coefficient): {results['mDice']:.3f}\n")
        f.write(f"mIoU (Mean Intersection over Union): {results['mIoU']:.3f}\n")
        f.write(f"mHD (Mean Hausdorff Distance): {results['mHD']:.3f}\n")
        f.write(f"mHD95 (95% Hausdorff Distance): {results['mHD95']:.3f}\n")
        f.write("-" * 60 + "\n")

    print(f"\n✓ Evaluation results saved to: {result_file}")


def main():
    """
    Main evaluation function
    """
    # ==================== Step 1: Parse arguments ====================
    args = parse_args()

    # Checkpoints live in "<outputs_dir><dir_flag>/"; the file name depends on
    # which "best" model was requested.
    checkpoint_path = Path(args.outputs_dir + args.dir_flag) / f"best_dice_model_{args.dataset_name}.pt"
    if args.best_metric:
        checkpoint_path = Path(args.outputs_dir + args.dir_flag) / f"best_metric_model_{args.dataset_name}.pt"

    print("\n" + "=" * 60)
    print("Polyp Segmentation Model Evaluation")
    print("=" * 60)
    print(f"Dataset: {args.dataset_name}")
    print(f"Model checkpoint: {checkpoint_path}")

    # ==================== Step 2: Create validation set and data loader ====================
    print("\nLoading validation set...")
    val_transform = create_val_transform(args.target_spatial_size)

    val_dataset = PolypDetectionDataset(
        root_dir=Path(args.data_root) / args.dataset_name,
        flag='val',
        transform=val_transform
    )

    val_loader = create_dataloader(args, val_dataset)
    print(f"✓ Validation set size: {len(val_dataset)} samples")
    print(f"✓ Data loader: {len(val_loader)} batches")

    # ==================== Step 3: Load model ====================
    model = load_model(args, checkpoint_path)

    # ==================== Step 4: Evaluate model ====================
    results = evaluate_model(model, val_loader, args)

    # ==================== Step 5: Print and save results ====================
    print_results(results, args.dataset_name, checkpoint_path, model)

    if args.save_results:
        save_results(results, args.dataset_name, checkpoint_path, args.results_dir + args.dir_flag, model)

    print("\n" + "=" * 60)
    print("Evaluation completed!")
    print("=" * 60)


if __name__ == "__main__":
    main()