| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
from __future__ import annotations

from collections.abc import Mapping
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn
# MONAI is optional at import time: a failed import is stored rather than
# raised so this module stays importable, and ``_require_monai`` re-raises the
# stored error when ``build_loss`` is called.
try:
    from monai.losses.dice import DiceCELoss, DiceFocalLoss, GeneralizedDiceFocalLoss
    from monai.losses.hausdorff_loss import HausdorffDTLoss
except ImportError as exc:
    # Any missing piece disables every MONAI-backed entry at once.
    DiceCELoss = DiceFocalLoss = GeneralizedDiceFocalLoss = HausdorffDTLoss = None
    _MONAI_IMPORT_ERROR = exc
else:
    _MONAI_IMPORT_ERROR = None

# Lower-cased loss ``name`` accepted by ``build_loss`` -> MONAI loss class.
# Every value is None when the MONAI import above failed.
LOSS_REGISTRY = {
    "dicece": DiceCELoss,
    "dicefocal": DiceFocalLoss,
    # Long name and short alias resolve to the same class.
    **dict.fromkeys(("generalized_dice_focal", "gdl_focal"), GeneralizedDiceFocalLoss),
    "hausdorff": HausdorffDTLoss,
}
def _dice_ce_recipe() -> dict[str, Any]:
    """Return a new Dice + cross-entropy preset (Dice weighted 0.7, CE 0.3)."""
    return {
        "name": "dicece",
        "params": {"include_background": True, "lambda_dice": 0.7, "lambda_ce": 0.3},
    }


# Task presets used by ``build_loss`` when a config sets ``task_name``. Each
# entry names a ``LOSS_REGISTRY`` loss plus its constructor kwargs. "lun" and
# "pe" receive separate (equal) dicts, so mutating one never affects the other.
DEFAULT_TASK_LOSS = {
    "lun": _dice_ce_recipe(),
    "pe": _dice_ce_recipe(),
    "b": {
        "name": "dicefocal",
        "params": {
            "include_background": True,
            "lambda_dice": 1.0,
            "lambda_focal": 1.0,
            "gamma": 2.0,
        },
    },
}
def _require_monai() -> None:
    """Fail fast when MONAI is unavailable.

    Raises:
        ImportError: chained to the error recorded when importing MONAI failed.
    """
    import_error = _MONAI_IMPORT_ERROR
    if import_error is None:
        return
    raise ImportError(
        "MONAI is required for lib.tools.tools. Install monai before building losses."
    ) from import_error
- def _mode_defaults(task_mode: str) -> dict[str, Any]:
- mode = task_mode.lower()
- if mode == "binary":
- return {
- "sigmoid": True,
- "softmax": False,
- "to_onehot_y": False,
- "reduction": "mean",
- }
- if mode == "multiclass":
- return {
- "sigmoid": False,
- "softmax": True,
- "to_onehot_y": True,
- "reduction": "mean",
- }
- raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")
def _get_task_config(task_name: str) -> dict[str, Any]:
    """Return a ``{"name", "params"}`` copy of the ``DEFAULT_TASK_LOSS`` preset.

    ``task_name`` is lower-cased before lookup. ``params`` is shallow-copied, so
    the caller can update it without mutating ``DEFAULT_TASK_LOSS``.

    Raises:
        ValueError: if the lower-cased name is not a preset key.
    """
    key = task_name.lower()
    if key in DEFAULT_TASK_LOSS:
        preset = DEFAULT_TASK_LOSS[key]
        copied_params = dict(preset.get("params", {}))
        return {"name": preset["name"], "params": copied_params}
    raise ValueError(f"Unsupported task '{task_name}'. Expected one of: {', '.join(DEFAULT_TASK_LOSS)}.")
def build_loss(config: dict[str, Any]) -> nn.Module:
    """Build a MONAI loss from a plain YAML-style config dict.

    The loss comes either from a task preset (``task_name``, looked up in
    ``DEFAULT_TASK_LOSS``) or directly from ``name`` (a case-insensitive key
    of ``LOSS_REGISTRY``). Constructor kwargs are layered, with later layers
    winning: ``task_mode`` defaults -> preset params (when ``task_name`` is
    given) -> ``params``.

    Supported examples:
        {"task_name": "lun"}
        {"name": "dicece", "task_mode": "binary", "params": {"lambda_ce": 0.5}}
        {"name": "dicece", "task_mode": "multiclass", "params": {"include_background": True}}

    Args:
        config: Non-empty dict. Recognized keys:
            ``task_name`` (str or None): preset to start from.
            ``name`` (str): loss name. Required without ``task_name``;
                otherwise it overrides the preset's loss (``None`` keeps the
                preset's).
            ``task_mode`` (str, default ``"binary"``): ``"binary"`` or
                ``"multiclass"``.
            ``params`` (mapping or None): extra/overriding constructor kwargs.

    Returns:
        An instance of the selected ``LOSS_REGISTRY`` class.

    Raises:
        ImportError: If MONAI could not be imported.
        ValueError: If the config is malformed or names an unknown task,
            mode, or loss.
    """
    _require_monai()
    if not isinstance(config, dict) or not config:
        raise ValueError("Loss config must be a non-empty dict.")

    task_mode = config.get("task_mode", "binary")
    if not isinstance(task_mode, str):
        raise ValueError("Loss config field 'task_mode' must be a string.")

    # Validate before merging. Otherwise dict()/.update() either crash with a
    # TypeError on None or silently accept arbitrary iterables of pairs.
    # A bare YAML ``params:`` loads as None and means "no overrides".
    user_params = config.get("params")
    if user_params is None:
        user_params = {}
    elif not isinstance(user_params, Mapping):
        raise ValueError("Loss config field 'params' must be a dict if provided.")

    name = config.get("name")
    task_name = config.get("task_name")
    if task_name is not None:
        if not isinstance(task_name, str):
            raise ValueError("Loss config field 'task_name' must be a string.")
        merged = _get_task_config(task_name)
        # Only a real string overrides the preset's loss. Without this check,
        # ``name: null`` would clobber it and fail later on ``.lower()``.
        if name is not None:
            if not isinstance(name, str):
                raise ValueError("Loss config field 'name' must be a string.")
            merged["name"] = name
        merged["params"].update(user_params)
    else:
        if not isinstance(name, str) or not name:
            raise ValueError("Loss config must provide 'name' or 'task_name'.")
        merged = {"name": name, "params": dict(user_params)}

    params = _mode_defaults(task_mode)
    params.update(merged["params"])

    loss_name = merged["name"].lower()
    loss_cls = LOSS_REGISTRY.get(loss_name)
    if loss_cls is None:
        raise ValueError(
            f"Unsupported tools '{merged['name']}'. Expected one of: {', '.join(LOSS_REGISTRY)}."
        )
    return loss_cls(**params)
class BinaryBoundaryLoss(nn.Module):
    """Weighted sum of BCE-with-logits and a per-sample soft Dice loss.

    Args:
        bce_weight: Multiplier for the BCE term.
        dice_weight: Multiplier for the batch-averaged soft Dice term.
        eps: Smoothing added to the Dice numerator and denominator. An empty
            prediction on an empty target therefore scores a Dice loss of ~0
            instead of dividing by zero.
    """

    def __init__(self, bce_weight: float = 1.0, dice_weight: float = 1.0, eps: float = 1e-6) -> None:
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.eps = eps

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return ``bce_weight * BCE + dice_weight * mean_over_batch(1 - Dice)``.

        ``logits`` and ``target`` must share a shape ``(N, ...)`` with at
        least one non-batch dim. BCE is averaged over all elements. Dice is
        computed per sample over every non-batch dim, so 2-D (N, C, H, W) and
        3-D (N, C, D, H, W) inputs are both supported.

        Raises:
            ValueError: if ``logits`` has fewer than 2 dims.
        """
        if logits.ndim < 2:
            raise ValueError(
                f"BinaryBoundaryLoss expects inputs shaped (N, ...) with ndim >= 2, got {logits.ndim}."
            )
        target = target.float()
        bce = F.binary_cross_entropy_with_logits(logits, target)
        probs = torch.sigmoid(logits)
        # Reduce every non-batch dim. This equals (1, 2, 3) for 4-D input, but
        # does not hard-code the rank.
        reduce_dims = tuple(range(1, logits.ndim))
        intersection = (probs * target).sum(dim=reduce_dims)
        union = probs.sum(dim=reduce_dims) + target.sum(dim=reduce_dims)
        dice = 1.0 - ((2.0 * intersection + self.eps) / (union + self.eps))
        return self.bce_weight * bce + self.dice_weight * dice.mean()
class MaskBoundaryConsistencyLoss(nn.Module):
    """L1 penalty tying a predicted boundary map to the edges of a segmentation map.

    The edge target is the clamped sum of absolute forward differences of the
    foreground probability. The boundary prediction is ``sigmoid(boundary_logits)``.
    """

    def forward(self, seg_logits: torch.Tensor, boundary_logits: torch.Tensor) -> torch.Tensor:
        """Return ``l1_loss(sigmoid(boundary_logits), edge_target)``.

        The foreground probability is ``sigmoid`` for a single-channel
        ``seg_logits``. Otherwise it is softmax channel 1 only. Inputs appear
        to be 4-D, presumably (N, C, H, W); confirm with callers.
        """
        single_channel = seg_logits.shape[1] == 1
        foreground = (
            torch.sigmoid(seg_logits)
            if single_channel
            else torch.softmax(seg_logits, dim=1)[:, 1:2]
        )
        # Forward differences along dim 3 and dim 2. The column/row lost by
        # differencing is zero-padded back at the end so shapes match the input.
        horizontal = F.pad(torch.diff(foreground, dim=3).abs(), (0, 1, 0, 0))
        vertical = F.pad(torch.diff(foreground, dim=2).abs(), (0, 0, 0, 1))
        edge_target = (horizontal + vertical).clamp(0.0, 1.0)
        return F.l1_loss(torch.sigmoid(boundary_logits), edge_target)
# Public API of this module; private helpers such as ``_require_monai`` are
# intentionally left out.
__all__ = [
    "DEFAULT_TASK_LOSS",
    "LOSS_REGISTRY",
    "build_loss",
    "BinaryBoundaryLoss",
    "MaskBoundaryConsistencyLoss",
]
|