""" MRFFI 模块与 WaveletFFTBlock(2D 版本)。 """ import torch import torch.nn as nn from typing import Literal from .attentions_2d import WaveletAttentionGlobalBranch2d from .layers_2d import ( Conv2dBN, DWConv2dBNReLU, DropPath, FFN2d, Residual, ) class WaveletFFTMRFFIModule2d(nn.Module): def __init__( self, dim, global_ratio=0.25, local_ratio=0.25, kernel_size=5, wt_levels=1, wt_type="db1", wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero", proj_drop=0.0, ): super().__init__() self.dim = dim self.global_channels = min(int(global_ratio * dim), dim) tentative_local = int(local_ratio * dim) if self.global_channels + tentative_local > dim: self.local_channels = max(dim - self.global_channels, 0) else: self.local_channels = tentative_local self.identity_channels = dim - self.global_channels - self.local_channels if self.global_channels > 0: self.global_op = WaveletAttentionGlobalBranch2d( self.global_channels, kernel_size=kernel_size, wt_levels=wt_levels, wt_type=wt_type, wt_mode=wt_mode, proj_drop=proj_drop, ) else: self.global_op = nn.Identity() if self.local_channels > 0: self.local_op = DWConv2dBNReLU(self.local_channels, self.local_channels, kernel_size=kernel_size) else: self.local_op = nn.Identity() self.proj = nn.Sequential( nn.ReLU(inplace=True), Conv2dBN(dim, dim, bn_weight_init=0.0), ) def forward(self, x): x_global, x_local, x_identity = torch.split( x, [self.global_channels, self.local_channels, self.identity_channels], dim=1, ) x_global = self.global_op(x_global) x_local = self.local_op(x_local) x = torch.cat([x_global, x_local, x_identity], dim=1) return self.proj(x) class WaveletFFTBlock2d(nn.Module): def __init__( self, dim, global_ratio=0.25, local_ratio=0.25, kernel_size=5, wt_levels=1, wt_type="db1", wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero", proj_drop=0.0, drop_path=0.0, has_skip=True, ): super().__init__() self.dw0 = Residual(Conv2dBN(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.0)) self.ffn0 = Residual(FFN2d(dim, int(dim * 2))) self.mixer = Residual( WaveletFFTMRFFIModule2d( dim, global_ratio=global_ratio, local_ratio=local_ratio, kernel_size=kernel_size, wt_levels=wt_levels, wt_type=wt_type, wt_mode=wt_mode, proj_drop=proj_drop, ) ) self.dw1 = Residual(Conv2dBN(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.0)) self.ffn1 = Residual(FFN2d(dim, int(dim * 2))) self.has_skip = has_skip self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x): shortcut = x x = self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x))))) if self.has_skip: x = shortcut + self.drop_path(x) return x