"""Wavelet-domain channel attention modules built on ``pytorch_wavelets``."""

import torch
import torch.nn as nn
from pytorch_wavelets import DWT, IDWT


class AdaptiveWaveletAttention(nn.Module):
    """Adaptive wavelet attention module for feature enhancement.

    Two squeeze-and-excitation style branches (global average pool ->
    1x1 conv -> ReLU -> 1x1 conv -> sigmoid) are evaluated on the same
    input. The first yields a per-channel base weight; the second yields a
    per-channel gate that, scaled by the learnable ``bias_scale``, boosts
    the base weight.
    """

    def __init__(
        self,
        in_channels: int,
        reduction_ratio: int = 4,
        init_bias: float = 0.2,
    ) -> None:
        """Initialize attention module.

        Args:
            in_channels: Number of input channels (must be >= 1).
            reduction_ratio: Channel compression ratio. Clamped into
                ``[1, in_channels]`` so the bottleneck always keeps at
                least one channel.
            init_bias: Initial value of the learnable ``bias_scale``.

        Raises:
            ValueError: If ``in_channels`` is smaller than 1.
        """
        super().__init__()

        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")

        # Clamp the ratio so ``in_channels // ratio`` can never reach zero.
        safe_reduction = max(1, min(reduction_ratio, in_channels))
        hidden_channels = in_channels // safe_reduction

        # Channel attention branch with squeeze-and-excitation.
        self.channel_attention = self._make_gate(in_channels, hidden_channels)

        # Learnable bias scale parameter. ``float(...)`` keeps the parameter
        # floating point even when an integer ``init_bias`` is passed.
        self.bias_scale = nn.Parameter(torch.tensor(float(init_bias)))

        # Fusion gate for adaptive enhancement (same layout, own weights).
        self.fusion_gate = self._make_gate(in_channels, hidden_channels)

        self._init_weight()

    @staticmethod
    def _make_gate(in_channels: int, hidden_channels: int) -> nn.Sequential:
        """Build one pool -> conv -> ReLU -> conv -> sigmoid gating stack."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False),
            nn.Sigmoid(),
        )

    def _init_weight(self) -> None:
        """Initialize weights using Kaiming initialization.

        Conv layers get Kaiming-normal weights (``fan_out``, ReLU gain) and
        zero bias when a bias exists; any BatchNorm2d layer is reset to
        weight 1 and bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with adaptive wavelet enhancement.

        Args:
            x: Input tensor [B, C, H, W]

        Returns:
            Enhanced tensor with same shape
        """
        # Compute base channel attention weights.
        base_weight = self.channel_attention(x)

        # Compute adaptive gate factor and turn it into a boost >= 1
        # (for a non-negative ``bias_scale``).
        gate_factor = self.fusion_gate(x)
        enhanced_factor = 1.0 + (self.bias_scale * gate_factor)

        final_weight = base_weight * enhanced_factor

        # The lower clamp keeps every channel from being scaled below 0.1.
        return x * torch.clamp(final_weight, min=0.1)


class AdaptiveWaveletAugmentedEnhancer(nn.Module):
    """Adaptive wavelet augmented enhancer with multi-level attention.

    The input is decomposed with a ``J``-level discrete wavelet transform,
    the low-frequency band and every high-frequency subband are re-weighted
    by :class:`AdaptiveWaveletAttention`, and the result is reconstructed
    with the inverse transform.
    """

    def __init__(
        self,
        in_channels: int,
        J: int = 1,
        wave: str = "db4",
        mode: str = "symmetric",
        reduction_ratio: int = 4,
    ) -> None:
        """Initialize wavelet enhancer.

        Args:
            in_channels: Number of input channels
            J: Wavelet decomposition levels (1 to 3 inclusive)
            wave: Wavelet basis type
            mode: Wavelet padding mode
            reduction_ratio: Attention compression ratio

        Raises:
            ValueError: If ``J`` is outside ``[1, 3]``.
        """
        super().__init__()

        # Validate decomposition levels. An explicit raise is used instead
        # of ``assert`` so the check survives ``python -O``.
        if not 1 <= J <= 3:
            raise ValueError(f"J must be in [1, 3], got {J}")
        self.J = J

        # Initialize discrete wavelet transform and its inverse.
        self.dwt = DWT(J=J, wave=wave, mode=mode)
        self.idwt = IDWT(wave=wave, mode=mode)

        # Low-frequency attention for approximation coefficients.
        self.ll_att = AdaptiveWaveletAttention(
            in_channels=in_channels,
            reduction_ratio=reduction_ratio,
            init_bias=0.2,
        )

        # High-frequency attention for detail coefficients: one module per
        # level, shared by the three subbands of that level.
        self.yh_att = nn.ModuleList(
            [
                AdaptiveWaveletAttention(
                    in_channels=in_channels,
                    reduction_ratio=reduction_ratio,
                    init_bias=0.4,
                )
                for _ in range(J)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with multi-level wavelet enhancement.

        Args:
            x: Input tensor [B, C, H, W]

        Returns:
            Enhanced tensor after inverse wavelet transform, cropped to the
            spatial size of ``x`` when the reconstruction is larger.
        """
        # Perform wavelet decomposition.
        yl, yh = self.dwt(x)

        # Enhance low-frequency approximation.
        yl_enhanced = self.ll_att(yl)

        # Enhance high-frequency details at each level. Dim 2 of every
        # level tensor indexes the subbands (LH, HL, HH).
        yh_enhanced = []
        for level_coeffs, attention in zip(yh, self.yh_att):
            subbands = [attention(band) for band in level_coeffs.unbind(dim=2)]
            yh_enhanced.append(torch.stack(subbands, dim=2))

        # Reconstruct enhanced signal via inverse wavelet transform.
        reconstructed = self.idwt((yl_enhanced, yh_enhanced))

        # The reconstruction can come back larger than the input (e.g. one
        # extra trailing sample for an odd H or W); crop so the module is
        # shape-preserving. Inputs that round-trip exactly are untouched.
        height, width = x.shape[-2:]
        if reconstructed.shape[-2:] != x.shape[-2:]:
            reconstructed = reconstructed[..., :height, :width]
        return reconstructed


if __name__ == "__main__":
    input_tensor = torch.randn(1, 64, 256, 256)
    wavelet_enhancer = AdaptiveWaveletAugmentedEnhancer(in_channels=64, J=2)
    enhanced_tensor = wavelet_enhancer(input_tensor)
    print("input shape:", input_tensor.shape)
    print("output shape:", enhanced_tensor.shape)