model.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. from typing import List, Union
  2. import torch
  3. from monai.networks.nets import SwinUNETR
  4. from torch import Tensor
  5. from torch import nn
  6. from lib.modules.awae import AdaptiveWaveletAugmentedEnhancer
  7. from lib.modules.frcab import FFTRCAB
  8. class Wavelet_FFT_SwinUNETR(SwinUNETR):
  9. """SwinUNETR with Wavelet and FFT enhancement modules."""
  10. def __init__(
  11. self,
  12. in_channels=3,
  13. out_channels=2,
  14. patch_size=2,
  15. depths=(2, 2, 2, 2),
  16. num_heads=(3, 6, 12, 24),
  17. window_size=7,
  18. qkv_bias=True,
  19. mlp_ratio=4.0,
  20. feature_size=48,
  21. norm_name="instance",
  22. drop_rate=0.0,
  23. attn_drop_rate=0.0,
  24. dropout_path_rate=0.0,
  25. normalize=True,
  26. norm_layer=nn.LayerNorm,
  27. patch_norm=False,
  28. use_checkpoint=False,
  29. spatial_dims=2,
  30. downsample="merging",
  31. use_v2=True,
  32. wavelet_enhancement=True,
  33. wavelet_J=2,
  34. wavelet_wave="db4",
  35. wavelet_mode="symmetric",
  36. wavelet_reduction=16,
  37. fft_enhancement=True,
  38. ):
  39. """Initialize model.
  40. Args:
  41. in_channels: Number of input channels
  42. out_channels: Number of output channels
  43. feature_size: Base feature dimension
  44. spatial_dims: Spatial dimensions (2D or 3D)
  45. wavelet_enhancement: Enable wavelet enhancement module
  46. wavelet_J: Wavelet decomposition levels
  47. wavelet_wave: Wavelet basis type
  48. wavelet_mode: Wavelet mode
  49. wavelet_reduction: Wavelet attention compression ratio
  50. fft_enhancement: Enable FFT enhancement module
  51. """
  52. # Initialize parent SwinUNETR class
  53. super().__init__(
  54. in_channels=in_channels,
  55. out_channels=out_channels,
  56. patch_size=patch_size,
  57. feature_size=feature_size,
  58. depths=depths,
  59. num_heads=num_heads,
  60. window_size=window_size,
  61. qkv_bias=qkv_bias,
  62. mlp_ratio=mlp_ratio,
  63. norm_name=norm_name,
  64. drop_rate=drop_rate,
  65. attn_drop_rate=attn_drop_rate,
  66. dropout_path_rate=dropout_path_rate,
  67. normalize=normalize,
  68. norm_layer=norm_layer,
  69. patch_norm=patch_norm,
  70. use_checkpoint=use_checkpoint,
  71. spatial_dims=spatial_dims,
  72. downsample=downsample,
  73. use_v2=use_v2,
  74. )
  75. # Initialize wavelet enhancement modules for different feature levels
  76. # Initialize wavelet enhancement modules for different feature levels
  77. self.wavelet_enhancement = wavelet_enhancement
  78. if self.wavelet_enhancement:
  79. self.wavelet_enhancer_f0 = AdaptiveWaveletAugmentedEnhancer(
  80. in_channels=feature_size,
  81. J=wavelet_J,
  82. wave=wavelet_wave,
  83. mode=wavelet_mode,
  84. reduction_ratio=wavelet_reduction,
  85. )
  86. self.wavelet_enhancer_f1 = AdaptiveWaveletAugmentedEnhancer(
  87. in_channels=feature_size,
  88. J=wavelet_J,
  89. wave=wavelet_wave,
  90. mode=wavelet_mode,
  91. reduction_ratio=wavelet_reduction,
  92. )
  93. self.wavelet_enhancer_f2 = AdaptiveWaveletAugmentedEnhancer(
  94. in_channels=2 * feature_size,
  95. J=wavelet_J,
  96. wave=wavelet_wave,
  97. mode=wavelet_mode,
  98. reduction_ratio=wavelet_reduction,
  99. )
  100. self.wavelet_enhancer_f3 = AdaptiveWaveletAugmentedEnhancer(
  101. in_channels=4 * feature_size,
  102. J=wavelet_J,
  103. wave=wavelet_wave,
  104. mode=wavelet_mode,
  105. reduction_ratio=wavelet_reduction,
  106. )
  107. self.wavelet_enhancer_bottleneck = AdaptiveWaveletAugmentedEnhancer(
  108. in_channels=16 * feature_size,
  109. J=1,
  110. wave=wavelet_wave,
  111. mode=wavelet_mode,
  112. reduction_ratio=wavelet_reduction,
  113. )
  114. # Initialize FFT enhancement modules for different decoder levels
  115. # Initialize FFT enhancement modules for different decoder levels
  116. self.fft_enhancement = fft_enhancement
  117. if self.fft_enhancement:
  118. self.fft_enhancer_f1 = FFTRCAB(feature_size)
  119. self.fft_enhancer_f2 = FFTRCAB(2 * feature_size)
  120. self.fft_enhancer_f3 = FFTRCAB(4 * feature_size)
  121. self.fft_enhancer_f4 = FFTRCAB(8 * feature_size)
  122. def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]:
  123. """Forward pass.
  124. Args:
  125. x: Input tensor [B, C, H, W]
  126. Returns:
  127. Output logits
  128. """
  129. # Check input size compatibility
  130. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  131. self._check_input_size(x.shape[2:])
  132. # Extract multiscale features from Swin ViT backbone
  133. hidden_states_out = self.swinViT(x, self.normalize)
  134. if self.wavelet_enhancement:
  135. enc0 = self.wavelet_enhancer_f0(self.encoder1(x))
  136. enc1 = self.wavelet_enhancer_f1(self.encoder2(hidden_states_out[0]))
  137. enc2 = self.wavelet_enhancer_f2(self.encoder3(hidden_states_out[1]))
  138. enc3 = self.wavelet_enhancer_f3(self.encoder4(hidden_states_out[2]))
  139. else:
  140. # Use standard encoder features without enhancement
  141. enc0 = self.encoder1(x)
  142. enc1 = self.encoder2(hidden_states_out[0])
  143. enc2 = self.encoder3(hidden_states_out[1])
  144. enc3 = self.encoder4(hidden_states_out[2])
  145. # Process bottleneck features
  146. dec4 = self.encoder10(hidden_states_out[4])
  147. if self.wavelet_enhancement:
  148. dec4 = self.wavelet_enhancer_bottleneck(dec4)
  149. if self.fft_enhancement:
  150. dec3 = self.decoder5(dec4, hidden_states_out[3])
  151. dec3 = self.fft_enhancer_f4(dec3)
  152. dec2 = self.decoder4(dec3, enc3)
  153. dec2 = self.fft_enhancer_f3(dec2)
  154. dec1 = self.decoder3(dec2, enc2)
  155. dec1 = self.fft_enhancer_f2(dec1)
  156. dec0 = self.decoder2(dec1, enc1)
  157. dec0 = self.fft_enhancer_f1(dec0)
  158. out = self.decoder1(dec0, enc0)
  159. else:
  160. # Standard decoding without FFT enhancement
  161. dec3 = self.decoder5(dec4, hidden_states_out[3])
  162. dec2 = self.decoder4(dec3, enc3)
  163. dec1 = self.decoder3(dec2, enc2)
  164. dec0 = self.decoder2(dec1, enc1)
  165. out = self.decoder1(dec0, enc0)
  166. # Generate final output logits
  167. logits = self.out(out)
  168. return logits
  169. if __name__ == "__main__":
  170. image = torch.randn(1, 3, 512, 512)
  171. model = Wavelet_FFT_SwinUNETR(
  172. in_channels=3,
  173. out_channels=1,
  174. patch_size=2,
  175. depths=(2, 2, 2, 2),
  176. num_heads=(3, 6, 12, 24),
  177. window_size=7,
  178. qkv_bias=True,
  179. mlp_ratio=4.0,
  180. feature_size=48,
  181. norm_name="instance",
  182. drop_rate=0.0,
  183. attn_drop_rate=0.0,
  184. dropout_path_rate=0.0,
  185. normalize=True,
  186. norm_layer=nn.LayerNorm,
  187. patch_norm=False,
  188. use_checkpoint=False,
  189. spatial_dims=2,
  190. downsample="merging",
  191. use_v2=True,
  192. wavelet_enhancement=True,
  193. wavelet_J=2,
  194. wavelet_wave="db4",
  195. wavelet_mode="symmetric",
  196. wavelet_reduction=16,
  197. fft_enhancement=True,
  198. )
  199. hidden_states_out = model.swinViT(image, normalize=True)
  200. print("hidden_states_out:", [i.shape for i in hidden_states_out])
  201. print(model(image).shape)
  202. print("total parameters: ", sum(p.numel() for p in model.parameters()))