""" Circulant Attention 2D. 核心思想: 自注意力矩阵近似 BC CB 结构,通过 2D FFT 在 O(N log N) 时间内计算。 """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Literal try: import ptwt except ImportError as exc: raise ImportError( "wavelet_fft requires ptwt. Install it before importing this package." ) from exc from .layers_2d import Scale class ComplexLinear(nn.Linear): def __init__(self, in_features, out_features, device=None, dtype=None): super().__init__(in_features, out_features, bias=False, device=device, dtype=dtype) def forward(self, inp): x = torch.view_as_real(inp).transpose(-2, -1) x = F.linear(x, self.weight).transpose(-2, -1) if x.dtype != torch.float32: x = x.to(torch.float32) return torch.view_as_complex(x.contiguous()) class CirculantAttention2d(nn.Module): def __init__(self, dim, proj_drop=0.0): super().__init__() self.qkv = ComplexLinear(dim, dim * 3) self.gate = nn.Sequential(nn.Linear(dim, dim), nn.SiLU()) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): b, c, h, w = x.shape spatial_perm = [0, 2, 3, 1] spatial = x.permute(spatial_perm).contiguous() gate = self.gate(spatial.reshape(b, h * w, c)).reshape(b, h, w, c) freq = torch.fft.rfft2(spatial, dim=(1, 2), norm="ortho") qkv = self.qkv(freq) q, k, v = torch.chunk(qkv, chunks=3, dim=-1) attn = torch.conj(q) * k attn = torch.fft.irfft2(attn, s=(h, w), dim=(1, 2), norm="ortho") attn = attn.reshape(b, h * w, c).softmax(dim=1).reshape(b, h, w, c) attn = torch.fft.rfft2(attn, dim=(1, 2)) out = torch.conj(attn) * v out = torch.fft.irfft2(out, s=(h, w), dim=(1, 2), norm="ortho") out = out.reshape(b, h * w, c) * gate.reshape(b, h * w, c) out = self.proj_drop(self.proj(out)) return out.transpose(1, 2).reshape(b, c, h, w) class WaveletAttentionGlobalBranch2d(nn.Module): def __init__( self, in_channels, kernel_size=5, stride=1, wt_levels=1, wt_type="db1", wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero", proj_drop=0.0, ): super().__init__() if in_channels <= 0: raise ValueError("in_channels must be positive.") self.in_channels = in_channels self.wt_levels = wt_levels self.stride = stride self.wavelet = wt_type self.wt_mode = wt_mode self.global_attn = CirculantAttention2d(in_channels, proj_drop=proj_drop) self.base_scale = Scale((1, in_channels, 1, 1)) self.wavelet_convs = nn.ModuleList([ nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, 1, kernel_size // 2, groups=in_channels * 4, bias=False) for _ in range(wt_levels) ]) self.wavelet_scale = nn.ModuleList([ Scale((1, in_channels * 4, 1, 1), init_scale=0.1) for _ in range(wt_levels) ]) if stride > 1: self.register_buffer("stride_filter", torch.ones(in_channels, 1, 1, 1), persistent=False) else: self.stride_filter = None def forward(self, x): low_levels, high_levels, shapes_in_levels = [], [], [] curr_low = x for level in range(self.wt_levels): shapes_in_levels.append(curr_low.shape[-2:]) coeffs = ptwt.wavedec2(curr_low, self.wavelet, mode=self.wt_mode, level=1) low = coeffs[0] detail = coeffs[1] high = torch.stack([detail.horizontal, detail.vertical, detail.diagonal], dim=2) bands = torch.cat([low.unsqueeze(2), high], dim=2) b, c, _, h_half, w_half = bands.shape bands = bands.reshape(b, c * 4, h_half, w_half) bands = self.wavelet_scale[level](self.wavelet_convs[level](bands)) bands = bands.reshape(b, c, 4, h_half, w_half) low_levels.append(bands[:, :, 0, :, :]) high_levels.append(bands[:, :, 1:4, :, :]) curr_low = low wavelet_out = x if self.wt_levels > 0: next_low = None for level in range(self.wt_levels - 1, -1, -1): low = low_levels.pop() high = high_levels.pop() height, width = shapes_in_levels.pop() if next_low is not None: low = low + next_low cH, cV, cD = high.unbind(dim=2) next_low = ptwt.waverec2((low, ptwt.constants.WaveletDetailTuple2d(cH, cV, cD)), self.wavelet) next_low = next_low[:, :, :height, :width] wavelet_out = next_low out = self.base_scale(self.global_attn(x)) + wavelet_out if self.stride_filter is not None: out = F.conv2d(out, self.stride_filter, stride=self.stride, groups=self.in_channels) return out