""" WaveletFFTNet(2D 版本)。 """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Literal from .blocks_2d import WaveletFFTBlock2d from .layers_2d import ( BNLinear1d, Conv2dBN, FFN2d, PatchMerging2d, Residual, ) class WaveletFFTNet2d(nn.Module): def __init__( self, img_size=224, in_chans=3, num_classes=1000, embed_dim=(192, 384, 448), global_ratio=(0.8, 0.7, 0.6), local_ratio=(0.2, 0.2, 0.3), depth=(1, 2, 2), kernels=(7, 5, 3), down_ops=(("subsample", 2), ("subsample", 2), ("",)), distillation=False, drop_path=0.0, wt_levels=1, wt_type="db1", wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero", proj_drop=0.0, ): super().__init__() self.img_size = img_size self.patch_embed = nn.Sequential( Conv2dBN(in_chans, embed_dim[0] // 8, 3, 2, 1), nn.ReLU(inplace=True), Conv2dBN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1), nn.ReLU(inplace=True), Conv2dBN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1), nn.ReLU(inplace=True), Conv2dBN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1), ) stages = [[], [], []] dprs = [x.item() for x in torch.linspace(0, drop_path, sum(depth))] for stage_idx, (ed, dpth, gr, lr, down_op, kernel) in enumerate( zip(embed_dim, depth, global_ratio, local_ratio, down_ops, kernels) ): start = sum(depth[:stage_idx]) stage_drop = dprs[start: start + dpth] for block_idx in range(dpth): stages[stage_idx].append( WaveletFFTBlock2d( ed, global_ratio=gr, local_ratio=lr, kernel_size=kernel, wt_levels=wt_levels, wt_type=wt_type, wt_mode=wt_mode, proj_drop=proj_drop, drop_path=stage_drop[block_idx], ) ) if stage_idx < len(embed_dim) - 1 and down_op[0] == "subsample": stages[stage_idx + 1].append( nn.Sequential( Residual( Conv2dBN(embed_dim[stage_idx], embed_dim[stage_idx], 3, 1, 1, groups=embed_dim[stage_idx])), Residual(FFN2d(embed_dim[stage_idx], int(embed_dim[stage_idx] * 2))), ) ) stages[stage_idx + 1].append(PatchMerging2d(embed_dim[stage_idx], embed_dim[stage_idx + 1])) stages[stage_idx + 1].append( nn.Sequential( Residual(Conv2dBN(embed_dim[stage_idx + 1], embed_dim[stage_idx + 1], 3, 1, 1, groups=embed_dim[stage_idx + 1])), Residual(FFN2d(embed_dim[stage_idx + 1], int(embed_dim[stage_idx + 1] * 2))), ) ) self.blocks1 = nn.Sequential(*stages[0]) self.blocks2 = nn.Sequential(*stages[1]) self.blocks3 = nn.Sequential(*stages[2]) self.head = BNLinear1d(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() self.distillation = distillation if distillation: self.head_dist = BNLinear1d(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) x = self.blocks1(x) x = self.blocks2(x) x = self.blocks3(x) return F.adaptive_avg_pool2d(x, 1).flatten(1) def forward(self, x): x = self.forward_features(x) if self.distillation: x = self.head(x), self.head_dist(x) if not self.training: x = (x[0] + x[1]) / 2 return x return self.head(x) CFG_WAVELET_FFT_T2 = { "img_size": 192, "embed_dim": (144, 272, 368), "depth": (1, 2, 2), "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3), "kernels": (7, 5, 3), "drop_path": 0.0, } CFG_WAVELET_FFT_T4 = { "img_size": 192, "embed_dim": (176, 368, 448), "depth": (1, 2, 2), "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3), "kernels": (7, 5, 3), "drop_path": 0.0, } CFG_WAVELET_FFT_S6 = { "img_size": 224, "embed_dim": (192, 384, 448), "depth": (1, 2, 2), "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3), "kernels": (7, 5, 3), "drop_path": 0.0, } CFG_WAVELET_FFT_B1 = { "img_size": 256, "embed_dim": (200, 376, 448), "depth": (2, 3, 2), "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3), "kernels": (7, 5, 3), "drop_path": 0.03, } CFG_WAVELET_FFT_B2 = { "img_size": 384, "embed_dim": (200, 376, 448), "depth": (2, 3, 2), "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3), "kernels": (7, 5, 3), "drop_path": 0.03, } CFG_WAVELET_FFT_B4 = { "img_size": 512, "embed_dim": (200, 376, 448), "depth": (2, 3, 2), "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3), "kernels": (7, 5, 3), "drop_path": 0.03, } def _build_model(model_cfg, **kwargs): cfg = dict(model_cfg) cfg.update(kwargs) return WaveletFFTNet2d(**cfg) def wavelet_fft_t2(**kwargs): return _build_model(CFG_WAVELET_FFT_T2, **kwargs) def wavelet_fft_t4(**kwargs): return _build_model(CFG_WAVELET_FFT_T4, **kwargs) def wavelet_fft_s6(**kwargs): return _build_model(CFG_WAVELET_FFT_S6, **kwargs) def wavelet_fft_b1(**kwargs): return _build_model(CFG_WAVELET_FFT_B1, **kwargs) def wavelet_fft_b2(**kwargs): return _build_model(CFG_WAVELET_FFT_B2, **kwargs) def wavelet_fft_b4(**kwargs): return _build_model(CFG_WAVELET_FFT_B4, **kwargs)