from typing import List, Union

import torch
from monai.networks.nets import SwinUNETR
from torch import Tensor
from torch import nn

from lib.modules.awae import AdaptiveWaveletAugmentedEnhancer
from lib.modules.frcab import FFTRCAB


class Wavelet_FFT_SwinUNETR(SwinUNETR):
    """SwinUNETR variant with optional wavelet and FFT feature enhancers.

    When ``wavelet_enhancement`` is on, an ``AdaptiveWaveletAugmentedEnhancer``
    is applied to the output of ``encoder1``..``encoder4`` and to the
    bottleneck (``encoder10``). When ``fft_enhancement`` is on, an ``FFTRCAB``
    is applied to the output of ``decoder5``..``decoder2``. With both flags
    off, the forward pass only calls the parent's submodules.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=2,
        patch_size=2,
        depths=(2, 2, 2, 2),
        num_heads=(3, 6, 12, 24),
        window_size=7,
        qkv_bias=True,
        mlp_ratio=4.0,
        feature_size=48,
        norm_name="instance",
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        normalize=True,
        norm_layer=nn.LayerNorm,
        patch_norm=False,
        use_checkpoint=False,
        spatial_dims=2,
        downsample="merging",
        use_v2=True,
        wavelet_enhancement=True,
        wavelet_J=2,
        wavelet_wave="db4",
        wavelet_mode="symmetric",
        wavelet_reduction=16,
        fft_enhancement=True,
    ):
        """Build the SwinUNETR backbone and attach the enhancement modules.

        Every argument from ``in_channels`` through ``use_v2`` is forwarded
        unchanged, by keyword, to ``SwinUNETR.__init__``; see the MONAI
        documentation for their meaning. ``feature_size`` is additionally
        used here to size the enhancement modules.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            feature_size: Base feature dimension; enhancer widths are
                multiples of it.
            spatial_dims: Spatial dimensions (2D or 3D).
            wavelet_enhancement: Enable the wavelet enhancement modules.
            wavelet_J: Wavelet decomposition levels for the four encoder
                enhancers (the bottleneck enhancer always uses ``J=1``).
            wavelet_wave: Wavelet basis type.
            wavelet_mode: Wavelet mode.
            wavelet_reduction: Value passed as ``reduction_ratio`` to each
                wavelet enhancer.
            fft_enhancement: Enable the FFT enhancement modules.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            patch_size=patch_size,
            feature_size=feature_size,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            qkv_bias=qkv_bias,
            mlp_ratio=mlp_ratio,
            norm_name=norm_name,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            dropout_path_rate=dropout_path_rate,
            normalize=normalize,
            norm_layer=norm_layer,
            patch_norm=patch_norm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
            downsample=downsample,
            use_v2=use_v2,
        )

        def build_wavelet(channels, levels=wavelet_J):
            # All wavelet enhancers share basis, mode and reduction ratio;
            # only the channel width and decomposition depth vary.
            return AdaptiveWaveletAugmentedEnhancer(
                in_channels=channels,
                J=levels,
                wave=wavelet_wave,
                mode=wavelet_mode,
                reduction_ratio=wavelet_reduction,
            )

        # Wavelet enhancers, one per encoder level plus the bottleneck.
        # Construction order is kept stable so parameter registration (and
        # therefore seeded initialisation) does not change.
        self.wavelet_enhancement = wavelet_enhancement
        if self.wavelet_enhancement:
            self.wavelet_enhancer_f0 = build_wavelet(feature_size)
            self.wavelet_enhancer_f1 = build_wavelet(feature_size)
            self.wavelet_enhancer_f2 = build_wavelet(2 * feature_size)
            self.wavelet_enhancer_f3 = build_wavelet(4 * feature_size)
            # The bottleneck is pinned to a single decomposition level
            # regardless of wavelet_J.
            # NOTE(review): presumably because the bottleneck feature map is
            # spatially small -- confirm.
            self.wavelet_enhancer_bottleneck = build_wavelet(
                16 * feature_size, levels=1
            )

        # FFT enhancers, one per intermediate decoder output.
        self.fft_enhancement = fft_enhancement
        if self.fft_enhancement:
            self.fft_enhancer_f1 = FFTRCAB(feature_size)
            self.fft_enhancer_f2 = FFTRCAB(2 * feature_size)
            self.fft_enhancer_f3 = FFTRCAB(4 * feature_size)
            self.fft_enhancer_f4 = FFTRCAB(8 * feature_size)

    def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]:
        """Run the enhanced encoder/decoder and return the output logits.

        Args:
            x: Input tensor, [B, C, H, W] for the default ``spatial_dims=2``.

        Returns:
            The result of the parent's ``out`` head applied to the last
            decoder output.
        """
        # Input-size validation is skipped under TorchScript scripting/tracing.
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            self._check_input_size(x.shape[2:])

        # Multiscale features from the Swin ViT backbone.
        stages = self.swinViT(x, self.normalize)
        use_wavelet = self.wavelet_enhancement
        use_fft = self.fft_enhancement

        # Encoder path: encoder1 sees the raw input, encoder2..4 see backbone
        # stages 0..2; each output is optionally wavelet-enhanced right away.
        enc0 = self.encoder1(x)
        if use_wavelet:
            enc0 = self.wavelet_enhancer_f0(enc0)

        enc1 = self.encoder2(stages[0])
        if use_wavelet:
            enc1 = self.wavelet_enhancer_f1(enc1)

        enc2 = self.encoder3(stages[1])
        if use_wavelet:
            enc2 = self.wavelet_enhancer_f2(enc2)

        enc3 = self.encoder4(stages[2])
        if use_wavelet:
            enc3 = self.wavelet_enhancer_f3(enc3)

        # Bottleneck is built from the deepest backbone stage.
        dec4 = self.encoder10(stages[4])
        if use_wavelet:
            dec4 = self.wavelet_enhancer_bottleneck(dec4)

        # Decoder path: each intermediate decoder output is optionally passed
        # through its FFT block before feeding the next decoder.
        dec3 = self.decoder5(dec4, stages[3])
        if use_fft:
            dec3 = self.fft_enhancer_f4(dec3)

        dec2 = self.decoder4(dec3, enc3)
        if use_fft:
            dec2 = self.fft_enhancer_f3(dec2)

        dec1 = self.decoder3(dec2, enc2)
        if use_fft:
            dec1 = self.fft_enhancer_f2(dec1)

        dec0 = self.decoder2(dec1, enc1)
        if use_fft:
            dec0 = self.fft_enhancer_f1(dec0)

        # The final decoder output goes straight to the head (no FFT block).
        out = self.decoder1(dec0, enc0)
        return self.out(out)


if __name__ == "__main__":
    # Smoke test: build the model, then print backbone feature shapes, the
    # output shape and the parameter count. The input is sampled before the
    # model is constructed.
    image = torch.randn(1, 3, 512, 512)

    demo_config = dict(
        in_channels=3,
        out_channels=1,
        patch_size=2,
        depths=(2, 2, 2, 2),
        num_heads=(3, 6, 12, 24),
        window_size=7,
        qkv_bias=True,
        mlp_ratio=4.0,
        feature_size=48,
        norm_name="instance",
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        normalize=True,
        norm_layer=nn.LayerNorm,
        patch_norm=False,
        use_checkpoint=False,
        spatial_dims=2,
        downsample="merging",
        use_v2=True,
        wavelet_enhancement=True,
        wavelet_J=2,
        wavelet_wave="db4",
        wavelet_mode="symmetric",
        wavelet_reduction=16,
        fft_enhancement=True,
    )
    model = Wavelet_FFT_SwinUNETR(**demo_config)

    hidden_states_out = model.swinViT(image, normalize=True)
    stage_shapes = [feat.shape for feat in hidden_states_out]
    print("hidden_states_out:", stage_shapes)

    logits = model(image)
    print(logits.shape)

    param_total = sum(p.numel() for p in model.parameters())
    print("total parameters: ", param_total)