| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261 |
- from __future__ import annotations
- from typing import Any
- import torch
- from torch import nn
- try:
- from monai.metrics.hausdorff_distance import HausdorffDistanceMetric
- from monai.metrics.meandice import DiceMetric
- from monai.metrics.meaniou import MeanIoU
- except ImportError as exc:
- DiceMetric = None
- HausdorffDistanceMetric = None
- MeanIoU = None
- _MONAI_IMPORT_ERROR = exc
- else:
- _MONAI_IMPORT_ERROR = None
# Lower-case metric name (as written in yaml configs) -> MONAI metric class.
# Several aliases share one class. When the optional MONAI import failed, the
# classes are None, so every value here is None as well.
METRIC_REGISTRY = {
    alias: metric_cls
    for metric_cls, aliases in (
        (DiceMetric, ("dice",)),
        (MeanIoU, ("iou", "miou")),
        (HausdorffDistanceMetric, ("hausdorff", "hd", "hd95")),
    )
    for alias in aliases
}
# Default metric selection per task mode: Dice plus an overlap metric
# ("iou" for binary, "miou" for multiclass). Each entry has the same shape as
# a user-supplied config: {"task_mode": ..., "metrics": [{"name": ...}, ...]}.
DEFAULT_METRIC_CONFIG = {
    mode: {"task_mode": mode, "metrics": [{"name": "dice"}, {"name": overlap_name}]}
    for mode, overlap_name in (("binary", "iou"), ("multiclass", "miou"))
}
def _require_monai() -> None:
    """Fail fast when the optional MONAI import at module load time failed.

    Raises:
        ImportError: if importing MONAI raised; the original import error is
            chained as ``__cause__``.
    """
    import_error = _MONAI_IMPORT_ERROR
    if import_error is None:
        return
    raise ImportError(
        "MONAI is required for lib.tools.metrics. Install monai before building metrics."
    ) from import_error
- def _mode_defaults(task_mode: str) -> dict[str, Any]:
- mode = task_mode.lower()
- if mode == "binary":
- return {
- "include_background": True,
- "reduction": "mean",
- }
- if mode == "multiclass":
- return {
- "include_background": False,
- "reduction": "mean",
- }
- raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")
def prepare_metric_inputs(
    logits: torch.Tensor,
    target: torch.Tensor,
    *,
    task_mode: str = "binary",
    threshold: float = 0.5,
    num_classes: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn raw logits and a ground-truth target into MONAI-ready tensors.

    Binary mode: the prediction is ``sigmoid(logits) >= threshold`` and the
    reference is ``target > 0``, both as float masks. A target with one fewer
    dimension than ``logits`` gets a channel dimension inserted at dim 1.

    Multiclass mode: the prediction is the one-hot encoding of
    ``argmax(logits, dim=1)`` with the class axis at dim 1. The target may
    already be one-hot, meaning the same rank as ``logits`` with
    ``num_classes`` entries on dim 1. It may also hold class indices, with or
    without a singleton dim 1, and is then one-hot encoded the same way.

    Args:
        logits: raw network output with the class/channel axis at dim 1.
        target: ground truth in one of the layouts described above.
        task_mode: "binary" or "multiclass" (case-insensitive).
        threshold: probability cut-off applied after the sigmoid in binary mode.
        num_classes: class count for multiclass one-hot encoding; defaults to
            ``logits.shape[1]``.

    Returns:
        ``(pred, target)`` as float tensors in channel-first layout.

    Raises:
        ValueError: for an unknown ``task_mode`` or a target whose rank/shape
            matches none of the accepted layouts.
    """
    mode = task_mode.lower()

    if mode == "binary":
        pred_mask = (torch.sigmoid(logits) >= threshold).float()
        rank_gap = pred_mask.ndim - target.ndim
        if rank_gap == 1:
            # Target lacks the channel axis; add it so shapes line up.
            target = target.unsqueeze(1)
        elif rank_gap != 0:
            raise ValueError("Binary target shape must match logits or be missing the channel dimension.")
        return pred_mask, (target > 0).float()

    if mode != "multiclass":
        raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")

    if num_classes is None:
        if logits.ndim < 2:
            raise ValueError("Multiclass logits must include a class dimension.")
        num_classes = logits.shape[1]

    def _one_hot_channel_first(labels: torch.Tensor) -> torch.Tensor:
        # one_hot appends the class axis last; move it to dim 1 (channel-first).
        encoded = torch.nn.functional.one_hot(labels.long(), num_classes=num_classes)
        return encoded.movedim(-1, 1).float()

    pred_onehot = _one_hot_channel_first(torch.argmax(logits, dim=1))

    same_rank = target.ndim == logits.ndim
    if same_rank and target.shape[1] == num_classes:
        # Already one-hot encoded.
        return pred_onehot, target.float()
    if same_rank and target.shape[1] == 1:
        target = target.squeeze(1)
    elif target.ndim != logits.ndim - 1:
        raise ValueError(
            "Multiclass target must be class indices [B,H,W] / [B,1,H,W] or one-hot [B,C,H,W]."
        )
    return pred_onehot, _one_hot_channel_first(target)
def update_metrics(
    metrics: dict[str, Any],
    logits: torch.Tensor,
    target: torch.Tensor,
    *,
    task_mode: str = "binary",
    threshold: float = 0.5,
    num_classes: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert one batch and feed it to every metric in ``metrics`` in place.

    ``logits`` and ``target`` are converted by ``prepare_metric_inputs`` using
    the same keyword options. Each metric object is then called as
    ``metric(y_pred=pred, y=target)``.

    Returns:
        The ``(pred, target)`` pair that was passed to the metrics, so callers
        can reuse it without converting again.
    """
    y_pred, y_true = prepare_metric_inputs(
        logits,
        target,
        task_mode=task_mode,
        threshold=threshold,
        num_classes=num_classes,
    )
    for metric_obj in metrics.values():
        metric_obj(y_pred=y_pred, y=y_true)
    return y_pred, y_true
def compute_metrics(metrics: dict[str, Any]) -> dict[str, float]:
    """Aggregate each metric and report it as a plain python float.

    The value returned by ``metric.aggregate()`` is reduced to one float:

    * a tuple contributes its first element. MONAI metrics built with
      ``get_not_nans=True`` return ``(aggregated_value, not_nans)``;
    * a one-element tensor is unwrapped with ``.item()``;
    * a multi-element tensor (e.g. from ``reduction="none"``) is averaged
      with ``torch.nanmean``, so NaN entries (MONAI's marker for excluded
      values) do not turn the whole result into NaN;
    * anything else is converted with ``float()``.

    Args:
        metrics: mapping of name -> object exposing ``aggregate()``.

    Returns:
        Mapping with the same names and float values. A tensor whose entries
        are all NaN (or that is empty) yields ``nan``.
    """
    results: dict[str, float] = {}
    for name, metric in metrics.items():
        value = metric.aggregate()
        if isinstance(value, tuple):
            # get_not_nans=True -> (aggregated_value, not_nans_count); report
            # only the aggregated value instead of crashing on float(tuple).
            value = value[0]
        if isinstance(value, torch.Tensor):
            if value.numel() == 1:
                results[name] = float(value.item())
            else:
                # Unreduced buffers use NaN for excluded entries; skip them
                # rather than letting a single NaN poison the average.
                results[name] = float(torch.nanmean(value.float()).item())
        else:
            results[name] = float(value)
    return results
def reset_metrics(metrics: dict[str, Any]) -> None:
    """Call ``reset()`` on every metric object in ``metrics``.

    The metrics are reset in place and nothing is returned.
    """
    for metric_name in metrics:
        metrics[metric_name].reset()
def get_default_metric_config(task_mode: str = "binary") -> dict[str, Any]:
    """Return a copy of ``DEFAULT_METRIC_CONFIG`` for ``task_mode``.

    The outer dict, the "metrics" list and each per-metric dict are new
    objects, so callers can modify the result without changing the module
    default. ``task_mode`` is matched case-insensitively.

    Raises:
        ValueError: if no default exists for ``task_mode``.
    """
    template = DEFAULT_METRIC_CONFIG.get(task_mode.lower())
    if template is None:
        raise ValueError(
            f"Unsupported task_mode '{task_mode}'. Expected one of: {', '.join(DEFAULT_METRIC_CONFIG)}."
        )
    return {
        "task_mode": template["task_mode"],
        "metrics": list(map(dict, template["metrics"])),
    }
def build_metric(config: dict[str, Any], *, task_mode: str = "binary"):
    """Instantiate a single MONAI metric from a yaml-style config dict.

    ``config["name"]`` selects the class from ``METRIC_REGISTRY``
    (case-insensitive). The optional ``config["params"]`` holds keyword
    arguments that override the task-mode defaults (``include_background``,
    ``reduction``). For "hd95", ``percentile=95.0`` is filled in unless
    ``params`` sets a percentile. A ``params`` value of ``None`` is treated as
    "no overrides"; YAML produces this for a ``params:`` key with no value.

    Raises:
        ImportError: if MONAI is not installed.
        ValueError: for an empty/non-dict config, a missing name, an unknown
            metric name, an unsupported ``task_mode``, or non-dict ``params``.
    """
    _require_monai()
    if not isinstance(config, dict) or not config:
        raise ValueError("Metric config must be a non-empty dict.")
    name = config.get("name")
    if not isinstance(name, str) or not name:
        raise ValueError("Metric config must provide 'name'.")
    key = name.lower()
    metric_cls = METRIC_REGISTRY.get(key)
    if metric_cls is None:
        raise ValueError(
            f"Unsupported metric '{name}'. Expected one of: {', '.join(METRIC_REGISTRY)}."
        )
    params = _mode_defaults(task_mode)
    overrides = config.get("params")
    if overrides is None:
        # An empty "params:" entry in YAML arrives as None; dict.update(None)
        # would raise TypeError, so treat it as no overrides.
        overrides = {}
    elif not isinstance(overrides, dict):
        raise ValueError(
            f"Metric '{name}' field 'params' must be a dict, got {type(overrides).__name__}."
        )
    params.update(overrides)
    if key == "hd95":
        params.setdefault("percentile", 95.0)
    return metric_cls(**params)
def build_metrics(config: dict[str, Any] | list[dict[str, Any]] | None):
    """Build a ``{name: metric}`` dict from a yaml-style config.

    Supported examples:
        {"task_mode": "binary", "metrics": [{"name": "dice"}, {"name": "iou"}]}
        [{"name": "dice"}, {"name": "iou"}]

    A bare list uses task mode "binary". In a dict config, a missing or null
    "task_mode" also means "binary", and a missing or null "metrics" falls
    back to ``get_default_metric_config(task_mode)``. Result keys are the
    "name" strings exactly as written in the config.

    Returns:
        ``{}`` when ``config`` is None, otherwise the built metrics in config
        order.

    Raises:
        ImportError: if MONAI is not installed.
        ValueError: for a malformed config, or when two entries share a name.
            Previously the later entry silently replaced the earlier one.
    """
    _require_monai()
    if config is None:
        return {}
    if isinstance(config, list):
        metric_list = config
        task_mode = "binary"
    elif isinstance(config, dict):
        task_mode = config.get("task_mode")
        if task_mode is None:
            # Covers both an absent key and an empty "task_mode:" in YAML.
            task_mode = "binary"
        metric_list = config.get("metrics")
    else:
        raise ValueError("Metrics config must be a dict, a list, or None.")
    if metric_list is None:
        metric_list = get_default_metric_config(task_mode)["metrics"]
    if not isinstance(metric_list, list):
        raise ValueError("Metrics config field 'metrics' must be a list.")
    metrics = {}
    for metric_cfg in metric_list:
        if not isinstance(metric_cfg, dict):
            raise ValueError("Each metric config must be a dict.")
        name = metric_cfg.get("name")
        if not isinstance(name, str) or not name:
            raise ValueError("Each metric config must provide 'name'.")
        if name in metrics:
            raise ValueError(f"Duplicate metric name '{name}' in metrics config.")
        metrics[name] = build_metric(metric_cfg, task_mode=task_mode)
    return metrics
# Names exported by ``from <module> import *``. Underscore-prefixed helpers
# and the MONAI import-error sentinel are intentionally left out.
__all__ = [
    "DEFAULT_METRIC_CONFIG",  # default metric list per task mode
    "METRIC_REGISTRY",  # metric name -> MONAI metric class
    "build_metric",  # one metric from one config entry
    "build_metrics",  # all metrics from a full config
    "compute_metrics",  # aggregate metrics into floats
    "get_default_metric_config",  # copy of a default config
    "prepare_metric_inputs",  # logits/target -> MONAI-ready tensors
    "reset_metrics",  # call reset() on every metric
    "update_metrics",  # feed one batch to every metric
]
|