import torch
import torch.nn as nn


class FFTRCAB(nn.Module):
    """Residual conv block whose channel reweighting is computed in the Fourier domain.

    Steps:

    1. Apply two 3x3 convolutions with a LeakyReLU between them.
    2. Global-average-pool the result to a per-channel descriptor.
    3. Take a real FFT of that descriptor along the channel axis.
    4. Rescale the amplitude and phase with learned 1x1-conv gates.
    5. Transform back to channel space.
    6. Multiply the conv features by the result, then add the input back
       as a residual.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        # Number of rfft bins produced along a channel axis of length ``dim``.
        n_freq = dim // 2 + 1
        # Attribute names are kept unchanged so existing state_dict checkpoints still load.
        self.CBG3x3 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Gate applied to the spectrum amplitude (one value per frequency bin).
        self.xc_aEnhance = nn.Sequential(
            nn.Conv2d(dim, n_freq, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(n_freq, n_freq, 1, 1, 0),
        )
        # Gate applied to the spectrum phase (one value per frequency bin).
        self.xc_pEnhance = nn.Sequential(
            nn.Conv2d(dim, n_freq, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(n_freq, n_freq, 1, 1, 0),
        )

    def forward(self, x):
        """Apply the block.

        Args:
            x: Batched feature map of shape (B, dim, H, W).

        Returns:
            Tensor of shape (B, dim, H, W). The conv features are cast to
            float32 before the spectral branch, so the result is float32
            (after promotion with ``x`` in the residual add).
        """
        feat = self.CBG3x3(x)
        # NOTE(review): presumably cast because FFT support for reduced-precision
        # dtypes is limited — confirm intent.
        feat = feat.to(torch.float32)
        channels = feat.shape[1]

        desc = self.avg_pool(feat)  # (B, C, 1, 1)
        amp_gate = self.xc_aEnhance(desc)  # (B, C//2 + 1, 1, 1)
        pha_gate = self.xc_pEnhance(desc)  # (B, C//2 + 1, 1, 1)

        # 1-D real FFT along the channel axis. This is what the earlier
        # rfft2(..., dim=1) call computed, since it only transformed dim 1.
        spec = torch.fft.rfft(desc, dim=1, norm="ortho")
        amp = torch.abs(spec) * amp_gate
        pha = torch.angle(spec) * pha_gate
        spec_enh = torch.complex(amp * torch.cos(pha), amp * torch.sin(pha))

        # Pass n explicitly. Without it the inverse returns 2 * (n_bins - 1)
        # channels, which is channels - 1 for odd channel counts and breaks
        # the broadcast multiply below.
        weights = torch.fft.irfft(spec_enh, n=channels, dim=1, norm="ortho")  # (B, C, 1, 1)

        return feat * weights + x


if __name__ == "__main__":
    input_tensor = torch.randn(3, 64, 128, 128)
    fft_rcab = FFTRCAB(64)
    output_tensor = fft_rcab(input_tensor)
    print(output_tensor.shape)