| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- from typing import List, Union
- import torch
- from monai.networks.nets import SwinUNETR
- from torch import Tensor
- from torch import nn
- from lib.modules.awae import AdaptiveWaveletAugmentedEnhancer
- from lib.modules.frcab import FFTRCAB
class Wavelet_FFT_SwinUNETR(SwinUNETR):
    """SwinUNETR extended with optional wavelet and FFT enhancement modules.

    When ``wavelet_enhancement`` is on, the outputs of ``encoder1``-``encoder4``
    and of the bottleneck block ``encoder10`` are each passed through an
    ``AdaptiveWaveletAugmentedEnhancer``. When ``fft_enhancement`` is on, the
    outputs of ``decoder5``-``decoder2`` are each passed through an ``FFTRCAB``
    block before being handed to the next decoder stage.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=2,
        patch_size=2,
        depths=(2, 2, 2, 2),
        num_heads=(3, 6, 12, 24),
        window_size=7,
        qkv_bias=True,
        mlp_ratio=4.0,
        feature_size=48,
        norm_name="instance",
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        normalize=True,
        norm_layer=nn.LayerNorm,
        patch_norm=False,
        use_checkpoint=False,
        spatial_dims=2,
        downsample="merging",
        use_v2=True,
        wavelet_enhancement=True,
        wavelet_J=2,
        wavelet_wave="db4",
        wavelet_mode="symmetric",
        wavelet_reduction=16,
        fft_enhancement=True,
    ):
        """Build the SwinUNETR backbone and the optional enhancement modules.

        Every argument up to and including ``use_v2`` is forwarded unchanged
        to ``SwinUNETR.__init__``; see the MONAI documentation for their
        meaning. The remaining arguments configure the add-on modules.

        Args:
            in_channels: Number of input channels
            out_channels: Number of output channels
            feature_size: Base feature dimension; the enhancement modules are
                sized as multiples of this value
            spatial_dims: Spatial dimensions (2D or 3D)
            wavelet_enhancement: Enable wavelet enhancement modules
            wavelet_J: Wavelet decomposition levels for the skip-feature
                enhancers (the bottleneck enhancer always uses ``J=1``)
            wavelet_wave: Wavelet basis type
            wavelet_mode: Wavelet mode
            wavelet_reduction: Passed to each wavelet enhancer as
                ``reduction_ratio``
            fft_enhancement: Enable FFT enhancement modules
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            patch_size=patch_size,
            feature_size=feature_size,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            qkv_bias=qkv_bias,
            mlp_ratio=mlp_ratio,
            norm_name=norm_name,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            dropout_path_rate=dropout_path_rate,
            normalize=normalize,
            norm_layer=norm_layer,
            patch_norm=patch_norm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
            downsample=downsample,
            use_v2=use_v2,
        )

        # Wavelet enhancers, one per encoder feature level plus the bottleneck.
        # Creation order is kept stable so parameter registration (and hence
        # state_dict ordering) does not change.
        self.wavelet_enhancement = wavelet_enhancement
        if self.wavelet_enhancement:

            def make_wavelet(channels, levels):
                # All enhancers share wave/mode/reduction; only the channel
                # width and the number of decomposition levels vary.
                return AdaptiveWaveletAugmentedEnhancer(
                    in_channels=channels,
                    J=levels,
                    wave=wavelet_wave,
                    mode=wavelet_mode,
                    reduction_ratio=wavelet_reduction,
                )

            self.wavelet_enhancer_f0 = make_wavelet(feature_size, wavelet_J)
            self.wavelet_enhancer_f1 = make_wavelet(feature_size, wavelet_J)
            self.wavelet_enhancer_f2 = make_wavelet(2 * feature_size, wavelet_J)
            self.wavelet_enhancer_f3 = make_wavelet(4 * feature_size, wavelet_J)
            # The bottleneck uses a single decomposition level regardless of
            # wavelet_J.
            self.wavelet_enhancer_bottleneck = make_wavelet(16 * feature_size, 1)

        # FFT enhancers, one per intermediate decoder level.
        self.fft_enhancement = fft_enhancement
        if self.fft_enhancement:
            self.fft_enhancer_f1 = FFTRCAB(feature_size)
            self.fft_enhancer_f2 = FFTRCAB(2 * feature_size)
            self.fft_enhancer_f3 = FFTRCAB(4 * feature_size)
            self.fft_enhancer_f4 = FFTRCAB(8 * feature_size)

    def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]:
        """Run the enhanced encoder-decoder and return the output logits.

        Args:
            x: Input tensor, e.g. [B, C, H, W] for the 2D configuration. Its
                dimensions from index 2 onward are validated with
                ``_check_input_size`` unless scripting or tracing.

        Returns:
            Output logits produced by ``self.out``
        """
        # The size check is skipped while scripting or tracing.
        if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
            self._check_input_size(x.shape[2:])

        # Multiscale features from the Swin ViT backbone.
        stages = self.swinViT(x, self.normalize)

        # Encoder skip features, each optionally refined by its wavelet
        # enhancer right after it is computed.
        skip0 = self.encoder1(x)
        if self.wavelet_enhancement:
            skip0 = self.wavelet_enhancer_f0(skip0)
        skip1 = self.encoder2(stages[0])
        if self.wavelet_enhancement:
            skip1 = self.wavelet_enhancer_f1(skip1)
        skip2 = self.encoder3(stages[1])
        if self.wavelet_enhancement:
            skip2 = self.wavelet_enhancer_f2(skip2)
        skip3 = self.encoder4(stages[2])
        if self.wavelet_enhancement:
            skip3 = self.wavelet_enhancer_f3(skip3)

        # Bottleneck features come from the last backbone stage.
        feat = self.encoder10(stages[4])
        if self.wavelet_enhancement:
            feat = self.wavelet_enhancer_bottleneck(feat)

        # Decoder path: each intermediate decoder output is optionally passed
        # through its FFT enhancer before feeding the next stage.
        feat = self.decoder5(feat, stages[3])
        if self.fft_enhancement:
            feat = self.fft_enhancer_f4(feat)
        feat = self.decoder4(feat, skip3)
        if self.fft_enhancement:
            feat = self.fft_enhancer_f3(feat)
        feat = self.decoder3(feat, skip2)
        if self.fft_enhancement:
            feat = self.fft_enhancer_f2(feat)
        feat = self.decoder2(feat, skip1)
        if self.fft_enhancement:
            feat = self.fft_enhancer_f1(feat)
        feat = self.decoder1(feat, skip0)

        # Final projection to output logits.
        return self.out(feat)
if __name__ == "__main__":
    # Smoke test: build the model with every option spelled out, then print
    # the backbone feature shapes, the output shape, and the parameter count.
    # The random input is created before the model, as in the original order.
    image = torch.randn(1, 3, 512, 512)

    model_kwargs = {
        "in_channels": 3,
        "out_channels": 1,
        "patch_size": 2,
        "depths": (2, 2, 2, 2),
        "num_heads": (3, 6, 12, 24),
        "window_size": 7,
        "qkv_bias": True,
        "mlp_ratio": 4.0,
        "feature_size": 48,
        "norm_name": "instance",
        "drop_rate": 0.0,
        "attn_drop_rate": 0.0,
        "dropout_path_rate": 0.0,
        "normalize": True,
        "norm_layer": nn.LayerNorm,
        "patch_norm": False,
        "use_checkpoint": False,
        "spatial_dims": 2,
        "downsample": "merging",
        "use_v2": True,
        "wavelet_enhancement": True,
        "wavelet_J": 2,
        "wavelet_wave": "db4",
        "wavelet_mode": "symmetric",
        "wavelet_reduction": 16,
        "fft_enhancement": True,
    }
    model = Wavelet_FFT_SwinUNETR(**model_kwargs)

    # Shapes of the multiscale features returned by the Swin ViT backbone.
    hidden_states_out = model.swinViT(image, normalize=True)
    stage_shapes = [stage.shape for stage in hidden_states_out]
    print("hidden_states_out:", stage_shapes)

    # Full forward pass.
    logits = model(image)
    print(logits.shape)

    # Total number of parameter elements in the model.
    param_total = sum(param.numel() for param in model.parameters())
    print("total parameters: ", param_total)
|