| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- from __future__ import annotations
- import warnings
- from collections.abc import Callable, Sequence
- from typing import Any
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- # noinspection PyProtectedMember
- from torch.nn.modules.loss import _Loss
- from monai.losses.utils import compute_tp_fp_fn
- from monai.networks import one_hot
- from monai.utils import LossReduction, Weight, look_up_option
- from monai.losses import DiceLoss, DiceCELoss, HausdorffDTLoss
class IoULoss(_Loss):
    """
    Compute the average IoU (Intersection over Union, a.k.a. Jaccard) loss between two tensors.

    The loss is ``1 - IoU`` with ``IoU = (TP + smooth_nr) / (TP + FP + FN + smooth_dr)``,
    evaluated per channel (and per batch item unless ``batch=True``) and then reduced.

    Author's note: IoU loss optimizes the intersection-over-union metric directly; compared
    with Dice loss it is more sensitive to boundary errors and can provide a stronger
    gradient signal for training.
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        squared_pred: bool = False,
        reduction: LossReduction | str = LossReduction.MEAN,
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
        weight: Sequence[float] | float | int | torch.Tensor | None = None,
        soft_label: bool = False,
    ) -> None:
        """
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the calculation.
                If the non-background region is small compared with the whole image, excluding the
                background can help convergence.
            to_onehot_y: whether to convert the ``target`` into the one-hot format,
                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
            sigmoid: if True, apply a sigmoid function to the prediction.
            softmax: if True, apply a softmax function to the prediction.
            other_act: callable function to execute other activation layers, Defaults to ``None``.
                for example: ``other_act = torch.tanh``.
            squared_pred: use squared versions of targets and predictions in the denominator or not.
                Author's note: squaring strengthens the penalty on regions with large errors.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                Defaults to False: the loss is computed for every batch item independently
                before the reduction is applied.
            weight: weights to apply to the voxels of each class. If None no weights are applied.
                The input can be a single value (same weight for all classes), a sequence of values (the length
                of the sequence should be the same as the number of classes. If not ``include_background``,
                the number of classes should not include the background category class 0).
                The value/values should be no less than 0. Defaults to None.
            soft_label: whether the target contains non-binary values (soft labels) or not.
                If True a soft label formulation of the loss will be used.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.

        Example:
            >>> import torch
            >>> from monai.networks import one_hot
            >>> from lib.modules.iou_loss import IoULoss
            >>> B, C, H, W = 7, 5, 3, 2
            >>> y_input = torch.rand(B, C, H, W)
            >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
            >>> target = one_hot(target_idx[:, None, ...], num_classes=C)
            >>> loss_fn = IoULoss(reduction='none')
            >>> loss = loss_fn(y_input, target)
        """
        super().__init__(reduction=LossReduction(reduction).value)
        if other_act is not None and not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.squared_pred = squared_pred
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.batch = batch
        # Stored as a buffer so the weights follow the module across devices/dtypes.
        weight = torch.as_tensor(weight) if weight is not None else None
        self.register_buffer("class_weight", weight)
        self.class_weight: None | torch.Tensor
        self.soft_label = soft_label

    def forward(self, y_input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            y_input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Returns:
            The reduced loss. With ``reduction="none"`` the per-item/per-channel values are
            returned with trailing singleton dimensions so they broadcast against ``y_input``.

        Raises:
            AssertionError: When input and target (after one hot transform if set) have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
            ValueError: When the length of ``weight`` differs from the number of classes,
                or when ``weight`` contains a negative value.
        """
        if self.sigmoid:
            y_input = torch.sigmoid(y_input)

        n_pred_ch = y_input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                y_input = torch.softmax(y_input, 1)

        if self.other_act is not None:
            y_input = self.other_act(y_input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                y_input = y_input[:, 1:]

        if target.shape != y_input.shape:
            raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({y_input.shape})")

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis: list[int] = list(range(2, y_input.ndim))
        if self.batch:
            # reducing spatial dimensions and batch
            reduce_axis = [0] + reduce_axis

        y_ord = 2 if self.squared_pred else 1
        tp, fp, fn = compute_tp_fp_fn(y_input, target, reduce_axis, y_ord, self.soft_label)

        # Core IoU formula: IoU = TP / (TP + FP + FN).
        # Unlike Dice (2TP / (2TP + FP + FN)), the numerator has no factor of 2 and the
        # denominator contains TP only once.
        numerator = tp + self.smooth_nr
        denominator = tp + fp + fn + self.smooth_dr
        iou: torch.Tensor = numerator / denominator
        loss: torch.Tensor = 1 - iou  # IoU loss = 1 - IoU

        num_of_classes = target.shape[1]
        if self.class_weight is not None and num_of_classes != 1:
            # Work on a local reference: the registered buffer must not be rewritten here,
            # otherwise a scalar weight would be frozen to the channel count of the first
            # call (changing the state_dict shape and breaking reuse with other inputs).
            class_weight = self.class_weight
            # A 0-dim weight applies to every class through broadcasting; a sequence must
            # provide exactly one entry per class.
            if class_weight.ndim != 0 and class_weight.shape[0] != num_of_classes:
                raise ValueError(
                    "the length of the `weight` sequence should be the same as the number of classes.\n"
                    "If `include_background=False`, the weight should not include\n"
                    "the background category class 0."
                )
            if class_weight.min() < 0:
                raise ValueError("the value/values of the `weight` should be no less than 0.")
            # apply class_weight to loss (matched to the loss dtype/device)
            loss = loss * class_weight.to(loss)

        if self.reduction == LossReduction.MEAN.value:
            loss = torch.mean(loss)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            loss = torch.sum(loss)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            # If we are not computing voxelwise loss components at least
            # make sure a none reduction maintains a broadcastable shape
            broadcast_shape = list(loss.shape[0:2]) + [1] * (y_input.ndim - 2)
            loss = loss.view(broadcast_shape)
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
        return loss
class CombinedDiceCEIoULoss(_Loss):
    """
    Combined loss made of three terms: Dice + CrossEntropy (via ``DiceCELoss``) + IoU.

    - DiceCE: Dice + CrossEntropy, mixed inside ``DiceCELoss`` with normalised weights.
    - IoU: intersection-over-union term computed by ``IoULoss``.

    The total is ``(dice_weight + ce_weight) * dice_ce + iou_weight * iou``.
    """

    def __init__(
        self,
        dice_weight: float = 1.0,
        ce_weight: float = 1.0,
        iou_weight: float = 1.0,
        include_background: bool = True,
        to_onehot_y: bool = False,
        softmax: bool = False,
        sigmoid: bool = False,
    ):
        """
        Args:
            dice_weight: relative weight of the Dice term.
            ce_weight: relative weight of the cross-entropy term.
            iou_weight: weight of the IoU term in the total loss.
            include_background: forwarded to both ``DiceCELoss`` and ``IoULoss``.
            to_onehot_y: forwarded to both ``DiceCELoss`` and ``IoULoss``.
            softmax: forwarded to both ``DiceCELoss`` and ``IoULoss``.
            sigmoid: forwarded to both ``DiceCELoss`` and ``IoULoss``.
        """
        super().__init__()
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.softmax = softmax
        self.sigmoid = sigmoid

        # Dice and CE are normalised to sum to 1 inside DiceCELoss and rescaled by their
        # sum in `forward`. When the sum is zero (e.g. both weights are 0 to obtain an
        # IoU-only loss) the normalisation is skipped instead of dividing by zero.
        dice_ce_total = dice_weight + ce_weight
        scale = dice_ce_total if dice_ce_total != 0 else 1.0

        # DiceCE loss (Dice + CrossEntropy)
        self.dice_ce_loss = DiceCELoss(
            include_background=include_background,
            to_onehot_y=to_onehot_y,
            softmax=softmax,
            sigmoid=sigmoid,
            lambda_dice=dice_weight / scale,
            lambda_ce=ce_weight / scale,
            reduction="mean",
        )

        # IoU loss
        self.iou_loss = IoULoss(
            include_background=include_background,
            to_onehot_y=to_onehot_y,
            softmax=softmax,
            sigmoid=sigmoid,
            reduction="mean",
        )

        # Outer weights applied when the two components are combined.
        self.dice_ce_weight = dice_ce_total
        self.iou_weight = iou_weight

    def forward(
        self, y_input: torch.Tensor, target: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            y_input: prediction passed unchanged to both component losses.
            target: ground truth passed unchanged to both component losses.

        Returns:
            A tuple ``(total_loss, dice_ce_loss, iou_loss)`` where ``dice_ce_loss`` and
            ``iou_loss`` are the unweighted outputs of the two component losses and
            ``total_loss`` is their weighted sum.
        """
        # DiceCE term
        dice_ce_loss = self.dice_ce_loss(y_input, target)
        # IoU term
        iou_loss = self.iou_loss(y_input, target)
        # Weighted combination
        total_loss = self.dice_ce_weight * dice_ce_loss + self.iou_weight * iou_loss
        return total_loss, dice_ce_loss, iou_loss
|