attentions_2d.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. """
  2. Circulant Attention 2D.
  3. 核心思想: 将自注意力矩阵近似为 BCCB(循环块构成的块循环矩阵)结构,通过 2D FFT 在 O(N log N) 时间内计算。
  4. """
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from typing import Literal
  9. try:
  10. import ptwt
  11. except ImportError as exc:
  12. raise ImportError(
  13. "wavelet_fft requires ptwt. Install it before importing this package."
  14. ) from exc
  15. from .layers_2d import Scale
  16. class ComplexLinear(nn.Linear):
  17. def __init__(self, in_features, out_features, device=None, dtype=None):
  18. super().__init__(
  19. in_features, out_features, bias=False, device=device, dtype=dtype
  20. )
  21. def forward(self, inp):
  22. x = torch.view_as_real(inp).transpose(-2, -1)
  23. x = F.linear(x, self.weight).transpose(-2, -1)
  24. if x.dtype != torch.float32:
  25. x = x.to(torch.float32)
  26. return torch.view_as_complex(x.contiguous())
  27. class CirculantAttention2d(nn.Module):
  28. def __init__(self, dim, proj_drop=0.0):
  29. super().__init__()
  30. self.qkv = ComplexLinear(dim, dim * 3)
  31. self.gate = nn.Sequential(nn.Linear(dim, dim), nn.SiLU())
  32. self.proj = nn.Linear(dim, dim)
  33. self.proj_drop = nn.Dropout(proj_drop)
  34. def forward(self, x):
  35. b, c, h, w = x.shape
  36. spatial_perm = [0, 2, 3, 1]
  37. spatial = x.permute(spatial_perm).contiguous()
  38. gate = self.gate(spatial.reshape(b, h * w, c)).reshape(b, h, w, c)
  39. freq = torch.fft.rfft2(spatial, dim=(1, 2), norm="ortho")
  40. qkv = self.qkv(freq)
  41. q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
  42. attn = torch.conj(q) * k
  43. attn = torch.fft.irfft2(attn, s=(h, w), dim=(1, 2), norm="ortho")
  44. attn = attn.reshape(b, h * w, c).softmax(dim=1).reshape(b, h, w, c)
  45. attn = torch.fft.rfft2(attn, dim=(1, 2))
  46. out = torch.conj(attn) * v
  47. out = torch.fft.irfft2(out, s=(h, w), dim=(1, 2), norm="ortho")
  48. out = out.reshape(b, h * w, c) * gate.reshape(b, h * w, c)
  49. out = self.proj_drop(self.proj(out))
  50. return out.transpose(1, 2).reshape(b, c, h, w)
  51. class WaveletAttentionGlobalBranch2d(nn.Module):
  52. def __init__(
  53. self,
  54. in_channels,
  55. kernel_size=5,
  56. stride=1,
  57. wt_levels=1,
  58. wt_type="db1",
  59. wt_mode: Literal[
  60. "constant", "zero", "reflect", "periodic", "symmetric"
  61. ] = "zero",
  62. proj_drop=0.0,
  63. ):
  64. super().__init__()
  65. if in_channels <= 0:
  66. raise ValueError("in_channels must be positive.")
  67. self.in_channels = in_channels
  68. self.wt_levels = wt_levels
  69. self.stride = stride
  70. self.wavelet = wt_type
  71. self.wt_mode = wt_mode
  72. self.global_attn = CirculantAttention2d(in_channels, proj_drop=proj_drop)
  73. self.base_scale = Scale((1, in_channels, 1, 1))
  74. self.wavelet_convs = nn.ModuleList(
  75. [
  76. nn.Conv2d(
  77. in_channels * 4,
  78. in_channels * 4,
  79. kernel_size,
  80. 1,
  81. kernel_size // 2,
  82. groups=in_channels * 4,
  83. bias=False,
  84. )
  85. for _ in range(wt_levels)
  86. ]
  87. )
  88. self.wavelet_scale = nn.ModuleList(
  89. [
  90. Scale((1, in_channels * 4, 1, 1), init_scale=0.1)
  91. for _ in range(wt_levels)
  92. ]
  93. )
  94. if stride > 1:
  95. self.register_buffer(
  96. "stride_filter", torch.ones(in_channels, 1, 1, 1), persistent=False
  97. )
  98. else:
  99. self.stride_filter = None
  100. def forward(self, x):
  101. low_levels, high_levels, shapes_in_levels = [], [], []
  102. curr_low = x
  103. for level in range(self.wt_levels):
  104. shapes_in_levels.append(curr_low.shape[-2:])
  105. coeffs = ptwt.wavedec2(curr_low, self.wavelet, mode=self.wt_mode, level=1)
  106. low = coeffs[0]
  107. detail = coeffs[1]
  108. high = torch.stack(
  109. [detail.horizontal, detail.vertical, detail.diagonal], dim=2
  110. )
  111. bands = torch.cat([low.unsqueeze(2), high], dim=2)
  112. b, c, _, h_half, w_half = bands.shape
  113. bands = bands.reshape(b, c * 4, h_half, w_half)
  114. bands = self.wavelet_scale[level](self.wavelet_convs[level](bands))
  115. bands = bands.reshape(b, c, 4, h_half, w_half)
  116. low_levels.append(bands[:, :, 0, :, :])
  117. high_levels.append(bands[:, :, 1:4, :, :])
  118. curr_low = low
  119. wavelet_out = x
  120. if self.wt_levels > 0:
  121. next_low = None
  122. for level in range(self.wt_levels - 1, -1, -1):
  123. low = low_levels.pop()
  124. high = high_levels.pop()
  125. height, width = shapes_in_levels.pop()
  126. if next_low is not None:
  127. low = low + next_low
  128. cH, cV, cD = high.unbind(dim=2)
  129. next_low = ptwt.waverec2(
  130. (low, ptwt.constants.WaveletDetailTuple2d(cH, cV, cD)), self.wavelet
  131. )
  132. next_low = next_low[:, :, :height, :width]
  133. wavelet_out = next_low
  134. out = self.base_scale(self.global_attn(x)) + wavelet_out
  135. if self.stride_filter is not None:
  136. out = F.conv2d(
  137. out, self.stride_filter, stride=self.stride, groups=self.in_channels
  138. )
  139. return out