# swinv2_fwta_encoder_2d.py (4.2 KB)

  1. from __future__ import annotations
  2. from argparse import Namespace
  3. from pathlib import Path
  4. from typing import Any
  5. import torch
  6. import torch.nn as nn
  7. from .build_swinv2 import build_swinv2
  8. from .fwta_2d import FourierWaveletTokenAggregation
  9. class SwinV2FWTAEncoder2d(nn.Module):
  10. """
  11. 面向分割的 SwinV2 + FWTA 编码器封装。
  12. """
  13. def __init__(
  14. self,
  15. model_name: str | None = None,
  16. config_path: str | Path | None = None,
  17. weight_path: str | Path | None = None,
  18. args: Namespace | None = None,
  19. *,
  20. load_weights: bool = True,
  21. normalize_features: bool = True,
  22. use_multiscale_features: bool = True,
  23. include_patch_embed: bool = True,
  24. fwta_wavelet: str = "haar",
  25. fwta_level: int = 1,
  26. fwta_sigma_ratio: float = 0.35,
  27. fwta_tau_fourier: float = 0.15,
  28. fwta_gate_temperature: float = 1.0,
  29. fwta_fusion_hidden_ratio: float = 0.5,
  30. fwta_use_global_conditioning: bool = True,
  31. fwta_residual_scale_init: float = 1.0,
  32. fwta_learnable_global_token: bool = True,
  33. fwta_global_token_use_image_conditioning: bool = True,
  34. **model_kwargs: Any,
  35. ) -> None:
  36. super().__init__()
  37. backbone, cfg = build_swinv2(
  38. model_name=model_name,
  39. config_path=config_path,
  40. weight_path=weight_path,
  41. args=args,
  42. load_weights=load_weights,
  43. return_config=True,
  44. **model_kwargs,
  45. )
  46. self.backbone = backbone
  47. self.cfg = cfg
  48. self.normalize_features = normalize_features
  49. self.use_multiscale_features = use_multiscale_features
  50. self.include_patch_embed = include_patch_embed
  51. depths = tuple(cfg.MODEL.SWINV2.DEPTHS)
  52. embed_dim = int(cfg.MODEL.SWINV2.EMBED_DIM)
  53. if self.use_multiscale_features:
  54. stage_channels = []
  55. if self.include_patch_embed:
  56. stage_channels.append(embed_dim)
  57. for i in range(len(depths)):
  58. channel_multiplier = 2 ** min(i + 1, len(depths) - 1)
  59. stage_channels.append(int(embed_dim * channel_multiplier))
  60. self.stage_channels = stage_channels
  61. else:
  62. self.stage_channels = [int(embed_dim * 2 ** i) for i in range(len(depths))]
  63. final_resolution = (
  64. int(self.backbone.patches_resolution[0] // (2 ** (len(depths) - 1))),
  65. int(self.backbone.patches_resolution[1] // (2 ** (len(depths) - 1))),
  66. )
  67. self.fwta = FourierWaveletTokenAggregation(
  68. dim=int(self.backbone.num_features),
  69. grid_size=final_resolution,
  70. wavelet=fwta_wavelet,
  71. wavelet_level=fwta_level,
  72. sigma_ratio=fwta_sigma_ratio,
  73. tau_fourier=fwta_tau_fourier,
  74. gate_temperature=fwta_gate_temperature,
  75. residual_scale_init=fwta_residual_scale_init,
  76. fusion_hidden_ratio=fwta_fusion_hidden_ratio,
  77. use_cls_conditioning=fwta_use_global_conditioning,
  78. learnable_global_token=fwta_learnable_global_token,
  79. global_token_use_image_conditioning=fwta_global_token_use_image_conditioning,
  80. )
  81. def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
  82. if self.use_multiscale_features:
  83. features = self.backbone.forward_multiscale_features(
  84. x,
  85. normalize=self.normalize_features,
  86. include_patch_embed=self.include_patch_embed,
  87. )
  88. else:
  89. features = self.backbone.forward_stage_features(x, normalize=self.normalize_features)
  90. deepest = features[-1]
  91. patch_tokens = deepest.flatten(2).transpose(1, 2)
  92. cls_out, gate, stability_prior, saliency_prior = self.fwta.forward_with_map(
  93. patch_tokens=patch_tokens,
  94. )
  95. return {
  96. "features": features,
  97. "deepest_feature": deepest,
  98. "global_token": cls_out,
  99. "token_gate": gate,
  100. "stability_prior": stability_prior,
  101. "saliency_prior": saliency_prior,
  102. }