| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- """
- WaveletFFTNet (2D version).
- """
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from typing import Literal
- from .blocks_2d import WaveletFFTBlock2d
- from .layers_2d import (
- BNLinear1d,
- Conv2dBN,
- FFN2d,
- PatchMerging2d,
- Residual,
- )
class WaveletFFTNet2d(nn.Module):
    """Image classifier with up to three stages of ``WaveletFFTBlock2d``.

    The input passes through a four-layer ``Conv2dBN`` stem (``patch_embed``),
    then ``blocks1``, ``blocks2`` and ``blocks3``, is globally average-pooled
    and finally fed to the classification head(s).

    Args:
        img_size: Stored as ``self.img_size``; not otherwise read by this class.
        in_chans: Input channel count given to the first stem layer.
        num_classes: Output size of the head(s). A value ``<= 0`` replaces the
            head(s) with ``nn.Identity``.
        embed_dim: Channel width of each stage. ``embed_dim[0]`` also sets the
            stem widths and ``embed_dim[-1]`` the head input width.
        global_ratio: Per-stage ``global_ratio`` forwarded to the blocks.
        local_ratio: Per-stage ``local_ratio`` forwarded to the blocks.
        depth: Number of ``WaveletFFTBlock2d`` modules in each stage.
        kernels: Per-stage ``kernel_size`` forwarded to the blocks.
        down_ops: Per-stage transition spec. Only element 0 is read: when it
            equals ``"subsample"`` and the stage is not the last one, a
            transition (mixer, ``PatchMerging2d``, mixer) is placed at the
            start of the next stage.
        distillation: When true, a second head ``head_dist`` is created.
        drop_path: Largest per-block ``drop_path`` value; the values handed to
            the blocks are linearly spaced from 0 to this number over all
            blocks of all stages.
        wt_levels: Forwarded unchanged to every block.
        wt_type: Forwarded unchanged to every block.
        wt_mode: Forwarded unchanged to every block.
        proj_drop: Forwarded unchanged to every block.

    Raises:
        ValueError: If the per-stage arguments differ in length, or describe
            fewer than one or more than three stages.
    """

    # Only ``blocks1`` .. ``blocks3`` exist, so at most three stages fit.
    _MAX_STAGES = 3

    def __init__(
        self, img_size=224, in_chans=3, num_classes=1000,
        embed_dim=(192, 384, 448), global_ratio=(0.8, 0.7, 0.6),
        local_ratio=(0.2, 0.2, 0.3), depth=(1, 2, 2),
        kernels=(7, 5, 3), down_ops=(("subsample", 2), ("subsample", 2), ("",)),
        distillation=False, drop_path=0.0, wt_levels=1,
        wt_type="db1",
        wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero",
        proj_drop=0.0,
    ):
        super().__init__()
        # Fail early with a clear message: zip() below would otherwise silently
        # truncate mismatched tuples and build a different network.
        self._check_stage_config(
            embed_dim=embed_dim, depth=depth, global_ratio=global_ratio,
            local_ratio=local_ratio, kernels=kernels, down_ops=down_ops,
        )
        self.img_size = img_size

        # Stem: widths grow embed_dim[0] / 8 -> / 4 -> / 2 -> embed_dim[0].
        # NOTE(review): the (3, 2, 1) arguments are presumably
        # (kernel, stride, padding), i.e. 16x total down-sampling - confirm
        # against Conv2dBN.
        stem_dim = embed_dim[0]
        self.patch_embed = nn.Sequential(
            Conv2dBN(in_chans, stem_dim // 8, 3, 2, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(stem_dim // 8, stem_dim // 4, 3, 2, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(stem_dim // 4, stem_dim // 2, 3, 2, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(stem_dim // 2, stem_dim, 3, 2, 1),
        )

        stages = [[] for _ in range(self._MAX_STAGES)]
        # One drop-path value per block, counted across the whole network.
        dprs = [rate.item() for rate in torch.linspace(0, drop_path, sum(depth))]
        first_block = 0  # index into dprs of the current stage's first block
        for stage_idx, (ed, dpth, gr, lr, down_op, kernel) in enumerate(
            zip(embed_dim, depth, global_ratio, local_ratio, down_ops, kernels)
        ):
            stage_drop = dprs[first_block:first_block + dpth]
            first_block += dpth
            for block_drop in stage_drop:
                stages[stage_idx].append(
                    WaveletFFTBlock2d(
                        ed, global_ratio=gr, local_ratio=lr, kernel_size=kernel,
                        wt_levels=wt_levels, wt_type=wt_type, wt_mode=wt_mode,
                        proj_drop=proj_drop, drop_path=block_drop,
                    )
                )
            if stage_idx < len(embed_dim) - 1 and down_op[0] == "subsample":
                # The transition belongs to the *next* stage so that it runs
                # before that stage's blocks.
                next_dim = embed_dim[stage_idx + 1]
                stages[stage_idx + 1].append(self._local_mixer(ed))
                stages[stage_idx + 1].append(PatchMerging2d(ed, next_dim))
                stages[stage_idx + 1].append(self._local_mixer(next_dim))

        self.blocks1 = nn.Sequential(*stages[0])
        self.blocks2 = nn.Sequential(*stages[1])
        self.blocks3 = nn.Sequential(*stages[2])

        self.head = BNLinear1d(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.distillation = distillation
        if distillation:
            self.head_dist = BNLinear1d(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    @classmethod
    def _check_stage_config(cls, **per_stage):
        """Validate that all per-stage sequences agree and fit the network."""
        lengths = {name: len(values) for name, values in per_stage.items()}
        if len(set(lengths.values())) != 1:
            raise ValueError(
                f"Per-stage arguments must all have the same length, got {lengths}."
            )
        num_stages = lengths["embed_dim"]
        if not 1 <= num_stages <= cls._MAX_STAGES:
            raise ValueError(
                f"Expected between 1 and {cls._MAX_STAGES} stages, got {num_stages}."
            )

    @staticmethod
    def _local_mixer(dim):
        """Return a residual ``Conv2dBN`` (groups == dim) plus a residual ``FFN2d``."""
        return nn.Sequential(
            Residual(Conv2dBN(dim, dim, 3, 1, 1, groups=dim)),
            Residual(FFN2d(dim, int(dim * 2))),
        )

    def forward_features(self, x):
        """Run stem and stages, then global-average-pool the result.

        For a 4-D ``(N, C, H, W)`` feature map the returned tensor is ``(N, C)``.
        """
        x = self.patch_embed(x)
        x = self.blocks1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        return F.adaptive_avg_pool2d(x, 1).flatten(1)

    def forward(self, x):
        """Return the head output for ``x``.

        With ``distillation`` enabled, training mode returns the tuple
        ``(head(x), head_dist(x))`` and evaluation mode returns their mean.
        """
        feats = self.forward_features(x)
        if not self.distillation:
            return self.head(feats)
        outputs = self.head(feats), self.head_dist(feats)
        if self.training:
            return outputs
        return (outputs[0] + outputs[1]) / 2
# Named presets. Each dict holds keyword arguments for WaveletFFTNet2d and is
# expanded by _build_model, where caller keyword arguments take precedence.
# All presets share the same ratios and kernels; they differ in input size,
# stage widths, stage depths and drop-path value.
CFG_WAVELET_FFT_T2 = dict(
    img_size=192,
    embed_dim=(144, 272, 368),
    depth=(1, 2, 2),
    global_ratio=(0.8, 0.7, 0.6),
    local_ratio=(0.2, 0.2, 0.3),
    kernels=(7, 5, 3),
    drop_path=0.0,
)

CFG_WAVELET_FFT_T4 = dict(
    img_size=192,
    embed_dim=(176, 368, 448),
    depth=(1, 2, 2),
    global_ratio=(0.8, 0.7, 0.6),
    local_ratio=(0.2, 0.2, 0.3),
    kernels=(7, 5, 3),
    drop_path=0.0,
)

CFG_WAVELET_FFT_S6 = dict(
    img_size=224,
    embed_dim=(192, 384, 448),
    depth=(1, 2, 2),
    global_ratio=(0.8, 0.7, 0.6),
    local_ratio=(0.2, 0.2, 0.3),
    kernels=(7, 5, 3),
    drop_path=0.0,
)

# B1, B2 and B4 are identical except for img_size (256 / 384 / 512).
CFG_WAVELET_FFT_B1 = dict(
    img_size=256,
    embed_dim=(200, 376, 448),
    depth=(2, 3, 2),
    global_ratio=(0.8, 0.7, 0.6),
    local_ratio=(0.2, 0.2, 0.3),
    kernels=(7, 5, 3),
    drop_path=0.03,
)

CFG_WAVELET_FFT_B2 = dict(
    img_size=384,
    embed_dim=(200, 376, 448),
    depth=(2, 3, 2),
    global_ratio=(0.8, 0.7, 0.6),
    local_ratio=(0.2, 0.2, 0.3),
    kernels=(7, 5, 3),
    drop_path=0.03,
)

CFG_WAVELET_FFT_B4 = dict(
    img_size=512,
    embed_dim=(200, 376, 448),
    depth=(2, 3, 2),
    global_ratio=(0.8, 0.7, 0.6),
    local_ratio=(0.2, 0.2, 0.3),
    kernels=(7, 5, 3),
    drop_path=0.03,
)
def _build_model(model_cfg, **kwargs):
    """Instantiate WaveletFFTNet2d from a preset plus keyword overrides.

    ``model_cfg`` is copied, never mutated; entries in ``kwargs`` replace
    preset entries with the same key.
    """
    merged = dict(model_cfg, **kwargs)
    return WaveletFFTNet2d(**merged)
def wavelet_fft_t2(**kwargs):
    """Build the model described by CFG_WAVELET_FFT_T2.

    Keyword arguments override preset entries with the same name.
    """
    model = _build_model(CFG_WAVELET_FFT_T2, **kwargs)
    return model
def wavelet_fft_t4(**kwargs):
    """Build the model described by CFG_WAVELET_FFT_T4.

    Keyword arguments override preset entries with the same name.
    """
    model = _build_model(CFG_WAVELET_FFT_T4, **kwargs)
    return model
def wavelet_fft_s6(**kwargs):
    """Build the model described by CFG_WAVELET_FFT_S6.

    Keyword arguments override preset entries with the same name.
    """
    model = _build_model(CFG_WAVELET_FFT_S6, **kwargs)
    return model
def wavelet_fft_b1(**kwargs):
    """Build the model described by CFG_WAVELET_FFT_B1.

    Keyword arguments override preset entries with the same name.
    """
    model = _build_model(CFG_WAVELET_FFT_B1, **kwargs)
    return model
def wavelet_fft_b2(**kwargs):
    """Build the model described by CFG_WAVELET_FFT_B2.

    Keyword arguments override preset entries with the same name.
    """
    model = _build_model(CFG_WAVELET_FFT_B2, **kwargs)
    return model
def wavelet_fft_b4(**kwargs):
    """Build the model described by CFG_WAVELET_FFT_B4.

    Keyword arguments override preset entries with the same name.
    """
    model = _build_model(CFG_WAVELET_FFT_B4, **kwargs)
    return model
|