swinv2_encoder_2d.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from __future__ import annotations
  2. from argparse import Namespace
  3. from pathlib import Path
  4. from typing import Any
  5. import torch
  6. import torch.nn as nn
  7. from .build_swinv2 import build_swinv2
  8. class SwinV2Encoder2d(nn.Module):
  9. """
  10. 面向分割的纯 SwinV2 编码器封装
  11. """
  12. def __init__(
  13. self,
  14. model_name: str | None = None,
  15. config_path: str | Path | None = None,
  16. weight_path: str | Path | None = None,
  17. args: Namespace | None = None,
  18. *,
  19. load_weights: bool = True,
  20. normalize_features: bool = True,
  21. use_multiscale_features: bool = True,
  22. include_patch_embed: bool = True,
  23. **model_kwargs: Any,
  24. ) -> None:
  25. super().__init__()
  26. backbone, cfg = build_swinv2(
  27. model_name=model_name,
  28. config_path=config_path,
  29. weight_path=weight_path,
  30. args=args,
  31. load_weights=load_weights,
  32. return_config=True,
  33. **model_kwargs,
  34. )
  35. self.backbone = backbone
  36. self.cfg = cfg
  37. self.normalize_features = normalize_features
  38. self.use_multiscale_features = use_multiscale_features
  39. self.include_patch_embed = include_patch_embed
  40. depths = tuple(cfg.MODEL.SWINV2.DEPTHS)
  41. embed_dim = int(cfg.MODEL.SWINV2.EMBED_DIM)
  42. if self.use_multiscale_features:
  43. stage_channels = []
  44. if self.include_patch_embed:
  45. stage_channels.append(embed_dim)
  46. for i in range(len(depths)):
  47. channel_multiplier = 2 ** min(i + 1, len(depths) - 1)
  48. stage_channels.append(int(embed_dim * channel_multiplier))
  49. self.stage_channels = stage_channels
  50. else:
  51. self.stage_channels = [int(embed_dim * 2**i) for i in range(len(depths))]
  52. def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
  53. if self.use_multiscale_features:
  54. features = self.backbone.forward_multiscale_features(
  55. x,
  56. normalize=self.normalize_features,
  57. include_patch_embed=self.include_patch_embed,
  58. )
  59. else:
  60. features = self.backbone.forward_stage_features(
  61. x, normalize=self.normalize_features
  62. )
  63. deepest = features[-1]
  64. return {
  65. "features": features,
  66. "deepest_feature": deepest,
  67. }