# loss.py: loss-function factory built on MONAI.
  1. from __future__ import annotations
  2. from typing import Any
  3. from torch import nn
  4. try:
  5. from monai.losses.dice import DiceCELoss, DiceFocalLoss, GeneralizedDiceFocalLoss
  6. from monai.losses.hausdorff_loss import HausdorffDTLoss
  7. except ImportError as exc:
  8. DiceCELoss = None
  9. DiceFocalLoss = None
  10. GeneralizedDiceFocalLoss = None
  11. HausdorffDTLoss = None
  12. _MONAI_IMPORT_ERROR = exc
  13. else:
  14. _MONAI_IMPORT_ERROR = None
  15. LOSS_REGISTRY = {
  16. "dicece": DiceCELoss,
  17. "dicefocal": DiceFocalLoss,
  18. "generalized_dice_focal": GeneralizedDiceFocalLoss,
  19. "gdl_focal": GeneralizedDiceFocalLoss,
  20. "hausdorff": HausdorffDTLoss,
  21. }
  22. DEFAULT_TASK_LOSS = {
  23. "lun": {
  24. "name": "dicece",
  25. "params": {
  26. "include_background": True,
  27. "lambda_dice": 0.7,
  28. "lambda_ce": 0.3,
  29. },
  30. },
  31. "pe": {
  32. "name": "dicece",
  33. "params": {
  34. "include_background": True,
  35. "lambda_dice": 0.7,
  36. "lambda_ce": 0.3,
  37. },
  38. },
  39. "b": {
  40. "name": "dicefocal",
  41. "params": {
  42. "include_background": True,
  43. "lambda_dice": 1.0,
  44. "lambda_focal": 1.0,
  45. "gamma": 2.0,
  46. },
  47. },
  48. }
  49. def _require_monai() -> None:
  50. if _MONAI_IMPORT_ERROR is not None:
  51. raise ImportError(
  52. "MONAI is required for lib.tools.tools. Install monai before building losses."
  53. ) from _MONAI_IMPORT_ERROR
  54. def _mode_defaults(task_mode: str) -> dict[str, Any]:
  55. mode = task_mode.lower()
  56. if mode == "binary":
  57. return {
  58. "sigmoid": True,
  59. "softmax": False,
  60. "to_onehot_y": False,
  61. "reduction": "mean",
  62. }
  63. if mode == "multiclass":
  64. return {
  65. "sigmoid": False,
  66. "softmax": True,
  67. "to_onehot_y": True,
  68. "reduction": "mean",
  69. }
  70. raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")
  71. def _get_task_config(task_name: str) -> dict[str, Any]:
  72. task = task_name.lower()
  73. if task not in DEFAULT_TASK_LOSS:
  74. raise ValueError(f"Unsupported task '{task_name}'. Expected one of: {', '.join(DEFAULT_TASK_LOSS)}.")
  75. task_cfg = DEFAULT_TASK_LOSS[task]
  76. return {
  77. "name": task_cfg["name"],
  78. "params": dict(task_cfg.get("params", {})),
  79. }
  80. def build_loss(config: dict[str, Any]) -> nn.Module:
  81. """Build a MONAI tools from a plain yaml-style config dict.
  82. Supported examples:
  83. {"task_name": "lun"}
  84. {"name": "dicece", "task_mode": "binary", "params": {"lambda_ce": 0.5}}
  85. {"name": "dicece", "task_mode": "multiclass", "params": {"include_background": True}}
  86. """
  87. _require_monai()
  88. if not isinstance(config, dict) or not config:
  89. raise ValueError("Loss config must be a non-empty dict.")
  90. task_mode = config.get("task_mode", "binary")
  91. if not isinstance(task_mode, str):
  92. raise ValueError("Loss config field 'task_mode' must be a string.")
  93. if "task_name" in config and config["task_name"] is not None:
  94. if not isinstance(config["task_name"], str):
  95. raise ValueError("Loss config field 'task_name' must be a string.")
  96. merged = _get_task_config(config["task_name"])
  97. merged["params"].update(config.get("params", {}))
  98. if "name" in config:
  99. merged["name"] = config["name"]
  100. else:
  101. name = config.get("name")
  102. if not isinstance(name, str) or not name:
  103. raise ValueError("Loss config must provide 'name' or 'task_name'.")
  104. merged = {
  105. "name": name,
  106. "params": dict(config.get("params", {})),
  107. }
  108. if not isinstance(merged["params"], dict):
  109. raise ValueError("Loss config field 'params' must be a dict if provided.")
  110. params = _mode_defaults(task_mode)
  111. params.update(merged["params"])
  112. loss_name = merged["name"].lower()
  113. loss_cls = LOSS_REGISTRY.get(loss_name)
  114. if loss_cls is None:
  115. raise ValueError(
  116. f"Unsupported tools '{merged['name']}'. Expected one of: {', '.join(LOSS_REGISTRY)}."
  117. )
  118. return loss_cls(**params)
  119. __all__ = ["DEFAULT_TASK_LOSS", "LOSS_REGISTRY", "build_loss"]