from __future__ import annotations

from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
from lib.modules.fwta_2d import FourierWaveletTokenAggregation


class SwinTransformerV2FWTA(SwinTransformerV2):
    """
    Keep the original SwinTransformerV2 backbone intact and only replace the
    final global aggregation path.

    Everything the parent constructor builds (``patch_embed``, ``layers``,
    ``norm``, ``head``, ...) is reused untouched. The only change is in
    ``forward_features``: after the final norm, the token sequence and its
    mean over the token axis are both passed to a
    ``FourierWaveletTokenAggregation`` module (``self.fwta``), whose first
    output becomes the feature vector fed to ``self.head``.

    Keyword arguments prefixed with ``fwta_`` configure ``self.fwta``; all
    remaining positional and keyword arguments go to ``SwinTransformerV2``.
    """

    def __init__(
        self,
        *args,
        fwta_wavelet: str = "haar",
        fwta_level: int = 1,
        fwta_sigma_ratio: float = 0.35,
        fwta_tau_fourier: float = 0.15,
        fwta_gate_temperature: float = 1.0,
        fwta_fusion_hidden_ratio: float = 0.5,
        fwta_use_global_conditioning: bool = True,
        fwta_residual_scale_init: float = 1.0,
        **kwargs,
    ):
        """Build the Swin V2 backbone, then attach the FWTA aggregator.

        Args:
            *args: forwarded unchanged to ``SwinTransformerV2.__init__``.
            fwta_wavelet: passed to FWTA as ``wavelet``.
            fwta_level: passed to FWTA as ``wavelet_level``.
            fwta_sigma_ratio: passed to FWTA as ``sigma_ratio``.
            fwta_tau_fourier: passed to FWTA as ``tau_fourier``.
            fwta_gate_temperature: passed to FWTA as ``gate_temperature``.
            fwta_fusion_hidden_ratio: passed to FWTA as ``fusion_hidden_ratio``.
            fwta_use_global_conditioning: passed to FWTA as
                ``use_cls_conditioning`` (note the different name).
            fwta_residual_scale_init: passed to FWTA as ``residual_scale_init``.
            **kwargs: forwarded unchanged to ``SwinTransformerV2.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Token-grid size after the last stage. The formula assumes each of
        # the first (num_layers - 1) stages halves both spatial sides.
        # NOTE(review): this must match the parent's actual downsampling.
        shrink = 2 ** (self.num_layers - 1)
        grid_h = int(self.patches_resolution[0] // shrink)
        grid_w = int(self.patches_resolution[1] // shrink)

        self.fwta = FourierWaveletTokenAggregation(
            dim=int(self.num_features),
            grid_size=(grid_h, grid_w),
            use_cls_conditioning=fwta_use_global_conditioning,
            wavelet=fwta_wavelet,
            wavelet_level=fwta_level,
            sigma_ratio=fwta_sigma_ratio,
            tau_fourier=fwta_tau_fourier,
            gate_temperature=fwta_gate_temperature,
            fusion_hidden_ratio=fwta_fusion_hidden_ratio,
            residual_scale_init=fwta_residual_scale_init,
        )

    def forward_features(self, x, return_gate: bool = False):
        """Run the backbone and aggregate its final tokens with FWTA.

        Args:
            x: input batch, handed directly to ``self.patch_embed``.
            return_gate: if truthy, also return the second output of
                ``self.fwta``.

        Returns:
            The first output of ``self.fwta``, or the pair
            ``(feature, gate)`` when ``return_gate`` is truthy.
        """
        tokens = self.patch_embed(x)
        if self.ape:
            tokens = tokens + self.absolute_pos_embed
        tokens = self.pos_drop(tokens)

        for stage in self.layers:
            tokens = stage(tokens)
        tokens = self.norm(tokens)  # [B, L, C]

        # Mean over the token axis serves as FWTA's global summary input.
        pooled = tokens.mean(dim=1)  # [B, C]
        feat, gate = self.fwta(pooled, tokens)
        return (feat, gate) if return_gate else feat

    def forward(self, x, return_gate: bool = False):
        """Compute head logits for ``x``; optionally also return the FWTA gate.

        Returns:
            ``self.head(feature)``, or ``(logits, gate)`` when ``return_gate``
            is truthy.
        """
        want_gate = bool(return_gate)
        result = self.forward_features(x, return_gate=want_gate)
        if not want_gate:
            return self.head(result)
        feat, gate = result
        return self.head(feat), gate