# blocks_2d.py (3.2 KB)
"""
MRFFI module and WaveletFFTBlock (2D version).
"""
  4. import torch
  5. import torch.nn as nn
  6. from typing import Literal
  7. from .attentions_2d import WaveletAttentionGlobalBranch2d
  8. from .layers_2d import (
  9. Conv2dBN,
  10. DWConv2dBNReLU,
  11. DropPath,
  12. FFN2d,
  13. Residual,
  14. )
  15. class WaveletFFTMRFFIModule2d(nn.Module):
  16. def __init__(
  17. self, dim, global_ratio=0.25, local_ratio=0.25,
  18. kernel_size=5, wt_levels=1, wt_type="db1",
  19. wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero", proj_drop=0.0,
  20. ):
  21. super().__init__()
  22. self.dim = dim
  23. self.global_channels = min(int(global_ratio * dim), dim)
  24. tentative_local = int(local_ratio * dim)
  25. if self.global_channels + tentative_local > dim:
  26. self.local_channels = max(dim - self.global_channels, 0)
  27. else:
  28. self.local_channels = tentative_local
  29. self.identity_channels = dim - self.global_channels - self.local_channels
  30. if self.global_channels > 0:
  31. self.global_op = WaveletAttentionGlobalBranch2d(
  32. self.global_channels, kernel_size=kernel_size, wt_levels=wt_levels,
  33. wt_type=wt_type, wt_mode=wt_mode, proj_drop=proj_drop,
  34. )
  35. else:
  36. self.global_op = nn.Identity()
  37. if self.local_channels > 0:
  38. self.local_op = DWConv2dBNReLU(self.local_channels, self.local_channels, kernel_size=kernel_size)
  39. else:
  40. self.local_op = nn.Identity()
  41. self.proj = nn.Sequential(
  42. nn.ReLU(inplace=True),
  43. Conv2dBN(dim, dim, bn_weight_init=0.0),
  44. )
  45. def forward(self, x):
  46. x_global, x_local, x_identity = torch.split(
  47. x, [self.global_channels, self.local_channels, self.identity_channels], dim=1,
  48. )
  49. x_global = self.global_op(x_global)
  50. x_local = self.local_op(x_local)
  51. x = torch.cat([x_global, x_local, x_identity], dim=1)
  52. return self.proj(x)
  53. class WaveletFFTBlock2d(nn.Module):
  54. def __init__(
  55. self, dim, global_ratio=0.25, local_ratio=0.25,
  56. kernel_size=5, wt_levels=1, wt_type="db1",
  57. wt_mode: Literal["constant", "zero", "reflect", "periodic", "symmetric"] = "zero",
  58. proj_drop=0.0, drop_path=0.0, has_skip=True,
  59. ):
  60. super().__init__()
  61. self.dw0 = Residual(Conv2dBN(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.0))
  62. self.ffn0 = Residual(FFN2d(dim, int(dim * 2)))
  63. self.mixer = Residual(
  64. WaveletFFTMRFFIModule2d(
  65. dim, global_ratio=global_ratio, local_ratio=local_ratio,
  66. kernel_size=kernel_size, wt_levels=wt_levels, wt_type=wt_type,
  67. wt_mode=wt_mode, proj_drop=proj_drop,
  68. )
  69. )
  70. self.dw1 = Residual(Conv2dBN(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.0))
  71. self.ffn1 = Residual(FFN2d(dim, int(dim * 2)))
  72. self.has_skip = has_skip
  73. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  74. def forward(self, x):
  75. shortcut = x
  76. x = self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
  77. if self.has_skip:
  78. x = shortcut + self.drop_path(x)
  79. return x