segmentation_2d.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. from __future__ import annotations
  2. from argparse import Namespace
  3. from pathlib import Path
  4. from typing import Any, Sequence
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from .decoder_2d import StructureAwareDecoder2d
  9. from .layers_2d import Conv2dBN
  10. from .swinv2_fwta_encoder_2d import SwinV2FWTAEncoder2d
  11. class SegmentationHead2d(nn.Module):
  12. def __init__(self, in_channels: int, out_channels: int) -> None:
  13. super().__init__()
  14. self.block = nn.Sequential(
  15. Conv2dBN(in_channels, in_channels, 3, 1, 1),
  16. nn.ReLU(inplace=True),
  17. nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
  18. )
  19. def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
  20. x = self.block(x)
  21. return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
  22. class BoundaryHead2d(nn.Module):
  23. def __init__(self, in_channels: int, out_channels: int = 1) -> None:
  24. super().__init__()
  25. self.block = nn.Sequential(
  26. Conv2dBN(in_channels, in_channels, 3, 1, 1),
  27. nn.ReLU(inplace=True),
  28. nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
  29. )
  30. def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
  31. x = self.block(x)
  32. return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
  33. class GlobalTokenConditioning2d(nn.Module):
  34. """
  35. 使用 FWTA 更新后的全局前景 token 对解码特征做通道调制。
  36. """
  37. def __init__(self, token_channels: int, feature_channels: int) -> None:
  38. super().__init__()
  39. hidden_channels = max(feature_channels // 2, 32)
  40. self.gate = nn.Sequential(
  41. nn.LayerNorm(token_channels),
  42. nn.Linear(token_channels, hidden_channels),
  43. nn.GELU(),
  44. nn.Linear(hidden_channels, feature_channels),
  45. nn.Sigmoid(),
  46. )
  47. def forward(self, x: torch.Tensor, global_token: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  48. channel_gate = self.gate(global_token).unsqueeze(-1).unsqueeze(-1)
  49. return x * (1.0 + channel_gate), channel_gate
  50. class SegmentationNet2d(nn.Module):
  51. """
  52. 第一版超声分割主网络骨架。
  53. 当前职责:
  54. - 编码器输出多尺度特征和稳定性图
  55. - 结构感知解码器恢复分割特征
  56. - 同时输出分割图和边界图
  57. """
  58. def __init__(
  59. self,
  60. num_classes: int,
  61. model_name: str | None = None,
  62. config_path: str | Path | None = None,
  63. weight_path: str | Path | None = None,
  64. args: Namespace | None = None,
  65. *,
  66. decoder_channels: Sequence[int] | None = None,
  67. load_weights: bool = True,
  68. **encoder_kwargs: Any,
  69. ) -> None:
  70. super().__init__()
  71. self.encoder = SwinV2FWTAEncoder2d(
  72. model_name=model_name,
  73. config_path=config_path,
  74. weight_path=weight_path,
  75. args=args,
  76. load_weights=load_weights,
  77. **encoder_kwargs,
  78. )
  79. self.decoder = StructureAwareDecoder2d(
  80. encoder_channels=self.encoder.stage_channels,
  81. decoder_channels=decoder_channels,
  82. )
  83. self.global_conditioning = GlobalTokenConditioning2d(
  84. token_channels=self.encoder.stage_channels[-1],
  85. feature_channels=self.decoder.out_channels,
  86. )
  87. self.segmentation_head = SegmentationHead2d(self.decoder.out_channels, num_classes)
  88. self.boundary_head = BoundaryHead2d(self.decoder.out_channels, out_channels=1)
  89. def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
  90. encoder_outputs = self.encoder(x)
  91. features = encoder_outputs["features"]
  92. stability_map = encoder_outputs["stability_map"]
  93. decoder_out, decoder_features = self.decoder(
  94. features=features,
  95. stability_map=stability_map,
  96. )
  97. conditioned_decoder_out, global_channel_gate = self.global_conditioning(
  98. decoder_out,
  99. encoder_outputs["global_token"],
  100. )
  101. output_size = x.shape[-2:]
  102. seg_logits = self.segmentation_head(conditioned_decoder_out, output_size=output_size)
  103. boundary_logits = self.boundary_head(conditioned_decoder_out, output_size=output_size)
  104. return {
  105. "seg_logits": seg_logits,
  106. "boundary_logits": boundary_logits,
  107. "stability_map": F.interpolate(
  108. stability_map, size=output_size, mode="bilinear", align_corners=False
  109. ),
  110. "encoder_features": features,
  111. "decoder_features": decoder_features,
  112. "conditioned_decoder_feature": conditioned_decoder_out,
  113. "deepest_feature": encoder_outputs["deepest_feature"],
  114. "global_token": encoder_outputs["global_token"],
  115. "global_channel_gate": global_channel_gate,
  116. "token_gate": encoder_outputs["token_gate"],
  117. }