| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- import torch
- import torch.nn as nn
class FFTRCAB(nn.Module):
    """Residual conv block gated by a channel-wise spectral attention map.

    Data flow:

    1. The input passes through a ``3x3 conv -> LeakyReLU(0.1) -> 3x3 conv``
       branch (no biases). The result is cast to float32.
    2. The branch output is average-pooled to one value per channel.
    3. A 1-D real FFT is taken across the channel axis of the pooled vector,
       giving ``dim // 2 + 1`` complex bins.
    4. The amplitude and phase of those bins are each multiplied by the output
       of a small ``1x1 conv -> LeakyReLU -> 1x1 conv`` projection of the
       pooled vector.
    5. The modified spectrum is transformed back to ``dim`` per-channel
       weights.
    6. Those weights multiply the conv branch, and the block input is added
       as a residual.

    Args:
        dim: Number of input and output channels. Any positive value works,
            including odd counts.
    """

    def __init__(self, dim):
        super().__init__()
        # NOTE: attribute names below are state_dict keys; keep them unchanged
        # so existing checkpoints still load.
        self.CBG3x3 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # A real FFT over `dim` samples yields dim // 2 + 1 frequency bins, so
        # the amplitude/phase modulators predict exactly that many values.
        freq_bins = dim // 2 + 1
        self.xc_aEnhance = nn.Sequential(
            nn.Conv2d(dim, freq_bins, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(freq_bins, freq_bins, 1, 1, 0),
        )
        self.xc_pEnhance = nn.Sequential(
            nn.Conv2d(dim, freq_bins, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(freq_bins, freq_bins, 1, 1, 0),
        )

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape ``(B, dim, H, W)``. The channel axis must be
                axis 1, because both the convolutions and the FFT use it.

        Returns:
            ``conv_branch(x) * spectral_weights + x``, with the same shape as
            ``x``. The conv branch is computed in float32, so the result is
            float32 unless type promotion with ``x`` gives a wider dtype.
        """
        x_conv = self.CBG3x3(x)
        # The spectral path runs in float32. Presumably this is because
        # half-precision CUDA FFTs only support power-of-two lengths.
        x_conv = x_conv.to(torch.float32)
        channels = x_conv.size(1)

        x_pool = self.avg_pool(x_conv)       # (B, C, 1, 1)
        xc_a = self.xc_aEnhance(x_pool)      # (B, C // 2 + 1, 1, 1)
        xc_p = self.xc_pEnhance(x_pool)      # (B, C // 2 + 1, 1, 1)

        # 1-D real FFT across the channel axis -> C // 2 + 1 complex bins.
        x_fft = torch.fft.rfft(x_pool, dim=1, norm="ortho")
        amplitude = torch.abs(x_fft) * xc_a
        phase = torch.angle(x_fft) * xc_p
        # Rebuild the spectrum explicitly from cos/sin. torch.polar is not used
        # because `amplitude` can be negative (xc_a is an unconstrained conv
        # output), and polar's behavior is undefined in that case.
        spectrum = torch.complex(
            amplitude * torch.cos(phase),
            amplitude * torch.sin(phase),
        )
        # Pass n explicitly. irfft's default output length is
        # 2 * (bins - 1), which equals C only for even C. Without n, odd
        # channel counts would yield C - 1 weights (a broadcast error below),
        # and C == 1 would yield length 0.
        xc = torch.fft.irfft(spectrum, n=channels, dim=1, norm="ortho")

        x_out = x_conv * xc
        return x_out + x
if __name__ == "__main__":
    # Smoke test: run one random batch through the block and report the
    # resulting shape. The input is drawn before the module is built,
    # matching the original RNG call order.
    batch, channels, height, width = 3, 64, 128, 128
    sample = torch.randn(batch, channels, height, width)
    block = FFTRCAB(channels)
    result = block(sample)
    print(result.shape)
|