| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397 |
- from typing import Optional
- import torch
- import torch.nn as nn
- import torch.fft as fft
- from monai.losses import DiceLoss, DiceCELoss, HausdorffDTLoss
- class FocalFrequencyLoss(nn.Module):
- """
- 焦点频域损失函数 (Focal Frequency Loss)
-
- 核心思想:
- 传统的空间域损失(如 MSE、Dice)主要关注像素级别的差异,而频域损失通过傅里叶变换
- 将图像转换到频率域,从频率角度衡量预测图像与真实图像的差异。
-
- 该损失的创新点在于"焦点"机制:
- 1. 自动计算频谱权重矩阵,对不同频率成分赋予不同的重要性
- 2. 对难以重建的频率成分给予更高权重(类似 Focal Loss 的思想)
- 3. 可以捕捉图像的全局结构和纹理细节,弥补空间域损失的不足
-
- 适用场景:
- - 医学图像分割:增强边缘和纹理的恢复
- - 图像超分辨率:重建高频细节
- - 图像去噪/去模糊:平衡低频和高频信息
-
- 参数说明:
- loss_weight: 损失权重系数,用于平衡该损失与其他损失的重要性
- alpha: 频谱权重的幂次参数,控制权重分布的陡峭程度
- alpha 越大,困难频率成分的权重越突出
- patch_factor: 图像分块因子,将图像分成多个小块分别进行 FFT
- 值为 1 表示不分块,对整个图像做 FFT
- 值大于 1 时,将图像分成 patch_factor×patch_factor 个小块
- ave_spectrum: 是否对 batch 内的频谱进行平均,用于减少 batch 内差异
- log_matrix: 是否对频谱差异取对数,用于压缩动态范围
- batch_matrix: 权重归一化方式
- True: 在整个 batch 范围内归一化到 [0,1]
- False: 对每张图像单独归一化
- """
- def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False,
- batch_matrix=False):
- """
- 初始化焦点频域损失函数的所有超参数
-
- Args:
- loss_weight (float): 损失权重,默认 1.0
- alpha (float): 频谱权重指数,默认 1.0
- patch_factor (int): 图像分块因子,默认 1(不分块)
- ave_spectrum (bool): 是否对 batch 频谱平均,默认 False
- log_matrix (bool): 是否使用对数矩阵,默认 False
- batch_matrix (bool): 是否使用 batch 级矩阵,默认 False
- """
- super(FocalFrequencyLoss, self).__init__()
- self.loss_weight = loss_weight
- self.alpha = alpha
- self.patch_factor = patch_factor
- self.ave_spectrum = ave_spectrum
- self.log_matrix = log_matrix
- self.batch_matrix = batch_matrix
- def tensor2freq(self, x):
- """
- 将空间域图像张量转换为频域表示
-
- 工作原理:
- 1. 如果 patch_factor > 1,先将图像分割成多个小块
- 2. 对每个小块执行 2D 快速傅里叶变换 (FFT)
- 3. 将复数形式的 FFT 结果分解为实部和虚部
-
- 傅里叶变换的物理意义:
- - 低频成分:对应图像的平滑区域和整体轮廓
- - 高频成分:对应图像的边缘、纹理和噪声
- - 通过分析频谱,可以分离和处理不同频率的特征
-
- Args:
- x (torch.Tensor): 输入图像张量,形状为 (N, C, H, W)
- N=batch_size, C=channels, H=height, W=width
-
- Returns:
- freq (torch.Tensor): 频域表示,形状为 (N, patch_factor², C, H/pf, W/pf, 2)
- 最后一维的 2 个通道分别是 [实部,虚部]
- patch_factor² 表示分成了多少个小块
-
- Example:
- 输入:x.shape = (4, 1, 256, 256), patch_factor=4
- 输出:freq.shape = (4, 16, 1, 64, 64, 2)
- - 16 = 4×4 个小块
- - 64×64 = 每个小块的尺寸
- - 2 = [实部,虚部]
- """
- # 获取分块因子
- patch_factor = self.patch_factor
- # 获取输入图像的尺寸信息
- _, _, h, w = x.shape
- # 断言检查:确保图像尺寸可以被 patch_factor 整除
- # 这是为了保证分块时每个小块大小一致,避免边界问题
- assert h % patch_factor == 0 and w % patch_factor == 0, (
- 'Patch factor should be divisible by image height and width')
- # 初始化列表用于存储所有小块的频域表示
- patch_list = []
- # 计算每个小块的高度和宽度
- # 例如:原图 256×256, patch_factor=4 → 每个小块 64×64
- patch_h = h // patch_factor
- patch_w = w // patch_factor
- # 双重循环遍历所有小块
- # i 控制垂直方向的索引,j 控制水平方向的索引
- for i in range(patch_factor):
- for j in range(patch_factor):
- # 切片操作:提取第 (i,j) 个小块
- # 垂直方向:从 i*patch_h 到 (i+1)*patch_h
- # 水平方向:从 j*patch_w 到 (j+1)*patch_w
- # [:, :, ...] 保持 batch 和 channel 维度不变
- patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
- # 将所有小块堆叠成一个新的维度
- # 原始形状:(N, C, patch_h, patch_w) 的列表,长度为 patch_factor²
- # 堆叠后形状:(N, patch_factor², C, patch_h, patch_w)
- # dim=1 表示在第 1 个维度(channel 之后)插入新的分块维度
- y = torch.stack(patch_list, 1)
- # 对每个小块执行 2D 快速傅里叶变换
- # torch.fft.fft2: 计算二维离散傅里叶变换
- # norm='ortho': 使用正交归一化,保证变换前后能量守恒
- # 变换结果是一个复数张量,包含每个频率成分的振幅和相位信息
- freq = torch.fft.fft2(y, norm='ortho')
- # 将复数形式转换为实数表示,方便后续神经网络处理
- # torch.stack([freq.real, freq.imag], -1): 将实部和虚部堆叠到最后一个维度
- # freq.real: 复数的实部,代表余弦分量的幅度
- # freq.imag: 复数的虚部,代表正弦分量的幅度
- # 最终形状:(N, patch_factor², C, patch_h, patch_w, 2)
- freq = torch.stack([freq.real, freq.imag], -1)
- return freq
- def loss_formulation(self, recon_freq, real_freq, matrix=None):
- """
- 构建并计算焦点频域损失的核心公式
-
- 核心思想:
- 传统频域损失直接计算预测频谱与真实频谱的距离(如 MSE)。
- 本方法引入"焦点"机制:根据每个频率成分的重建难度动态调整权重。
-
- 权重计算逻辑:
- 1. 如果未提供预定义权重矩阵,则在线计算动态权重
- 2. 重建误差大的频率 → 高权重(重点关注)
- 3. 重建误差小的频率 → 低权重(减少关注)
- 这类似于 Focal Loss 中"关注难例"的思想
-
- Args:
- recon_freq (torch.Tensor): 重建(预测)图像的频域表示
- 形状:(N, P, C, H, W, 2),P=patch_factor²
- real_freq (torch.Tensor): 真实(目标)图像的频域表示
- 形状:(N, P, C, H, W, 2)
- matrix (torch.Tensor, optional): 预定义的频谱权重矩阵
- 如果为 None,则动态计算权重
-
- Returns:
- loss (torch.Tensor): 标量损失值
-
- 工作流程:
- Step 1: 确定权重矩阵(预定义或动态计算)
- Step 2: 计算频谱距离(复数空间中的欧氏距离)
- Step 3: 加权求和得到最终损失
- """
- # ==================== Step 1: 确定权重矩阵 ====================
- if matrix is not None:
- # 情况 A: 使用预定义的权重矩阵
- # 这种模式允许外部指定固定的频率权重,适用于某些先验知识已知的场景
- # 例如:人工指定某些频率更重要,或者使用其他算法计算的权重
- weight_matrix = matrix.detach()
- # .detach() 确保权重矩阵不参与梯度反向传播
- # 这样权重是固定的,不会影响梯度的流动
- else:
- # 情况 B: 动态计算自适应权重矩阵(推荐模式)
- # --- 子步骤 1: 计算初步的频谱差异 ---
- # 逐元素计算预测频谱与真实频谱的差值的平方
- # 这是一个复数差的平方,需要分别处理实部和虚部
- # 形状:(N, P, C, H, W, 2)
- matrix_tmp = (recon_freq - real_freq) ** 2
- # --- 子步骤 2: 计算频谱幅度差异 ---
- # 复数的模长公式:|a + bi| = sqrt(a² + b²)
- # 这里计算的是每个频率成分的预测误差的幅度
- # [..., 0]: 实部的平方, [..., 1]: 虚部的平方
- # torch.sqrt(...): 开平方得到欧几里得距离
- # ** self.alpha: 应用幂次变换,alpha 控制权重的分布特性
- # - alpha > 1: 放大差异,使大误差更突出
- # - alpha < 1: 缩小差异,使权重分布更均匀
- # - alpha = 1: 线性关系(默认情况)
- matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
- # --- 子步骤 3: 可选的对数变换 ---
- if self.log_matrix:
- # 对数变换的作用:压缩动态范围
- # 当频谱差异的范围很大时(几个数量级),对数变换可以防止某些频率主导损失
- # log(x + 1.0): 加 1 是为了避免 log(0) 的数值不稳定
- matrix_tmp = torch.log(matrix_tmp + 1.0)
- # --- 子步骤 4: 归一化权重到 [0, 1] 范围 ---
- # 归一化的目的是确保权重有统一的尺度,便于控制和解释
- if self.batch_matrix:
- # 模式 A: Batch 级归一化
- # 在整个 batch 的所有像素、所有频率上找最大值,然后统一归一化
- # 优点:batch 内所有样本的权重在同一尺度
- # 缺点:可能掩盖样本间的差异
- matrix_tmp = matrix_tmp / matrix_tmp.max()
- else:
- # 模式 B: 样本级归一化(推荐)
- # 对每个样本单独归一化,保持样本间的相对差异
- # .max(-1).values: 沿最后一个维度(实部/虚部维度)取最大值
- # .max(-1).values[:, :, :, None, None]: 再沿空间维度取最大值
- # [:, :, :, None, None]: 添加维度以保持广播兼容性
- # 最终每张图片的权重矩阵独立归一化到 [0,1]
- matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
- # --- 子步骤 5: 数值稳定性处理 ---
- # 处理可能出现的 NaN 值(例如 0/0 的情况)
- # 将 NaN 替换为 0,表示这些位置不参与加权
- matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
- # --- 子步骤 6: 截断到合法范围 ---
- # 确保所有权值都在 [0, 1] 区间内
- # torch.clamp: 将小于 0 的值设为 0,大于 1 的值设为 1
- # 这是防御性编程,防止数值溢出导致训练不稳定
- matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
- # --- 子步骤 7: 创建最终的权重矩阵 ---
- # .clone().detach() 创建副本并断开梯度连接
- # 这样权重矩阵在本次前向传播中是固定的,不会自我影响
- # 这是关键设计:权重基于当前误差计算,但不参与本次的梯度回传
- weight_matrix = matrix_tmp.clone().detach()
- # ==================== 权重矩阵有效性验证 ====================
- # 断言检查:确保权重矩阵的所有值都在 [0, 1] 范围内
- # 这是一个安全检查,帮助调试时发现问题
- # .item() 将标量张量转换为 Python 浮点数,方便打印
- assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
- 'The values of spectrum weight matrix should be in the range [0, 1], '
- 'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
- # ==================== Step 2: 计算频谱距离 ====================
- # 计算预测频谱与真实频谱的逐元素差异的平方
- # 与前面计算 matrix_tmp 的第一步相同
- # 形状:(N, P, C, H, W, 2)
- tmp = (recon_freq - real_freq) ** 2
- # 将实部和虚部的平方和相加,得到复数空间中的欧几里得距离平方
- # 这就是每个频率成分的重建误差
- # 形状:(N, P, C, H, W)
- freq_distance = tmp[..., 0] + tmp[..., 1]
- # ==================== Step 3: 加权求和得到损失 ====================
- # 逐元素相乘:权重矩阵 × 频谱距离
- # 效果:
- # - 权重高的频率(重建困难)→ 损失贡献大 → 梯度大 → 模型重点关注
- # - 权重低的频率(重建简单)→ 损失贡献小 → 梯度小 → 模型次要关注
- # 形状:(N, P, C, H, W)
- loss = weight_matrix * freq_distance
- # 对所有维度取平均,得到标量损失值
- # torch.mean(): 将整个张量压缩成一个标量
- # 这样得到的损失可以直接用于反向传播
- return torch.mean(loss)
- def forward(self, pred, target, matrix=None):
- """
- 焦点频域损失的前向传播计算
-
- 这是损失函数的主要入口,当调用 loss_function(pred, target) 时执行此方法
-
- Args:
- pred (torch.Tensor): 预测的图像张量
- 形状:(N, C, H, W)
- N: batch size(批次大小)
- C: channels(通道数,对于灰度图 C=1,RGB 图 C=3)
- H: height(图像高度)
- W: width(图像宽度)
-
- target (torch.Tensor): 目标的图像张量(真实标签)
- 形状:(N, C, H, W),必须与 pred 完全相同
-
- matrix (torch.Tensor, optional): 预定义的频谱权重矩阵
- 如果提供,则使用该固定权重而非动态计算
- 默认:None(动态计算自适应权重)
-
- Returns:
- torch.Tensor: 标量损失值(0 维张量),可以直接用于反向传播
-
- 完整计算流程:
- ┌─────────────────┐
- │ 输入:pred, target │
- └────────┬──────────┘
- │
- ▼
- ┌─────────────────┐
- │ Step 1: 傅里叶变换 │ tensor2freq()
- │ pred → pred_freq │ 空间域 → 频域
- │ target → target_freq│
- └────────┬──────────┘
- │
- ▼
- ┌─────────────────┐
- │ Step 2: 可选的平均 │ if ave_spectrum
- │ 对 batch 维度平均 │ 减少样本间差异
- └────────┬──────────┘
- │
- ▼
- ┌─────────────────┐
- │ Step 3: 计算损失 │ loss_formulation()
- │ 动态权重 × 频谱距离 │ 焦点机制核心
- └────────┬──────────┘
- │
- ▼
- ┌─────────────────┐
- │ Step 4: 应用权重 │ × loss_weight
- │ 返回最终损失值 │
- └─────────────────┘
-
- 物理意义解释:
- 1. 傅里叶变换:将图像从"像素空间"转换到"频率空间"
- - 像素空间:关注每个点的亮度值
- - 频率空间:关注图像的周期性模式(纹理、边缘、轮廓)
-
- 2. 频谱比较:在频率空间中衡量预测与真实的差异
- - 低频误差:反映整体结构的偏差
- - 高频误差:反映细节纹理的偏差
-
- 3. 焦点权重:自动识别并强调难以重建的频率成分
- - 这是与传统频域损失(如简单的频谱 MSE)的关键区别
- - 类似于注意力机制,让模型"聚焦"于困难频率
-
- """
- # ==================== Step 1: 将预测图像转换为频域 ====================
- # 调用 tensor2freq 方法对预测图像进行傅里叶变换
- # 将空间域的像素表示转换为频域的频谱表示
- # pred_freq 包含了预测图像在各个频率上的振幅和相位信息
- pred_freq = self.tensor2freq(pred)
- # ==================== Step 2: 将目标图像转换为频域 ====================
- # 同样对真实标签图像进行傅里叶变换
- # 这样我们就可以在频率空间中比较预测与真实的差异
- target_freq = self.tensor2freq(target)
- # ==================== Step 3: 可选的 Batch 频谱平均 ====================
- if self.ave_spectrum:
- # 如果启用了 ave_spectrum 选项,对 batch 维度(第 0 维)取平均
- # keepdim=True 保持维度数量不变,只是将第 0 维的大小设为 1
- #
- # 这个操作的效果:
- # - 原始形状:(N, P, C, H, W, 2)
- # - 平均后:(1, P, C, H, W, 2)
- #
- # 为什么要这样做?
- # 1. 减少 batch 内样本间的随机波动
- # 2. 计算一个"平均频谱"作为代表
- # 3. 在某些任务中可以提高训练稳定性
- #
- # 注意:这会改变损失的语义,从"逐个样本的损失"变成"batch 级别的损失"
- pred_freq = torch.mean(pred_freq, 0, keepdim=True)
- target_freq = torch.mean(target_freq, 0, keepdim=True)
- # ==================== Step 4: 计算最终的焦点频域损失 ====================
- # 调用 loss_formulation 方法计算加权后的频域损失
- # 该方法会:
- # 1. 动态计算频谱权重矩阵(如果 matrix=None)
- # 2. 计算预测频谱与真实频谱的距离
- # 3. 用权重矩阵对距离加权,得到最终损失
- #
- # 返回值是一个标量张量,表示整个 batch 的平均损失
- loss_value = self.loss_formulation(pred_freq, target_freq, matrix)
- # ==================== Step 5: 应用损失权重系数 ====================
- # 将计算得到的损失乘以预设的权重系数 loss_weight
- #
- # 这个参数的作用:
- # - 在多损失联合训练时,平衡不同损失的重要性
- # - 例如:总损失 = 1.0 * DiceLoss + 0.1 * FocalFrequencyLoss
- # - 这样可以让频域损失作为辅助损失,不会主导训练过程
- #
- # 为什么需要这样做?
- # 不同的损失函数量级可能差异很大
- # DiceLoss 可能在 0-1 之间,而频域损失可能在 0-100 之间
- # 通过调整 loss_weight,可以确保各个损失在同一数量级
- final_loss = loss_value * self.loss_weight
- return final_loss
|