from __future__ import annotations

from typing import Any

from torch import nn

# MONAI is imported optionally: a missing install is recorded here and only
# reported (via _require_monai) when a loss is actually built.
try:
    from monai.losses.dice import DiceCELoss, DiceFocalLoss, GeneralizedDiceFocalLoss
    from monai.losses.hausdorff_loss import HausdorffDTLoss
except ImportError as exc:
    DiceCELoss = None
    DiceFocalLoss = None
    GeneralizedDiceFocalLoss = None
    HausdorffDTLoss = None
    # Rebind: the `exc` name itself is cleared when the except block exits.
    _MONAI_IMPORT_ERROR = exc
else:
    _MONAI_IMPORT_ERROR = None


# Lower-case config name -> MONAI loss class (None when MONAI is unavailable).
# "generalized_dice_focal" and "gdl_focal" are aliases for the same class.
LOSS_REGISTRY = {
    "dicece": DiceCELoss,
    "dicefocal": DiceFocalLoss,
    "generalized_dice_focal": GeneralizedDiceFocalLoss,
    "gdl_focal": GeneralizedDiceFocalLoss,
    "hausdorff": HausdorffDTLoss,
}

# Per-task default loss name and constructor kwargs, selected via the
# config's "task_name" field (matched case-insensitively).
DEFAULT_TASK_LOSS = {
    "lun": {
        "name": "dicece",
        "params": {
            "include_background": True,
            "lambda_dice": 0.7,
            "lambda_ce": 0.3,
        },
    },
    "pe": {
        "name": "dicece",
        "params": {
            "include_background": True,
            "lambda_dice": 0.7,
            "lambda_ce": 0.3,
        },
    },
    "b": {
        "name": "dicefocal",
        "params": {
            "include_background": True,
            "lambda_dice": 1.0,
            "lambda_focal": 1.0,
            "gamma": 2.0,
        },
    },
}


def _require_monai() -> None:
    """Raise ImportError (chained to the original failure) if MONAI did not import."""
    if _MONAI_IMPORT_ERROR is not None:
        raise ImportError(
            "MONAI is required for lib.tools.tools. Install monai before building losses."
        ) from _MONAI_IMPORT_ERROR


def _mode_defaults(task_mode: str) -> dict[str, Any]:
    """Return a fresh dict of default loss kwargs for ``task_mode``.

    ``"binary"`` enables the ``sigmoid`` flag. ``"multiclass"`` enables
    ``softmax`` and ``to_onehot_y``. Both use ``reduction="mean"``. Matching
    is case-insensitive.

    Raises:
        ValueError: if ``task_mode`` is neither ``"binary"`` nor ``"multiclass"``.
    """
    mode = task_mode.lower()
    if mode == "binary":
        return {
            "sigmoid": True,
            "softmax": False,
            "to_onehot_y": False,
            "reduction": "mean",
        }
    if mode == "multiclass":
        return {
            "sigmoid": False,
            "softmax": True,
            "to_onehot_y": True,
            "reduction": "mean",
        }
    raise ValueError(f"Unsupported task_mode '{task_mode}'. Expected 'binary' or 'multiclass'.")


def _get_task_config(task_name: str) -> dict[str, Any]:
    """Return a copy of the DEFAULT_TASK_LOSS entry for ``task_name``.

    The ``params`` dict is copied, so callers may mutate it without touching
    the module-level defaults.

    Raises:
        ValueError: if ``task_name`` (case-insensitive) is not a known task.
    """
    task = task_name.lower()
    if task not in DEFAULT_TASK_LOSS:
        raise ValueError(f"Unsupported task '{task_name}'. Expected one of: {', '.join(DEFAULT_TASK_LOSS)}.")
    task_cfg = DEFAULT_TASK_LOSS[task]
    return {
        "name": task_cfg["name"],
        "params": dict(task_cfg.get("params", {})),
    }


def _resolve_loss_config(config: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    """Validate ``config`` and return ``(loss_name, constructor_kwargs)``.

    Kwargs are layered in increasing priority:
    1. the ``task_mode`` defaults;
    2. the task defaults, if ``task_name`` is set;
    3. the user's ``params``.

    ``loss_name`` is returned exactly as configured; the caller lower-cases
    it for the registry lookup.

    Raises:
        ValueError: for any malformed field, unknown task or unknown task_mode.
    """
    if not isinstance(config, dict) or not config:
        raise ValueError("Loss config must be a non-empty dict.")

    task_mode = config.get("task_mode", "binary")
    if not isinstance(task_mode, str):
        raise ValueError("Loss config field 'task_mode' must be a string.")

    # Validate before merging: a None/list/str here previously surfaced as a
    # TypeError from dict()/update(), and the post-hoc dict check was dead code.
    user_params = config.get("params")
    if user_params is None:
        user_params = {}
    elif not isinstance(user_params, dict):
        raise ValueError("Loss config field 'params' must be a dict if provided.")

    task_name = config.get("task_name")
    name = config.get("name")
    if task_name is not None:
        if not isinstance(task_name, str):
            raise ValueError("Loss config field 'task_name' must be a string.")
        task_cfg = _get_task_config(task_name)
        loss_name = task_cfg["name"]
        task_params = task_cfg["params"]
        # An explicit name overrides the task's loss; the task's params still apply.
        if name is not None:
            if not isinstance(name, str):
                raise ValueError("Loss config field 'name' must be a string.")
            loss_name = name
    else:
        if not isinstance(name, str) or not name:
            raise ValueError("Loss config must provide 'name' or 'task_name'.")
        loss_name = name
        task_params = {}

    params = _mode_defaults(task_mode)
    params.update(task_params)
    params.update(user_params)
    return loss_name, params


def build_loss(config: dict[str, Any]) -> nn.Module:
    """Build a MONAI loss from a plain yaml-style config dict.

    Recognised fields:
        task_name: optional key into DEFAULT_TASK_LOSS supplying a loss name
            and default params.
        name: loss name from LOSS_REGISTRY. It is required when ``task_name``
            is absent; otherwise it overrides the task's loss.
        task_mode: ``"binary"`` (default) or ``"multiclass"``; selects the
            activation/one-hot defaults.
        params: optional dict of constructor kwargs, applied last.

    Supported examples:
        {"task_name": "lun"}
        {"name": "dicece", "task_mode": "binary", "params": {"lambda_ce": 0.5}}
        {"name": "dicece", "task_mode": "multiclass", "params": {"include_background": True}}

    Raises:
        ImportError: if MONAI is not installed (checked before the config).
        ValueError: if the config is malformed or names an unknown loss/task.
    """
    _require_monai()
    loss_name, params = _resolve_loss_config(config)

    loss_cls = LOSS_REGISTRY.get(loss_name.lower())
    if loss_cls is None:
        raise ValueError(
            f"Unsupported loss '{loss_name}'. Expected one of: {', '.join(LOSS_REGISTRY)}."
        )
    return loss_cls(**params)


__all__ = ["DEFAULT_TASK_LOSS", "LOSS_REGISTRY", "build_loss"]