loss.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. from __future__ import annotations
  2. from typing import Any
  3. import torch
  4. import torch.nn.functional as F
  5. from torch import nn
  6. try:
  7. from monai.losses.dice import DiceCELoss, DiceFocalLoss, GeneralizedDiceFocalLoss
  8. from monai.losses.hausdorff_loss import HausdorffDTLoss
  9. except ImportError as exc:
  10. DiceCELoss = None
  11. DiceFocalLoss = None
  12. GeneralizedDiceFocalLoss = None
  13. HausdorffDTLoss = None
  14. _MONAI_IMPORT_ERROR = exc
  15. else:
  16. _MONAI_IMPORT_ERROR = None
  17. LOSS_REGISTRY = {
  18. "dicece": DiceCELoss,
  19. "dicefocal": DiceFocalLoss,
  20. "generalized_dice_focal": GeneralizedDiceFocalLoss,
  21. "gdl_focal": GeneralizedDiceFocalLoss,
  22. "hausdorff": HausdorffDTLoss,
  23. }
  24. DEFAULT_TASK_LOSS = {
  25. "lun": {
  26. "name": "dicece",
  27. "params": {
  28. "include_background": True,
  29. "lambda_dice": 0.7,
  30. "lambda_ce": 0.3,
  31. },
  32. },
  33. "pe": {
  34. "name": "dicece",
  35. "params": {
  36. "include_background": True,
  37. "lambda_dice": 0.7,
  38. "lambda_ce": 0.3,
  39. },
  40. },
  41. "b": {
  42. "name": "dicefocal",
  43. "params": {
  44. "include_background": True,
  45. "lambda_dice": 1.0,
  46. "lambda_focal": 1.0,
  47. "gamma": 2.0,
  48. },
  49. },
  50. }
  51. def _require_monai() -> None:
  52. if _MONAI_IMPORT_ERROR is not None:
  53. raise ImportError(
  54. "MONAI is required for lib.tools.tools. Install monai before building losses."
  55. ) from _MONAI_IMPORT_ERROR
  56. def _mode_defaults(task_mode: str) -> dict[str, Any]:
  57. mode = task_mode.lower()
  58. if mode == "binary":
  59. return {
  60. "sigmoid": True,
  61. "softmax": False,
  62. "to_onehot_y": False,
  63. "reduction": "mean",
  64. }
  65. if mode == "multiclass":
  66. return {
  67. "sigmoid": False,
  68. "softmax": True,
  69. "to_onehot_y": True,
  70. "reduction": "mean",
  71. }
  72. raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")
  73. def _get_task_config(task_name: str) -> dict[str, Any]:
  74. task = task_name.lower()
  75. if task not in DEFAULT_TASK_LOSS:
  76. raise ValueError(f"Unsupported task '{task_name}'. Expected one of: {', '.join(DEFAULT_TASK_LOSS)}.")
  77. task_cfg = DEFAULT_TASK_LOSS[task]
  78. return {
  79. "name": task_cfg["name"],
  80. "params": dict(task_cfg.get("params", {})),
  81. }
  82. def build_loss(config: dict[str, Any]) -> nn.Module:
  83. """Build a MONAI tools from a plain yaml-style config dict.
  84. Supported examples:
  85. {"task_name": "lun"}
  86. {"name": "dicece", "task_mode": "binary", "params": {"lambda_ce": 0.5}}
  87. {"name": "dicece", "task_mode": "multiclass", "params": {"include_background": True}}
  88. """
  89. _require_monai()
  90. if not isinstance(config, dict) or not config:
  91. raise ValueError("Loss config must be a non-empty dict.")
  92. task_mode = config.get("task_mode", "binary")
  93. if not isinstance(task_mode, str):
  94. raise ValueError("Loss config field 'task_mode' must be a string.")
  95. if "task_name" in config and config["task_name"] is not None:
  96. if not isinstance(config["task_name"], str):
  97. raise ValueError("Loss config field 'task_name' must be a string.")
  98. merged = _get_task_config(config["task_name"])
  99. merged["params"].update(config.get("params", {}))
  100. if "name" in config:
  101. merged["name"] = config["name"]
  102. else:
  103. name = config.get("name")
  104. if not isinstance(name, str) or not name:
  105. raise ValueError("Loss config must provide 'name' or 'task_name'.")
  106. merged = {
  107. "name": name,
  108. "params": dict(config.get("params", {})),
  109. }
  110. if not isinstance(merged["params"], dict):
  111. raise ValueError("Loss config field 'params' must be a dict if provided.")
  112. params = _mode_defaults(task_mode)
  113. params.update(merged["params"])
  114. loss_name = merged["name"].lower()
  115. loss_cls = LOSS_REGISTRY.get(loss_name)
  116. if loss_cls is None:
  117. raise ValueError(
  118. f"Unsupported tools '{merged['name']}'. Expected one of: {', '.join(LOSS_REGISTRY)}."
  119. )
  120. return loss_cls(**params)
  121. class BinaryBoundaryLoss(nn.Module):
  122. def __init__(self, bce_weight: float = 1.0, dice_weight: float = 1.0, eps: float = 1e-6) -> None:
  123. super().__init__()
  124. self.bce_weight = bce_weight
  125. self.dice_weight = dice_weight
  126. self.eps = eps
  127. def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  128. target = target.float()
  129. bce = F.binary_cross_entropy_with_logits(logits, target)
  130. probs = torch.sigmoid(logits)
  131. intersection = (probs * target).sum(dim=(1, 2, 3))
  132. union = probs.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
  133. dice = 1.0 - ((2.0 * intersection + self.eps) / (union + self.eps))
  134. return self.bce_weight * bce + self.dice_weight * dice.mean()
  135. class MaskBoundaryConsistencyLoss(nn.Module):
  136. def forward(self, seg_logits: torch.Tensor, boundary_logits: torch.Tensor) -> torch.Tensor:
  137. if seg_logits.shape[1] == 1:
  138. seg_prob = torch.sigmoid(seg_logits)
  139. else:
  140. seg_prob = torch.softmax(seg_logits, dim=1)[:, 1:2]
  141. boundary_prob = torch.sigmoid(boundary_logits)
  142. grad_x = torch.abs(seg_prob[:, :, :, 1:] - seg_prob[:, :, :, :-1])
  143. grad_y = torch.abs(seg_prob[:, :, 1:, :] - seg_prob[:, :, :-1, :])
  144. grad_x = F.pad(grad_x, (0, 1, 0, 0))
  145. grad_y = F.pad(grad_y, (0, 0, 0, 1))
  146. edge_proxy = torch.clamp(grad_x + grad_y, 0.0, 1.0)
  147. return F.l1_loss(boundary_prob, edge_proxy)
  148. __all__ = [
  149. "DEFAULT_TASK_LOSS",
  150. "LOSS_REGISTRY",
  151. "build_loss",
  152. "BinaryBoundaryLoss",
  153. "MaskBoundaryConsistencyLoss",
  154. ]