combined_loss.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. from __future__ import annotations
  2. import warnings
  3. from collections.abc import Callable, Sequence
  4. from typing import Any
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. # noinspection PyProtectedMember
  10. from torch.nn.modules.loss import _Loss
  11. from monai.losses.utils import compute_tp_fp_fn
  12. from monai.networks import one_hot
  13. from monai.utils import LossReduction, Weight, look_up_option
  14. from monai.losses import DiceLoss, DiceCELoss, HausdorffDTLoss
  15. class IoULoss(_Loss):
  16. """
  17. Compute average IoU (Intersection over Union) loss between two tensors.
  18. IoU Loss 直接优化交并比指标,相比 Dice Loss 对边界误差更敏感,
  19. 能够提供更强的梯度信号用于训练。
  20. """
  21. def __init__(
  22. self,
  23. include_background: bool = True,
  24. to_onehot_y: bool = False,
  25. sigmoid: bool = False,
  26. softmax: bool = False,
  27. other_act: Callable | None = None,
  28. squared_pred: bool = False,
  29. reduction: LossReduction | str = LossReduction.MEAN,
  30. smooth_nr: float = 1e-5,
  31. smooth_dr: float = 1e-5,
  32. batch: bool = False,
  33. weight: Sequence[float] | float | int | torch.Tensor | None = None,
  34. soft_label: bool = False,
  35. ) -> None:
  36. """
  37. Args:
  38. include_background: if False, channel index 0 (background category) is excluded from the calculation.
  39. 如果非背景区域相对于整个图像较小,排除背景可以帮助收敛。
  40. to_onehot_y: whether to convert the ``target`` into the one-hot format,
  41. using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
  42. sigmoid: if True, apply a sigmoid function to the prediction.
  43. softmax: if True, apply a softmax function to the prediction.
  44. other_act: callable function to execute other activation layers, Defaults to ``None``.
  45. for example: ``other_act = torch.tanh``.
  46. squared_pred: use squared versions of targets and predictions in the denominator or not.
  47. 使用平方可以加强大误差区域的惩罚。
  48. reduction: {``"none"``, ``"mean"``, ``"sum"``}
  49. Specifies the reduction to apply to the output. Defaults to ``"mean"``.
  50. - ``"none"``: no reduction will be applied.
  51. - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
  52. - ``"sum"``: the output will be summed.
  53. smooth_nr: a small constant added to the numerator to avoid zero.
  54. smooth_dr: a small constant added to the denominator to avoid nan.
  55. batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
  56. Defaults to False, 每个 batch item 独立计算损失后再 reduction。
  57. weight: weights to apply to the voxels of each class. If None no weights are applied.
  58. The input can be a single value (same weight for all classes), a sequence of values (the length
  59. of the sequence should be the same as the number of classes. If not ``include_background``,
  60. the number of classes should not include the background category class 0).
  61. The value/values should be no less than 0. Defaults to None.
  62. soft_label: whether the target contains non-binary values (soft labels) or not.
  63. If True a soft label formulation of the loss will be used.
  64. Raises:
  65. TypeError: When ``other_act`` is not an ``Optional[Callable]``.
  66. ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
  67. Incompatible values.
  68. Example:
  69. >>> import torch
  70. >>> from lib.modules.iou_loss import IoULoss
  71. >>> B, C, H, W = 7, 5, 3, 2
  72. >>> y_input = torch.rand(B, C, H, W)
  73. >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
  74. >>> target = one_hot(target_idx[:, None, ...], num_classes=C)
  75. >>> loss_fn = IoULoss(reduction='none')
  76. >>> loss = loss_fn(y_input, target)
  77. """
  78. super().__init__(reduction=LossReduction(reduction).value)
  79. if other_act is not None and not callable(other_act):
  80. raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
  81. if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
  82. raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
  83. self.include_background = include_background
  84. self.to_onehot_y = to_onehot_y
  85. self.sigmoid = sigmoid
  86. self.softmax = softmax
  87. self.other_act = other_act
  88. self.squared_pred = squared_pred
  89. self.smooth_nr = float(smooth_nr)
  90. self.smooth_dr = float(smooth_dr)
  91. self.batch = batch
  92. weight = torch.as_tensor(weight) if weight is not None else None
  93. self.register_buffer("class_weight", weight)
  94. self.class_weight: None | torch.Tensor
  95. self.soft_label = soft_label
  96. def forward(self, y_input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  97. """
  98. Args:
  99. y_input: the shape should be BNH[WD], where N is the number of classes.
  100. target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
  101. Raises:
  102. AssertionError: When input and target (after one hot transform if set) have different shapes.
  103. ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
  104. """
  105. if self.sigmoid:
  106. y_input = torch.sigmoid(y_input)
  107. n_pred_ch = y_input.shape[1]
  108. if self.softmax:
  109. if n_pred_ch == 1:
  110. warnings.warn("single channel prediction, `softmax=True` ignored.")
  111. else:
  112. y_input = torch.softmax(y_input, 1)
  113. if self.other_act is not None:
  114. y_input = self.other_act(y_input)
  115. if self.to_onehot_y:
  116. if n_pred_ch == 1:
  117. warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
  118. else:
  119. target = one_hot(target, num_classes=n_pred_ch)
  120. if not self.include_background:
  121. if n_pred_ch == 1:
  122. warnings.warn("single channel prediction, `include_background=False` ignored.")
  123. else:
  124. # if skipping background, removing first channel
  125. target = target[:, 1:]
  126. y_input = y_input[:, 1:]
  127. if target.shape != y_input.shape:
  128. raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({y_input.shape})")
  129. # reducing only spatial dimensions (not batch nor channels)
  130. reduce_axis: list[int] = torch.arange(2, len(y_input.shape)).tolist()
  131. if self.batch:
  132. # reducing spatial dimensions and batch
  133. reduce_axis = [0] + reduce_axis
  134. y_ord = 2 if self.squared_pred else 1
  135. tp, fp, fn = compute_tp_fp_fn(y_input, target, reduce_axis, y_ord, self.soft_label)
  136. # IoU 的核心公式:IoU = TP / (TP + FP + FN)
  137. # 注意:与 Dice 不同,IoU 的分子没有系数 2,分母也少了 TP
  138. numerator = tp + self.smooth_nr
  139. denominator = tp + fp + fn + self.smooth_dr
  140. iou: torch.Tensor = numerator / denominator
  141. loss: torch.Tensor = 1 - iou # IoU Loss = 1 - IoU
  142. num_of_classes = target.shape[1]
  143. if self.class_weight is not None and num_of_classes != 1:
  144. # make sure the lengths of weights are equal to the number of classes
  145. if self.class_weight.ndim == 0:
  146. # noinspection PyAttributeOutsideInit
  147. self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
  148. else:
  149. if self.class_weight.shape[0] != num_of_classes:
  150. raise ValueError(
  151. """the length of the `weight` sequence should be the same as the number of classes.
  152. If `include_background=False`, the weight should not include
  153. the background category class 0."""
  154. )
  155. if self.class_weight.min() < 0:
  156. raise ValueError("the value/values of the `weight` should be no less than 0.")
  157. # apply class_weight to loss
  158. loss = loss * self.class_weight.to(loss)
  159. if self.reduction == LossReduction.MEAN.value:
  160. loss = torch.mean(loss) # the batch and channel average
  161. elif self.reduction == LossReduction.SUM.value:
  162. loss = torch.sum(loss) # sum over the batch and channel dims
  163. elif self.reduction == LossReduction.NONE.value:
  164. # If we are not computing voxelwise loss components at least
  165. # make sure a none reduction maintains a broadcastable shape
  166. broadcast_shape = list(loss.shape[0:2]) + [1] * (len(y_input.shape) - 2)
  167. loss = loss.view(broadcast_shape)
  168. else:
  169. raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
  170. return loss
  171. class CombinedDiceCEIoULoss(_Loss):
  172. """
  173. 四合一组合损失:DiceCE + IoU
  174. - DiceCE: Dice + CrossEntropy (全局 + 局部)
  175. - IoU: 交并比优化
  176. """
  177. def __init__(
  178. self,
  179. dice_weight: float = 1.0,
  180. ce_weight: float = 1.0,
  181. iou_weight: float = 1.0,
  182. include_background: bool = True,
  183. to_onehot_y: bool = False,
  184. softmax: bool = False,
  185. sigmoid: bool = False,
  186. ):
  187. super().__init__()
  188. self.include_background = include_background
  189. self.to_onehot_y = to_onehot_y
  190. self.softmax = softmax
  191. # DiceBCE Loss (Dice + BCE)
  192. self.dice_ce_loss = DiceCELoss(
  193. include_background=include_background,
  194. to_onehot_y=to_onehot_y,
  195. softmax=softmax,
  196. sigmoid=sigmoid,
  197. lambda_dice=dice_weight / (dice_weight + ce_weight),
  198. lambda_ce=ce_weight / (dice_weight + ce_weight),
  199. reduction="mean",
  200. )
  201. # IoU Loss
  202. self.iou_loss = IoULoss(
  203. include_background=include_background,
  204. to_onehot_y=to_onehot_y,
  205. softmax=softmax,
  206. sigmoid=sigmoid,
  207. reduction="mean",
  208. )
  209. # 外部权重
  210. self.dice_ce_weight = dice_weight + ce_weight
  211. self.iou_weight = iou_weight
  212. def forward(self, y_input: torch.Tensor, target: torch.Tensor) -> tuple[float | Any, Any, Any]:
  213. # DiceCE Loss
  214. dice_ce_loss = self.dice_ce_loss(y_input, target)
  215. # IoU Loss
  216. iou_loss = self.iou_loss(y_input, target)
  217. # 加权组合
  218. total_loss = (
  219. self.dice_ce_weight * dice_ce_loss +
  220. self.iou_weight * iou_loss
  221. )
  222. return total_loss, dice_ce_loss, iou_loss