from __future__ import annotations import warnings from collections.abc import Callable, Sequence from typing import Any import numpy as np import torch import torch.nn as nn import torch.nn.functional as F # noinspection PyProtectedMember from torch.nn.modules.loss import _Loss from monai.losses.utils import compute_tp_fp_fn from monai.networks import one_hot from monai.utils import LossReduction, Weight, look_up_option from monai.losses import DiceLoss, DiceCELoss, HausdorffDTLoss class IoULoss(_Loss): """ Compute average IoU (Intersection over Union) loss between two tensors. IoU Loss 直接优化交并比指标,相比 Dice Loss 对边界误差更敏感, 能够提供更强的梯度信号用于训练。 """ def __init__( self, include_background: bool = True, to_onehot_y: bool = False, sigmoid: bool = False, softmax: bool = False, other_act: Callable | None = None, squared_pred: bool = False, reduction: LossReduction | str = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, weight: Sequence[float] | float | int | torch.Tensor | None = None, soft_label: bool = False, ) -> None: """ Args: include_background: if False, channel index 0 (background category) is excluded from the calculation. 如果非背景区域相对于整个图像较小,排除背景可以帮助收敛。 to_onehot_y: whether to convert the ``target`` into the one-hot format, using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False. sigmoid: if True, apply a sigmoid function to the prediction. softmax: if True, apply a softmax function to the prediction. other_act: callable function to execute other activation layers, Defaults to ``None``. for example: ``other_act = torch.tanh``. squared_pred: use squared versions of targets and predictions in the denominator or not. 使用平方可以加强大误差区域的惩罚。 reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. smooth_nr: a small constant added to the numerator to avoid zero. smooth_dr: a small constant added to the denominator to avoid nan. batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, 每个 batch item 独立计算损失后再 reduction。 weight: weights to apply to the voxels of each class. If None no weights are applied. The input can be a single value (same weight for all classes), a sequence of values (the length of the sequence should be the same as the number of classes. If not ``include_background``, the number of classes should not include the background category class 0). The value/values should be no less than 0. Defaults to None. soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. Incompatible values. Example: >>> import torch >>> from lib.modules.iou_loss import IoULoss >>> B, C, H, W = 7, 5, 3, 2 >>> y_input = torch.rand(B, C, H, W) >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() >>> target = one_hot(target_idx[:, None, ...], num_classes=C) >>> loss_fn = IoULoss(reduction='none') >>> loss = loss_fn(y_input, target) """ super().__init__(reduction=LossReduction(reduction).value) if other_act is not None and not callable(other_act): raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") self.include_background = include_background self.to_onehot_y = to_onehot_y self.sigmoid = sigmoid self.softmax = softmax self.other_act = other_act self.squared_pred = squared_pred self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch weight = torch.as_tensor(weight) if weight is not None else None self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor self.soft_label = soft_label def forward(self, y_input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: y_input: the shape should be BNH[WD], where N is the number of classes. target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. Raises: AssertionError: When input and target (after one hot transform if set) have different shapes. ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. """ if self.sigmoid: y_input = torch.sigmoid(y_input) n_pred_ch = y_input.shape[1] if self.softmax: if n_pred_ch == 1: warnings.warn("single channel prediction, `softmax=True` ignored.") else: y_input = torch.softmax(y_input, 1) if self.other_act is not None: y_input = self.other_act(y_input) if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: target = one_hot(target, num_classes=n_pred_ch) if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.") else: # if skipping background, removing first channel target = target[:, 1:] y_input = y_input[:, 1:] if target.shape != y_input.shape: raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({y_input.shape})") # reducing only spatial dimensions (not batch nor channels) reduce_axis: list[int] = torch.arange(2, len(y_input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis y_ord = 2 if self.squared_pred else 1 tp, fp, fn = compute_tp_fp_fn(y_input, target, reduce_axis, y_ord, self.soft_label) # IoU 的核心公式:IoU = TP / (TP + FP + FN) # 注意:与 Dice 不同,IoU 的分子没有系数 2,分母也少了 TP numerator = tp + self.smooth_nr denominator = tp + fp + fn + self.smooth_dr iou: torch.Tensor = numerator / denominator loss: torch.Tensor = 1 - iou # IoU Loss = 1 - IoU num_of_classes = target.shape[1] if self.class_weight is not None and num_of_classes != 1: # make sure the lengths of weights are equal to the number of classes if self.class_weight.ndim == 0: # noinspection PyAttributeOutsideInit self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes) else: if self.class_weight.shape[0] != num_of_classes: raise ValueError( """the length of the `weight` sequence should be the same as the number of classes. If `include_background=False`, the weight should not include the background category class 0.""" ) if self.class_weight.min() < 0: raise ValueError("the value/values of the `weight` should be no less than 0.") # apply class_weight to loss loss = loss * self.class_weight.to(loss) if self.reduction == LossReduction.MEAN.value: loss = torch.mean(loss) # the batch and channel average elif self.reduction == LossReduction.SUM.value: loss = torch.sum(loss) # sum over the batch and channel dims elif self.reduction == LossReduction.NONE.value: # If we are not computing voxelwise loss components at least # make sure a none reduction maintains a broadcastable shape broadcast_shape = list(loss.shape[0:2]) + [1] * (len(y_input.shape) - 2) loss = loss.view(broadcast_shape) else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return loss class CombinedDiceCEIoULoss(_Loss): """ 四合一组合损失:DiceCE + IoU - DiceCE: Dice + CrossEntropy (全局 + 局部) - IoU: 交并比优化 """ def __init__( self, dice_weight: float = 1.0, ce_weight: float = 1.0, iou_weight: float = 1.0, include_background: bool = True, to_onehot_y: bool = False, softmax: bool = False, sigmoid: bool = False, ): super().__init__() self.include_background = include_background self.to_onehot_y = to_onehot_y self.softmax = softmax # DiceBCE Loss (Dice + BCE) self.dice_ce_loss = DiceCELoss( include_background=include_background, to_onehot_y=to_onehot_y, softmax=softmax, sigmoid=sigmoid, lambda_dice=dice_weight / (dice_weight + ce_weight), lambda_ce=ce_weight / (dice_weight + ce_weight), reduction="mean", ) # IoU Loss self.iou_loss = IoULoss( include_background=include_background, to_onehot_y=to_onehot_y, softmax=softmax, sigmoid=sigmoid, reduction="mean", ) # 外部权重 self.dice_ce_weight = dice_weight + ce_weight self.iou_weight = iou_weight def forward(self, y_input: torch.Tensor, target: torch.Tensor) -> tuple[float | Any, Any, Any]: # DiceCE Loss dice_ce_loss = self.dice_ce_loss(y_input, target) # IoU Loss iou_loss = self.iou_loss(y_input, target) # 加权组合 total_loss = ( self.dice_ce_weight * dice_ce_loss + self.iou_weight * iou_loss ) return total_loss, dice_ce_loss, iou_loss