| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- from __future__ import annotations
- from argparse import Namespace
- from pathlib import Path
- from typing import Any
- import torch
- import torch.nn as nn
- from .build_swinv2 import build_swinv2
- from .fwta_2d import FourierWaveletTokenAggregation
class SwinV2FWTAEncoder2d(nn.Module):
    """SwinV2 + FWTA encoder wrapper for segmentation.

    Builds a SwinV2 backbone with ``build_swinv2`` and a
    ``FourierWaveletTokenAggregation`` (FWTA) module that is applied to the
    tokens of the deepest backbone feature map.

    Args:
        model_name, config_path, weight_path, args, load_weights: Forwarded to
            ``build_swinv2``, which is always called with ``return_config=True``.
        normalize_features: Passed as ``normalize=`` to the backbone feature
            extractor used in ``forward``.
        use_multiscale_features: Selects ``backbone.forward_multiscale_features``
            (True) or ``backbone.forward_stage_features`` (False) in ``forward``.
            It also selects how ``stage_channels`` is computed.
        include_patch_embed: Forwarded to ``forward_multiscale_features``. In
            multiscale mode it also prepends ``EMBED_DIM`` to ``stage_channels``.
        fwta_*: Forwarded to the FWTA constructor with the ``fwta_`` prefix
            dropped. Two are renamed: ``fwta_level`` becomes ``wavelet_level``
            and ``fwta_use_global_conditioning`` becomes ``use_cls_conditioning``.
        **model_kwargs: Extra keyword arguments forwarded to ``build_swinv2``.

    Attributes:
        backbone: The SwinV2 module returned by ``build_swinv2``.
        cfg: The config returned by ``build_swinv2``. ``cfg.MODEL.SWINV2`` is
            read for ``DEPTHS`` and ``EMBED_DIM``.
        stage_channels: Channel widths derived from ``EMBED_DIM`` and the
            number of stages. NOTE(review): presumably these match the tensors
            in ``forward()["features"]``; confirm against the backbone.
        fwta: The FWTA module. Its ``grid_size`` is fixed at construction from
            ``backbone.patches_resolution``, so inputs that yield a different
            deepest grid may not be supported (TODO confirm in FWTA).
    """

    def __init__(
        self,
        model_name: str | None = None,
        config_path: str | Path | None = None,
        weight_path: str | Path | None = None,
        args: Namespace | None = None,
        *,
        load_weights: bool = True,
        normalize_features: bool = True,
        use_multiscale_features: bool = True,
        include_patch_embed: bool = True,
        fwta_wavelet: str = "haar",
        fwta_level: int = 1,
        fwta_sigma_ratio: float = 0.35,
        fwta_tau_fourier: float = 0.15,
        fwta_gate_temperature: float = 1.0,
        fwta_fusion_hidden_ratio: float = 0.5,
        fwta_use_global_conditioning: bool = True,
        fwta_residual_scale_init: float = 1.0,
        fwta_learnable_global_token: bool = True,
        fwta_global_token_use_image_conditioning: bool = True,
        **model_kwargs: Any,
    ) -> None:
        super().__init__()

        # The project builder returns both the backbone and its resolved config.
        backbone, cfg = build_swinv2(
            model_name=model_name,
            config_path=config_path,
            weight_path=weight_path,
            args=args,
            load_weights=load_weights,
            return_config=True,
            **model_kwargs,
        )
        # Registered before ``fwta`` so submodule / state_dict ordering is stable.
        self.backbone = backbone
        self.cfg = cfg
        self.normalize_features = normalize_features
        self.use_multiscale_features = use_multiscale_features
        self.include_patch_embed = include_patch_embed

        swin_cfg = cfg.MODEL.SWINV2
        num_stages = len(tuple(swin_cfg.DEPTHS))
        embed_dim = int(swin_cfg.EMBED_DIM)

        if not self.use_multiscale_features:
            # Stage i is listed with width embed_dim * 2**i.
            self.stage_channels = [
                int(embed_dim * 2 ** stage) for stage in range(num_stages)
            ]
        else:
            # There is an optional patch-embedding level at embed_dim. After it
            # comes one entry per stage, with the exponent capped at
            # num_stages - 1.
            # NOTE(review): presumably each entry is a stage's output after its
            # patch-merging downsample, and the last stage has none (hence the
            # cap). Confirm against ``forward_multiscale_features``.
            top_exponent = num_stages - 1
            leading = [embed_dim] if self.include_patch_embed else []
            self.stage_channels = leading + [
                int(embed_dim * 2 ** min(stage + 1, top_exponent))
                for stage in range(num_stages)
            ]

        # Deepest-stage token grid is the backbone patch grid divided by
        # 2 ** (num_stages - 1). This assumes one 2x downsample between
        # consecutive stages.
        shrink = 2 ** (num_stages - 1)
        grid_h = int(self.backbone.patches_resolution[0] // shrink)
        grid_w = int(self.backbone.patches_resolution[1] // shrink)

        self.fwta = FourierWaveletTokenAggregation(
            dim=int(self.backbone.num_features),
            grid_size=(grid_h, grid_w),
            wavelet=fwta_wavelet,
            wavelet_level=fwta_level,
            sigma_ratio=fwta_sigma_ratio,
            tau_fourier=fwta_tau_fourier,
            gate_temperature=fwta_gate_temperature,
            residual_scale_init=fwta_residual_scale_init,
            fusion_hidden_ratio=fwta_fusion_hidden_ratio,
            use_cls_conditioning=fwta_use_global_conditioning,
            learnable_global_token=fwta_learnable_global_token,
            global_token_use_image_conditioning=fwta_global_token_use_image_conditioning,
        )

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
        """Run the backbone, then FWTA on the tokens of the deepest feature map.

        Args:
            x: Input batch handed straight to the backbone
                (assumed (B, C_in, H, W); TODO confirm).

        Returns:
            A dict with these keys, in order:
                ``features``: the backbone's feature sequence, as returned.
                ``deepest_feature``: the last element of ``features``.
                ``global_token``, ``token_gate``, ``stability_prior``,
                ``saliency_prior``: the four outputs of
                ``self.fwta.forward_with_map``, in that order.
        """
        backbone = self.backbone
        if not self.use_multiscale_features:
            features = backbone.forward_stage_features(x, normalize=self.normalize_features)
        else:
            features = backbone.forward_multiscale_features(
                x,
                normalize=self.normalize_features,
                include_patch_embed=self.include_patch_embed,
            )

        deepest = features[-1]
        # Flatten all spatial dims, then put channels last. For a (B, C, H, W)
        # map this gives a (B, H*W, C) token sequence.
        tokens = deepest.flatten(2).transpose(1, 2)
        global_token, token_gate, stability_prior, saliency_prior = self.fwta.forward_with_map(
            patch_tokens=tokens,
        )

        result: dict[str, torch.Tensor | list[torch.Tensor]] = {}
        result["features"] = features
        result["deepest_feature"] = deepest
        result["global_token"] = global_token
        result["token_gate"] = token_gate
        result["stability_prior"] = stability_prior
        result["saliency_prior"] = saliency_prior
        return result
|