awae.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import torch
  2. import torch.nn as nn
  3. from pytorch_wavelets import DWT, IDWT
  4. class AdaptiveWaveletAttention(nn.Module):
  5. """Adaptive wavelet attention module for feature enhancement."""
  6. def __init__(
  7. self, in_channels: int, reduction_ratio: int = 4, init_bias: float = 0.2
  8. ):
  9. """Initialize attention module.
  10. Args:
  11. in_channels: Number of input channels
  12. reduction_ratio: Channel compression ratio
  13. init_bias: Initial bias value
  14. """
  15. super().__init__()
  16. # Ensure safe reduction ratio
  17. safe_reduction = max(1, min(reduction_ratio, in_channels))
  18. # Channel attention branch with squeeze-and-excitation
  19. self.channel_attention = nn.Sequential(
  20. nn.AdaptiveAvgPool2d(1),
  21. nn.Conv2d(in_channels, in_channels // safe_reduction, 1, bias=False),
  22. nn.ReLU(inplace=True),
  23. nn.Conv2d(in_channels // safe_reduction, in_channels, 1, bias=False),
  24. nn.Sigmoid(),
  25. )
  26. # Learnable bias scale parameter
  27. self.bias_scale = nn.Parameter(torch.tensor(init_bias))
  28. # Fusion gate for adaptive enhancement
  29. self.fusion_gate = nn.Sequential(
  30. nn.AdaptiveAvgPool2d(1),
  31. nn.Conv2d(in_channels, in_channels // safe_reduction, 1, bias=False),
  32. nn.ReLU(inplace=True),
  33. nn.Conv2d(in_channels // safe_reduction, in_channels, 1, bias=False),
  34. nn.Sigmoid(),
  35. )
  36. self._init_weight()
  37. def _init_weight(self):
  38. """Initialize weights using Kaiming initialization."""
  39. for m in self.modules():
  40. if isinstance(m, nn.Conv2d):
  41. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
  42. if m.bias is not None:
  43. nn.init.constant_(m.bias, 0)
  44. elif isinstance(m, nn.BatchNorm2d):
  45. nn.init.constant_(m.weight, 1)
  46. nn.init.constant_(m.bias, 0)
  47. def forward(self, x: torch.Tensor) -> torch.Tensor:
  48. """Forward pass with adaptive wavelet enhancement.
  49. Args:
  50. x: Input tensor [B, C, H, W]
  51. Returns:
  52. Enhanced tensor with same shape
  53. """
  54. # Compute base channel attention weights
  55. # Compute base channel attention weights
  56. base_weight = self.channel_attention(x)
  57. # Compute adaptive gate factor
  58. gate_factor = self.fusion_gate(x)
  59. enhanced_factor = 1.0 + (self.bias_scale * gate_factor)
  60. final_weight = base_weight * enhanced_factor
  61. return x * torch.clamp(final_weight, min=0.1)
  62. class AdaptiveWaveletAugmentedEnhancer(nn.Module):
  63. """Adaptive wavelet augmented enhancer with multi-level attention."""
  64. def __init__(
  65. self,
  66. in_channels: int,
  67. J: int = 1,
  68. wave: str = "db4",
  69. mode: str = "symmetric",
  70. reduction_ratio: int = 4,
  71. ):
  72. """Initialize wavelet enhancer.
  73. Args:
  74. in_channels: Number of input channels
  75. J: Wavelet decomposition levels
  76. wave: Wavelet basis type
  77. mode: Wavelet padding mode
  78. reduction_ratio: Attention compression ratio
  79. """
  80. super().__init__()
  81. # Validate decomposition levels
  82. assert 1 <= J <= 3, "J must be in [1, 3]"
  83. self.J = J
  84. # Initialize discrete wavelet transform
  85. self.dwt = DWT(J=J, wave=wave, mode=mode)
  86. self.idwt = IDWT(wave=wave, mode=mode)
  87. # Low-frequency attention for approximation coefficients
  88. self.ll_att = AdaptiveWaveletAttention(
  89. in_channels=in_channels, reduction_ratio=reduction_ratio, init_bias=0.2
  90. )
  91. # High-frequency attention for detail coefficients at each level
  92. self.yh_att = nn.ModuleList(
  93. [
  94. AdaptiveWaveletAttention(
  95. in_channels=in_channels,
  96. reduction_ratio=reduction_ratio,
  97. init_bias=0.4,
  98. )
  99. for _ in range(J)
  100. ]
  101. )
  102. def forward(self, x: torch.Tensor) -> torch.Tensor:
  103. """Forward pass with multi-level wavelet enhancement.
  104. Args:
  105. x: Input tensor [B, C, H, W]
  106. Returns:
  107. Enhanced tensor after inverse wavelet transform
  108. """
  109. # Perform wavelet decomposition
  110. # Perform wavelet decomposition
  111. yl, yh = self.dwt(x)
  112. # Enhance low-frequency approximation
  113. yl_enhanced = self.ll_att(yl)
  114. # Enhance high-frequency details at each level
  115. yh_enhanced = []
  116. for i in range(self.J):
  117. level_features = []
  118. for j in range(3): # LH, HL, HH subbands
  119. subband = yh[i][:, :, j, :, :]
  120. enhanced_subband = self.yh_att[i](subband)
  121. level_features.append(enhanced_subband)
  122. yh_enhanced.append(torch.stack(level_features, dim=2))
  123. # Reconstruct enhanced signal via inverse wavelet transform
  124. return self.idwt((yl_enhanced, yh_enhanced))
  125. if __name__ == "__main__":
  126. input_tensor = torch.randn(1, 64, 256, 256)
  127. wavelet_enhancer = AdaptiveWaveletAugmentedEnhancer(in_channels=64, J=2)
  128. enhanced_tensor = wavelet_enhancer(input_tensor)
  129. print("input shape:", input_tensor.shape)
  130. print("output shape:", enhanced_tensor.shape)