metrics.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. from __future__ import annotations
  2. from typing import Any
  3. import torch
  4. from torch import nn
  5. try:
  6. from monai.metrics.hausdorff_distance import HausdorffDistanceMetric
  7. from monai.metrics.meandice import DiceMetric
  8. from monai.metrics.meaniou import MeanIoU
  9. except ImportError as exc:
  10. DiceMetric = None
  11. HausdorffDistanceMetric = None
  12. MeanIoU = None
  13. _MONAI_IMPORT_ERROR = exc
  14. else:
  15. _MONAI_IMPORT_ERROR = None
  16. METRIC_REGISTRY = {
  17. "dice": DiceMetric,
  18. "iou": MeanIoU,
  19. "miou": MeanIoU,
  20. "hausdorff": HausdorffDistanceMetric,
  21. "hd": HausdorffDistanceMetric,
  22. "hd95": HausdorffDistanceMetric,
  23. }
  24. DEFAULT_METRIC_CONFIG = {
  25. "binary": {
  26. "task_mode": "binary",
  27. "metrics": [
  28. {"name": "dice"},
  29. {"name": "iou"},
  30. ],
  31. },
  32. "multiclass": {
  33. "task_mode": "multiclass",
  34. "metrics": [
  35. {"name": "dice"},
  36. {"name": "miou"},
  37. ],
  38. },
  39. }
  40. def _require_monai() -> None:
  41. if _MONAI_IMPORT_ERROR is not None:
  42. raise ImportError(
  43. "MONAI is required for lib.tools.metrics. Install monai before building metrics."
  44. ) from _MONAI_IMPORT_ERROR
  45. def _mode_defaults(task_mode: str) -> dict[str, Any]:
  46. mode = task_mode.lower()
  47. if mode == "binary":
  48. return {
  49. "include_background": True,
  50. "reduction": "mean",
  51. }
  52. if mode == "multiclass":
  53. return {
  54. "include_background": False,
  55. "reduction": "mean",
  56. }
  57. raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")
  58. def prepare_metric_inputs(
  59. logits: torch.Tensor,
  60. target: torch.Tensor,
  61. *,
  62. task_mode: str = "binary",
  63. threshold: float = 0.5,
  64. num_classes: int | None = None,
  65. ) -> tuple[torch.Tensor, torch.Tensor]:
  66. """Convert logits and target to MONAI metric-ready tensors.
  67. Returns tensors in channel-first one-hot/binary mask format expected by MONAI metrics.
  68. """
  69. mode = task_mode.lower()
  70. if mode == "binary":
  71. pred = (torch.sigmoid(logits) >= threshold).float()
  72. if target.ndim == pred.ndim - 1:
  73. target = target.unsqueeze(1)
  74. elif target.ndim != pred.ndim:
  75. raise ValueError("Binary target shape must match logits or be missing the channel dimension.")
  76. return pred, (target > 0).float()
  77. if mode == "multiclass":
  78. if num_classes is None:
  79. if logits.ndim < 2:
  80. raise ValueError("Multiclass logits must include a class dimension.")
  81. num_classes = logits.shape[1]
  82. pred_labels = torch.argmax(logits, dim=1)
  83. pred = torch.nn.functional.one_hot(pred_labels.long(), num_classes=num_classes)
  84. pred = pred.movedim(-1, 1).float()
  85. if target.ndim == logits.ndim and target.shape[1] == num_classes:
  86. target_out = target.float()
  87. else:
  88. if target.ndim == logits.ndim and target.shape[1] == 1:
  89. target = target.squeeze(1)
  90. elif target.ndim != logits.ndim - 1:
  91. raise ValueError(
  92. "Multiclass target must be class indices [B,H,W] / [B,1,H,W] or one-hot [B,C,H,W]."
  93. )
  94. target_out = torch.nn.functional.one_hot(target.long(), num_classes=num_classes)
  95. target_out = target_out.movedim(-1, 1).float()
  96. return pred, target_out
  97. raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")
  98. def update_metrics(
  99. metrics: dict[str, Any],
  100. logits: torch.Tensor,
  101. target: torch.Tensor,
  102. *,
  103. task_mode: str = "binary",
  104. threshold: float = 0.5,
  105. num_classes: int | None = None,
  106. ) -> tuple[torch.Tensor, torch.Tensor]:
  107. """Prepare prediction/target and update all MONAI metrics in-place."""
  108. pred, target_out = prepare_metric_inputs(
  109. logits,
  110. target,
  111. task_mode=task_mode,
  112. threshold=threshold,
  113. num_classes=num_classes,
  114. )
  115. for metric in metrics.values():
  116. metric(y_pred=pred, y=target_out)
  117. return pred, target_out
  118. def compute_metrics(metrics: dict[str, Any]) -> dict[str, float]:
  119. """Aggregate MONAI metrics into plain python floats."""
  120. results = {}
  121. for name, metric in metrics.items():
  122. value = metric.aggregate()
  123. if isinstance(value, torch.Tensor):
  124. if value.numel() == 1:
  125. results[name] = float(value.item())
  126. else:
  127. results[name] = float(value.float().mean().item())
  128. else:
  129. results[name] = float(value)
  130. return results
  131. def reset_metrics(metrics: dict[str, Any]) -> None:
  132. """Reset all MONAI metrics in-place."""
  133. for metric in metrics.values():
  134. metric.reset()
  135. def get_default_metric_config(task_mode: str = "binary") -> dict[str, Any]:
  136. mode = task_mode.lower()
  137. if mode not in DEFAULT_METRIC_CONFIG:
  138. raise ValueError(
  139. f"Unsupported task_mode '{task_mode}'. Expected one of: {', '.join(DEFAULT_METRIC_CONFIG)}."
  140. )
  141. default_cfg = DEFAULT_METRIC_CONFIG[mode]
  142. return {
  143. "task_mode": default_cfg["task_mode"],
  144. "metrics": [dict(metric_cfg) for metric_cfg in default_cfg["metrics"]],
  145. }
  146. def build_metric(config: dict[str, Any], *, task_mode: str = "binary"):
  147. """Build a single MONAI metric from a yaml-style config dict."""
  148. _require_monai()
  149. if not isinstance(config, dict) or not config:
  150. raise ValueError("Metric config must be a non-empty dict.")
  151. name = config.get("name")
  152. if not isinstance(name, str) or not name:
  153. raise ValueError("Metric config must provide 'name'.")
  154. metric_cls = METRIC_REGISTRY.get(name.lower())
  155. if metric_cls is None:
  156. raise ValueError(
  157. f"Unsupported metric '{name}'. Expected one of: {', '.join(METRIC_REGISTRY)}."
  158. )
  159. params = _mode_defaults(task_mode)
  160. params.update(config.get("params", {}))
  161. if name.lower() == "hd95":
  162. params.setdefault("percentile", 95.0)
  163. return metric_cls(**params)
  164. def build_metrics(config: dict[str, Any] | list[dict[str, Any]] | None):
  165. """Build metrics from a yaml-style config.
  166. Supported examples:
  167. {"task_mode": "binary", "metrics": [{"name": "dice"}, {"name": "iou"}]}
  168. [{"name": "dice"}, {"name": "iou"}]
  169. """
  170. _require_monai()
  171. if config is None:
  172. return {}
  173. if isinstance(config, list):
  174. metric_list = config
  175. task_mode = "binary"
  176. elif isinstance(config, dict):
  177. task_mode = config.get("task_mode", "binary")
  178. metric_list = config.get("metrics")
  179. else:
  180. raise ValueError("Metrics config must be a dict, a list, or None.")
  181. if metric_list is None:
  182. metric_list = get_default_metric_config(task_mode)["metrics"]
  183. if not isinstance(metric_list, list):
  184. raise ValueError("Metrics config field 'metrics' must be a list.")
  185. metrics = {}
  186. for metric_cfg in metric_list:
  187. if not isinstance(metric_cfg, dict):
  188. raise ValueError("Each metric config must be a dict.")
  189. name = metric_cfg.get("name")
  190. if not isinstance(name, str) or not name:
  191. raise ValueError("Each metric config must provide 'name'.")
  192. metrics[name] = build_metric(metric_cfg, task_mode=task_mode)
  193. return metrics
# Public names exported by a star-import of this module. The
# underscore-prefixed helpers defined above are not listed.
__all__ = [
    "DEFAULT_METRIC_CONFIG",
    "METRIC_REGISTRY",
    "build_metric",
    "build_metrics",
    "compute_metrics",
    "get_default_metric_config",
    "prepare_metric_inputs",
    "reset_metrics",
    "update_metrics",
]