nets_2d.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. """
WaveletFFTNet (2D version): the WaveletFFTNet2d classification backbone, its preset configurations, and factory functions.
  3. """
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from typing import Literal
  8. from .blocks_2d import WaveletFFTBlock2d
  9. from .layers_2d import (
  10. BNLinear1d,
  11. Conv2dBN,
  12. FFN2d,
  13. PatchMerging2d,
  14. Residual,
  15. )
  16. class WaveletFFTNet2d(nn.Module):
  17. def __init__(
  18. self, img_size=224, in_chans=3, num_classes=1000,
  19. embed_dim=(192, 384, 448), global_ratio=(0.8, 0.7, 0.6),
  20. local_ratio=(0.2, 0.2, 0.3), depth=(1, 2, 2),
  21. kernels=(7, 5, 3), down_ops=(("subsample", 2), ("subsample", 2), ("",)),
  22. distillation=False, drop_path=0.0, wt_levels=1,
  23. wt_type="db1", wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero",
  24. proj_drop=0.0,
  25. ):
  26. super().__init__()
  27. self.img_size = img_size
  28. self.patch_embed = nn.Sequential(
  29. Conv2dBN(in_chans, embed_dim[0] // 8, 3, 2, 1),
  30. nn.ReLU(inplace=True),
  31. Conv2dBN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1),
  32. nn.ReLU(inplace=True),
  33. Conv2dBN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1),
  34. nn.ReLU(inplace=True),
  35. Conv2dBN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1),
  36. )
  37. stages = [[], [], []]
  38. dprs = [x.item() for x in torch.linspace(0, drop_path, sum(depth))]
  39. for stage_idx, (ed, dpth, gr, lr, down_op, kernel) in enumerate(
  40. zip(embed_dim, depth, global_ratio, local_ratio, down_ops, kernels)
  41. ):
  42. start = sum(depth[:stage_idx])
  43. stage_drop = dprs[start: start + dpth]
  44. for block_idx in range(dpth):
  45. stages[stage_idx].append(
  46. WaveletFFTBlock2d(
  47. ed, global_ratio=gr, local_ratio=lr, kernel_size=kernel,
  48. wt_levels=wt_levels, wt_type=wt_type, wt_mode=wt_mode,
  49. proj_drop=proj_drop, drop_path=stage_drop[block_idx],
  50. )
  51. )
  52. if stage_idx < len(embed_dim) - 1 and down_op[0] == "subsample":
  53. stages[stage_idx + 1].append(
  54. nn.Sequential(
  55. Residual(
  56. Conv2dBN(embed_dim[stage_idx], embed_dim[stage_idx], 3, 1, 1, groups=embed_dim[stage_idx])),
  57. Residual(FFN2d(embed_dim[stage_idx], int(embed_dim[stage_idx] * 2))),
  58. )
  59. )
  60. stages[stage_idx + 1].append(PatchMerging2d(embed_dim[stage_idx], embed_dim[stage_idx + 1]))
  61. stages[stage_idx + 1].append(
  62. nn.Sequential(
  63. Residual(Conv2dBN(embed_dim[stage_idx + 1], embed_dim[stage_idx + 1], 3, 1, 1,
  64. groups=embed_dim[stage_idx + 1])),
  65. Residual(FFN2d(embed_dim[stage_idx + 1], int(embed_dim[stage_idx + 1] * 2))),
  66. )
  67. )
  68. self.blocks1 = nn.Sequential(*stages[0])
  69. self.blocks2 = nn.Sequential(*stages[1])
  70. self.blocks3 = nn.Sequential(*stages[2])
  71. self.head = BNLinear1d(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
  72. self.distillation = distillation
  73. if distillation:
  74. self.head_dist = BNLinear1d(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
  75. def forward_features(self, x):
  76. x = self.patch_embed(x)
  77. x = self.blocks1(x)
  78. x = self.blocks2(x)
  79. x = self.blocks3(x)
  80. return F.adaptive_avg_pool2d(x, 1).flatten(1)
  81. def forward(self, x):
  82. x = self.forward_features(x)
  83. if self.distillation:
  84. x = self.head(x), self.head_dist(x)
  85. if not self.training:
  86. x = (x[0] + x[1]) / 2
  87. return x
  88. return self.head(x)
  89. CFG_WAVELET_FFT_T2 = {
  90. "img_size": 192, "embed_dim": (144, 272, 368), "depth": (1, 2, 2),
  91. "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
  92. "kernels": (7, 5, 3), "drop_path": 0.0,
  93. }
  94. CFG_WAVELET_FFT_T4 = {
  95. "img_size": 192, "embed_dim": (176, 368, 448), "depth": (1, 2, 2),
  96. "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
  97. "kernels": (7, 5, 3), "drop_path": 0.0,
  98. }
  99. CFG_WAVELET_FFT_S6 = {
  100. "img_size": 224, "embed_dim": (192, 384, 448), "depth": (1, 2, 2),
  101. "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
  102. "kernels": (7, 5, 3), "drop_path": 0.0,
  103. }
  104. CFG_WAVELET_FFT_B1 = {
  105. "img_size": 256, "embed_dim": (200, 376, 448), "depth": (2, 3, 2),
  106. "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
  107. "kernels": (7, 5, 3), "drop_path": 0.03,
  108. }
  109. CFG_WAVELET_FFT_B2 = {
  110. "img_size": 384, "embed_dim": (200, 376, 448), "depth": (2, 3, 2),
  111. "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
  112. "kernels": (7, 5, 3), "drop_path": 0.03,
  113. }
  114. CFG_WAVELET_FFT_B4 = {
  115. "img_size": 512, "embed_dim": (200, 376, 448), "depth": (2, 3, 2),
  116. "global_ratio": (0.8, 0.7, 0.6), "local_ratio": (0.2, 0.2, 0.3),
  117. "kernels": (7, 5, 3), "drop_path": 0.03,
  118. }
  119. def _build_model(model_cfg, **kwargs):
  120. cfg = dict(model_cfg)
  121. cfg.update(kwargs)
  122. return WaveletFFTNet2d(**cfg)
  123. def wavelet_fft_t2(**kwargs):
  124. return _build_model(CFG_WAVELET_FFT_T2, **kwargs)
  125. def wavelet_fft_t4(**kwargs):
  126. return _build_model(CFG_WAVELET_FFT_T4, **kwargs)
  127. def wavelet_fft_s6(**kwargs):
  128. return _build_model(CFG_WAVELET_FFT_S6, **kwargs)
  129. def wavelet_fft_b1(**kwargs):
  130. return _build_model(CFG_WAVELET_FFT_B1, **kwargs)
  131. def wavelet_fft_b2(**kwargs):
  132. return _build_model(CFG_WAVELET_FFT_B2, **kwargs)
  133. def wavelet_fft_b4(**kwargs):
  134. return _build_model(CFG_WAVELET_FFT_B4, **kwargs)