from __future__ import annotations from argparse import Namespace from pathlib import Path from typing import Any import torch import torch.nn as nn from .build_swinv2 import build_swinv2 from .fwta_2d import FourierWaveletTokenAggregation class SwinV2FWTAEncoder2d(nn.Module): """ 面向分割的 SwinV2 + FWTA 编码器封装。 """ def __init__( self, model_name: str | None = None, config_path: str | Path | None = None, weight_path: str | Path | None = None, args: Namespace | None = None, *, load_weights: bool = True, normalize_features: bool = True, use_multiscale_features: bool = True, include_patch_embed: bool = True, fwta_wavelet: str = "haar", fwta_level: int = 1, fwta_sigma_ratio: float = 0.35, fwta_tau_fourier: float = 0.15, fwta_gate_temperature: float = 1.0, fwta_fusion_hidden_ratio: float = 0.5, fwta_use_global_conditioning: bool = True, fwta_residual_scale_init: float = 1.0, fwta_learnable_global_token: bool = True, fwta_global_token_use_image_conditioning: bool = True, **model_kwargs: Any, ) -> None: super().__init__() backbone, cfg = build_swinv2( model_name=model_name, config_path=config_path, weight_path=weight_path, args=args, load_weights=load_weights, return_config=True, **model_kwargs, ) self.backbone = backbone self.cfg = cfg self.normalize_features = normalize_features self.use_multiscale_features = use_multiscale_features self.include_patch_embed = include_patch_embed depths = tuple(cfg.MODEL.SWINV2.DEPTHS) embed_dim = int(cfg.MODEL.SWINV2.EMBED_DIM) if self.use_multiscale_features: stage_channels = [] if self.include_patch_embed: stage_channels.append(embed_dim) for i in range(len(depths)): channel_multiplier = 2 ** min(i + 1, len(depths) - 1) stage_channels.append(int(embed_dim * channel_multiplier)) self.stage_channels = stage_channels else: self.stage_channels = [int(embed_dim * 2 ** i) for i in range(len(depths))] final_resolution = ( int(self.backbone.patches_resolution[0] // (2 ** (len(depths) - 1))), int(self.backbone.patches_resolution[1] // (2 ** (len(depths) - 1))), ) self.fwta = FourierWaveletTokenAggregation( dim=int(self.backbone.num_features), grid_size=final_resolution, wavelet=fwta_wavelet, wavelet_level=fwta_level, sigma_ratio=fwta_sigma_ratio, tau_fourier=fwta_tau_fourier, gate_temperature=fwta_gate_temperature, residual_scale_init=fwta_residual_scale_init, fusion_hidden_ratio=fwta_fusion_hidden_ratio, use_cls_conditioning=fwta_use_global_conditioning, learnable_global_token=fwta_learnable_global_token, global_token_use_image_conditioning=fwta_global_token_use_image_conditioning, ) def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]: if self.use_multiscale_features: features = self.backbone.forward_multiscale_features( x, normalize=self.normalize_features, include_patch_embed=self.include_patch_embed, ) else: features = self.backbone.forward_stage_features(x, normalize=self.normalize_features) deepest = features[-1] patch_tokens = deepest.flatten(2).transpose(1, 2) cls_out, gate, stability_prior, saliency_prior = self.fwta.forward_with_map( patch_tokens=patch_tokens, ) return { "features": features, "deepest_feature": deepest, "global_token": cls_out, "token_gate": gate, "stability_prior": stability_prior, "saliency_prior": saliency_prior, }