| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- from __future__ import annotations
- from typing import Any
- from torch import nn
- try:
- from monai.losses.dice import DiceCELoss, DiceFocalLoss, GeneralizedDiceFocalLoss
- from monai.losses.hausdorff_loss import HausdorffDTLoss
- except ImportError as exc:
- DiceCELoss = None
- DiceFocalLoss = None
- GeneralizedDiceFocalLoss = None
- HausdorffDTLoss = None
- _MONAI_IMPORT_ERROR = exc
- else:
- _MONAI_IMPORT_ERROR = None
# Maps the lowercase loss name used in configs to the MONAI loss class that
# implements it. "generalized_dice_focal" and "gdl_focal" are two aliases for
# the same class. When the optional MONAI import fails, every value here is
# None; build_loss() calls _require_monai() before consulting this table.
LOSS_REGISTRY = dict(
    dicece=DiceCELoss,
    dicefocal=DiceFocalLoss,
    generalized_dice_focal=GeneralizedDiceFocalLoss,
    gdl_focal=GeneralizedDiceFocalLoss,
    hausdorff=HausdorffDTLoss,
)
# Preset loss recipes keyed by lowercase task name; _get_task_config() looks
# them up when a config supplies "task_name". Each preset gives a
# LOSS_REGISTRY key ("name") and the keyword arguments for that loss
# ("params"). Every preset owns its own dicts, so editing one never leaks
# into another.
DEFAULT_TASK_LOSS = dict(
    lun=dict(
        name="dicece",
        params=dict(include_background=True, lambda_dice=0.7, lambda_ce=0.3),
    ),
    pe=dict(
        name="dicece",
        params=dict(include_background=True, lambda_dice=0.7, lambda_ce=0.3),
    ),
    b=dict(
        name="dicefocal",
        params=dict(
            include_background=True,
            lambda_dice=1.0,
            lambda_focal=1.0,
            gamma=2.0,
        ),
    ),
)
def _require_monai() -> None:
    """Fail loudly if the optional MONAI import at module load did not succeed.

    Raises:
        ImportError: When ``_MONAI_IMPORT_ERROR`` holds the captured import
            failure. That original error is chained as ``__cause__`` so the
            underlying reason stays visible in the traceback.
    """
    import_error = _MONAI_IMPORT_ERROR
    if import_error is None:
        return
    raise ImportError(
        "MONAI is required for lib.tools.tools. Install monai before building losses."
    ) from import_error
def _mode_defaults(task_mode: str) -> dict[str, Any]:
    """Return the activation / target-encoding defaults for a task mode.

    ``"binary"`` enables ``sigmoid`` and leaves the target as-is
    (``to_onehot_y=False``). ``"multiclass"`` enables ``softmax`` and
    ``to_onehot_y``. Both modes use ``reduction="mean"``. Matching is
    case-insensitive, and a new dict is built on every call, so callers may
    mutate the result.

    Raises:
        ValueError: If ``task_mode`` is neither ``"binary"`` nor ``"multiclass"``.
    """
    normalized = task_mode.lower()
    if normalized not in ("binary", "multiclass"):
        raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")
    multiclass = normalized == "multiclass"
    return {
        "sigmoid": not multiclass,
        "softmax": multiclass,
        "to_onehot_y": multiclass,
        "reduction": "mean",
    }
def _get_task_config(task_name: str) -> dict[str, Any]:
    """Return a private copy of the preset loss config for ``task_name``.

    The lookup in ``DEFAULT_TASK_LOSS`` is case-insensitive. The returned
    ``params`` dict is a shallow copy, so callers can update it without
    touching the module-level preset.

    Raises:
        ValueError: If ``task_name`` has no preset in ``DEFAULT_TASK_LOSS``.
    """
    key = task_name.lower()
    if key not in DEFAULT_TASK_LOSS:
        supported = ", ".join(DEFAULT_TASK_LOSS)
        raise ValueError(f"Unsupported task '{task_name}'. Expected one of: {supported}.")
    preset = DEFAULT_TASK_LOSS[key]
    preset_params = preset.get("params", {})
    return {"name": preset["name"], "params": dict(preset_params)}
def build_loss(config: dict[str, Any]) -> nn.Module:
    """Build a MONAI loss module from a plain YAML-style config dict.

    The config either names a loss directly (``name``) or selects a preset
    from ``DEFAULT_TASK_LOSS`` (``task_name``). When both are given, ``name``
    replaces the preset's loss, while the preset's ``params`` remain the base
    that the config's ``params`` are layered on top of. The activation /
    one-hot defaults derived from ``task_mode`` are applied first, so any
    explicit ``params`` entry overrides them.

    Supported examples:
        {"task_name": "lun"}
        {"name": "dicece", "task_mode": "binary", "params": {"lambda_ce": 0.5}}
        {"name": "dicece", "task_mode": "multiclass", "params": {"include_background": True}}

    Args:
        config: Dict with optional keys:

            - ``task_name`` (str)
            - ``name`` (str)
            - ``task_mode`` (``"binary"`` or ``"multiclass"``; default
              ``"binary"``)
            - ``params``: a dict of keyword arguments for the loss
              constructor. ``None`` is treated as empty.

    Returns:
        An instance of the matching class from ``LOSS_REGISTRY``.

    Raises:
        ImportError: If MONAI could not be imported.
        ValueError: If ``config`` is empty or not a dict, a field has the
            wrong type, or the task, task mode or loss name is unsupported.
    """
    _require_monai()
    if not isinstance(config, dict) or not config:
        raise ValueError("Loss config must be a non-empty dict.")

    task_mode = config.get("task_mode", "binary")
    if not isinstance(task_mode, str):
        raise ValueError("Loss config field 'task_mode' must be a string.")

    # Validate user params before merging. A YAML "params:" with no value
    # loads as None, and any non-dict would otherwise fail inside
    # dict()/update() with an unhelpful TypeError, never reaching a check.
    user_params = config.get("params")
    if user_params is None:
        user_params = {}
    elif not isinstance(user_params, dict):
        raise ValueError("Loss config field 'params' must be a dict if provided.")

    # A null or non-string "name" must not overwrite a task preset's name;
    # it would later crash on .lower() with an AttributeError.
    name = config.get("name")
    if name is not None and (not isinstance(name, str) or not name):
        raise ValueError("Loss config field 'name' must be a non-empty string.")

    task_name = config.get("task_name")
    if task_name is not None:
        if not isinstance(task_name, str):
            raise ValueError("Loss config field 'task_name' must be a string.")
        merged = _get_task_config(task_name)  # returns a private copy
        merged["params"].update(user_params)
        if name is not None:
            merged["name"] = name
    else:
        if name is None:
            raise ValueError("Loss config must provide 'name' or 'task_name'.")
        merged = {"name": name, "params": dict(user_params)}

    # Mode defaults first, so explicit params win.
    params = _mode_defaults(task_mode)
    params.update(merged["params"])

    loss_name = merged["name"].lower()
    loss_cls = LOSS_REGISTRY.get(loss_name)
    if loss_cls is None:
        raise ValueError(
            f"Unsupported loss '{merged['name']}'. Expected one of: {', '.join(LOSS_REGISTRY)}."
        )
    return loss_cls(**params)
# Public API of this module; underscore-prefixed helpers stay private.
__all__ = [
    "DEFAULT_TASK_LOSS",
    "LOSS_REGISTRY",
    "build_loss",
]
|