attentions_2d.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. """
  2. Circulant Attention 2D.
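Core idea: approximate the self-attention matrix with a BCCB
(block-circulant with circulant blocks) structure, so that it can be
evaluated in O(N log N) time via the 2D FFT.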
  4. """
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from typing import Literal
  9. try:
  10. import ptwt
  11. except ImportError as exc:
  12. raise ImportError(
  13. "wavelet_fft requires ptwt. Install it before importing this package."
  14. ) from exc
  15. from .layers_2d import Scale
  16. class ComplexLinear(nn.Linear):
  17. def __init__(self, in_features, out_features, device=None, dtype=None):
  18. super().__init__(in_features, out_features, bias=False, device=device, dtype=dtype)
  19. def forward(self, inp):
  20. x = torch.view_as_real(inp).transpose(-2, -1)
  21. x = F.linear(x, self.weight).transpose(-2, -1)
  22. if x.dtype != torch.float32:
  23. x = x.to(torch.float32)
  24. return torch.view_as_complex(x.contiguous())
  25. class CirculantAttention2d(nn.Module):
  26. def __init__(self, dim, proj_drop=0.0):
  27. super().__init__()
  28. self.qkv = ComplexLinear(dim, dim * 3)
  29. self.gate = nn.Sequential(nn.Linear(dim, dim), nn.SiLU())
  30. self.proj = nn.Linear(dim, dim)
  31. self.proj_drop = nn.Dropout(proj_drop)
  32. def forward(self, x):
  33. b, c, h, w = x.shape
  34. spatial_perm = [0, 2, 3, 1]
  35. spatial = x.permute(spatial_perm).contiguous()
  36. gate = self.gate(spatial.reshape(b, h * w, c)).reshape(b, h, w, c)
  37. freq = torch.fft.rfft2(spatial, dim=(1, 2), norm="ortho")
  38. qkv = self.qkv(freq)
  39. q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
  40. attn = torch.conj(q) * k
  41. attn = torch.fft.irfft2(attn, s=(h, w), dim=(1, 2), norm="ortho")
  42. attn = attn.reshape(b, h * w, c).softmax(dim=1).reshape(b, h, w, c)
  43. attn = torch.fft.rfft2(attn, dim=(1, 2))
  44. out = torch.conj(attn) * v
  45. out = torch.fft.irfft2(out, s=(h, w), dim=(1, 2), norm="ortho")
  46. out = out.reshape(b, h * w, c) * gate.reshape(b, h * w, c)
  47. out = self.proj_drop(self.proj(out))
  48. return out.transpose(1, 2).reshape(b, c, h, w)
  49. class WaveletAttentionGlobalBranch2d(nn.Module):
  50. def __init__(
  51. self, in_channels, kernel_size=5, stride=1, wt_levels=1,
  52. wt_type="db1", wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero",
  53. proj_drop=0.0,
  54. ):
  55. super().__init__()
  56. if in_channels <= 0:
  57. raise ValueError("in_channels must be positive.")
  58. self.in_channels = in_channels
  59. self.wt_levels = wt_levels
  60. self.stride = stride
  61. self.wavelet = wt_type
  62. self.wt_mode = wt_mode
  63. self.global_attn = CirculantAttention2d(in_channels, proj_drop=proj_drop)
  64. self.base_scale = Scale((1, in_channels, 1, 1))
  65. self.wavelet_convs = nn.ModuleList([
  66. nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, 1,
  67. kernel_size // 2, groups=in_channels * 4, bias=False)
  68. for _ in range(wt_levels)
  69. ])
  70. self.wavelet_scale = nn.ModuleList([
  71. Scale((1, in_channels * 4, 1, 1), init_scale=0.1)
  72. for _ in range(wt_levels)
  73. ])
  74. if stride > 1:
  75. self.register_buffer("stride_filter", torch.ones(in_channels, 1, 1, 1), persistent=False)
  76. else:
  77. self.stride_filter = None
  78. def forward(self, x):
  79. low_levels, high_levels, shapes_in_levels = [], [], []
  80. curr_low = x
  81. for level in range(self.wt_levels):
  82. shapes_in_levels.append(curr_low.shape[-2:])
  83. coeffs = ptwt.wavedec2(curr_low, self.wavelet, mode=self.wt_mode, level=1)
  84. low = coeffs[0]
  85. detail = coeffs[1]
  86. high = torch.stack([detail.horizontal, detail.vertical, detail.diagonal], dim=2)
  87. bands = torch.cat([low.unsqueeze(2), high], dim=2)
  88. b, c, _, h_half, w_half = bands.shape
  89. bands = bands.reshape(b, c * 4, h_half, w_half)
  90. bands = self.wavelet_scale[level](self.wavelet_convs[level](bands))
  91. bands = bands.reshape(b, c, 4, h_half, w_half)
  92. low_levels.append(bands[:, :, 0, :, :])
  93. high_levels.append(bands[:, :, 1:4, :, :])
  94. curr_low = low
  95. wavelet_out = x
  96. if self.wt_levels > 0:
  97. next_low = None
  98. for level in range(self.wt_levels - 1, -1, -1):
  99. low = low_levels.pop()
  100. high = high_levels.pop()
  101. height, width = shapes_in_levels.pop()
  102. if next_low is not None:
  103. low = low + next_low
  104. cH, cV, cD = high.unbind(dim=2)
  105. next_low = ptwt.waverec2((low, ptwt.constants.WaveletDetailTuple2d(cH, cV, cD)), self.wavelet)
  106. next_low = next_low[:, :, :height, :width]
  107. wavelet_out = next_low
  108. out = self.base_scale(self.global_attn(x)) + wavelet_out
  109. if self.stride_filter is not None:
  110. out = F.conv2d(out, self.stride_filter, stride=self.stride, groups=self.in_channels)
  111. return out