from __future__ import annotations

from argparse import Namespace
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn

from .build_swinv2 import build_swinv2


class SwinV2Encoder2d(nn.Module):
    """Pure SwinV2 encoder wrapper intended for segmentation.

    The backbone and its config are obtained from ``build_swinv2``; this
    wrapper only selects which backbone feature-extraction method is called
    and records the channel width expected for each returned feature.

    Attributes:
        backbone: Model object returned by ``build_swinv2``.
        cfg: Config object returned alongside the backbone.
        normalize_features: Forwarded to the backbone as ``normalize``.
        use_multiscale_features: Chooses ``forward_multiscale_features``
            over ``forward_stage_features``.
        include_patch_embed: Forwarded to the backbone in multiscale mode.
        stage_channels: Channel widths computed from
            ``cfg.MODEL.SWINV2.EMBED_DIM`` and ``cfg.MODEL.SWINV2.DEPTHS``.
            NOTE(review): presumably aligned one-to-one with the feature
            list the backbone returns -- confirm against the backbone.
    """

    def __init__(
        self,
        model_name: str | None = None,
        config_path: str | Path | None = None,
        weight_path: str | Path | None = None,
        args: Namespace | None = None,
        *,
        load_weights: bool = True,
        normalize_features: bool = True,
        use_multiscale_features: bool = True,
        include_patch_embed: bool = True,
        **model_kwargs: Any,
    ) -> None:
        """Build the backbone and precompute ``stage_channels``.

        Args:
            model_name: Passed through to ``build_swinv2``.
            config_path: Passed through to ``build_swinv2``.
            weight_path: Passed through to ``build_swinv2``.
            args: Passed through to ``build_swinv2``.
            load_weights: Passed through to ``build_swinv2``.
            normalize_features: Stored; used as ``normalize`` in ``forward``.
            use_multiscale_features: Stored; selects the backbone method
                used in ``forward`` and the ``stage_channels`` layout.
            include_patch_embed: Stored; in multiscale mode an extra leading
                ``embed_dim`` entry is added to ``stage_channels``.
            **model_kwargs: Extra keyword arguments for ``build_swinv2``.
        """
        super().__init__()

        # Unpack straight into attributes; the backbone is assigned first.
        self.backbone, self.cfg = build_swinv2(
            model_name=model_name,
            config_path=config_path,
            weight_path=weight_path,
            args=args,
            load_weights=load_weights,
            return_config=True,
            **model_kwargs,
        )

        self.normalize_features = normalize_features
        self.use_multiscale_features = use_multiscale_features
        self.include_patch_embed = include_patch_embed

        swin_cfg = self.cfg.MODEL.SWINV2
        num_stages = len(tuple(swin_cfg.DEPTHS))
        base_width = int(swin_cfg.EMBED_DIM)
        self.stage_channels = self._derive_stage_channels(
            base_width,
            num_stages,
            multiscale=self.use_multiscale_features,
            with_patch_embed=self.include_patch_embed,
        )

    @staticmethod
    def _derive_stage_channels(
        base_width: int,
        num_stages: int,
        *,
        multiscale: bool,
        with_patch_embed: bool,
    ) -> list[int]:
        """Return the list of channel widths for the selected feature mode.

        Per-stage mode yields ``base_width * 2**i`` for each stage index.
        Multiscale mode yields ``base_width * 2**min(i + 1, num_stages - 1)``
        per stage, optionally preceded by ``base_width`` itself.
        """
        if not multiscale:
            return [int(base_width * 2**idx) for idx in range(num_stages)]

        # The exponent is capped at the last stage index, so the final two
        # entries share the same width when there is more than one stage.
        max_exponent = num_stages - 1
        widths = [base_width] if with_patch_embed else []
        widths.extend(
            int(base_width * 2 ** min(idx + 1, max_exponent))
            for idx in range(num_stages)
        )
        return widths

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
        """Run the backbone and return its features.

        Args:
            x: Input tensor handed to the backbone unchanged.

        Returns:
            A dict with ``"features"`` (whatever the selected backbone
            method returned) and ``"deepest_feature"`` (its last element).
        """
        encoder = self.backbone
        if not self.use_multiscale_features:
            stage_outputs = encoder.forward_stage_features(
                x, normalize=self.normalize_features
            )
        else:
            stage_outputs = encoder.forward_multiscale_features(
                x,
                normalize=self.normalize_features,
                include_patch_embed=self.include_patch_embed,
            )

        return {
            "features": stage_outputs,
            "deepest_feature": stage_outputs[-1],
        }