swinv2_fwta_encoder_2d.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. from __future__ import annotations
  2. from argparse import Namespace
  3. from pathlib import Path
  4. from typing import Any
  5. import torch
  6. import torch.nn as nn
  7. from .build_swinv2 import build_swinv2
  8. from .fwta_2d import FourierWaveletTokenAggregation
  9. class SwinV2FWTAEncoder2d(nn.Module):
  10. """
  11. 面向分割的 SwinV2 + FWTA 编码器封装。
  12. """
  13. def __init__(
  14. self,
  15. model_name: str | None = None,
  16. config_path: str | Path | None = None,
  17. weight_path: str | Path | None = None,
  18. args: Namespace | None = None,
  19. *,
  20. load_weights: bool = True,
  21. normalize_features: bool = True,
  22. use_multiscale_features: bool = True,
  23. include_patch_embed: bool = True,
  24. fwta_wavelet: str = "haar",
  25. fwta_level: int = 1,
  26. fwta_sigma_ratio: float = 0.35,
  27. fwta_tau_fourier: float = 0.15,
  28. fwta_gate_temperature: float = 1.0,
  29. fwta_fusion_hidden_ratio: float = 0.5,
  30. fwta_use_global_conditioning: bool = True,
  31. fwta_residual_scale_init: float = 1.0,
  32. **model_kwargs: Any,
  33. ) -> None:
  34. super().__init__()
  35. backbone, cfg = build_swinv2(
  36. model_name=model_name,
  37. config_path=config_path,
  38. weight_path=weight_path,
  39. args=args,
  40. load_weights=load_weights,
  41. return_config=True,
  42. **model_kwargs,
  43. )
  44. self.backbone = backbone
  45. self.cfg = cfg
  46. self.normalize_features = normalize_features
  47. self.use_multiscale_features = use_multiscale_features
  48. self.include_patch_embed = include_patch_embed
  49. depths = tuple(cfg.MODEL.SWINV2.DEPTHS)
  50. embed_dim = int(cfg.MODEL.SWINV2.EMBED_DIM)
  51. if self.use_multiscale_features:
  52. stage_channels = []
  53. if self.include_patch_embed:
  54. stage_channels.append(embed_dim)
  55. for i in range(len(depths)):
  56. # forward_multiscale_features appends each layer output after its internal downsample.
  57. channel_multiplier = 2 ** min(i + 1, len(depths) - 1)
  58. stage_channels.append(int(embed_dim * channel_multiplier))
  59. self.stage_channels = stage_channels
  60. else:
  61. self.stage_channels = [int(embed_dim * 2 ** i) for i in range(len(depths))]
  62. final_resolution = (
  63. int(self.backbone.patches_resolution[0] // (2 ** (len(depths) - 1))),
  64. int(self.backbone.patches_resolution[1] // (2 ** (len(depths) - 1))),
  65. )
  66. self.fwta = FourierWaveletTokenAggregation(
  67. dim=int(self.backbone.num_features),
  68. grid_size=final_resolution,
  69. wavelet=fwta_wavelet,
  70. wavelet_level=fwta_level,
  71. sigma_ratio=fwta_sigma_ratio,
  72. tau_fourier=fwta_tau_fourier,
  73. gate_temperature=fwta_gate_temperature,
  74. residual_scale_init=fwta_residual_scale_init,
  75. fusion_hidden_ratio=fwta_fusion_hidden_ratio,
  76. use_cls_conditioning=fwta_use_global_conditioning,
  77. )
  78. def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
  79. if self.use_multiscale_features:
  80. features = self.backbone.forward_multiscale_features(
  81. x,
  82. normalize=self.normalize_features,
  83. include_patch_embed=self.include_patch_embed,
  84. )
  85. else:
  86. features = self.backbone.forward_stage_features(x, normalize=self.normalize_features)
  87. deepest = features[-1]
  88. b, c, h, w = deepest.shape
  89. patch_tokens = deepest.flatten(2).transpose(1, 2)
  90. cls_token = patch_tokens.mean(dim=1)
  91. cls_out, gate, stability_map = self.fwta.forward_with_map(cls_token, patch_tokens)
  92. return {
  93. "features": features,
  94. "deepest_feature": deepest,
  95. "global_token": cls_out,
  96. "token_gate": gate,
  97. "stability_map": stability_map,
  98. }