| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- import torch
- import torch.nn as nn
- from pytorch_wavelets import DWT, IDWT
class AdaptiveWaveletAttention(nn.Module):
    """Channel re-weighting block driven by two squeeze-and-excitation gates.

    One gate (``channel_attention``) yields the base per-channel weight; the
    other (``fusion_gate``) yields a boost that is scaled by the learnable
    scalar ``bias_scale``.  The input is multiplied by
    ``base * (1 + bias_scale * gate)``, floored at ``0.1``.
    """

    def __init__(
        self, in_channels: int, reduction_ratio: int = 4, init_bias: float = 0.2
    ):
        """Build both gates and the learnable boost scale.

        Args:
            in_channels: Number of input channels
            reduction_ratio: Channel compression ratio of the bottleneck
            init_bias: Initial value of the learnable ``bias_scale``
        """
        super().__init__()

        # Keep the ratio inside [1, in_channels] so that, for a positive
        # channel count, the bottleneck never collapses to zero channels.
        ratio = max(1, min(reduction_ratio, in_channels))
        hidden = in_channels // ratio

        # Registration order (gate, scalar, gate) is kept as-is: it fixes the
        # order in which the layers are created and later re-initialized.
        self.channel_attention = self._squeeze_excite(in_channels, hidden)
        self.bias_scale = nn.Parameter(torch.tensor(init_bias))
        self.fusion_gate = self._squeeze_excite(in_channels, hidden)

        self._init_weight()

    @staticmethod
    def _squeeze_excite(channels: int, hidden: int) -> nn.Sequential:
        """Return pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid."""
        layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
            nn.Sigmoid(),
        ]
        return nn.Sequential(*layers)

    def _init_weight(self):
        """Apply Kaiming-normal init to convs and unit/zero init to batch norms."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(
                    layer.weight, mode="fan_out", nonlinearity="relu"
                )
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` channel-wise by the gated attention weight.

        Args:
            x: Input tensor [B, C, H, W]

        Returns:
            Tensor with the same shape as ``x``
        """
        # Base weight first, then the boost, both computed from the same input.
        weight = self.channel_attention(x) * (
            1.0 + (self.bias_scale * self.fusion_gate(x))
        )
        # The floor keeps any channel from being suppressed below 10 %.
        return x * weight.clamp(min=0.1)
class AdaptiveWaveletAugmentedEnhancer(nn.Module):
    """Adaptive wavelet augmented enhancer with multi-level attention.

    The input is decomposed with a ``J``-level discrete wavelet transform,
    the low-frequency band and every high-frequency sub-band are re-weighted
    by their own :class:`AdaptiveWaveletAttention`, and the result is passed
    through the inverse transform.
    """

    def __init__(
        self,
        in_channels: int,
        J: int = 1,
        wave: str = "db4",
        mode: str = "symmetric",
        reduction_ratio: int = 4,
    ):
        """Initialize wavelet enhancer.

        Args:
            in_channels: Number of input channels
            J: Wavelet decomposition levels, must satisfy ``1 <= J <= 3``
            wave: Wavelet basis type (forwarded to ``DWT`` and ``IDWT``)
            mode: Wavelet padding mode (forwarded to ``DWT`` and ``IDWT``)
            reduction_ratio: Attention compression ratio

        Raises:
            ValueError: If ``J`` is outside ``[1, 3]``.
        """
        super().__init__()

        # Explicit raise instead of ``assert``: assert statements are removed
        # when Python runs with -O, which would silently skip this check.
        if not 1 <= J <= 3:
            raise ValueError("J must be in [1, 3]")
        self.J = J

        # Forward / inverse discrete wavelet transforms
        self.dwt = DWT(J=J, wave=wave, mode=mode)
        self.idwt = IDWT(wave=wave, mode=mode)

        # Low-frequency attention for approximation coefficients
        self.ll_att = AdaptiveWaveletAttention(
            in_channels=in_channels, reduction_ratio=reduction_ratio, init_bias=0.2
        )

        # One high-frequency attention per level; it is shared by the three
        # detail sub-bands of that level and starts with a larger boost (0.4).
        self.yh_att = nn.ModuleList(
            [
                AdaptiveWaveletAttention(
                    in_channels=in_channels,
                    reduction_ratio=reduction_ratio,
                    init_bias=0.4,
                )
                for _ in range(J)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with multi-level wavelet enhancement.

        Args:
            x: Input tensor [B, C, H, W]

        Returns:
            Enhanced tensor after inverse wavelet transform

        NOTE(review): whether the output spatial size equals the input size
        for odd H/W depends on the wavelet library -- confirm before relying
        on it (e.g. for residual connections).
        """
        # Perform wavelet decomposition: approximation + one entry per level
        yl, yh = self.dwt(x)

        # Enhance low-frequency approximation
        yl_enhanced = self.ll_att(yl)

        # Enhance the three detail sub-bands (LH, HL, HH, stored along dim 2)
        # of every level with that level's attention, then re-stack them.
        yh_enhanced = [
            torch.stack(
                [attention(level[:, :, band, :, :]) for band in range(3)], dim=2
            )
            for attention, level in zip(self.yh_att, yh)
        ]

        # Reconstruct enhanced signal via inverse wavelet transform
        return self.idwt((yl_enhanced, yh_enhanced))
if __name__ == "__main__":
    # Smoke run: push one random 64-channel 256x256 map through a two-level
    # enhancer and report the tensor shapes on both sides.
    sample = torch.randn(1, 64, 256, 256)
    enhancer = AdaptiveWaveletAugmentedEnhancer(in_channels=64, J=2)
    result = enhancer(sample)
    for label, tensor in (("input shape:", sample), ("output shape:", result)):
        print(label, tensor.shape)
|