| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
"""
MRFFI module and WaveletFFTBlock (2D version).
"""
- import torch
- import torch.nn as nn
- from typing import Literal
- from .attentions_2d import WaveletAttentionGlobalBranch2d
- from .layers_2d import (
- Conv2dBN,
- DWConv2dBNReLU,
- DropPath,
- FFN2d,
- Residual,
- )
class WaveletFFTMRFFIModule2d(nn.Module):
    """Mix features by routing three channel groups through different operators.

    Channels (dimension 1 of the input) are split into contiguous groups:

    * the first ``global_channels`` go through ``WaveletAttentionGlobalBranch2d``;
    * the next ``local_channels`` go through ``DWConv2dBNReLU``;
    * the remaining ``identity_channels`` are passed through unchanged.

    The groups are concatenated back in the same order and projected by
    ``ReLU`` followed by ``Conv2dBN(dim, dim, bn_weight_init=0.0)``.

    Args:
        dim: Total number of channels. The three group sizes always sum to
            ``dim``, so the input must have exactly ``dim`` channels on
            dimension 1.
        global_ratio: Fraction of ``dim`` assigned to the global branch.
            Values outside ``[0, 1]`` are clamped.
        local_ratio: Fraction of ``dim`` assigned to the local branch. It is
            capped by whatever the global branch left over and never negative.
        kernel_size: Forwarded to both the global and the local branch.
        wt_levels: Forwarded to ``WaveletAttentionGlobalBranch2d``.
        wt_type: Forwarded to ``WaveletAttentionGlobalBranch2d``.
        wt_mode: Forwarded to ``WaveletAttentionGlobalBranch2d``.
        proj_drop: Forwarded to ``WaveletAttentionGlobalBranch2d``.
    """

    def __init__(
        self,
        dim: int,
        global_ratio: float = 0.25,
        local_ratio: float = 0.25,
        kernel_size: int = 5,
        wt_levels: int = 1,
        wt_type: str = "db1",
        wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero",
        proj_drop: float = 0.0,
    ):
        super().__init__()
        self.dim = dim
        (
            self.global_channels,
            self.local_channels,
            self.identity_channels,
        ) = self._partition_channels(dim, global_ratio, local_ratio)

        # A branch that receives no channels is replaced by a no-op so that
        # forward() can treat all three groups uniformly.
        if self.global_channels > 0:
            self.global_op = WaveletAttentionGlobalBranch2d(
                self.global_channels,
                kernel_size=kernel_size,
                wt_levels=wt_levels,
                wt_type=wt_type,
                wt_mode=wt_mode,
                proj_drop=proj_drop,
            )
        else:
            self.global_op = nn.Identity()

        if self.local_channels > 0:
            self.local_op = DWConv2dBNReLU(
                self.local_channels, self.local_channels, kernel_size=kernel_size
            )
        else:
            self.local_op = nn.Identity()

        self.proj = nn.Sequential(
            # In-place is fine here: forward() feeds this the tensor freshly
            # produced by torch.cat, not the caller's input.
            nn.ReLU(inplace=True),
            # NOTE(review): bn_weight_init=0.0 presumably zero-initialises the
            # BN scale so the projection starts as a zero mapping - confirm
            # against layers_2d.Conv2dBN.
            Conv2dBN(dim, dim, bn_weight_init=0.0),
        )

    @staticmethod
    def _partition_channels(dim: int, global_ratio: float, local_ratio: float) -> tuple:
        """Return ``(global, local, identity)`` channel counts summing to ``dim``.

        The global branch is sized first, then the local branch takes its
        request capped by the remainder; the identity group gets the rest.
        Both counts are clamped to be non-negative so that a negative ratio
        cannot produce a negative split size.
        """
        global_channels = max(0, min(int(global_ratio * dim), dim))
        local_channels = max(0, min(int(local_ratio * dim), dim - global_channels))
        identity_channels = dim - global_channels - local_channels
        return global_channels, local_channels, identity_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Split ``x`` on dimension 1, transform each group, merge and project.

        Args:
            x: Input whose dimension 1 has size
                ``global_channels + local_channels + identity_channels``.

        Returns:
            The output of ``self.proj`` applied to the re-concatenated groups.
        """
        x_global, x_local, x_identity = torch.split(
            x,
            [self.global_channels, self.local_channels, self.identity_channels],
            dim=1,
        )
        x_global = self.global_op(x_global)
        x_local = self.local_op(x_local)
        # Concatenate in the same order the groups were split.
        merged = torch.cat([x_global, x_local, x_identity], dim=1)
        return self.proj(merged)
class WaveletFFTBlock2d(nn.Module):
    """Residual stack: depthwise conv, FFN, MRFFI mixer, depthwise conv, FFN.

    Every stage is wrapped in ``Residual``. When ``has_skip`` is true the
    input is additionally added to the (optionally drop-path'ed) output of
    the whole stack; when it is false the stack output is returned directly
    and ``drop_path`` is not applied.

    Args:
        dim: Channel count used by every stage.
        global_ratio, local_ratio, kernel_size, wt_levels, wt_type, wt_mode,
        proj_drop: Forwarded unchanged to ``WaveletFFTMRFFIModule2d``.
        drop_path: Passed to ``DropPath`` when greater than 0; otherwise an
            ``nn.Identity`` is used instead.
        has_skip: Whether to add the outer shortcut around the whole block.
    """

    def __init__(
        self,
        dim,
        global_ratio=0.25,
        local_ratio=0.25,
        kernel_size=5,
        wt_levels=1,
        wt_type="db1",
        wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero",
        proj_drop=0.0,
        drop_path=0.0,
        has_skip=True,
    ):
        super().__init__()
        hidden_dim = int(dim * 2)

        def make_depthwise():
            # 3x3, stride 1, padding 1, one group per channel.
            return Residual(Conv2dBN(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.0))

        def make_ffn():
            return Residual(FFN2d(dim, hidden_dim))

        mixer_options = dict(
            global_ratio=global_ratio,
            local_ratio=local_ratio,
            kernel_size=kernel_size,
            wt_levels=wt_levels,
            wt_type=wt_type,
            wt_mode=wt_mode,
            proj_drop=proj_drop,
        )

        # Stages are built in execution order, which also fixes the order in
        # which submodules are registered.
        self.dw0 = make_depthwise()
        self.ffn0 = make_ffn()
        self.mixer = Residual(WaveletFFTMRFFIModule2d(dim, **mixer_options))
        self.dw1 = make_depthwise()
        self.ffn1 = make_ffn()

        self.has_skip = has_skip
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Run the five stages in order and apply the optional outer skip."""
        out = x
        for stage in (self.dw0, self.ffn0, self.mixer, self.dw1, self.ffn1):
            out = stage(out)
        if not self.has_skip:
            return out
        return x + self.drop_path(out)
|