| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
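- """
- Circulant Attention 2D.
- Core idea: the self-attention matrix is approximated by a BCCB (block-circulant with circulant blocks) structure, so it can be computed in O(N log N) time via the 2D FFT.
- """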
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from typing import Literal
- try:
- import ptwt
- except ImportError as exc:
- raise ImportError(
- "wavelet_fft requires ptwt. Install it before importing this package."
- ) from exc
- from .layers_2d import Scale
class ComplexLinear(nn.Linear):
    """Bias-free linear layer for complex-valued inputs.

    The layer owns a single real weight matrix and applies it separately to
    the real and imaginary components of the input's last dimension. The
    result is always returned as a complex tensor built from float32 parts.
    """

    def __init__(self, in_features, out_features, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(in_features, out_features, bias=False, **factory_kwargs)

    def forward(self, inp):
        # view_as_real appends a trailing (real, imag) axis; move it in front
        # of the feature axis so F.linear mixes features only.
        parts = torch.view_as_real(inp).transpose(-1, -2)
        mixed = F.linear(parts, self.weight)
        # Put the (real, imag) axis back in last position.
        mixed = mixed.transpose(-1, -2)
        # view_as_complex below is fed float32 pairs, so cast any other
        # result dtype to float32 first.
        as_float = mixed if mixed.dtype == torch.float32 else mixed.to(torch.float32)
        return torch.view_as_complex(as_float.contiguous())
class CirculantAttention2d(nn.Module):
    """Frequency-domain attention over a (batch, channels, height, width) map.

    Queries, keys and values are produced from the 2D real FFT of the input.
    The per-channel attention map is the inverse FFT of ``conj(q) * k``,
    normalised with a softmax over all spatial positions, and is applied to
    the values by a second product in the frequency domain. The result is
    multiplied by a SiLU gate computed from the input, then projected.
    """

    def __init__(self, dim, proj_drop=0.0):
        super().__init__()
        self.qkv = ComplexLinear(dim, dim * 3)
        self.gate = nn.Sequential(nn.Linear(dim, dim), nn.SiLU())
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, channels, height, width = x.shape
        tokens = height * width
        fft_axes = (1, 2)
        size = (height, width)

        # Channels-last layout so the linear layers act on the channel axis.
        channels_last = x.permute(0, 2, 3, 1).contiguous()

        # Gate is computed per position and kept flattened as (B, H*W, C).
        gate = self.gate(channels_last.reshape(batch, tokens, channels))

        spectrum = torch.fft.rfft2(channels_last, dim=fft_axes, norm="ortho")
        q, k, v = self.qkv(spectrum).chunk(3, dim=-1)

        # Back to the spatial domain to get one score per position and channel.
        scores = torch.fft.irfft2(q.conj() * k, s=size, dim=fft_axes, norm="ortho")

        # Softmax runs over all H*W positions, independently per channel.
        weights = scores.reshape(batch, tokens, channels).softmax(dim=1)
        weights = weights.reshape(batch, height, width, channels)

        # This forward transform is deliberately left un-normalised.
        weights_hat = torch.fft.rfft2(weights, dim=fft_axes)
        mixed = torch.fft.irfft2(weights_hat.conj() * v, s=size, dim=fft_axes, norm="ortho")

        gated = mixed.reshape(batch, tokens, channels) * gate
        projected = self.proj_drop(self.proj(gated))

        # (B, H*W, C) -> (B, C, H, W)
        return projected.transpose(1, 2).reshape(batch, channels, height, width)
class WaveletAttentionGlobalBranch2d(nn.Module):
    """Sum of a scaled circulant-attention branch and a wavelet-domain conv branch.

    The wavelet branch runs ``wt_levels`` single-level 2D wavelet
    decompositions. At every level the four sub-bands (approximation plus
    horizontal, vertical and diagonal detail) are stacked on the channel
    axis, filtered by a depthwise convolution, scaled, and later fed to the
    inverse transform. The reconstruction is added to the scaled output of
    ``CirculantAttention2d`` applied to the same input.

    Args:
        in_channels: Number of input channels; must be positive.
        kernel_size: Kernel size of the depthwise sub-band convolutions. It
            must be a positive odd integer when ``wt_levels > 0``, because
            the padding ``kernel_size // 2`` preserves the sub-band size
            only for odd kernels.
        stride: When greater than 1, the output is subsampled with a strided
            1x1 all-ones depthwise convolution.
        wt_levels: Number of decomposition levels. With 0 (or less) the
            wavelet branch passes the input through unchanged.
        wt_type: Wavelet name passed to ``ptwt``.
        wt_mode: Boundary mode passed to ``ptwt.wavedec2``.
        proj_drop: Dropout probability passed to the attention branch.

    Raises:
        ValueError: If ``in_channels`` is not positive, or if ``wt_levels > 0``
            and ``kernel_size`` is not a positive odd integer.
    """

    def __init__(
        self,
        in_channels,
        kernel_size=5,
        stride=1,
        wt_levels=1,
        wt_type="db1",
        wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero",
        proj_drop=0.0,
    ):
        super().__init__()
        if in_channels <= 0:
            raise ValueError("in_channels must be positive.")
        # An even kernel with padding kernel_size // 2 makes the conv output one
        # pixel larger per axis, which would only surface later as a reshape
        # error inside forward(). Fail early with a clear message instead.
        if wt_levels > 0 and (kernel_size <= 0 or kernel_size % 2 == 0):
            raise ValueError(
                "kernel_size must be a positive odd integer when wt_levels > 0."
            )

        self.in_channels = in_channels
        self.wt_levels = wt_levels
        self.stride = stride
        self.wavelet = wt_type
        self.wt_mode = wt_mode

        self.global_attn = CirculantAttention2d(in_channels, proj_drop=proj_drop)
        self.base_scale = Scale((1, in_channels, 1, 1))

        # One depthwise conv (groups == channels) per level over the 4 stacked
        # sub-bands of every input channel.
        self.wavelet_convs = nn.ModuleList([
            nn.Conv2d(
                in_channels * 4,
                in_channels * 4,
                kernel_size,
                1,
                kernel_size // 2,
                groups=in_channels * 4,
                bias=False,
            )
            for _ in range(wt_levels)
        ])
        self.wavelet_scale = nn.ModuleList([
            Scale((1, in_channels * 4, 1, 1), init_scale=0.1)
            for _ in range(wt_levels)
        ])

        if stride > 1:
            # Non-persistent: the filter is constant and rebuilt on construction.
            self.register_buffer(
                "stride_filter", torch.ones(in_channels, 1, 1, 1), persistent=False
            )
        else:
            self.stride_filter = None

    def forward(self, x):
        """Apply both branches to ``x`` of shape (batch, channels, height, width).

        Returns the sum of the scaled attention output and the wavelet
        reconstruction, subsampled by ``stride`` when ``stride > 1``.
        """
        low_levels, high_levels, shapes_in_levels = [], [], []

        # Analysis: decompose level by level, filtering the sub-bands of each level.
        curr_low = x
        for level in range(self.wt_levels):
            # Remember the input size so the reconstruction can be cropped back.
            shapes_in_levels.append(curr_low.shape[-2:])
            coeffs = ptwt.wavedec2(curr_low, self.wavelet, mode=self.wt_mode, level=1)
            low = coeffs[0]
            detail = coeffs[1]
            high = torch.stack(
                [detail.horizontal, detail.vertical, detail.diagonal], dim=2
            )
            # (B, C, 4, H', W'): index 0 is the approximation, 1..3 the details.
            bands = torch.cat([low.unsqueeze(2), high], dim=2)
            b, c, _, h_half, w_half = bands.shape
            bands = bands.reshape(b, c * 4, h_half, w_half)
            bands = self.wavelet_scale[level](self.wavelet_convs[level](bands))
            bands = bands.reshape(b, c, 4, h_half, w_half)
            low_levels.append(bands[:, :, 0, :, :])
            high_levels.append(bands[:, :, 1:4, :, :])
            # The next level decomposes the unfiltered approximation.
            curr_low = low

        # Synthesis: rebuild from the coarsest level up, adding each coarser
        # reconstruction to the filtered approximation of the level above.
        next_low = None
        while low_levels:
            low = low_levels.pop()
            high = high_levels.pop()
            height, width = shapes_in_levels.pop()
            if next_low is not None:
                low = low + next_low
            c_h, c_v, c_d = high.unbind(dim=2)
            next_low = ptwt.waverec2(
                (low, ptwt.constants.WaveletDetailTuple2d(c_h, c_v, c_d)),
                self.wavelet,
            )
            # Crop back to the size this level was decomposed from.
            next_low = next_low[:, :, :height, :width]

        # With no decomposition levels the wavelet branch is the identity.
        wavelet_out = x if next_low is None else next_low

        out = self.base_scale(self.global_attn(x)) + wavelet_out
        if self.stride_filter is not None:
            # A strided 1x1 all-ones depthwise conv subsamples every channel.
            out = F.conv2d(
                out, self.stride_filter, stride=self.stride, groups=self.in_channels
            )
        return out
|