from typing import Optional import torch import torch.nn as nn import torch.fft as fft from monai.losses import DiceLoss, DiceCELoss, HausdorffDTLoss class FocalFrequencyLoss(nn.Module): """ 焦点频域损失函数 (Focal Frequency Loss) 核心思想: 传统的空间域损失(如 MSE、Dice)主要关注像素级别的差异,而频域损失通过傅里叶变换 将图像转换到频率域,从频率角度衡量预测图像与真实图像的差异。 该损失的创新点在于"焦点"机制: 1. 自动计算频谱权重矩阵,对不同频率成分赋予不同的重要性 2. 对难以重建的频率成分给予更高权重(类似 Focal Loss 的思想) 3. 可以捕捉图像的全局结构和纹理细节,弥补空间域损失的不足 适用场景: - 医学图像分割:增强边缘和纹理的恢复 - 图像超分辨率:重建高频细节 - 图像去噪/去模糊:平衡低频和高频信息 参数说明: loss_weight: 损失权重系数,用于平衡该损失与其他损失的重要性 alpha: 频谱权重的幂次参数,控制权重分布的陡峭程度 alpha 越大,困难频率成分的权重越突出 patch_factor: 图像分块因子,将图像分成多个小块分别进行 FFT 值为 1 表示不分块,对整个图像做 FFT 值大于 1 时,将图像分成 patch_factor×patch_factor 个小块 ave_spectrum: 是否对 batch 内的频谱进行平均,用于减少 batch 内差异 log_matrix: 是否对频谱差异取对数,用于压缩动态范围 batch_matrix: 权重归一化方式 True: 在整个 batch 范围内归一化到 [0,1] False: 对每张图像单独归一化 """ def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False): """ 初始化焦点频域损失函数的所有超参数 Args: loss_weight (float): 损失权重,默认 1.0 alpha (float): 频谱权重指数,默认 1.0 patch_factor (int): 图像分块因子,默认 1(不分块) ave_spectrum (bool): 是否对 batch 频谱平均,默认 False log_matrix (bool): 是否使用对数矩阵,默认 False batch_matrix (bool): 是否使用 batch 级矩阵,默认 False """ super(FocalFrequencyLoss, self).__init__() self.loss_weight = loss_weight self.alpha = alpha self.patch_factor = patch_factor self.ave_spectrum = ave_spectrum self.log_matrix = log_matrix self.batch_matrix = batch_matrix def tensor2freq(self, x): """ 将空间域图像张量转换为频域表示 工作原理: 1. 如果 patch_factor > 1,先将图像分割成多个小块 2. 对每个小块执行 2D 快速傅里叶变换 (FFT) 3. 将复数形式的 FFT 结果分解为实部和虚部 傅里叶变换的物理意义: - 低频成分:对应图像的平滑区域和整体轮廓 - 高频成分:对应图像的边缘、纹理和噪声 - 通过分析频谱,可以分离和处理不同频率的特征 Args: x (torch.Tensor): 输入图像张量,形状为 (N, C, H, W) N=batch_size, C=channels, H=height, W=width Returns: freq (torch.Tensor): 频域表示,形状为 (N, patch_factor², C, H/pf, W/pf, 2) 最后一维的 2 个通道分别是 [实部,虚部] patch_factor² 表示分成了多少个小块 Example: 输入:x.shape = (4, 1, 256, 256), patch_factor=4 输出:freq.shape = (4, 16, 1, 64, 64, 2) - 16 = 4×4 个小块 - 64×64 = 每个小块的尺寸 - 2 = [实部,虚部] """ # 获取分块因子 patch_factor = self.patch_factor # 获取输入图像的尺寸信息 _, _, h, w = x.shape # 断言检查:确保图像尺寸可以被 patch_factor 整除 # 这是为了保证分块时每个小块大小一致,避免边界问题 assert h % patch_factor == 0 and w % patch_factor == 0, ( 'Patch factor should be divisible by image height and width') # 初始化列表用于存储所有小块的频域表示 patch_list = [] # 计算每个小块的高度和宽度 # 例如:原图 256×256, patch_factor=4 → 每个小块 64×64 patch_h = h // patch_factor patch_w = w // patch_factor # 双重循环遍历所有小块 # i 控制垂直方向的索引,j 控制水平方向的索引 for i in range(patch_factor): for j in range(patch_factor): # 切片操作:提取第 (i,j) 个小块 # 垂直方向:从 i*patch_h 到 (i+1)*patch_h # 水平方向:从 j*patch_w 到 (j+1)*patch_w # [:, :, ...] 保持 batch 和 channel 维度不变 patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w]) # 将所有小块堆叠成一个新的维度 # 原始形状:(N, C, patch_h, patch_w) 的列表,长度为 patch_factor² # 堆叠后形状:(N, patch_factor², C, patch_h, patch_w) # dim=1 表示在第 1 个维度(channel 之后)插入新的分块维度 y = torch.stack(patch_list, 1) # 对每个小块执行 2D 快速傅里叶变换 # torch.fft.fft2: 计算二维离散傅里叶变换 # norm='ortho': 使用正交归一化,保证变换前后能量守恒 # 变换结果是一个复数张量,包含每个频率成分的振幅和相位信息 freq = torch.fft.fft2(y, norm='ortho') # 将复数形式转换为实数表示,方便后续神经网络处理 # torch.stack([freq.real, freq.imag], -1): 将实部和虚部堆叠到最后一个维度 # freq.real: 复数的实部,代表余弦分量的幅度 # freq.imag: 复数的虚部,代表正弦分量的幅度 # 最终形状:(N, patch_factor², C, patch_h, patch_w, 2) freq = torch.stack([freq.real, freq.imag], -1) return freq def loss_formulation(self, recon_freq, real_freq, matrix=None): """ 构建并计算焦点频域损失的核心公式 核心思想: 传统频域损失直接计算预测频谱与真实频谱的距离(如 MSE)。 本方法引入"焦点"机制:根据每个频率成分的重建难度动态调整权重。 权重计算逻辑: 1. 如果未提供预定义权重矩阵,则在线计算动态权重 2. 重建误差大的频率 → 高权重(重点关注) 3. 重建误差小的频率 → 低权重(减少关注) 这类似于 Focal Loss 中"关注难例"的思想 Args: recon_freq (torch.Tensor): 重建(预测)图像的频域表示 形状:(N, P, C, H, W, 2),P=patch_factor² real_freq (torch.Tensor): 真实(目标)图像的频域表示 形状:(N, P, C, H, W, 2) matrix (torch.Tensor, optional): 预定义的频谱权重矩阵 如果为 None,则动态计算权重 Returns: loss (torch.Tensor): 标量损失值 工作流程: Step 1: 确定权重矩阵(预定义或动态计算) Step 2: 计算频谱距离(复数空间中的欧氏距离) Step 3: 加权求和得到最终损失 """ # ==================== Step 1: 确定权重矩阵 ==================== if matrix is not None: # 情况 A: 使用预定义的权重矩阵 # 这种模式允许外部指定固定的频率权重,适用于某些先验知识已知的场景 # 例如:人工指定某些频率更重要,或者使用其他算法计算的权重 weight_matrix = matrix.detach() # .detach() 确保权重矩阵不参与梯度反向传播 # 这样权重是固定的,不会影响梯度的流动 else: # 情况 B: 动态计算自适应权重矩阵(推荐模式) # --- 子步骤 1: 计算初步的频谱差异 --- # 逐元素计算预测频谱与真实频谱的差值的平方 # 这是一个复数差的平方,需要分别处理实部和虚部 # 形状:(N, P, C, H, W, 2) matrix_tmp = (recon_freq - real_freq) ** 2 # --- 子步骤 2: 计算频谱幅度差异 --- # 复数的模长公式:|a + bi| = sqrt(a² + b²) # 这里计算的是每个频率成分的预测误差的幅度 # [..., 0]: 实部的平方, [..., 1]: 虚部的平方 # torch.sqrt(...): 开平方得到欧几里得距离 # ** self.alpha: 应用幂次变换,alpha 控制权重的分布特性 # - alpha > 1: 放大差异,使大误差更突出 # - alpha < 1: 缩小差异,使权重分布更均匀 # - alpha = 1: 线性关系(默认情况) matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha # --- 子步骤 3: 可选的对数变换 --- if self.log_matrix: # 对数变换的作用:压缩动态范围 # 当频谱差异的范围很大时(几个数量级),对数变换可以防止某些频率主导损失 # log(x + 1.0): 加 1 是为了避免 log(0) 的数值不稳定 matrix_tmp = torch.log(matrix_tmp + 1.0) # --- 子步骤 4: 归一化权重到 [0, 1] 范围 --- # 归一化的目的是确保权重有统一的尺度,便于控制和解释 if self.batch_matrix: # 模式 A: Batch 级归一化 # 在整个 batch 的所有像素、所有频率上找最大值,然后统一归一化 # 优点:batch 内所有样本的权重在同一尺度 # 缺点:可能掩盖样本间的差异 matrix_tmp = matrix_tmp / matrix_tmp.max() else: # 模式 B: 样本级归一化(推荐) # 对每个样本单独归一化,保持样本间的相对差异 # .max(-1).values: 沿最后一个维度(实部/虚部维度)取最大值 # .max(-1).values[:, :, :, None, None]: 再沿空间维度取最大值 # [:, :, :, None, None]: 添加维度以保持广播兼容性 # 最终每张图片的权重矩阵独立归一化到 [0,1] matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None] # --- 子步骤 5: 数值稳定性处理 --- # 处理可能出现的 NaN 值(例如 0/0 的情况) # 将 NaN 替换为 0,表示这些位置不参与加权 matrix_tmp[torch.isnan(matrix_tmp)] = 0.0 # --- 子步骤 6: 截断到合法范围 --- # 确保所有权值都在 [0, 1] 区间内 # torch.clamp: 将小于 0 的值设为 0,大于 1 的值设为 1 # 这是防御性编程,防止数值溢出导致训练不稳定 matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0) # --- 子步骤 7: 创建最终的权重矩阵 --- # .clone().detach() 创建副本并断开梯度连接 # 这样权重矩阵在本次前向传播中是固定的,不会自我影响 # 这是关键设计:权重基于当前误差计算,但不参与本次的梯度回传 weight_matrix = matrix_tmp.clone().detach() # ==================== 权重矩阵有效性验证 ==================== # 断言检查:确保权重矩阵的所有值都在 [0, 1] 范围内 # 这是一个安全检查,帮助调试时发现问题 # .item() 将标量张量转换为 Python 浮点数,方便打印 assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, ( 'The values of spectrum weight matrix should be in the range [0, 1], ' 'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item())) # ==================== Step 2: 计算频谱距离 ==================== # 计算预测频谱与真实频谱的逐元素差异的平方 # 与前面计算 matrix_tmp 的第一步相同 # 形状:(N, P, C, H, W, 2) tmp = (recon_freq - real_freq) ** 2 # 将实部和虚部的平方和相加,得到复数空间中的欧几里得距离平方 # 这就是每个频率成分的重建误差 # 形状:(N, P, C, H, W) freq_distance = tmp[..., 0] + tmp[..., 1] # ==================== Step 3: 加权求和得到损失 ==================== # 逐元素相乘:权重矩阵 × 频谱距离 # 效果: # - 权重高的频率(重建困难)→ 损失贡献大 → 梯度大 → 模型重点关注 # - 权重低的频率(重建简单)→ 损失贡献小 → 梯度小 → 模型次要关注 # 形状:(N, P, C, H, W) loss = weight_matrix * freq_distance # 对所有维度取平均,得到标量损失值 # torch.mean(): 将整个张量压缩成一个标量 # 这样得到的损失可以直接用于反向传播 return torch.mean(loss) def forward(self, pred, target, matrix=None): """ 焦点频域损失的前向传播计算 这是损失函数的主要入口,当调用 loss_function(pred, target) 时执行此方法 Args: pred (torch.Tensor): 预测的图像张量 形状:(N, C, H, W) N: batch size(批次大小) C: channels(通道数,对于灰度图 C=1,RGB 图 C=3) H: height(图像高度) W: width(图像宽度) target (torch.Tensor): 目标的图像张量(真实标签) 形状:(N, C, H, W),必须与 pred 完全相同 matrix (torch.Tensor, optional): 预定义的频谱权重矩阵 如果提供,则使用该固定权重而非动态计算 默认:None(动态计算自适应权重) Returns: torch.Tensor: 标量损失值(0 维张量),可以直接用于反向传播 完整计算流程: ┌─────────────────┐ │ 输入:pred, target │ └────────┬──────────┘ │ ▼ ┌─────────────────┐ │ Step 1: 傅里叶变换 │ tensor2freq() │ pred → pred_freq │ 空间域 → 频域 │ target → target_freq│ └────────┬──────────┘ │ ▼ ┌─────────────────┐ │ Step 2: 可选的平均 │ if ave_spectrum │ 对 batch 维度平均 │ 减少样本间差异 └────────┬──────────┘ │ ▼ ┌─────────────────┐ │ Step 3: 计算损失 │ loss_formulation() │ 动态权重 × 频谱距离 │ 焦点机制核心 └────────┬──────────┘ │ ▼ ┌─────────────────┐ │ Step 4: 应用权重 │ × loss_weight │ 返回最终损失值 │ └─────────────────┘ 物理意义解释: 1. 傅里叶变换:将图像从"像素空间"转换到"频率空间" - 像素空间:关注每个点的亮度值 - 频率空间:关注图像的周期性模式(纹理、边缘、轮廓) 2. 频谱比较:在频率空间中衡量预测与真实的差异 - 低频误差:反映整体结构的偏差 - 高频误差:反映细节纹理的偏差 3. 焦点权重:自动识别并强调难以重建的频率成分 - 这是与传统频域损失(如简单的频谱 MSE)的关键区别 - 类似于注意力机制,让模型"聚焦"于困难频率 """ # ==================== Step 1: 将预测图像转换为频域 ==================== # 调用 tensor2freq 方法对预测图像进行傅里叶变换 # 将空间域的像素表示转换为频域的频谱表示 # pred_freq 包含了预测图像在各个频率上的振幅和相位信息 pred_freq = self.tensor2freq(pred) # ==================== Step 2: 将目标图像转换为频域 ==================== # 同样对真实标签图像进行傅里叶变换 # 这样我们就可以在频率空间中比较预测与真实的差异 target_freq = self.tensor2freq(target) # ==================== Step 3: 可选的 Batch 频谱平均 ==================== if self.ave_spectrum: # 如果启用了 ave_spectrum 选项,对 batch 维度(第 0 维)取平均 # keepdim=True 保持维度数量不变,只是将第 0 维的大小设为 1 # # 这个操作的效果: # - 原始形状:(N, P, C, H, W, 2) # - 平均后:(1, P, C, H, W, 2) # # 为什么要这样做? # 1. 减少 batch 内样本间的随机波动 # 2. 计算一个"平均频谱"作为代表 # 3. 在某些任务中可以提高训练稳定性 # # 注意:这会改变损失的语义,从"逐个样本的损失"变成"batch 级别的损失" pred_freq = torch.mean(pred_freq, 0, keepdim=True) target_freq = torch.mean(target_freq, 0, keepdim=True) # ==================== Step 4: 计算最终的焦点频域损失 ==================== # 调用 loss_formulation 方法计算加权后的频域损失 # 该方法会: # 1. 动态计算频谱权重矩阵(如果 matrix=None) # 2. 计算预测频谱与真实频谱的距离 # 3. 用权重矩阵对距离加权,得到最终损失 # # 返回值是一个标量张量,表示整个 batch 的平均损失 loss_value = self.loss_formulation(pred_freq, target_freq, matrix) # ==================== Step 5: 应用损失权重系数 ==================== # 将计算得到的损失乘以预设的权重系数 loss_weight # # 这个参数的作用: # - 在多损失联合训练时,平衡不同损失的重要性 # - 例如:总损失 = 1.0 * DiceLoss + 0.1 * FocalFrequencyLoss # - 这样可以让频域损失作为辅助损失,不会主导训练过程 # # 为什么需要这样做? # 不同的损失函数量级可能差异很大 # DiceLoss 可能在 0-1 之间,而频域损失可能在 0-100 之间 # 通过调整 loss_weight,可以确保各个损失在同一数量级 final_loss = loss_value * self.loss_weight return final_loss