from __future__ import annotations

import copy
from typing import Any

import torch
from torch import nn

try:
    from monai.metrics.hausdorff_distance import HausdorffDistanceMetric
    from monai.metrics.meandice import DiceMetric
    from monai.metrics.meaniou import MeanIoU
except ImportError as exc:
    # MONAI is optional at import time: the metric classes are left as None and
    # the original error is kept so _require_monai() can re-raise it with context
    # only when a metric is actually built.
    DiceMetric = None
    HausdorffDistanceMetric = None
    MeanIoU = None
    _MONAI_IMPORT_ERROR = exc
else:
    _MONAI_IMPORT_ERROR = None


# Maps lower-case config names to metric classes. Several aliases point at the
# same class; "hd95" additionally gets a default percentile in build_metric().
METRIC_REGISTRY = {
    "dice": DiceMetric,
    "iou": MeanIoU,
    "miou": MeanIoU,
    "hausdorff": HausdorffDistanceMetric,
    "hd": HausdorffDistanceMetric,
    "hd95": HausdorffDistanceMetric,
}

# Metric sets used by build_metrics() when a dict config omits "metrics".
DEFAULT_METRIC_CONFIG = {
    "binary": {
        "task_mode": "binary",
        "metrics": [
            {"name": "dice"},
            {"name": "iou"},
        ],
    },
    "multiclass": {
        "task_mode": "multiclass",
        "metrics": [
            {"name": "dice"},
            {"name": "miou"},
        ],
    },
}


def _require_monai() -> None:
    """Raise ImportError (chained to the original failure) if MONAI is unavailable."""
    if _MONAI_IMPORT_ERROR is not None:
        raise ImportError(
            "MONAI is required for lib.tools.metrics. Install monai before building metrics."
        ) from _MONAI_IMPORT_ERROR


def _mode_defaults(task_mode: str) -> dict[str, Any]:
    """Return fresh default constructor kwargs for metrics in the given task mode.

    Binary mode keeps the single foreground channel (``include_background=True``);
    multiclass mode drops channel 0 (``include_background=False``).

    Raises:
        ValueError: if ``task_mode`` is not 'binary' or 'multiclass' (case-insensitive).
    """
    mode = task_mode.lower()
    if mode == "binary":
        return {
            "include_background": True,
            "reduction": "mean",
        }
    if mode == "multiclass":
        return {
            "include_background": False,
            "reduction": "mean",
        }
    raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")


def _prepare_binary(
    logits: torch.Tensor,
    target: torch.Tensor,
    threshold: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Threshold sigmoid probabilities and binarize the target to the same rank."""
    pred = (torch.sigmoid(logits) >= threshold).float()
    if target.ndim == pred.ndim - 1:
        # Target lacks the channel axis; insert it at dim 1 (channel-first).
        target = target.unsqueeze(1)
    elif target.ndim != pred.ndim:
        raise ValueError("Binary target shape must match logits or be missing the channel dimension.")
    # Any positive label value counts as foreground.
    return pred, (target > 0).float()


def _prepare_multiclass(
    logits: torch.Tensor,
    target: torch.Tensor,
    num_classes: int | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """One-hot encode argmax predictions and (if needed) index targets along dim 1."""
    # argmax(dim=1) below needs a class axis whether or not num_classes is given.
    if logits.ndim < 2:
        raise ValueError("Multiclass logits must include a class dimension.")
    if num_classes is None:
        num_classes = logits.shape[1]

    pred_labels = torch.argmax(logits, dim=1)
    # one_hot appends the class axis last; move it to dim 1 (channel-first).
    pred = nn.functional.one_hot(pred_labels.long(), num_classes=num_classes)
    pred = pred.movedim(-1, 1).float()

    if target.ndim == logits.ndim and target.shape[1] == num_classes:
        # Already one-hot with a full class axis.
        return pred, target.float()

    if target.ndim == logits.ndim and target.shape[1] == 1:
        target = target.squeeze(1)
    elif target.ndim != logits.ndim - 1:
        raise ValueError(
            "Multiclass target must be class indices [B,H,W] / [B,1,H,W] or one-hot [B,C,H,W]."
        )
    target_out = nn.functional.one_hot(target.long(), num_classes=num_classes)
    return pred, target_out.movedim(-1, 1).float()


def prepare_metric_inputs(
    logits: torch.Tensor,
    target: torch.Tensor,
    *,
    task_mode: str = "binary",
    threshold: float = 0.5,
    num_classes: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert logits and target to MONAI metric-ready tensors.

    Returns tensors in channel-first one-hot/binary mask format expected by MONAI metrics.

    Args:
        logits: Raw model outputs with the channel/class axis at dim 1.
        target: Ground truth. Binary mode accepts the same rank as ``logits`` or one
            less (channel axis missing). Multiclass mode accepts class indices with or
            without a singleton channel axis, or a one-hot tensor.
        task_mode: 'binary' (sigmoid + threshold) or 'multiclass' (argmax + one-hot);
            case-insensitive.
        threshold: Probability cut-off used in binary mode (``>=`` counts as positive).
        num_classes: Number of classes for one-hot encoding in multiclass mode;
            defaults to ``logits.shape[1]``.

    Returns:
        ``(pred, target)`` as float tensors of matching rank.

    Raises:
        ValueError: on an unsupported ``task_mode`` or incompatible shapes.
    """
    mode = task_mode.lower()
    if mode == "binary":
        return _prepare_binary(logits, target, threshold)
    if mode == "multiclass":
        return _prepare_multiclass(logits, target, num_classes)
    raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")


def update_metrics(
    metrics: dict[str, Any],
    logits: torch.Tensor,
    target: torch.Tensor,
    *,
    task_mode: str = "binary",
    threshold: float = 0.5,
    num_classes: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepare prediction/target and update all MONAI metrics in-place.

    Each metric is called as ``metric(y_pred=pred, y=target)``. The prepared
    ``(pred, target)`` pair is returned so callers can reuse it.
    """
    pred, target_out = prepare_metric_inputs(
        logits,
        target,
        task_mode=task_mode,
        threshold=threshold,
        num_classes=num_classes,
    )
    for metric in metrics.values():
        metric(y_pred=pred, y=target_out)
    return pred, target_out


def _aggregate_to_float(value: Any) -> float:
    """Collapse a metric ``aggregate()`` result into a single Python float."""
    # Metrics configured with get_not_nans=True return a (value, not_nans) pair;
    # only the metric value is reported here.
    if isinstance(value, (tuple, list)) and value:
        value = value[0]
    if isinstance(value, torch.Tensor):
        if value.numel() == 1:
            return float(value.item())
        # Non-scalar results (e.g. reduction="none") are averaged while skipping
        # NaN entries, so one undefined sample/class does not make the whole
        # summary NaN. An all-NaN tensor still yields NaN.
        return float(torch.nanmean(value.float()).item())
    return float(value)


def compute_metrics(metrics: dict[str, Any]) -> dict[str, float]:
    """Aggregate MONAI metrics into plain python floats keyed by metric name."""
    return {name: _aggregate_to_float(metric.aggregate()) for name, metric in metrics.items()}


def reset_metrics(metrics: dict[str, Any]) -> None:
    """Reset all MONAI metrics in-place."""
    for metric in metrics.values():
        metric.reset()


def get_default_metric_config(task_mode: str = "binary") -> dict[str, Any]:
    """Return an independent copy of the default metric config for ``task_mode``.

    The copy is deep so callers may mutate it (including any nested ``params``)
    without altering ``DEFAULT_METRIC_CONFIG``.

    Raises:
        ValueError: if ``task_mode`` has no default config.
    """
    mode = task_mode.lower()
    if mode not in DEFAULT_METRIC_CONFIG:
        raise ValueError(
            f"Unsupported task_mode '{task_mode}'. Expected one of: {', '.join(DEFAULT_METRIC_CONFIG)}."
        )
    default_cfg = DEFAULT_METRIC_CONFIG[mode]
    return {
        "task_mode": default_cfg["task_mode"],
        "metrics": copy.deepcopy(default_cfg["metrics"]),
    }


def build_metric(config: dict[str, Any], *, task_mode: str = "binary") -> Any:
    """Build a single MONAI metric from a yaml-style config dict.

    Args:
        config: ``{"name": <registry key>, "params": {...}}``. ``params`` is optional
            (``None`` is treated as empty) and overrides the task-mode defaults.
        task_mode: Selects default constructor kwargs via ``_mode_defaults``.

    Raises:
        ImportError: if MONAI is not installed.
        ValueError: on a malformed config, unknown metric name, or bad task mode.
    """
    _require_monai()
    if not isinstance(config, dict) or not config:
        raise ValueError("Metric config must be a non-empty dict.")
    name = config.get("name")
    if not isinstance(name, str) or not name:
        raise ValueError("Metric config must provide 'name'.")
    key = name.lower()
    metric_cls = METRIC_REGISTRY.get(key)
    if metric_cls is None:
        raise ValueError(
            f"Unsupported metric '{name}'. Expected one of: {', '.join(METRIC_REGISTRY)}."
        )

    # An empty YAML "params:" entry parses as None; treat it as "no overrides".
    user_params = config.get("params")
    if user_params is None:
        user_params = {}
    elif not isinstance(user_params, dict):
        raise ValueError("Metric config field 'params' must be a dict.")

    params = _mode_defaults(task_mode)
    params.update(user_params)
    if key == "hd95":
        params.setdefault("percentile", 95.0)
    return metric_cls(**params)


def build_metrics(config: dict[str, Any] | list[dict[str, Any]] | None) -> dict[str, Any]:
    """Build metrics from a yaml-style config.

    Supported examples:
        {"task_mode": "binary", "metrics": [{"name": "dice"}, {"name": "iou"}]}
        [{"name": "dice"}, {"name": "iou"}]

    A ``None`` config yields no metrics (and does not require MONAI). A list config
    uses binary mode. A dict config without "metrics" uses the task-mode defaults.
    The result is keyed by each entry's ``name`` exactly as written.

    Raises:
        ImportError: if MONAI is not installed and metrics must be built.
        ValueError: on a malformed config or duplicate metric names.
    """
    if config is None:
        return {}
    _require_monai()

    if isinstance(config, list):
        metric_list = config
        task_mode = "binary"
    elif isinstance(config, dict):
        task_mode = config.get("task_mode")
        if task_mode is None:
            task_mode = "binary"
        metric_list = config.get("metrics")
    else:
        raise ValueError("Metrics config must be a dict, a list, or None.")

    if not isinstance(task_mode, str):
        raise ValueError("Metrics config field 'task_mode' must be a string.")
    if metric_list is None:
        metric_list = get_default_metric_config(task_mode)["metrics"]
    if not isinstance(metric_list, list):
        raise ValueError("Metrics config field 'metrics' must be a list.")

    metrics: dict[str, Any] = {}
    for metric_cfg in metric_list:
        if not isinstance(metric_cfg, dict):
            raise ValueError("Each metric config must be a dict.")
        name = metric_cfg.get("name")
        if not isinstance(name, str) or not name:
            raise ValueError("Each metric config must provide 'name'.")
        # Silently overwriting would drop a configured metric.
        if name in metrics:
            raise ValueError(f"Duplicate metric name '{name}' in metrics config.")
        metrics[name] = build_metric(metric_cfg, task_mode=task_mode)
    return metrics


__all__ = [
    "DEFAULT_METRIC_CONFIG",
    "METRIC_REGISTRY",
    "build_metric",
    "build_metrics",
    "compute_metrics",
    "get_default_metric_config",
    "prepare_metric_inputs",
    "reset_metrics",
    "update_metrics",
]