# frcab.py
import torch
import torch.nn as nn
  3. class FFTRCAB(nn.Module):
  4. def __init__(self, dim):
  5. super(FFTRCAB, self).__init__()
  6. self.CBG3x3 = nn.Sequential(
  7. nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
  8. nn.LeakyReLU(0.1, inplace=True),
  9. nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
  10. )
  11. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  12. self.xc_aEnhance = nn.Sequential(
  13. nn.Conv2d(dim, dim // 2 + 1, 1, 1, 0),
  14. nn.LeakyReLU(0.1, inplace=True),
  15. nn.Conv2d(dim // 2 + 1, dim // 2 + 1, 1, 1, 0),
  16. )
  17. self.xc_pEnhance = nn.Sequential(
  18. nn.Conv2d(dim, dim // 2 + 1, 1, 1, 0),
  19. nn.LeakyReLU(0.1, inplace=True),
  20. nn.Conv2d(dim // 2 + 1, dim // 2 + 1, 1, 1, 0),
  21. )
  22. def forward(self, x):
  23. x_conv = self.CBG3x3(x)
  24. x_conv = x_conv.to(torch.float32)
  25. x_pool = self.avg_pool(x_conv)
  26. xc_a = self.xc_aEnhance(x_pool)
  27. xc_p = self.xc_pEnhance(x_pool)
  28. x_fft = torch.fft.rfft2(x_pool, dim=1, norm="ortho")
  29. x_a = torch.abs(x_fft)
  30. x_p = torch.angle(x_fft)
  31. xa_enh = x_a * xc_a
  32. xp_enh = x_p * xc_p
  33. xa = xa_enh * torch.cos(xp_enh)
  34. xp = xa_enh * torch.sin(xp_enh)
  35. x_comp = torch.complex(xa, xp)
  36. xc = torch.fft.irfft2(x_comp, dim=1, norm="ortho")
  37. x_out = x_conv * xc
  38. return x_out + x
  39. if __name__ == "__main__":
  40. input_tensor = torch.randn(3, 64, 128, 128)
  41. fft_rcab = FFTRCAB(64)
  42. output_tensor = fft_rcab(input_tensor)
  43. print(output_tensor.shape)