loss.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. from typing import Optional
  2. import torch
  3. import torch.nn as nn
  4. import torch.fft as fft
  5. from monai.losses import DiceLoss, DiceCELoss, HausdorffDTLoss
  6. class FocalFrequencyLoss(nn.Module):
  7. """
  8. 焦点频域损失函数 (Focal Frequency Loss)
  9. 核心思想:
  10. 传统的空间域损失(如 MSE、Dice)主要关注像素级别的差异,而频域损失通过傅里叶变换
  11. 将图像转换到频率域,从频率角度衡量预测图像与真实图像的差异。
  12. 该损失的创新点在于"焦点"机制:
  13. 1. 自动计算频谱权重矩阵,对不同频率成分赋予不同的重要性
  14. 2. 对难以重建的频率成分给予更高权重(类似 Focal Loss 的思想)
  15. 3. 可以捕捉图像的全局结构和纹理细节,弥补空间域损失的不足
  16. 适用场景:
  17. - 医学图像分割:增强边缘和纹理的恢复
  18. - 图像超分辨率:重建高频细节
  19. - 图像去噪/去模糊:平衡低频和高频信息
  20. 参数说明:
  21. loss_weight: 损失权重系数,用于平衡该损失与其他损失的重要性
  22. alpha: 频谱权重的幂次参数,控制权重分布的陡峭程度
  23. alpha 越大,困难频率成分的权重越突出
  24. patch_factor: 图像分块因子,将图像分成多个小块分别进行 FFT
  25. 值为 1 表示不分块,对整个图像做 FFT
  26. 值大于 1 时,将图像分成 patch_factor×patch_factor 个小块
  27. ave_spectrum: 是否对 batch 内的频谱进行平均,用于减少 batch 内差异
  28. log_matrix: 是否对频谱差异取对数,用于压缩动态范围
  29. batch_matrix: 权重归一化方式
  30. True: 在整个 batch 范围内归一化到 [0,1]
  31. False: 对每张图像单独归一化
  32. """
  33. def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False,
  34. batch_matrix=False):
  35. """
  36. 初始化焦点频域损失函数的所有超参数
  37. Args:
  38. loss_weight (float): 损失权重,默认 1.0
  39. alpha (float): 频谱权重指数,默认 1.0
  40. patch_factor (int): 图像分块因子,默认 1(不分块)
  41. ave_spectrum (bool): 是否对 batch 频谱平均,默认 False
  42. log_matrix (bool): 是否使用对数矩阵,默认 False
  43. batch_matrix (bool): 是否使用 batch 级矩阵,默认 False
  44. """
  45. super(FocalFrequencyLoss, self).__init__()
  46. self.loss_weight = loss_weight
  47. self.alpha = alpha
  48. self.patch_factor = patch_factor
  49. self.ave_spectrum = ave_spectrum
  50. self.log_matrix = log_matrix
  51. self.batch_matrix = batch_matrix
  52. def tensor2freq(self, x):
  53. """
  54. 将空间域图像张量转换为频域表示
  55. 工作原理:
  56. 1. 如果 patch_factor > 1,先将图像分割成多个小块
  57. 2. 对每个小块执行 2D 快速傅里叶变换 (FFT)
  58. 3. 将复数形式的 FFT 结果分解为实部和虚部
  59. 傅里叶变换的物理意义:
  60. - 低频成分:对应图像的平滑区域和整体轮廓
  61. - 高频成分:对应图像的边缘、纹理和噪声
  62. - 通过分析频谱,可以分离和处理不同频率的特征
  63. Args:
  64. x (torch.Tensor): 输入图像张量,形状为 (N, C, H, W)
  65. N=batch_size, C=channels, H=height, W=width
  66. Returns:
  67. freq (torch.Tensor): 频域表示,形状为 (N, patch_factor², C, H/pf, W/pf, 2)
  68. 最后一维的 2 个通道分别是 [实部,虚部]
  69. patch_factor² 表示分成了多少个小块
  70. Example:
  71. 输入:x.shape = (4, 1, 256, 256), patch_factor=4
  72. 输出:freq.shape = (4, 16, 1, 64, 64, 2)
  73. - 16 = 4×4 个小块
  74. - 64×64 = 每个小块的尺寸
  75. - 2 = [实部,虚部]
  76. """
  77. # 获取分块因子
  78. patch_factor = self.patch_factor
  79. # 获取输入图像的尺寸信息
  80. _, _, h, w = x.shape
  81. # 断言检查:确保图像尺寸可以被 patch_factor 整除
  82. # 这是为了保证分块时每个小块大小一致,避免边界问题
  83. assert h % patch_factor == 0 and w % patch_factor == 0, (
  84. 'Patch factor should be divisible by image height and width')
  85. # 初始化列表用于存储所有小块的频域表示
  86. patch_list = []
  87. # 计算每个小块的高度和宽度
  88. # 例如:原图 256×256, patch_factor=4 → 每个小块 64×64
  89. patch_h = h // patch_factor
  90. patch_w = w // patch_factor
  91. # 双重循环遍历所有小块
  92. # i 控制垂直方向的索引,j 控制水平方向的索引
  93. for i in range(patch_factor):
  94. for j in range(patch_factor):
  95. # 切片操作:提取第 (i,j) 个小块
  96. # 垂直方向:从 i*patch_h 到 (i+1)*patch_h
  97. # 水平方向:从 j*patch_w 到 (j+1)*patch_w
  98. # [:, :, ...] 保持 batch 和 channel 维度不变
  99. patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
  100. # 将所有小块堆叠成一个新的维度
  101. # 原始形状:(N, C, patch_h, patch_w) 的列表,长度为 patch_factor²
  102. # 堆叠后形状:(N, patch_factor², C, patch_h, patch_w)
  103. # dim=1 表示在第 1 个维度(channel 之后)插入新的分块维度
  104. y = torch.stack(patch_list, 1)
  105. # 对每个小块执行 2D 快速傅里叶变换
  106. # torch.fft.fft2: 计算二维离散傅里叶变换
  107. # norm='ortho': 使用正交归一化,保证变换前后能量守恒
  108. # 变换结果是一个复数张量,包含每个频率成分的振幅和相位信息
  109. freq = torch.fft.fft2(y, norm='ortho')
  110. # 将复数形式转换为实数表示,方便后续神经网络处理
  111. # torch.stack([freq.real, freq.imag], -1): 将实部和虚部堆叠到最后一个维度
  112. # freq.real: 复数的实部,代表余弦分量的幅度
  113. # freq.imag: 复数的虚部,代表正弦分量的幅度
  114. # 最终形状:(N, patch_factor², C, patch_h, patch_w, 2)
  115. freq = torch.stack([freq.real, freq.imag], -1)
  116. return freq
  117. def loss_formulation(self, recon_freq, real_freq, matrix=None):
  118. """
  119. 构建并计算焦点频域损失的核心公式
  120. 核心思想:
  121. 传统频域损失直接计算预测频谱与真实频谱的距离(如 MSE)。
  122. 本方法引入"焦点"机制:根据每个频率成分的重建难度动态调整权重。
  123. 权重计算逻辑:
  124. 1. 如果未提供预定义权重矩阵,则在线计算动态权重
  125. 2. 重建误差大的频率 → 高权重(重点关注)
  126. 3. 重建误差小的频率 → 低权重(减少关注)
  127. 这类似于 Focal Loss 中"关注难例"的思想
  128. Args:
  129. recon_freq (torch.Tensor): 重建(预测)图像的频域表示
  130. 形状:(N, P, C, H, W, 2),P=patch_factor²
  131. real_freq (torch.Tensor): 真实(目标)图像的频域表示
  132. 形状:(N, P, C, H, W, 2)
  133. matrix (torch.Tensor, optional): 预定义的频谱权重矩阵
  134. 如果为 None,则动态计算权重
  135. Returns:
  136. loss (torch.Tensor): 标量损失值
  137. 工作流程:
  138. Step 1: 确定权重矩阵(预定义或动态计算)
  139. Step 2: 计算频谱距离(复数空间中的欧氏距离)
  140. Step 3: 加权求和得到最终损失
  141. """
  142. # ==================== Step 1: 确定权重矩阵 ====================
  143. if matrix is not None:
  144. # 情况 A: 使用预定义的权重矩阵
  145. # 这种模式允许外部指定固定的频率权重,适用于某些先验知识已知的场景
  146. # 例如:人工指定某些频率更重要,或者使用其他算法计算的权重
  147. weight_matrix = matrix.detach()
  148. # .detach() 确保权重矩阵不参与梯度反向传播
  149. # 这样权重是固定的,不会影响梯度的流动
  150. else:
  151. # 情况 B: 动态计算自适应权重矩阵(推荐模式)
  152. # --- 子步骤 1: 计算初步的频谱差异 ---
  153. # 逐元素计算预测频谱与真实频谱的差值的平方
  154. # 这是一个复数差的平方,需要分别处理实部和虚部
  155. # 形状:(N, P, C, H, W, 2)
  156. matrix_tmp = (recon_freq - real_freq) ** 2
  157. # --- 子步骤 2: 计算频谱幅度差异 ---
  158. # 复数的模长公式:|a + bi| = sqrt(a² + b²)
  159. # 这里计算的是每个频率成分的预测误差的幅度
  160. # [..., 0]: 实部的平方, [..., 1]: 虚部的平方
  161. # torch.sqrt(...): 开平方得到欧几里得距离
  162. # ** self.alpha: 应用幂次变换,alpha 控制权重的分布特性
  163. # - alpha > 1: 放大差异,使大误差更突出
  164. # - alpha < 1: 缩小差异,使权重分布更均匀
  165. # - alpha = 1: 线性关系(默认情况)
  166. matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
  167. # --- 子步骤 3: 可选的对数变换 ---
  168. if self.log_matrix:
  169. # 对数变换的作用:压缩动态范围
  170. # 当频谱差异的范围很大时(几个数量级),对数变换可以防止某些频率主导损失
  171. # log(x + 1.0): 加 1 是为了避免 log(0) 的数值不稳定
  172. matrix_tmp = torch.log(matrix_tmp + 1.0)
  173. # --- 子步骤 4: 归一化权重到 [0, 1] 范围 ---
  174. # 归一化的目的是确保权重有统一的尺度,便于控制和解释
  175. if self.batch_matrix:
  176. # 模式 A: Batch 级归一化
  177. # 在整个 batch 的所有像素、所有频率上找最大值,然后统一归一化
  178. # 优点:batch 内所有样本的权重在同一尺度
  179. # 缺点:可能掩盖样本间的差异
  180. matrix_tmp = matrix_tmp / matrix_tmp.max()
  181. else:
  182. # 模式 B: 样本级归一化(推荐)
  183. # 对每个样本单独归一化,保持样本间的相对差异
  184. # .max(-1).values: 沿最后一个维度(实部/虚部维度)取最大值
  185. # .max(-1).values[:, :, :, None, None]: 再沿空间维度取最大值
  186. # [:, :, :, None, None]: 添加维度以保持广播兼容性
  187. # 最终每张图片的权重矩阵独立归一化到 [0,1]
  188. matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
  189. # --- 子步骤 5: 数值稳定性处理 ---
  190. # 处理可能出现的 NaN 值(例如 0/0 的情况)
  191. # 将 NaN 替换为 0,表示这些位置不参与加权
  192. matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
  193. # --- 子步骤 6: 截断到合法范围 ---
  194. # 确保所有权值都在 [0, 1] 区间内
  195. # torch.clamp: 将小于 0 的值设为 0,大于 1 的值设为 1
  196. # 这是防御性编程,防止数值溢出导致训练不稳定
  197. matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
  198. # --- 子步骤 7: 创建最终的权重矩阵 ---
  199. # .clone().detach() 创建副本并断开梯度连接
  200. # 这样权重矩阵在本次前向传播中是固定的,不会自我影响
  201. # 这是关键设计:权重基于当前误差计算,但不参与本次的梯度回传
  202. weight_matrix = matrix_tmp.clone().detach()
  203. # ==================== 权重矩阵有效性验证 ====================
  204. # 断言检查:确保权重矩阵的所有值都在 [0, 1] 范围内
  205. # 这是一个安全检查,帮助调试时发现问题
  206. # .item() 将标量张量转换为 Python 浮点数,方便打印
  207. assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
  208. 'The values of spectrum weight matrix should be in the range [0, 1], '
  209. 'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
  210. # ==================== Step 2: 计算频谱距离 ====================
  211. # 计算预测频谱与真实频谱的逐元素差异的平方
  212. # 与前面计算 matrix_tmp 的第一步相同
  213. # 形状:(N, P, C, H, W, 2)
  214. tmp = (recon_freq - real_freq) ** 2
  215. # 将实部和虚部的平方和相加,得到复数空间中的欧几里得距离平方
  216. # 这就是每个频率成分的重建误差
  217. # 形状:(N, P, C, H, W)
  218. freq_distance = tmp[..., 0] + tmp[..., 1]
  219. # ==================== Step 3: 加权求和得到损失 ====================
  220. # 逐元素相乘:权重矩阵 × 频谱距离
  221. # 效果:
  222. # - 权重高的频率(重建困难)→ 损失贡献大 → 梯度大 → 模型重点关注
  223. # - 权重低的频率(重建简单)→ 损失贡献小 → 梯度小 → 模型次要关注
  224. # 形状:(N, P, C, H, W)
  225. loss = weight_matrix * freq_distance
  226. # 对所有维度取平均,得到标量损失值
  227. # torch.mean(): 将整个张量压缩成一个标量
  228. # 这样得到的损失可以直接用于反向传播
  229. return torch.mean(loss)
  230. def forward(self, pred, target, matrix=None):
  231. """
  232. 焦点频域损失的前向传播计算
  233. 这是损失函数的主要入口,当调用 loss_function(pred, target) 时执行此方法
  234. Args:
  235. pred (torch.Tensor): 预测的图像张量
  236. 形状:(N, C, H, W)
  237. N: batch size(批次大小)
  238. C: channels(通道数,对于灰度图 C=1,RGB 图 C=3)
  239. H: height(图像高度)
  240. W: width(图像宽度)
  241. target (torch.Tensor): 目标的图像张量(真实标签)
  242. 形状:(N, C, H, W),必须与 pred 完全相同
  243. matrix (torch.Tensor, optional): 预定义的频谱权重矩阵
  244. 如果提供,则使用该固定权重而非动态计算
  245. 默认:None(动态计算自适应权重)
  246. Returns:
  247. torch.Tensor: 标量损失值(0 维张量),可以直接用于反向传播
  248. 完整计算流程:
  249. ┌─────────────────┐
  250. │ 输入:pred, target │
  251. └────────┬──────────┘
  252. ┌─────────────────┐
  253. │ Step 1: 傅里叶变换 │ tensor2freq()
  254. │ pred → pred_freq │ 空间域 → 频域
  255. │ target → target_freq│
  256. └────────┬──────────┘
  257. ┌─────────────────┐
  258. │ Step 2: 可选的平均 │ if ave_spectrum
  259. │ 对 batch 维度平均 │ 减少样本间差异
  260. └────────┬──────────┘
  261. ┌─────────────────┐
  262. │ Step 3: 计算损失 │ loss_formulation()
  263. │ 动态权重 × 频谱距离 │ 焦点机制核心
  264. └────────┬──────────┘
  265. ┌─────────────────┐
  266. │ Step 4: 应用权重 │ × loss_weight
  267. │ 返回最终损失值 │
  268. └─────────────────┘
  269. 物理意义解释:
  270. 1. 傅里叶变换:将图像从"像素空间"转换到"频率空间"
  271. - 像素空间:关注每个点的亮度值
  272. - 频率空间:关注图像的周期性模式(纹理、边缘、轮廓)
  273. 2. 频谱比较:在频率空间中衡量预测与真实的差异
  274. - 低频误差:反映整体结构的偏差
  275. - 高频误差:反映细节纹理的偏差
  276. 3. 焦点权重:自动识别并强调难以重建的频率成分
  277. - 这是与传统频域损失(如简单的频谱 MSE)的关键区别
  278. - 类似于注意力机制,让模型"聚焦"于困难频率
  279. """
  280. # ==================== Step 1: 将预测图像转换为频域 ====================
  281. # 调用 tensor2freq 方法对预测图像进行傅里叶变换
  282. # 将空间域的像素表示转换为频域的频谱表示
  283. # pred_freq 包含了预测图像在各个频率上的振幅和相位信息
  284. pred_freq = self.tensor2freq(pred)
  285. # ==================== Step 2: 将目标图像转换为频域 ====================
  286. # 同样对真实标签图像进行傅里叶变换
  287. # 这样我们就可以在频率空间中比较预测与真实的差异
  288. target_freq = self.tensor2freq(target)
  289. # ==================== Step 3: 可选的 Batch 频谱平均 ====================
  290. if self.ave_spectrum:
  291. # 如果启用了 ave_spectrum 选项,对 batch 维度(第 0 维)取平均
  292. # keepdim=True 保持维度数量不变,只是将第 0 维的大小设为 1
  293. #
  294. # 这个操作的效果:
  295. # - 原始形状:(N, P, C, H, W, 2)
  296. # - 平均后:(1, P, C, H, W, 2)
  297. #
  298. # 为什么要这样做?
  299. # 1. 减少 batch 内样本间的随机波动
  300. # 2. 计算一个"平均频谱"作为代表
  301. # 3. 在某些任务中可以提高训练稳定性
  302. #
  303. # 注意:这会改变损失的语义,从"逐个样本的损失"变成"batch 级别的损失"
  304. pred_freq = torch.mean(pred_freq, 0, keepdim=True)
  305. target_freq = torch.mean(target_freq, 0, keepdim=True)
  306. # ==================== Step 4: 计算最终的焦点频域损失 ====================
  307. # 调用 loss_formulation 方法计算加权后的频域损失
  308. # 该方法会:
  309. # 1. 动态计算频谱权重矩阵(如果 matrix=None)
  310. # 2. 计算预测频谱与真实频谱的距离
  311. # 3. 用权重矩阵对距离加权,得到最终损失
  312. #
  313. # 返回值是一个标量张量,表示整个 batch 的平均损失
  314. loss_value = self.loss_formulation(pred_freq, target_freq, matrix)
  315. # ==================== Step 5: 应用损失权重系数 ====================
  316. # 将计算得到的损失乘以预设的权重系数 loss_weight
  317. #
  318. # 这个参数的作用:
  319. # - 在多损失联合训练时,平衡不同损失的重要性
  320. # - 例如:总损失 = 1.0 * DiceLoss + 0.1 * FocalFrequencyLoss
  321. # - 这样可以让频域损失作为辅助损失,不会主导训练过程
  322. #
  323. # 为什么需要这样做?
  324. # 不同的损失函数量级可能差异很大
  325. # DiceLoss 可能在 0-1 之间,而频域损失可能在 0-100 之间
  326. # 通过调整 loss_weight,可以确保各个损失在同一数量级
  327. final_loss = loss_value * self.loss_weight
  328. return final_loss