| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- from __future__ import annotations
- from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
- from lib.modules.fwta_2d import FourierWaveletTokenAggregation
class SwinTransformerV2FWTA(SwinTransformerV2):
    """SwinTransformerV2 whose final global pooling is replaced by FWTA.

    The backbone path (``patch_embed``, optional absolute position
    embedding, ``pos_drop``, the stage stack in ``self.layers`` and the final
    ``norm``) is inherited unchanged from ``SwinTransformerV2``. Only the
    aggregation step between the final norm and ``self.head`` differs.
    Instead of handing the mean-pooled token to the head, the pooled token and
    the full token sequence are passed to a ``FourierWaveletTokenAggregation``
    module, and its output feeds ``self.head``.

    Positional arguments and any keyword argument not prefixed with ``fwta_``
    are forwarded unchanged to ``SwinTransformerV2.__init__``.
    """

    def __init__(
        self,
        *args,
        fwta_wavelet: str = "haar",
        fwta_level: int = 1,
        fwta_sigma_ratio: float = 0.35,
        fwta_tau_fourier: float = 0.15,
        fwta_gate_temperature: float = 1.0,
        fwta_fusion_hidden_ratio: float = 0.5,
        fwta_use_global_conditioning: bool = True,
        fwta_residual_scale_init: float = 1.0,
        **kwargs,
    ):
        """Build the Swin V2 backbone, then attach the FWTA aggregation module.

        Args:
            *args: Forwarded to ``SwinTransformerV2.__init__``.
            fwta_wavelet: Passed to FWTA as ``wavelet``.
            fwta_level: Passed to FWTA as ``wavelet_level``.
            fwta_sigma_ratio: Passed to FWTA as ``sigma_ratio``.
            fwta_tau_fourier: Passed to FWTA as ``tau_fourier``.
            fwta_gate_temperature: Passed to FWTA as ``gate_temperature``.
            fwta_fusion_hidden_ratio: Passed to FWTA as ``fusion_hidden_ratio``.
            fwta_use_global_conditioning: Passed to FWTA as
                ``use_cls_conditioning``. At forward time the conditioning
                vector supplied is the mean-pooled token, not a CLS token.
            fwta_residual_scale_init: Passed to FWTA as ``residual_scale_init``.
            **kwargs: Forwarded to ``SwinTransformerV2.__init__``.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): this module is created after the parent constructor has
        # returned, so any weight-init pass the parent runs inside __init__
        # does not touch it. Confirm FWTA is expected to initialise itself.
        self.fwta = FourierWaveletTokenAggregation(
            dim=int(self.num_features),
            grid_size=self._final_grid_size(),
            wavelet=fwta_wavelet,
            wavelet_level=fwta_level,
            sigma_ratio=fwta_sigma_ratio,
            tau_fourier=fwta_tau_fourier,
            gate_temperature=fwta_gate_temperature,
            residual_scale_init=fwta_residual_scale_init,
            fusion_hidden_ratio=fwta_fusion_hidden_ratio,
            use_cls_conditioning=fwta_use_global_conditioning,
        )

    def _final_grid_size(self) -> tuple[int, int]:
        """Return the token-grid size after the last stage.

        Each axis of ``self.patches_resolution`` is floor-divided by
        ``2 ** (num_layers - 1)``. This assumes stage ``i`` runs at
        ``patches_resolution // 2**i``, i.e. the parent's downsampling
        schedule. TODO: verify against SwinTransformerV2 if that changes.
        """
        scale = 2 ** (self.num_layers - 1)
        return (
            int(self.patches_resolution[0] // scale),
            int(self.patches_resolution[1] // scale),
        )

    def forward_features(self, x, return_gate: bool = False):
        """Run the backbone and aggregate its tokens with FWTA.

        Args:
            x: Input batch, passed to ``self.patch_embed``.
            return_gate: When truthy, also return the gate from ``self.fwta``.

        Returns:
            The aggregated feature, or ``(feature, gate)`` if ``return_gate``.
        """
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)  # [B, L, C]
        # Mean over the token axis: the vector plain average pooling would
        # produce. FWTA receives it as the global conditioning input.
        gap = x.mean(dim=1)  # [B, C]
        feat, gate = self.fwta(gap, x)
        return (feat, gate) if return_gate else feat

    def forward(self, x, return_gate: bool = False):
        """Classify ``x``, optionally also returning the FWTA gate.

        Args:
            x: Input batch.
            return_gate: When truthy, also return the gate from ``self.fwta``.

        Returns:
            ``self.head`` logits, or ``(logits, gate)`` if ``return_gate``.
        """
        # forward_features computes the gate unconditionally, so always ask
        # for it. This gives one code path that applies the head exactly once.
        feat, gate = self.forward_features(x, return_gate=True)
        logits = self.head(feat)
        return (logits, gate) if return_gate else logits
|