swin_transformer_v2_fwta.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from __future__ import annotations
  2. from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
  3. from lib.modules.fwta_2d import FourierWaveletTokenAggregation
  4. class SwinTransformerV2FWTA(SwinTransformerV2):
  5. """
  6. Keep the original SwinTransformerV2 backbone intact and only replace the
  7. final global aggregation path.
  8. """
  9. def __init__(
  10. self,
  11. *args,
  12. fwta_wavelet: str = "haar",
  13. fwta_level: int = 1,
  14. fwta_sigma_ratio: float = 0.35,
  15. fwta_tau_fourier: float = 0.15,
  16. fwta_gate_temperature: float = 1.0,
  17. fwta_fusion_hidden_ratio: float = 0.5,
  18. fwta_use_global_conditioning: bool = True,
  19. fwta_residual_scale_init: float = 1.0,
  20. **kwargs,
  21. ):
  22. super().__init__(*args, **kwargs)
  23. final_resolution = (
  24. int(self.patches_resolution[0] // (2 ** (self.num_layers - 1))),
  25. int(self.patches_resolution[1] // (2 ** (self.num_layers - 1))),
  26. )
  27. self.fwta = FourierWaveletTokenAggregation(
  28. dim=int(self.num_features),
  29. grid_size=final_resolution,
  30. wavelet=fwta_wavelet,
  31. wavelet_level=fwta_level,
  32. sigma_ratio=fwta_sigma_ratio,
  33. tau_fourier=fwta_tau_fourier,
  34. gate_temperature=fwta_gate_temperature,
  35. residual_scale_init=fwta_residual_scale_init,
  36. fusion_hidden_ratio=fwta_fusion_hidden_ratio,
  37. use_cls_conditioning=fwta_use_global_conditioning,
  38. )
  39. def forward_features(self, x, return_gate: bool = False):
  40. x = self.patch_embed(x)
  41. if self.ape:
  42. x = x + self.absolute_pos_embed
  43. x = self.pos_drop(x)
  44. for layer in self.layers:
  45. x = layer(x)
  46. x = self.norm(x) # [B, L, C]
  47. gap = x.mean(dim=1) # [B, C]
  48. feat, gate = self.fwta(gap, x)
  49. if return_gate:
  50. return feat, gate
  51. return feat
  52. def forward(self, x, return_gate: bool = False):
  53. if return_gate:
  54. feat, gate = self.forward_features(x, return_gate=True)
  55. logits = self.head(feat)
  56. return logits, gate
  57. feat = self.forward_features(x, return_gate=False)
  58. logits = self.head(feat)
  59. return logits