| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- from __future__ import annotations
- from argparse import Namespace
- from pathlib import Path
- from typing import Any
- import torch
- import torch.nn as nn
- from .build_swinv2 import build_swinv2
class SwinV2Encoder2d(nn.Module):
    """Segmentation-oriented wrapper around a plain SwinV2 encoder.

    The backbone and its config are created by ``build_swinv2``. ``forward``
    returns the backbone's feature list together with its last entry.

    Attributes:
        backbone: Model object returned by ``build_swinv2``.
        cfg: Config returned alongside the backbone. ``MODEL.SWINV2.DEPTHS``
            and ``MODEL.SWINV2.EMBED_DIM`` are read from it.
        normalize_features: Forwarded as ``normalize=`` to the backbone's
            feature methods.
        use_multiscale_features: Selects ``forward_multiscale_features``
            (True) or ``forward_stage_features`` (False).
        include_patch_embed: Forwarded to ``forward_multiscale_features``
            (multiscale mode only).
        stage_channels: Channel counts derived purely from the config. They
            are not measured from the backbone; presumably they line up with
            the entries of ``features`` (TODO confirm against the backbone).
    """

    def __init__(
        self,
        model_name: str | None = None,
        config_path: str | Path | None = None,
        weight_path: str | Path | None = None,
        args: Namespace | None = None,
        *,
        load_weights: bool = True,
        normalize_features: bool = True,
        use_multiscale_features: bool = True,
        include_patch_embed: bool = True,
        **model_kwargs: Any,
    ) -> None:
        """Build the backbone and precompute ``stage_channels``.

        Args:
            model_name: Forwarded unchanged to ``build_swinv2``.
            config_path: Forwarded unchanged to ``build_swinv2``.
            weight_path: Forwarded unchanged to ``build_swinv2``.
            args: Forwarded unchanged to ``build_swinv2``.
            load_weights: Forwarded unchanged to ``build_swinv2``.
            normalize_features: Passed as ``normalize=`` to the backbone's
                feature methods in ``forward``.
            use_multiscale_features: Chooses which backbone feature method
                ``forward`` calls and how ``stage_channels`` is computed.
            include_patch_embed: In multiscale mode, passed to the backbone
                and, when True, prepends ``EMBED_DIM`` to ``stage_channels``.
            **model_kwargs: Extra keyword arguments for ``build_swinv2``.
        """
        super().__init__()
        # return_config=True makes the builder hand back (backbone, cfg).
        self.backbone, self.cfg = build_swinv2(
            model_name=model_name,
            config_path=config_path,
            weight_path=weight_path,
            args=args,
            load_weights=load_weights,
            return_config=True,
            **model_kwargs,
        )
        self.normalize_features = normalize_features
        self.use_multiscale_features = use_multiscale_features
        self.include_patch_embed = include_patch_embed

        swin_cfg = self.cfg.MODEL.SWINV2
        num_stages = len(tuple(swin_cfg.DEPTHS))
        base_dim = int(swin_cfg.EMBED_DIM)

        if not use_multiscale_features:
            # Stage k is taken to have base_dim * 2**k channels.
            self.stage_channels = [int(base_dim * 2**k) for k in range(num_stages)]
            return

        # Stage k is reported at 2**(k + 1) * base_dim, capped at the widest
        # (last) stage. Presumably every stage except the last ends with a
        # channel-doubling downsample in this mode. NOTE(review): confirm
        # against forward_multiscale_features.
        widest = num_stages - 1
        channels = [base_dim] if include_patch_embed else []
        channels += [
            int(base_dim * 2 ** min(k + 1, widest)) for k in range(num_stages)
        ]
        self.stage_channels = channels

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
        """Run the backbone and package its features.

        Args:
            x: Input passed straight to the backbone. Presumably an image
                batch; the expected shape is not visible here.

        Returns:
            A dict with two entries:

            - ``"features"``: whatever the selected backbone method returns.
            - ``"deepest_feature"``: the last element of that sequence.

            Indexing ``[-1]`` fails if the backbone returns an empty sequence.
        """
        if not self.use_multiscale_features:
            feats = self.backbone.forward_stage_features(
                x, normalize=self.normalize_features
            )
        else:
            feats = self.backbone.forward_multiscale_features(
                x,
                normalize=self.normalize_features,
                include_patch_embed=self.include_patch_embed,
            )
        return {"features": feats, "deepest_feature": feats[-1]}
|